In [None]:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Embedding, Dense, LSTM, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.text import Tokenizer
import pandas as pd
import numpy as np

In [None]:
with open('processed_texts.csv', 'r', encoding='UTF-8') as file:
    train_data = [line.strip('\n') for line in file]

print('Number of training sentences: ', len(train_data))

max_words = 50000 # Max size of the dictionary
tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(train_data)
sequences = tokenizer.texts_to_sequences(train_data)

# Flatten the list of lists resulting from the tokenization. This will reduce the list
# to one dimension, allowing us to apply the sliding window technique to predict the next word
text = [item for sublist in sequences for item in sublist]
vocab_size = len(tokenizer.word_index)

In [17]:
# Training on 19 words to predict the 20th
sentence_len = 20
pred_len = 1
train_len = sentence_len - pred_len
seq = []
# Sliding window to generate train data
for i in range(len(text)-sentence_len):
    seq.append(text[i:i+sentence_len])
# Reverse dictionary to decode tokenized sequences back to words
reverse_word_map = dict(map(reversed, tokenizer.word_index.items()))

# Each row in seq is a 20 word long window. We append he first 19 words as the input to predict the 20th word
trainX = []
trainy = []
for i in seq:
    trainX.append(i[:train_len])
    trainy.append(i[-1])

len(trainy)

5117695

In [18]:
# define model
model_2 = Sequential([
    Embedding(vocab_size+1, 50, input_length=train_len),
    LSTM(100, return_sequences=True),
    LSTM(100),
    Dense(100, activation='relu'),
    Dropout(0.1),
    Dense(vocab_size, activation='softmax')
])

In [19]:
def my_metric(y_true, y_pred):
    return 1- tf.gather(y_pred, tf.dtypes.cast(y_true, tf.int32))

# Train model with checkpoints
model_2.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [20]:
filepath = "./model_2_weights.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]
x = np.asarray(trainX)
y = np.asarray(trainy)

In [None]:
model_2.fit(x,y, epochs = 30, batch_size = 512, callbacks = callbacks_list)

Epoch 1/30
Epoch 00001: loss improved from inf to 6.08176, saving model to ./model_2_weights.hdf5
Epoch 2/30
Epoch 00002: loss improved from 6.08176 to 5.66452, saving model to ./model_2_weights.hdf5
Epoch 3/30
Epoch 00003: loss improved from 5.66452 to 5.52301, saving model to ./model_2_weights.hdf5
Epoch 4/30
Epoch 00004: loss improved from 5.52301 to 5.43311, saving model to ./model_2_weights.hdf5
Epoch 5/30
Epoch 00005: loss improved from 5.43311 to 5.36716, saving model to ./model_2_weights.hdf5
Epoch 6/30
Epoch 00006: loss improved from 5.36716 to 5.31694, saving model to ./model_2_weights.hdf5
Epoch 7/30
Epoch 00007: loss improved from 5.31694 to 5.27721, saving model to ./model_2_weights.hdf5
Epoch 8/30

In [None]:
def gen(model,seq,max_len = 20):
    ''' Generates a sequence given a string seq using specified model until the total sequence length
    reaches max_len'''
    # Tokenize the input string
    tokenized_sent = tokenizer.texts_to_sequences([seq])
    max_len = max_len+len(tokenized_sent[0])
    # If sentence is not as long as the desired sentence length, we need to 'pad sequence' so that
    # the array input shape is correct going into our LSTM. the `pad_sequences` function adds 
    # zeroes to the left side of our sequence until it becomes 19 long, the number of input features.
    while len(tokenized_sent[0]) < max_len:
        padded_sentence = tf.keras.preprocessing.sequence.pad_sequences(tokenized_sent[-19:],maxlen=19)
        op = model.predict(np.asarray(padded_sentence).reshape(1,-1))
        tokenized_sent[0].append(op.argmax()+1)
        
    return " ".join(map(lambda x : reverse_word_map[x],tokenized_sent[0]))

In [None]:
gen(model_2,'The climate change is a global crisis for ')