# LSTM Baseline

In [None]:
# Import necessary modules
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from keras.preprocessing.sequence import pad_sequences
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Embedding, TimeDistributed

In [None]:
import os

# Path to the `opi` dataset directory. Edit the default below, or set the
# OPI_DATA_PATH environment variable to point elsewhere without editing code.
data_path = os.environ.get("OPI_DATA_PATH", "./clin-summ/data/opi")

In [None]:
import os

# Fail fast with a clear message if the dataset directory is missing.
# An explicit raise (unlike `assert`) still fires when Python runs with -O.
if not os.path.exists(data_path):
    raise FileNotFoundError('Dataset not found ({})'.format(data_path))

def load_data(filename: str):
    """Read one JSON-lines split located inside ``data_path``.

    Args:
        filename: Name of the split file, e.g. ``"train.jsonl"``.

    Returns:
        A ``(inputs, targets)`` pair of NumPy arrays taken from the
        ``"inputs"`` and ``"target"`` columns of the file.
    """
    split_path = os.path.join(data_path, filename)
    frame = pd.read_json(split_path, lines=True)
    inputs, targets = (frame[column].to_numpy() for column in ("inputs", "target"))
    return inputs, targets

# Load the three splits in order: train, validate, test.
(train_x, train_y), (val_x, val_y), (test_x, test_y) = map(
    load_data, ("train.jsonl", "validate.jsonl", "test.jsonl")
)

In [None]:
# Wrap every summary in explicit start/end marker words; inference seeds the
# decoder with 'sostok' and stops generating when it samples 'eostok'.
train_y, val_y, test_y = (
    ['sostok ' + summary + ' eostok' for summary in split]
    for split in (train_y, val_y, test_y)
)

In [None]:
import matplotlib.pyplot as plt

# Whitespace-token counts per training example. Their maxima (next cell) set
# the padded sequence lengths used by the model.
text_lengths = [len(text.split()) for text in train_x]
summary_lengths = [len(summary.split()) for summary in train_y]

# Show the distribution of text/summary lengths side by side, labelled so the
# two panels can be told apart.
fig, (ax1, ax2) = plt.subplots(1, 2)
counts_x, bins_x = np.histogram(text_lengths)
ax1.stairs(counts_x, bins_x)
ax1.set_title("Input text length")
ax1.set_xlabel("tokens (whitespace-split)")
ax1.set_ylabel("examples")
counts_y, bins_y = np.histogram(summary_lengths)
ax2.stairs(counts_y, bins_y)
# Summaries already carry the sostok/eostok markers at this point.
ax2.set_title("Summary length (incl. markers)")
ax2.set_xlabel("tokens (whitespace-split)")
fig.tight_layout()

In [None]:
# Longest training text / summary in whitespace tokens; used as the fixed
# padded lengths below. Built-in `max` over a list of ints yields plain Python
# ints (np.max would yield np.int64), which is what the Keras shape/length
# arguments (`Input(shape=...)`, `pad_sequences(maxlen=...)`) expect.
max_text_len = max(text_lengths)
max_summary_len = max(summary_lengths)

In [None]:
from keras.layers import TextVectorization

# Tokenize texts and summaries, then pad/truncate them to fixed lengths.

max_tokens = 5000
x_tokenizer = TextVectorization(
    max_tokens=max_tokens,
    output_mode='int')
x_tokenizer.adapt(train_x)

train_x_seq = x_tokenizer(train_x)
val_x_seq = x_tokenizer(val_x)
test_x_seq = x_tokenizer(test_x)

y_tokenizer = TextVectorization(
    max_tokens=max_tokens,
    output_mode='int')
y_tokenizer.adapt(train_y)

train_y_seq = y_tokenizer(train_y)
val_y_seq = y_tokenizer(val_y)
test_y_seq = y_tokenizer(test_y)


def _pad(sequences, maxlen):
    """Zero-pad at the end to `maxlen`; if longer, cut from the end (keep the start)."""
    # The tokenizer output is already zero-padded to the longest row of each
    # split, so if a val/test split's longest row exceeds `maxlen`, EVERY row
    # in that split is too long. The default truncating="pre" would then drop
    # each row's first tokens (including 'sostok') and keep its trailing 0s.
    return pad_sequences(sequences, maxlen=maxlen, padding="post", truncating="post")


def _shift_left(decoder_input):
    """Teacher-forcing target: position t holds token t+1; the last slot is 0."""
    target = np.roll(decoder_input, -1, axis=1)
    target[:, -1] = 0
    return target


encoder_input_tr = _pad(train_x_seq, max_text_len)
encoder_input_val = _pad(val_x_seq, max_text_len)
encoder_input_test = _pad(test_x_seq, max_text_len)

decoder_input_tr = _pad(train_y_seq, max_summary_len)
decoder_input_val = _pad(val_y_seq, max_summary_len)
decoder_input_test = _pad(test_y_seq, max_summary_len)

decoder_target_tr = _shift_left(decoder_input_tr)
decoder_target_val = _shift_left(decoder_input_val)
decoder_target_test = _shift_left(decoder_input_test)

# TextVectorization's vocabulary already reserves index 0 for padding ('') and
# index 1 for OOV ('[UNK]'), so its length covers every id the tokenizer can
# emit. Adding 1 would create an id no token maps to: a dead softmax class in
# the decoder with no entry in the reverse vocabulary.
x_voc = len(x_tokenizer.get_vocabulary())
y_voc = len(y_tokenizer.get_vocabulary())

In [None]:
# Quick check of the padded training encoder input: (examples, max_text_len).
print(f"{encoder_input_tr.shape}")

In [None]:
# Reference: https://keras.io/examples/nlp/lstm_seq2seq/
latent_dim = 300
embedding_dim = 200

# Encoder input: fixed-length sequences of source token ids.
encoder_inputs = Input(shape=(max_text_len,), name='enc_input')

# Embedding layer: each source id -> an `embedding_dim`-sized vector.
# (The embedding width is a model hyperparameter, unrelated to the sequence
# length `max_text_len`.)
enc_emb = Embedding(x_voc, embedding_dim, trainable=True, name='enc_embedding')(encoder_inputs)

# Encoder: three stacked LSTMs. Only the final (h, c) of the last one is used,
# as the decoder's initial state.
encoder_lstm1 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.5, recurrent_dropout=0.5, name='enc_lstm1')
encoder_outputs1, state_h1, state_c1 = encoder_lstm1(enc_emb)

encoder_lstm2 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.5, recurrent_dropout=0.5, name='enc_lstm2')
encoder_outputs2, state_h2, state_c2 = encoder_lstm2(encoder_outputs1)

encoder_lstm3 = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.5, recurrent_dropout=0.5, name='enc_lstm3')
encoder_outputs, state_h, state_c = encoder_lstm3(encoder_outputs2)

# Decoder input: variable-length summary token ids (teacher forcing in training,
# one token at a time at inference).
decoder_inputs = Input(shape=(None,), name='dec_input')

# Embedding layer for summary tokens.
dec_emb_layer = Embedding(y_voc, embedding_dim, trainable=True, name='dec_embedding')
dec_emb = dec_emb_layer(decoder_inputs)

# Decoder LSTM, initialised with the encoder's final hidden/cell state.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True, dropout=0.5, recurrent_dropout=0.25, name='dec_lstm')
decoder_outputs, decoder_state_h, decoder_state_c = decoder_lstm(dec_emb, initial_state=[state_h, state_c])

# Per-timestep softmax over the summary vocabulary.
decoder_dense = TimeDistributed(Dense(y_voc, activation='softmax'), name='dec_output')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model: (encoder ids, decoder ids) -> next-token distributions.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs, name='lstm_baseline')

model.summary()

In [None]:
from keras.callbacks import EarlyStopping

# Cache path for the trained model. The existence check, the load and the save
# all use this one variable, so changing it cannot make them disagree.
lstm_model = 'lstm_model.keras'

if os.path.exists(lstm_model):
    model = keras.models.load_model(lstm_model)

else:
    model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    # Stop once validation loss has not improved for 3 epochs and roll back to
    # the best weights seen.
    early_stop_cb = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    batch_size = 40
    epochs = 20

    model.fit([encoder_input_tr, decoder_input_tr],
              decoder_target_tr,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=([encoder_input_val, decoder_input_val], decoder_target_val),
              callbacks=[early_stop_cb]
             )

    model.save(lstm_model)

In [None]:
# Id -> token lookup tables used to turn model output ids back into words.
reverse_encode_seq = dict(enumerate(x_tokenizer.get_vocabulary()))
reverse_decode_seq = dict(enumerate(y_tokenizer.get_vocabulary()))

In [None]:
def sample(preds, temperature=1.0):
    """Draw one index from a probability vector after temperature scaling.

    Args:
        preds: 1-D array-like of non-negative probabilities (e.g. one row of a
            softmax output).
        temperature: Values below 1 sharpen the distribution and values above
            1 flatten it. ``0`` means greedy decoding (plain ``argmax``).

    Returns:
        The sampled index as a Python ``int``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    preds = np.asarray(preds, dtype="float64")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if temperature == 0:
        return int(np.argmax(preds))
    # Zero probabilities become -inf logits (warning silenced) and map back to
    # exactly 0 after exp.
    with np.errstate(divide="ignore"):
        logits = np.log(preds) / temperature
    # Shift so the largest logit is 0 before exponentiating. Otherwise a small
    # temperature underflows every weight to 0, and normalizing yields NaNs.
    logits -= np.max(logits)
    weights = np.exp(logits)
    probs = weights / np.sum(weights)
    return int(np.argmax(np.random.multinomial(1, probs, 1)))

In [None]:
# Rebuild separate encoder and decoder models for token-by-token inference,
# reusing the trained layers of the saved model. Layers are fetched by the
# names assigned when the training graph was built, not by position in
# `model.layers`; that ordering is not part of the model definition and shifts
# if layers are added or rearranged.
model = keras.models.load_model(lstm_model)

# Encoder: source ids -> final (h, c) of the last stacked encoder LSTM, the
# same state the decoder was initialised with during training.
encoder_inputs = model.input[0]
encoder_outputs, state_h_enc, state_c_enc = model.get_layer('enc_lstm3').output
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(encoder_inputs, encoder_states)

# Decoder: (previous token ids, h, c) -> (next-token distribution, new h, new c).
decoder_inputs = model.input[1]
decoder_emb = model.get_layer('dec_embedding')(decoder_inputs)
decoder_lstm = model.get_layer('dec_lstm')
# Take the state width from the trained layer so it always matches its weights.
decoder_state_input_h = keras.Input(shape=(decoder_lstm.units,))
decoder_state_input_c = keras.Input(shape=(decoder_lstm.units,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(decoder_emb, initial_state=decoder_states_inputs)

decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.get_layer('dec_output')
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)

def decode_sequence(input_sequence, temp=1.0, max_len=None):
    """Generate a summary for one encoded input by sampling token by token.

    Args:
        input_sequence: Padded encoder input for a single example, sliced as a
            batch of one (e.g. ``encoder_input_tr[i : i + 1]``).
        temp: Sampling temperature forwarded to ``sample``.
        max_len: Maximum number of tokens to emit. Defaults to
            ``max_summary_len``, the padded summary length used in training.

    Returns:
        The generated tokens joined by single spaces, without the
        ``sostok``/``eostok`` markers.
    """
    if max_len is None:
        max_len = max_summary_len
    # Look the start id up in the vocabulary. Writing the tokenizer's tensor
    # output into a numpy scalar slot relies on deprecated array->scalar
    # conversion.
    start_index = y_tokenizer.get_vocabulary().index('sostok')

    states_value = encoder_model.predict(input_sequence, verbose=0)
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = start_index

    decoded_tokens = []
    while len(decoded_tokens) < max_len:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value, verbose=0
        )

        # Sample a token from the distribution at the last (only) timestep.
        sampled_token_index = sample(output_tokens[0, -1, :], temp)
        # The padding id maps to '' and only follows 'eostok' in training
        # targets. An id outside the vocabulary (possible if the softmax has
        # more classes than the vocabulary) also gets '' instead of KeyError.
        sampled_vocab = reverse_decode_seq.get(sampled_token_index, '')
        if sampled_vocab in ('eostok', ''):
            break
        decoded_tokens.append(sampled_vocab)

        # Feed the sampled token back in as the next length-1 decoder input.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index

        # Carry the LSTM state forward.
        states_value = [h, c]
    return ' '.join(decoded_tokens)

In [None]:
# Sanity-check inference on the first 20 training examples: print the source
# text, the sampled summary, and the reference summary with its first and
# last words (the sostok/eostok markers) dropped.
for idx in range(20):
    single_input = encoder_input_tr[idx:idx + 1]
    prediction = decode_sequence(single_input, temp=0.7)
    reference = ' '.join(train_y[idx].split()[1:-1])
    print("-")
    print("Input sentence:", train_x[idx])
    print("Decoded sentence:", prediction)
    print("Expected:", reference)
