## Global modules import

In [49]:
%load_ext autoreload
%autoreload 2

In [94]:
import json
import numpy as np
from operator import itemgetter
import random as rnd
import sys
import torch

## Local modules import

In [14]:
# Make the parent directory importable (presumably where the local
# `data_loading` module lives). The membership guard keeps this cell
# idempotent: re-running it no longer piles duplicate entries onto sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [62]:
from data_loading import create_word_lists

In [76]:
from sklearn.model_selection import train_test_split

## Loading data

In [17]:
# Add the data directory to the import path, without duplicating the entry
# when the cell is re-run.
# NOTE(review): nothing visible in this notebook imports from '../data' (the
# corpus is opened by relative path below) — confirm this is still needed.
if '../data' not in sys.path:
    sys.path.append('../data')

In [30]:
# The corpus contains German text, so decode it explicitly as UTF-8 (the
# encoding JSON is specified in) rather than relying on the platform's
# default encoding, which differs between systems (e.g. cp1252 on Windows).
with open('../data/corpus_data.json', encoding='utf-8') as json_file:
    corpus = json.load(json_file)

# Only the list of records under the 'records' key is used below.
data = corpus['records']

In [31]:
# Build one list per transcript field; position i in each list comes from
# record i, so the two lists stay aligned with `data`.
human_transcripts, stt_transcripts = (
    [record[field] for record in data]
    for field in ('human_transcript', 'stt_transcript')
)

In [63]:
# Unpack the five lists returned by `create_word_lists`. They are assumed to
# be parallel to `data`, because the split below indexes all of them with the
# same index lists.
(human_words, stt_words, word_labels,
 word_grams, word_sems) = create_word_lists(data)

# PIPELINE START
---

## Train-test split

First, determine which sentences contain at least one German word, so that the train-test split can be stratified on this flag:

In [90]:
# Length of the longest label row.
max_length = max(len(row) for row in word_labels)

# Right-pad every row with False so the rows stack into a rectangular matrix;
# False padding can never switch a sentence's flag on.
padded_labels = np.array([
    row + [False] * (max_length - len(row)) for row in word_labels
])

# One flag per sentence: True when any of its word labels is truthy (per the
# markdown above, i.e. the sentence contains a German word).
stat_labels = padded_labels.any(axis=1)

Here we split only the indices, not the data itself, because the data consists of variable-length sequences, which do not work with `train_test_split`:

In [91]:
# Split sentence positions 0..N-1 rather than the data; the resulting index
# lists are reused below to slice every parallel list the same way.
indices = [*range(len(human_transcripts))]
tr_indices, te_indices = train_test_split(
    indices,
    test_size=0.2,
    shuffle=True,
    stratify=stat_labels,  # approximately preserve the flagged/unflagged ratio in both splits
    random_state=0,        # fixed seed -> reproducible split
)

These are helper functions that extract the entries selected by the train and test indices:

In [96]:
def _make_extractor(indices):
    """Return a function that selects ``seq[i]`` for every ``i`` in *indices*.

    The returned function always yields a tuple, in the order of *indices*.
    This differs from ``operator.itemgetter(*indices)``, which returns the
    bare element (not a 1-tuple) when given exactly one index and raises
    ``TypeError`` when given none.
    """
    picked = tuple(indices)

    def extract(seq):
        return tuple(seq[i] for i in picked)

    return extract


extract_train = _make_extractor(tr_indices)
extract_test = _make_extractor(te_indices)

Finally, split the data:

In [97]:
# Every list that must be split consistently, in a fixed order.
parallel_lists = (
    human_transcripts, stt_transcripts, human_words, stt_words,
    word_labels, word_grams, word_sems,
)

# Training portion of each list.
(tr_human_transcripts, tr_stt_transcripts, tr_human_words, tr_stt_words,
 tr_word_labels, tr_word_grams, tr_word_sems) = map(extract_train, parallel_lists)

# Test portion of each list.
(te_human_transcripts, te_stt_transcripts, te_human_words, te_stt_words,
 te_word_labels, te_word_grams, te_word_sems) = map(extract_test, parallel_lists)