# Boilerplate

In [121]:
# files
TRAINING_DIRECTORY = 'cnn/stories/'
EXTENSION = '.story'
MAX_FILES = 10000

# tokenization
FILTERS = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n'  # default minus >, <
TARGET_BEGIN_CHAR = '<target-begin>'
END_CHAR = '<end>'
OOV_CHAR = '<unk>'
TARGET_BEGIN_TOKEN = 1
END_TOKEN = 2
OOV_TOKEN = 3
NUM_WORDS = 2000

# MODEL_PARAMS
SENTENCE_LEN = 200
MAX_INPUT_LEN = 150

# Read in files

In [2]:
import glob

In [3]:
FILES = glob.glob('%s/*%s' % (TRAINING_DIRECTORY, EXTENSION))
print(len(FILES))
FILES

92579


['cnn/stories/dc295a3a6fbc0cda66119560c938418c6e4a237a.story',
 'cnn/stories/d9f04519987f0b8b276ac29e5922fa166dd4db24.story',
 'cnn/stories/760e36c897b001a669fb77fb7370ed0e733bbb49.story',
 'cnn/stories/25f7bb8875c84968ef8d8e1fb7cc87950dd01c4e.story',
 'cnn/stories/c3933bf13f505ec9d9a62af21a896ef533709b3b.story',
 'cnn/stories/d558e1ce6dd3a1594d4eeca96f4815b8444a63a6.story',
 'cnn/stories/80b5cbd1477153e8502a6af912a1374a2cc11564.story',
 'cnn/stories/1dc981ec98233ec5a0136f7f73ad70da6a940760.story',
 'cnn/stories/d29cf9d5118713d240e8e0102e4086dc8f2e1495.story',
 'cnn/stories/5731ca7672bece19c47bc07ca4d887484d75b567.story',
 'cnn/stories/313021a532d8f95d8d74b2e576ffa56c32761f94.story',
 'cnn/stories/2d13c8cb0938c1de2769b430e7a2a05281f9e891.story',
 'cnn/stories/6ea17cbb6db99d009d6d7eccd7c5b3ec1470171a.story',
 'cnn/stories/33e41b123b136a9f9e60fd7bf2cee225aed599cf.story',
 'cnn/stories/82f2640c1ec65cb01aa7b4575e8fbb75d2342549.story',
 'cnn/stories/dc1dd85fa9e7ef464bed89906e4cba1a5fe1774f.

In [4]:
FILES = FILES[:MAX_FILES]

# Define method for generating text from files

In [5]:
def preprocessor(text):
    table = {ord(c): None for c in '<>'}
    text = text.translate(table)
    return text

In [6]:
def text_generator(files, preprocessor=None):
    for f in files:
        text = open(f).read()
        if preprocessor is not None:
            text = preprocessor(text)
        # remove highlights
        body, highlight1, *_ = text.split('@highlight')
        yield body, highlight1

In [66]:
from keras.preprocessing.sequence import pad_sequences

def tokenize(input_text, target_text, tokenizer, max_input_len, target_begin_token, end_token):
    input_tokens = tokenizer([input_text])[0]
    target_tokens = tokenizer([target_text])[0]
    input_tokens = pad_sequences([input_tokens[:max_input_len]])[0].tolist()
    target_tokens = target_tokens
    return [input_tokens + [target_begin_token] + target_tokens + [end_token]]

In [44]:
next(text_generator(FILES, preprocessor=preprocessor))

('(CNN)The Spanish nurse\'s aide who contracted Ebola after treating virus-stricken patients in Madrid is now free of the virus, her doctors announced Tuesday after another test on her.\n\nTeresa Romero Ramos is clear of Ebola, physicians at Carlos III hospital said.\n\nShe received an initial test, which turned up no virus in her blood, doctors said Sunday. More tests were administered to be sure she was virus-free.\n\nWhile Spain welcomes the good news about Ebola, the United States is doing more to help prevent the spread of the virus. The Department of Homeland Security said Tuesday that all arriving passengers from West African countries that Ebola has hit hardest -- Liberia, Sierra Leone and Guinea -- must land in one of the five U.S. airports that have enhanced Ebola screening.\n\nThose airports are New York\'s John F. Kennedy International; D.C.\'s Washington Dulles; New Jersey\'s Newark Liberty International; Chicago\'s O\'Hare International; and Hartsfield-Jackson Internation

# Initialize tokenizer

In [45]:
from keras.preprocessing.text import text_to_word_sequence, Tokenizer as _Tokenizer

class Tokenizer(_Tokenizer):
    def fit_on_texts(self, texts):
        """Updates internal vocabulary based on a list of texts.
        In the case where texts contains lists, we assume each entry of the lists
        to be a token.
        Required before using `texts_to_sequences` or `texts_to_matrix`.
        # Arguments
            texts: can be a list of strings,
                a generator of strings (for memory-efficiency),
                or a list of list of strings.
        """
        for text in texts:
            self.document_count += 1
            if self.char_level or isinstance(text, list):
                seq = text
            else:
                seq = text_to_word_sequence(text,
                                            self.filters,
                                            self.lower,
                                            self.split)
            for w in seq:
                if w in self.word_counts:
                    self.word_counts[w] += 1
                else:
                    self.word_counts[w] = 1
            for w in set(seq):
                if w in self.word_docs:
                    self.word_docs[w] += 1
                else:
                    self.word_docs[w] = 1

        wcounts = list(self.word_counts.items())
        wcounts.sort(key=lambda x: x[1], reverse=True)
        sorted_voc = [wc[0] for wc in wcounts]
        # note that index 0, 1, 2 is reserved, never assigned to an existing word
        self.word_index = dict(list(zip(sorted_voc, list(range(4, len(sorted_voc) + 4)))))
        self.word_index[self.oov_token] = 3

        for w, c in list(self.word_docs.items()):
            self.index_docs[self.word_index[w]] = c

    def texts_to_sequences_generator(self, texts):
        """Transforms each text in `texts` in a sequence of integers.
        Each item in texts can also be a list, in which case we assume each item of that list
        to be a token.
        Only top "num_words" most frequent words will be taken into account.
        Only words known by the tokenizer will be taken into account.
        # Arguments
            texts: A list of texts (strings).
        # Yields
            Yields individual sequences.
        """
        num_words = self.num_words
        for text in texts:
            if self.char_level or isinstance(text, list):
                seq = text
            else:
                seq = text_to_word_sequence(text,
                                            self.filters,
                                            self.lower,
                                            self.split)
            vect = []
            for w in seq:
                i = self.word_index.get(w)
                if i is not None and (self.num_words and i < self.num_words):
                    vect.append(i)
                elif self.oov_token is not None:
                    i = self.word_index.get(self.oov_token)
                    if i is not None:
                        vect.append(i)
            yield vect

In [46]:
TOKENIZER = Tokenizer(
    num_words=NUM_WORDS,
    filters=FILTERS,  # no newline
    oov_token=OOV_CHAR)

In [47]:
gen = text_generator(FILES, preprocessor=preprocessor)

In [48]:
%%time
TOKENIZER.fit_on_texts(text for train_pair in gen for text in train_pair)

CPU times: user 5.29 s, sys: 72 ms, total: 5.36 s
Wall time: 5.36 s


In [49]:
TOKENIZER.num_words

2000

In [50]:
TOKENIZER.document_count

20000

In [51]:
len(TOKENIZER.word_index), TOKENIZER.word_index

(101044,
 {"seals'": 48836,
  'pastiches': 66367,
  'bombastic': 61346,
  'tanzman': 84829,
  'ween': 30415,
  "''but": 86621,
  '173': 27584,
  'concerts': 7327,
  'apotex': 80394,
  'audience': 1752,
  'committees': 5708,
  'pashupatinath': 92702,
  'realise': 37778,
  'ashish': 50539,
  'unregulated': 20207,
  'enrichment': 7007,
  'negatives': 21002,
  'eavis': 65209,
  'gunpoint': 13312,
  'kapiolani': 79979,
  'ischemic': 44334,
  'ljubicic': 41326,
  'lakeway': 90017,
  'wifely': 82727,
  "1991's": 53172,
  "aurobindo's": 83248,
  "'already": 65682,
  'parisian': 20261,
  'fiji': 16438,
  'pregnancy': 4864,
  "intellectuals'": 100677,
  'obliviousness': 85729,
  '396': 57291,
  'murkiest': 78159,
  "'actually": 47599,
  'bout': 10801,
  'newsfeed': 31897,
  'dialog': 40959,
  'ataye': 61535,
  "innocent's": 92597,
  '1981': 6267,
  'voyage': 9713,
  'panjshir': 94169,
  'ligure': 74196,
  'wearing': 1587,
  'sarawak': 42536,
  'vivino': 79000,
  'sternberg': 72847,
  'shaggy': 2

In [52]:
index_to_word = {v: k for k, v in TOKENIZER.word_index.items()}
index_to_word[0] = '<pad>'
index_to_word[TARGET_BEGIN_TOKEN] = TARGET_BEGIN_CHAR
index_to_word[END_TOKEN] = END_CHAR

In [53]:
sorted(index_to_word.items(), key=lambda x: x[0])

[(0, '<pad>'),
 (1, '<target-begin>'),
 (2, '<end>'),
 (3, '<unk>'),
 (4, 'the'),
 (5, 'to'),
 (6, 'of'),
 (7, 'and'),
 (8, 'a'),
 (9, 'in'),
 (10, 'that'),
 (11, 'for'),
 (12, 'is'),
 (13, 'said'),
 (14, 'on'),
 (15, 'was'),
 (16, 'with'),
 (17, 'he'),
 (18, 'it'),
 (19, 'as'),
 (20, 'at'),
 (21, 'his'),
 (22, 'have'),
 (23, 'from'),
 (24, 'are'),
 (25, 'be'),
 (26, 'i'),
 (27, 'but'),
 (28, 'by'),
 (29, 'this'),
 (30, 'has'),
 (31, 'an'),
 (32, 'not'),
 (33, 'they'),
 (34, 'who'),
 (35, 'we'),
 (36, 'will'),
 (37, 'were'),
 (38, 'their'),
 (39, 'you'),
 (40, 'she'),
 (41, 'her'),
 (42, 'had'),
 (43, 'about'),
 (44, 'more'),
 (45, 'been'),
 (46, 'one'),
 (47, 'or'),
 (48, 'cnn'),
 (49, 'people'),
 (50, 'after'),
 (51, 'when'),
 (52, 'new'),
 (53, 'all'),
 (54, 'would'),
 (55, 'out'),
 (56, 'which'),
 (57, 'up'),
 (58, 'there'),
 (59, 'what'),
 (60, 'also'),
 (61, 'its'),
 (62, "it's"),
 (63, 'year'),
 (64, 'if'),
 (65, 'can'),
 (66, 'u'),
 (67, 'so'),
 (68, 'two'),
 (69, 'than'),
 (70

In [54]:
TOKENIZER.num_words = min(len(TOKENIZER.word_index)+1, TOKENIZER.num_words)

In [55]:
gen = text_generator(FILES)
x, y = next(gen)

In [56]:
len(x)

3566

In [57]:
x

'(CNN)The Spanish nurse\'s aide who contracted Ebola after treating virus-stricken patients in Madrid is now free of the virus, her doctors announced Tuesday after another test on her.\n\nTeresa Romero Ramos is clear of Ebola, physicians at Carlos III hospital said.\n\nShe received an initial test, which turned up no virus in her blood, doctors said Sunday. More tests were administered to be sure she was virus-free.\n\nWhile Spain welcomes the good news about Ebola, the United States is doing more to help prevent the spread of the virus. The Department of Homeland Security said Tuesday that all arriving passengers from West African countries that Ebola has hit hardest -- Liberia, Sierra Leone and Guinea -- must land in one of the five U.S. airports that have enhanced Ebola screening.\n\nThose airports are New York\'s John F. Kennedy International; D.C.\'s Washington Dulles; New Jersey\'s Newark Liberty International; Chicago\'s O\'Hare International; and Hartsfield-Jackson Internationa

In [67]:
seq = tokenize(
    x,
    y,
    TOKENIZER.texts_to_sequences,
    max_input_len=MAX_INPUT_LEN,
    target_begin_token=TARGET_BEGIN_TOKEN,
    end_token=END_TOKEN)
seq

[[48,
  4,
  1410,
  3,
  3,
  34,
  3,
  1284,
  50,
  3,
  1640,
  3,
  1165,
  9,
  1593,
  12,
  93,
  346,
  6,
  4,
  1640,
  41,
  1134,
  489,
  202,
  50,
  157,
  928,
  14,
  41,
  3,
  3,
  3,
  12,
  405,
  6,
  1284,
  3,
  20,
  3,
  3,
  459,
  13,
  40,
  624,
  31,
  3,
  928,
  56,
  735,
  57,
  71,
  1640,
  9,
  41,
  1037,
  1134,
  13,
  260,
  44,
  1799,
  37,
  3,
  5,
  25,
  524,
  40,
  15,
  1640,
  346,
  107,
  1537,
  3,
  4,
  192,
  179,
  43,
  1284,
  4,
  106,
  110,
  12,
  374,
  44,
  5,
  190,
  1349,
  4,
  1400,
  6,
  4,
  1640,
  4,
  251,
  6,
  3,
  161,
  13,
  202,
  10,
  53,
  3,
  929,
  23,
  449,
  495,
  359,
  10,
  1284,
  30,
  407,
  3,
  3,
  3,
  3,
  7,
  3,
  354,
  1006,
  9,
  46,
  6,
  4,
  231,
  66,
  73,
  3,
  10,
  22,
  3,
  1284,
  3,
  98,
  3,
  24,
  52,
  3,
  355,
  1971,
  1689,
  181,
  994,
  1313,
  3,
  277,
  3,
  52,
  3,
  3,
  3,
  1,
  66,
  73,
  3,
  3,
  3,
  23,
  3,
  3,
  3,
  3,
  5,
  231

In [68]:
[[index_to_word[i] for i in L] for L in seq]

[['cnn',
  'the',
  'spanish',
  '<unk>',
  '<unk>',
  'who',
  '<unk>',
  'ebola',
  'after',
  '<unk>',
  'virus',
  '<unk>',
  'patients',
  'in',
  'madrid',
  'is',
  'now',
  'free',
  'of',
  'the',
  'virus',
  'her',
  'doctors',
  'announced',
  'tuesday',
  'after',
  'another',
  'test',
  'on',
  'her',
  '<unk>',
  '<unk>',
  '<unk>',
  'is',
  'clear',
  'of',
  'ebola',
  '<unk>',
  'at',
  '<unk>',
  '<unk>',
  'hospital',
  'said',
  'she',
  'received',
  'an',
  '<unk>',
  'test',
  'which',
  'turned',
  'up',
  'no',
  'virus',
  'in',
  'her',
  'blood',
  'doctors',
  'said',
  'sunday',
  'more',
  'tests',
  'were',
  '<unk>',
  'to',
  'be',
  'sure',
  'she',
  'was',
  'virus',
  'free',
  'while',
  'spain',
  '<unk>',
  'the',
  'good',
  'news',
  'about',
  'ebola',
  'the',
  'united',
  'states',
  'is',
  'doing',
  'more',
  'to',
  'help',
  'prevent',
  'the',
  'spread',
  'of',
  'the',
  'virus',
  'the',
  'department',
  'of',
  '<unk>',
  's

In [69]:
len(seq), len(seq[0])

(1, 165)

In [71]:
s = seq[0]
s

[48,
 4,
 1410,
 3,
 3,
 34,
 3,
 1284,
 50,
 3,
 1640,
 3,
 1165,
 9,
 1593,
 12,
 93,
 346,
 6,
 4,
 1640,
 41,
 1134,
 489,
 202,
 50,
 157,
 928,
 14,
 41,
 3,
 3,
 3,
 12,
 405,
 6,
 1284,
 3,
 20,
 3,
 3,
 459,
 13,
 40,
 624,
 31,
 3,
 928,
 56,
 735,
 57,
 71,
 1640,
 9,
 41,
 1037,
 1134,
 13,
 260,
 44,
 1799,
 37,
 3,
 5,
 25,
 524,
 40,
 15,
 1640,
 346,
 107,
 1537,
 3,
 4,
 192,
 179,
 43,
 1284,
 4,
 106,
 110,
 12,
 374,
 44,
 5,
 190,
 1349,
 4,
 1400,
 6,
 4,
 1640,
 4,
 251,
 6,
 3,
 161,
 13,
 202,
 10,
 53,
 3,
 929,
 23,
 449,
 495,
 359,
 10,
 1284,
 30,
 407,
 3,
 3,
 3,
 3,
 7,
 3,
 354,
 1006,
 9,
 46,
 6,
 4,
 231,
 66,
 73,
 3,
 10,
 22,
 3,
 1284,
 3,
 98,
 3,
 24,
 52,
 3,
 355,
 1971,
 1689,
 181,
 994,
 1313,
 3,
 277,
 3,
 52,
 3,
 3,
 3,
 1,
 66,
 73,
 3,
 3,
 3,
 23,
 3,
 3,
 3,
 3,
 5,
 231,
 3,
 2]

In [72]:
one_hot = TOKENIZER.sequences_to_matrix([[i] for i in s])

In [73]:
one_hot

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]])

In [74]:
# only one per row
import numpy as np
np.argwhere(one_hot == 1)[:50]

array([[   0,   48],
       [   1,    4],
       [   2, 1410],
       [   3,    3],
       [   4,    3],
       [   5,   34],
       [   6,    3],
       [   7, 1284],
       [   8,   50],
       [   9,    3],
       [  10, 1640],
       [  11,    3],
       [  12, 1165],
       [  13,    9],
       [  14, 1593],
       [  15,   12],
       [  16,   93],
       [  17,  346],
       [  18,    6],
       [  19,    4],
       [  20, 1640],
       [  21,   41],
       [  22, 1134],
       [  23,  489],
       [  24,  202],
       [  25,   50],
       [  26,  157],
       [  27,  928],
       [  28,   14],
       [  29,   41],
       [  30,    3],
       [  31,    3],
       [  32,    3],
       [  33,   12],
       [  34,  405],
       [  35,    6],
       [  36, 1284],
       [  37,    3],
       [  38,   20],
       [  39,    3],
       [  40,    3],
       [  41,  459],
       [  42,   13],
       [  43,   40],
       [  44,  624],
       [  45,   31],
       [  46,    3],
       [  47,

# Define batch generator

In [122]:
def sequencer(tokens, L):
    return [tokens[i:L+i] for i in range(0, len(tokens)-L+1)]

In [123]:
list(sequencer('a quick brown fox', 10))

['a quick br',
 ' quick bro',
 'quick brow',
 'uick brown',
 'ick brown ',
 'ck brown f',
 'k brown fo',
 ' brown fox']

In [124]:
import random
import numpy as np

class BatchGenerator:
    def __init__(self, files, tokenizer, max_input_len, sentence_len, batch_size,
                 target_begin_token, end_token, epoch_end=None):
        self.files = files
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.sentence_len = sentence_len
        self.batch_size = batch_size
        self.target_begin_token = target_begin_token
        self.end_token = end_token
        self.epoch_end = epoch_end
        
    def generate(self):
        steps = []
        while True:
            random.shuffle(self.files)
            for input_text, target_text in self.iter_files(self.files):
                tokens = self.tokenize(input_text, target_text)

                for seq_tokens in self.sequence(tokens):
                    steps.append(seq_tokens)
                    
                while len(steps) >= self.batch_size:
                    X = steps[:self.batch_size]
                    y = self.tokenizer.sequences_to_matrix([[i] for s in X for i in s])
                    y = y.reshape((self.batch_size, self.sentence_len, self.tokenizer.num_words))
                    
                    # offset
                    X = X[:-1]
                    y = y[1:]
                    yield X, y
                    
                    # reset
                    steps = steps[self.batch_size:]
            yield self.epoch_end
        
    def preprocess(self, text):
        # replace all occurences of multiple newlines and replace them
        # with a single newline padded with spaces so it is treated as a
        # token
        text = ' \n '.join(t for t in text.split('\n') if t)
        table = {ord(c): None for c in '<>'}
        text = text.translate(table)
        return text

    def iter_files(self, files):
        for f in files:
            text = open(f).read()
            text = self.preprocess(text)
            # remove highlights
            body, highlight1, *_ = text.split('@highlight')
            yield body, highlight1
        
    def tokenize(self, input_text, target_text):
        input_tokens = self.tokenizer.texts_to_sequences([input_text])[0]
        target_tokens = self.tokenizer.texts_to_sequences([target_text])[0]
        input_tokens = pad_sequences([input_tokens[:self.max_input_len]], maxlen=self.max_input_len)[0].tolist()
        target_tokens = target_tokens
        return input_tokens + [self.target_begin_token] + target_tokens + [self.end_token]

    def sequence(self, tokens):
        L = self.sentence_len
        return [tokens[i:L+i] for i in range(0, len(tokens)-L+1)]

In [125]:
batch_gen = BatchGenerator(
    files=FILES,
    tokenizer=TOKENIZER,
    max_input_len=MAX_INPUT_LEN,
    sentence_len=SENTENCE_LEN,
    batch_size=32,
    target_begin_token=TARGET_BEGIN_TOKEN,
    end_token=END_TOKEN).generate()

In [126]:
X, y = next(batch_gen)

TypeError: 'NoneType' object is not iterable

In [None]:
X

In [98]:
X.shape, y.shape

AttributeError: 'list' object has no attribute 'shape'

In [None]:
X

In [None]:
y

In [None]:
[' '.join([index_to_word.get(i, '<pad>') for i in x]) for x in X]

In [None]:
# only one per row
import numpy as np
ys = np.argwhere(y[0] == 1)

In [None]:
import numpy as np
for j in range(0, len(y), 5):
    ys = np.argwhere(y[j] == 1)
    assert len(ys) == len({row for row, idx in ys})
    print(' '.join(index_to_word[idx] for row, idx in ys))

# Training

In [None]:
N_HEADS = 8
N_LAYERS = 6
D_MODEL = 64*N_HEADS
VOCAB_SIZE = TOKENIZER.num_words
WARMUP_STEPS = 200

In [None]:
batch_gen = BatchGenerator(
    files=FILES,
    tokenizer=TOKENIZER,
    maxlen=MAX_SEQUENCE_LEN,
    batch_size=32,
    target_begin_token=TARGET_BEGIN_TOKEN,
    end_token=END_TOKEN
).generate()

In [None]:
# loop over batch generator until we hit the end of the epoch
# to calculate number of batches in epoch and compute some
# stats along the way
steps_per_epoch = 0
for batch in batch_gen:
    if batch is None:
        break
    steps_per_epoch += 1

In [None]:
print('steps per epoch', steps_per_epoch)

In [None]:
train_gen = (X for X in batch_gen if not X == 0)

In [None]:
from keras.callbacks import TerminateOnNaN
callbacks = [TerminateOnNaN()]

In [None]:
from model_decoder import TransformerDecoder
model = TransformerDecoder(
        n_heads=N_HEADS, decoder_layers=N_LAYERS,
        d_model=D_MODEL, vocab_size=VOCAB_SIZE, sequence_len=MAX_SEQUENCE_LEN,
        layer_normalization=True, dropout=True,
        residual_connections=True)

In [None]:
model.summary()

In [None]:
# import keras.backend as K
# def loss(y_true, y_pred):
#    return K.categorical_crossentropy(y_true[:,-1:,:], y_pred[:,-1:,:])

In [None]:
loss = 'categorical_crossentropy'

In [None]:
class LRScheduler:
    def __init__(self, d_model, warmup_steps):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.epoch = 1

    def lr(self, epoch):
        lr = self.d_model**-.5 * min(self.epoch**-.5, epoch*(self.warmup_steps**-1.5))
        self.epoch += 1
        return lr
lr_scheduler = LRScheduler(D_MODEL, WARMUP_STEPS)

In [None]:
from keras.callbacks import LearningRateScheduler
# callbacks.append(LearningRateScheduler(lr_scheduler.lr))

In [None]:
from keras.optimizers import adam
model.compile(loss=loss, optimizer=adam(lr=1e-4))

In [None]:
# from keras import backend as K
# old_lr = K.get_value(model.optimizer.lr)
# K.set_value(model.optimizer.lr, 1e-4)

In [None]:
n_epochs = 1000
model.fit_generator(
    train_gen, steps_per_epoch=steps_per_epoch,
    epochs=n_epochs, callbacks=callbacks)

In [None]:
X, y = next(batch_gen)

In [None]:
y

In [None]:
def show_X(X):
    print('X:', ' '.join(index_to_word[i] for i in X))
    
def show_y(y):
    ones = np.argwhere(y == 1)
    print('y:', ' '.join(index_to_word[idx] for row, idx in ones))

def show_results(model, X, y):
    show_X(X[0])
    show_y(y)
    y_hat = model.predict(X)
    show_y(y_hat[0])