# **File structure**

The first part of this file covers only the training of the model. The second part implements the text-generation methods that are used to test the trained model.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Dropout, SimpleRNN, TimeDistributed, Flatten
from keras.preprocessing.sequence import pad_sequences
from keras.callbacks import ModelCheckpoint
import pandas as pd
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

In [None]:
# Import training data: the 'train' split of the Amazon "Digital Software"
# reviews. batch_size=-1 asks tfds for the whole split at once instead of an
# iterable dataset (the next cell calls .numpy() on the result).
data = tfds.load(
    'amazon_us_reviews/Digital_Software_v1_00',
    split='train',
    shuffle_files=True,
    download=True,
    batch_size=-1,
)

In [None]:
# Preprocessing step: tokenization.
# The review bodies arrive as raw bytes; decode each one into a Python string.
reviews = [
    raw_review.decode('utf-8')
    for raw_review in data["data"]["review_body"].numpy()
]

# Fit the tokenizer on all reviews and map every review to word indices.
max_words = 50000
tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(reviews)
sequences = tokenizer.texts_to_sequences(reviews)

# Concatenate all tokenized reviews into one long stream of word indices.
text = [token for review_tokens in sequences for token in review_tokens]

# NOTE(review): word_index appears to hold every word seen by fit_on_texts,
# not only the num_words most frequent ones, so vocab_size may be larger than
# max_words. The value is left as-is because the model's layer sizes (and any
# previously saved weights) depend on it - confirm before changing.
vocab_size = len(tokenizer.word_index)

In [None]:
# Generate training sequences.
sentence_len = 20                        # tokens per sliding window
pred_len = 1                             # tokens to predict per window
train_len = sentence_len - pred_len      # tokens fed to the model as input

# Slide a window of sentence_len tokens over the token stream, one token at a
# time. The "+ 1" includes the final full window, which a plain
# range(len(text) - sentence_len) would skip.
seq = [
    text[start:start + sentence_len]
    for start in range(len(text) - sentence_len + 1)
]

# Inverse of tokenizer.word_index: word index -> word.
reverse_word_map = {index: word for word, index in tokenizer.word_index.items()}

# The first train_len tokens of every window are the model input, the last
# token of the window is the prediction target.
trainX = [window[:train_len] for window in seq]
trainy = [window[-1] for window in seq]

In [None]:
# Custom Generator function
def format_data_as_generator(dataX, dataY, i, batch_size, size, num_classes=None):
    """Build the i-th training batch, wrapping around the end of the data.

    Args:
        dataX: Sequence of input samples (each a sequence of word indices).
        dataY: Sequence of target word indices, aligned with dataX. A target
            ``w`` is one-hot encoded at position ``w - 1``.
        i: Zero-based batch counter; it may grow without bound because the
            batch position is taken modulo ``size``.
        batch_size: Number of samples per batch.
        size: Total number of samples in dataX / dataY.
        num_classes: Width of the one-hot target vectors. Defaults to the
            module-level ``vocab_size`` when omitted.

    Returns:
        Tuple ``(x, y)`` of numpy arrays: the batch inputs and the one-hot
        encoded targets with shape ``(len(batch), num_classes)``.
    """
    if num_classes is None:
        num_classes = vocab_size

    start_of_batch = i * batch_size % size
    end_of_batch = (i + 1) * batch_size % size

    if start_of_batch < end_of_batch:
        batch_of_x = dataX[start_of_batch:end_of_batch]
        batch_of_y = dataY[start_of_batch:end_of_batch]
    else:
        # The batch runs past the end of the data: take the tail (up to and
        # including the very last sample) and continue from the beginning.
        batch_of_x = list(dataX[start_of_batch:size]) + list(dataX[:end_of_batch])
        batch_of_y = list(dataY[start_of_batch:size]) + list(dataY[:end_of_batch])

    # Vectorized one-hot encoding: row r gets a 1 in column target[r] - 1.
    targets = np.asarray(batch_of_y, dtype=np.int64)
    one_hot_y = np.zeros((len(targets), num_classes))
    one_hot_y[np.arange(len(targets)), targets - 1] = 1

    return (np.asarray(batch_of_x), one_hot_y)

def batch_generator(dataX, dataY, batch_size):
    """Yield training batches forever, cycling through the data.

    Args:
        dataX: Sequence of input samples.
        dataY: Sequence of targets aligned with dataX.
        batch_size: Number of samples per yielded batch.

    Yields:
        Whatever ``format_data_as_generator`` returns for consecutive batch
        numbers 0, 1, 2, ...
    """
    total_samples = len(dataX)
    batch_number = 0

    while True:
        yield format_data_as_generator(
            dataX, dataY, batch_number, batch_size, total_samples
        )
        batch_number += 1

# **Note that the cell below should only be run when training the model**

If you would like to only test the model by generating new texts, skip this block

To use a previously saved set of weights as the starting point for training, uncomment the `model.load_weights(...)` line in the middle of the cell below.

In [None]:
# Actual training
batch_size = 3500
epochs = 300
# Only complete batches count towards one epoch.
step_size_per_epoch = len(trainX) // batch_size
generator = batch_generator(trainX, trainy, batch_size)

# Embedding -> two stacked LSTMs -> dense head with a softmax over the
# vocabulary. The embedding has vocab_size+1 rows because word indices used
# here start at 1.
model = Sequential([
    Embedding(vocab_size+1, 50, input_length=train_len),
    LSTM(100, return_sequences=True),
    LSTM(100),
    Dense(100, activation='relu'),
    Dropout(0.1),
    Dense(vocab_size, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# uncomment line below to load earlier saved model weights
# model.load_weights("./model_lstm_weights.200.hdf5")

# A checkpoint file is written after every epoch (save_best_only=False); the
# epoch number is part of the file name.
filepath = "./model_lstm_weights_new.{epoch:03d}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=False, mode='min')
callbacks_list = [checkpoint]

# NOTE(review): fit_generator is deprecated in newer Keras releases in favour
# of fit() - confirm against the installed version before switching.
history = model.fit_generator(generator,
                              epochs=epochs,  # use the setting defined above
                              steps_per_epoch=step_size_per_epoch,
                              callbacks=callbacks_list,
                              max_queue_size=10,
                              verbose=1)

# **Text Generation**

The remaining part of this file will take care of the testing of the model by generating new text sequences. If you only want to train the model you can stop here.

Note that you must have run all of the cells above, except for the training cell, for the code below to work.

In [None]:
# generation of new text based on argmax method
def genArgMax(model, seq, max_len=20, context_len=19):
    """Extend a seed text by always picking the most probable next word.

    Args:
        model: Model whose ``predict`` maps ``context_len`` word indices to
            one score per vocabulary entry.
        seq: Seed text (a string); it is tokenized with the module-level
            ``tokenizer``.
        max_len: Number of words to generate on top of the seed.
        context_len: Number of most recent tokens fed to the model. Must match
            the model's input length; the default of 19 is the value that was
            previously hard-coded here.

    Returns:
        The tokenized seed followed by the generated words, joined by spaces.
    """
    tokens = tokenizer.texts_to_sequences([seq])[0]
    target_len = len(tokens) + max_len

    while len(tokens) < target_len:
        # Feed only the most recent context_len tokens; pad_sequences pads
        # shorter seeds up to context_len.
        context = pad_sequences([tokens[-context_len:]], maxlen=context_len)
        # verbose=0 keeps predict from printing a progress line per word.
        scores = model.predict(np.asarray(context).reshape(1, -1), verbose=0)
        # Output unit k stands for word index k + 1 (training targets were
        # one-hot encoded at position word - 1).
        tokens.append(int(scores.argmax()) + 1)

    return " ".join(reverse_word_map[token] for token in tokens)

In [None]:
# generation of new text based on probability distribution sampling method
def genProbDist(model, seq, max_len=20, context_len=19):
    """Extend a seed text by sampling each next word from the model's output.

    Args:
        model: Model whose ``predict`` maps ``context_len`` word indices to a
            probability distribution over the vocabulary.
        seq: Seed text (a string); it is tokenized with the module-level
            ``tokenizer``.
        max_len: Number of words to generate on top of the seed.
        context_len: Number of most recent tokens fed to the model. Must match
            the model's input length; the default of 19 is the value that was
            previously hard-coded here.

    Returns:
        The tokenized seed followed by the sampled words, joined by spaces.
    """
    tokens = tokenizer.texts_to_sequences([seq])[0]
    target_len = len(tokens) + max_len

    while len(tokens) < target_len:
        # Feed only the most recent context_len tokens; pad_sequences pads
        # shorter seeds up to context_len.
        context = pad_sequences([tokens[-context_len:]], maxlen=context_len)
        # verbose=0 keeps predict from printing a progress line per word.
        scores = model.predict(np.asarray(context).reshape(1, -1), verbose=0)

        # np.random.choice rejects distributions that do not sum to 1 within a
        # tight tolerance, which low-precision softmax output can miss, so
        # renormalize in float64 first.
        probabilities = np.asarray(scores[0], dtype=np.float64)
        probabilities = probabilities / probabilities.sum()

        sampled_unit = np.random.choice(len(probabilities), p=probabilities)
        # Output unit k stands for word index k + 1 (training targets were
        # one-hot encoded at position word - 1).
        tokens.append(int(sampled_unit) + 1)

    return " ".join(reverse_word_map[token] for token in tokens)

In [None]:
# Test data is loaded and saved from:
# tfds.load('amazon_us_reviews/Digital_Video_Games_v1_00', split='train', shuffle_files=True, download=True, batch_size=-1)

# Preprocess test data: walk through the files /data/dat_0.txt, /data/dat_1.txt,
# ... and use the first 19 words of each file's first line as a seed text,
# until numResults seeds have been collected.
numResults = 25
inputs = []
i = 0

while len(inputs) < numResults:
    # The context manager closes every file again. open() raises
    # FileNotFoundError if the numbered files run out before enough seeds
    # were found.
    with open("/data/dat_" + str(i) + ".txt", "r") as text_file:
        first_line = text_file.readline()

    # An empty file yields an empty line and is skipped like any other file
    # whose first line has fewer than 19 words.
    dataTok = first_line.split()
    if len(dataTok) >= 19:
        inputs.append(' '.join(dataTok[:19]))
    i += 1

print(inputs[0])

In [None]:
# generate new texts
# Rebuild the same layer stack as used for training, one layer at a time, so
# that the saved weights can be loaded into it.
model_out = Sequential()
model_out.add(Embedding(vocab_size+1, 50, input_length=train_len))
model_out.add(LSTM(100, return_sequences=True))
model_out.add(LSTM(100))
model_out.add(Dense(100, activation='relu'))
model_out.add(Dropout(0.1))
model_out.add(Dense(vocab_size, activation='softmax'))

print("\n\x1b[31mModel after training 200 epochs\x1b[0m\n")
model_out.load_weights("./model_lstm_weights.200.hdf5")

# For every seed text print the seed itself, then the argmax continuation,
# then the sampled continuation, followed by an empty line.
for seq in inputs:
    print(seq)
    for generate in (genArgMax, genProbDist):
        print(generate(model_out, seq, 25))
    print()