In [1]:
# Reload imported modules before each cell executes, so edits to the project's
# .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from sct_dataset import read_sct_stories
from sct_dataset import SCTCachedReader
from tokenizer import Tokenizer

In [3]:
DATA_DIR = 'data'
CACHE_DIR = 'cache'

# download the data
! test -d $DATA_DIR || mkdir $DATA_DIR
! test -f "$DATA_DIR/sct_train.csv" || curl "http://n.ethz.ch/~thomasdi/download/sct_train.csv" --output "$DATA_DIR/sct_train.csv"
! test -f "$DATA_DIR/sct_val.csv" || curl "http://n.ethz.ch/~thomasdi/download/sct_val.csv" --output "$DATA_DIR/sct_val.csv"

# initialize cache dir
! test -d $CACHE_DIR && rm -rf $CACHE_DIR

In [4]:
# Read the raw story texts for both splits. Paths are built from DATA_DIR so
# they always point at the directory the files were downloaded into.
texts_train = read_sct_stories(DATA_DIR + '/sct_train.csv')
texts_eval = read_sct_stories(DATA_DIR + '/sct_val.csv')

In [5]:
# Fit the tokenizer on training data only: the `begin` entries combined with
# the `end_real` entries of the training split.
tok = Tokenizer().fit(
    texts_train.begin + texts_train.end_real
)

In [6]:
# Display the vocabulary size reported by the fitted tokenizer.
tok.vocabulary_size

34961

In [7]:
# Summarise the distribution of len() over every entry in `.begin`,
# once for the training split and once for the evaluation split.
for header, entries in (
    ('# STATS FOR TRAINING DATA', texts_train.begin),
    ('# STATS FOR EVAL DATA', texts_eval.begin),
):
    print(header)
    lengths = pd.Series([len(entry) for entry in entries])
    print(lengths.describe())

# STATS FOR TRAINING DATA
count    352644.000000
mean         44.065451
std          13.230812
min           9.000000
25%          34.000000
50%          44.000000
75%          54.000000
max          86.000000
dtype: float64
# STATS FOR EVAL DATA
count    7484.000000
mean       45.646045
std        12.919189
min        11.000000
25%        36.000000
50%        46.000000
75%        56.000000
max        72.000000
dtype: float64


In [8]:
MAX_SEQ_LEN = 91

# create cachedir
! test -d $CACHE_DIR || mkdir $CACHE_DIR

# prepare SCT reader
sctreader = SCTCachedReader(CACHE_DIR, tok, MAX_SEQ_LEN)

In [9]:
# Read the training stories through the reader and display the result.
# The path is built from DATA_DIR so it follows the configured download directory.
sctreader.read_stories(DATA_DIR + '/sct_train.csv')

SCTSequences(begin=array([[[    0,     0,     0, ...,  2457,     4,     2],
        [    0,     0,     0, ...,   399,     4,     2],
        [    0,     0,     0, ...,   399,     4,     2],
        [    0,     0,     0, ...,   119,     4,     2]],

       [[    0,     0,     0, ...,    15,     4,     2],
        [    0,     0,     0, ...,   318,     4,     2],
        [    0,     0,     0, ...,  2008,     4,     2],
        [    0,     0,     0, ...,  9822,    42,     2]],

       [[    0,     0,     0, ...,  1673,     4,     2],
        [    0,     0,     0, ...,  1673,     4,     2],
        [    0,     0,     0, ...,  2564,    13,     2],
        [    0,     0,     0, ...,  5308,     4,     2]],

       ...,

       [[    0,     0,     0, ...,    99,     4,     2],
        [    0,     0,     0, ...,    73,     4,     2],
        [    0,     0,     0, ...,  3955,     4,     2],
        [    0,     0,     0, ...,  1018,     4,     2]],

       [[    0,     0,     0, ...,   320,     4,

In [41]:
# Decode one validation story back to text to sanity-check the encoding.
story_index = 3
print('STORY BEGIN:')
# Path is built from DATA_DIR so it follows the configured download directory.
stories = sctreader.read_stories(DATA_DIR + '/sct_val.csv')
for text in tok.sequences_to_texts(stories.begin[story_index]):
    print('  ', text)
# Indexing with [None, story_index] adds a leading axis of length 1, so
# sequences_to_texts receives a batch holding the single ending; [0] then
# takes the one decoded text.
print('REAL END:')
print('  ', tok.sequences_to_texts(stories.end_real[None, story_index])[0])
print('FAKE END:')
print('  ', tok.sequences_to_texts(stories.end_fake[None, story_index])[0])


STORY BEGIN:
   <bos> gina was worried the cookie dough in the tube would be gross . <eos>
   <bos> she was very happy to find she was wrong . <eos>
   <bos> the cookies from the tube were as good as from scratch . <eos>
   <bos> gina intended to only eat 2 cookies and save the rest . <eos>
REAL END:
   <bos> gina liked the cookies so much she ate them all in one sitting . <eos>
FAKE END:
   <bos> gina gave the cookies away at her church . <eos>
