In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tokenizer import Tokenizer
from util import head

In [3]:
# Fresh, unfitted tokenizer used for the small round-trip demo in the next cells.
tok = Tokenizer()

In [4]:
# Fit the tokenizer on a two-sentence toy corpus.
# The recorded output shows that fit() returns a Tokenizer object
# (presumably the same instance, which would allow chaining -- confirm in tokenizer.py).
tok.fit(['hello world', 'foo bar'])

<tokenizer.Tokenizer at 0x1060fa390>

In [5]:
# Encode two texts; 'blub' was not part of the corpus passed to fit() above,
# so it exercises the tokenizer's handling of unseen words.
seqs = tok.texts_to_sequences(['hello world blub', 'foo bar'])

In [6]:
# Decode the sequences back to text to inspect the round trip. In the recorded
# output the unseen word comes back as <unk> and every text is wrapped in
# <bos> ... <eos> markers.
tok.sequences_to_texts(seqs)

['<bos> hello world <unk> <eos>', '<bos> foo bar <eos>']

In [64]:
import csv
import os
import pickle
from collections import namedtuple

from tensorflow.keras.preprocessing.sequence import pad_sequences

from tokenizer import Tokenizer
from util import load_or_compute
# Two parallel record types with identical slots:
#   begin    -- story context (flat list of sentences for SCTStories,
#               a reshaped array for SCTSequences)
#   end_real -- the genuine ending of each story
#   end_fake -- the distractor ending of each story, or None when the
#               input provided none
SCTStories = namedtuple('SCTStories', 'begin end_real end_fake')
SCTSequences = namedtuple('SCTSequences', 'begin end_real end_fake')

def read_sct_stories(fname, skip_header=True):
    """Read a Story Cloze Test style CSV file into an ``SCTStories`` tuple.

    Two row layouts are recognised by their column count:

    * 7 columns (ROCStories, only real endings): columns 2-5 are taken as
      the four context sentences and the last column as the real ending.
    * 8 columns (eval/test set, real and fake endings): columns 1-4 are the
      four context sentences, columns 5 and 6 the two candidate endings and
      the last column an integer label. Label ``1`` marks the first
      candidate as the real ending; any other integer marks the second.

    Completely empty rows (for example a trailing blank line) are skipped.

    Args:
        fname: Path of the CSV file to read.
        skip_header: When true, the first row is discarded as a header.

    Returns:
        SCTStories: ``begin`` is a flat list holding four context sentences
        per story in file order, ``end_real`` holds one real ending per
        story, and ``end_fake`` holds one fake ending per 8-column row or is
        ``None`` when the file contained no fake endings.

    Raises:
        ValueError: If a non-empty row has neither 7 nor 8 columns, or if
            the label column of an 8-column row is not an integer.
    """
    beginnings = []
    real_endings = []
    fake_endings = []
    # newline='' is what the csv module asks for so that line breaks inside
    # quoted fields are handled by the reader itself; the explicit encoding
    # keeps the result independent of the platform's locale.
    with open(fname, newline='', encoding='utf-8') as f:
        csvreader = csv.reader(f, delimiter=',')
        if skip_header:
            next(csvreader, None)  # skip header
        for row in csvreader:
            if not row:
                # csv.reader yields [] for blank lines; they carry no story.
                continue
            if len(row) == 7:  # ROCstories, only real endings
                beginnings.extend(row[2:-1])
                real_endings.append(row[-1])
            elif len(row) == 8:  # Eval/test set, real and fake endings
                beginnings.extend(row[1:-3])
                first, second = row[-3], row[-2]
                # The last column names the real ending: 1 selects the first
                # candidate, every other value selects the second.
                if int(row[-1]) == 1:
                    real_endings.append(first)
                    fake_endings.append(second)
                else:
                    real_endings.append(second)
                    fake_endings.append(first)
            else:
                raise ValueError(
                    'wrong number of items in input file '
                    '(line {}: expected 7 or 8 columns, got {})'.format(
                        csvreader.line_num, len(row)))
    return SCTStories(beginnings, real_endings, fake_endings or None)

def sct_stories_to_sequences(texts_to_sequences_func, sct_stories, max_seq_len=90):
    """Turn the texts of an ``SCTStories`` tuple into padded index arrays.

    Args:
        texts_to_sequences_func: Callable mapping a list of texts to a list
            of integer sequences.
        sct_stories: Tuple with ``begin``, ``end_real`` and ``end_fake``
            text lists; ``end_fake`` may be ``None`` or empty.
        max_seq_len: ``maxlen`` handed to ``pad_sequences`` for every field.

    Returns:
        SCTSequences: ``begin`` regrouped to
        ``(len(begin) // 4, 4, padded width)``, ``end_real`` as returned by
        ``pad_sequences``, and ``end_fake`` likewise or ``None`` when the
        input had no fake endings.
    """
    def _encode(texts):
        # Tokenise first, then pad/truncate to the common length.
        return pad_sequences(texts_to_sequences_func(texts), maxlen=max_seq_len)

    flat_begin = _encode(sct_stories.begin)
    n_sentences, width = flat_begin.shape[0], flat_begin.shape[1]
    # Context sentences arrive flattened, four per story; regroup them.
    grouped_begin = flat_begin.reshape(n_sentences // 4, 4, width)

    real = _encode(sct_stories.end_real)
    fake = _encode(sct_stories.end_fake) if sct_stories.end_fake else None
    return SCTSequences(grouped_begin, real, fake)

In [65]:
# Load the raw story texts for training and validation. Paths are relative to
# the notebook's working directory; the header row of each file is skipped
# (read_sct_stories default).
texts_train = read_sct_stories('data/sct_train.csv')
texts_eval = read_sct_stories('data/sct_val.csv')

In [66]:
# Fit a new tokenizer on all training sentences (context sentences plus real
# endings). This rebinds `tok`, replacing the toy tokenizer created earlier.
# The assignment relies on fit() returning the tokenizer -- confirm in tokenizer.py.
tok = Tokenizer().fit(texts_train.begin + texts_train.end_real)

In [67]:
# Inspect the vocabulary size of the tokenizer fitted on the training texts.
tok.vocabulary_size

64647

In [68]:
import pandas as pd
# Summary statistics of len() over every context sentence, first for the
# training set and then for the validation set.
# NOTE(review): the elements of `.begin` are the raw CSV strings collected by
# read_sct_stories, so len() here counts characters per sentence, not tokens.
# Confirm this is the intended measure if these numbers are used to choose a
# token-level sequence length.
print(pd.Series([len(seq) for seq in texts_train.begin]).describe())
print(pd.Series([len(seq) for seq in texts_eval.begin]).describe())

count    352644.000000
mean         44.065451
std          13.230812
min           9.000000
25%          34.000000
50%          44.000000
75%          54.000000
max          86.000000
dtype: float64
count    7484.000000
mean       45.646045
std        12.919189
min        11.000000
25%        36.000000
50%        46.000000
75%        56.000000
max        72.000000
dtype: float64


In [71]:
# Directory (relative to the working directory) that holds the pickled
# intermediate results used by the cells below.
cache_dir = 'cache'
# exist_ok=True makes re-running this cell harmless when the directory exists.
os.makedirs(cache_dir, exist_ok=True)

In [74]:
# Obtain the tokenizer through the cache file cache/tokenizer.pickle.
# NOTE(review): load_or_compute comes from util and is not shown here;
# presumably it returns the cached object when the file exists and otherwise
# calls tok.fit(<texts>) and stores the result -- confirm.
# NOTE(review): `tok` was already fitted on these same texts in an earlier
# cell, so on a cache miss fit() would run a second time on the same
# instance -- confirm that refitting is harmless for Tokenizer.
tok = load_or_compute(os.path.join(cache_dir, 'tokenizer.pickle'), tok.fit, texts_train.begin + texts_train.end_real)

In [75]:
# Convert the training and validation texts to padded sequences, going through
# the cache files cache/train.pickle and cache/eval.pickle.
# The trailing arguments are the function to run (sct_stories_to_sequences)
# followed by its own arguments; max_seq_len is left at its default of 90.
# NOTE(review): presumably a stale cache file is reused even if the tokenizer
# or the input data changed -- confirm against util.load_or_compute.
seqs_train = load_or_compute(os.path.join(cache_dir, 'train.pickle'), sct_stories_to_sequences, tok.texts_to_sequences, texts_train)
seqs_eval = load_or_compute(os.path.join(cache_dir, 'eval.pickle'), sct_stories_to_sequences, tok.texts_to_sequences, texts_eval)