# Seq2Seq LSTM (TensorFlow)

The basic code on which this is based comes from [this Keras tutorial on a Seq2Seq LSTM model](https://keras.io/examples/nlp/lstm_seq2seq/). However, this differs in a few key places:

1. The tutorial model predicts the text at the character level (character-by-character). This model predicts at the token level.
2. The tutorial uses basic one-hot encoding for characters. This model uses word embeddings.
3. This code includes not only a demo of how sequence prediction (decoding) works, but also an example of how to evaluate the predictions with the BLEU score, a standard metric for machine translation tasks.
4. The tutorial constructs the inference models after training the base model. This code constructs the training model along with the inference encoder and decoder at the same time, leveraging the fact that the encoder and decoder are simply re-using components of the training model that will be fit to the data during training.

The data used can be found [here](http://www.manythings.org/anki/fra-eng.zip).

In [None]:
# Imports
import re
import nltk
import numpy as np
import tensorflow as tf
from statistics import mean
from tensorflow import keras
from collections import defaultdict

## Setup

In [None]:
# GPU check
# `gpus[0].name` is used below as the device for training and inference.
# Fall back to the CPU logical device when no GPU is visible so that the
# indexing here (and in the later `tf.device(gpus[0].name)` calls) does not
# raise IndexError on a CPU-only machine.
gpus = tf.config.list_logical_devices('GPU')
if not gpus:
    print("No GPU found; falling back to CPU.")
    gpus = tf.config.list_logical_devices('CPU')
print(gpus[0].name)

In [None]:
# Hyperparameters
BATCH_SIZE = 64        # batch size passed to model.fit
EPOCHS = 125           # number of epochs to train for
LATENT_DIM = 256       # number of units in the encoder and decoder LSTMs
NUM_SAMPLES = 10000    # maximum number of lines read from the data file
EMBED_SIZE = 512       # dimensionality of each token embedding vector

In [None]:
# Data path
DATA_PATH = "fra.txt"  # path to the tab-separated source/target sentence file

In [None]:
# Simple tokenization utils
BOS = "[BOS]"          # beginning of sequence token
EOS = "[EOS]"          # end of sequence token
PAD = "[PAD]"          # padding token
SPEC_TOKENS = [PAD, BOS, EOS]

# Punctuation regex (runs of '.', '!' and '?')
PUNKT_RE = re.compile(r"[\.!\?]+")
# Multiple space regex. No longer needed by format_text (str.split handles
# whitespace runs); kept so that any other code referencing it still works.
MULTI_SPACE_RE = re.compile(r"\s\s+")

def format_text(text):
    '''
    Remove sentence punctuation ('.', '!', '?'), lowercase, and split a
    sequence into tokens on whitespace.

    Splitting uses str.split() with no argument, so any run of whitespace
    (spaces, tabs, no-break spaces, newlines) separates tokens and leading
    or trailing whitespace is ignored. Empty or punctuation-only text
    yields an empty list rather than [""].
    '''
    return PUNKT_RE.sub("", text).lower().split()

## Data

In [None]:
# Load and prep data

# Input/target text lists (each entry is a list of tokens)
INPUT_TEXTS = []
TARGET_TEXTS = []

# Input/target token sets
INPUT_TOKENS = set()
TARGET_TOKENS = set()

# Dictionary of input sequence to output sequence(s)
SRC_TO_TGT_MAP = defaultdict(set)

with open(DATA_PATH, "r", encoding="utf-8") as _f:
    for line in _f:
        if len(INPUT_TEXTS) >= NUM_SAMPLES:
            break

        # Each line is "source<TAB>target[<TAB>extra columns]". Skip lines
        # that do not have both columns (e.g. blank lines) instead of
        # crashing on the unpacking.
        fields = line.rstrip("\r\n").split("\t")
        if len(fields) < 2:
            continue

        input_text = format_text(fields[0])
        target_text = format_text(fields[1])

        SRC_TO_TGT_MAP[" ".join(input_text)].add(" ".join(target_text))

        # Collect the vocabulary BEFORE adding BOS/EOS. The special tokens
        # already own the first indices of the token-to-index mappings;
        # adding them to the vocabulary as well would give them a second
        # index and leave indices 1 and 2 without a reverse mapping.
        INPUT_TOKENS |= set(input_text)
        TARGET_TOKENS |= set(target_text)

        INPUT_TEXTS.append(input_text)
        TARGET_TEXTS.append([BOS] + target_text + [EOS])

# Aggregate tokens, sequences, and counts.
# Special tokens are removed from the vocabularies explicitly so that every
# index maps to exactly one token.
_spec_token_set = set(SPEC_TOKENS)
INPUT_TOKENS = sorted(INPUT_TOKENS - _spec_token_set)
TARGET_TOKENS = sorted(TARGET_TOKENS - _spec_token_set)
ALL_TOKENS = sorted(set(INPUT_TOKENS) | set(TARGET_TOKENS))
# All three counts include the special tokens, so each one is a valid upper
# bound (exclusive) on the indices produced by the mappings below.
NUM_ALL_TOKENS = len(ALL_TOKENS) + len(SPEC_TOKENS)
NUM_INPUT_TOKENS = len(INPUT_TOKENS) + len(SPEC_TOKENS)
NUM_TARGET_TOKENS = len(TARGET_TOKENS) + len(SPEC_TOKENS)
MAX_INPUT_SEQ_LEN = max(len(txt) for txt in INPUT_TEXTS)
MAX_TARGET_SEQ_LEN = max(len(txt) for txt in TARGET_TEXTS)

# Token-to-index (and reverse) mappings.
# Special tokens take indices 0..len(SPEC_TOKENS)-1 (PAD is 0, which matches
# the zero-initialised padding of the data arrays); vocabulary tokens follow
# in sorted order.
INPUT_TOK_TO_IDX = {tok: i for i, tok in enumerate(SPEC_TOKENS + INPUT_TOKENS)}
INPUT_IDX_TO_TOK = {i: tok for tok, i in INPUT_TOK_TO_IDX.items()}
TARGET_TOK_TO_IDX = {tok: i for i, tok in enumerate(SPEC_TOKENS + TARGET_TOKENS)}
TARGET_IDX_TO_TOK = {i: tok for tok, i in TARGET_TOK_TO_IDX.items()}

print(f"NUMBER OF SAMPLES: {len(INPUT_TEXTS)}")
print(f"NUMBER OF UNIQUE INPUT TOKENS: {NUM_INPUT_TOKENS}")
print(f"NUMBER OF UNIQUE TARGET TOKENS: {NUM_TARGET_TOKENS}")
print(f"MAX INPUT SEQ LEN: {MAX_INPUT_SEQ_LEN}")
print(f"MAX TARGET SEQ LEN: {MAX_TARGET_SEQ_LEN}")

In [None]:
# Build the encoder input and decoder input/target numpy arrays for training.
# This is done by hand here; Keras utilities such as its Tokenizer class
# could be used instead. The goal is simply to turn each token sequence into
# numbers: one integer index per token for the inputs, one one-hot vector per
# token for the targets.

# Encoder/decoder inputs: one row of integer token indices per sequence,
# looked up in the tok_to_idx mappings. Unused trailing slots stay 0.
# The embedding layer converts these indices into vectors.
ENCODER_INPUT_DATA = np.zeros((len(INPUT_TEXTS), MAX_INPUT_SEQ_LEN), dtype="int32")
DECODER_INPUT_DATA = np.zeros((len(INPUT_TEXTS), MAX_TARGET_SEQ_LEN), dtype="int32")

# Decoder targets: one one-hot vector per time step, because prediction is
# treated as multi-class classification over the target vocabulary.
DECODER_TARGET_DATA = np.zeros((len(INPUT_TEXTS), MAX_TARGET_SEQ_LEN, NUM_TARGET_TOKENS), dtype="float32")

for row, (src_tokens, tgt_tokens) in enumerate(zip(INPUT_TEXTS, TARGET_TEXTS)):
    src_ids = np.asarray([INPUT_TOK_TO_IDX[tok] for tok in src_tokens], dtype="int32")
    tgt_ids = np.asarray([TARGET_TOK_TO_IDX[tok] for tok in tgt_tokens], dtype="int32")

    ENCODER_INPUT_DATA[row, :src_ids.size] = src_ids
    DECODER_INPUT_DATA[row, :tgt_ids.size] = tgt_ids

    # The target is the decoder input shifted one step to the left: the
    # token at input position k is the expected output at position k - 1.
    next_ids = tgt_ids[1:]
    DECODER_TARGET_DATA[row, np.arange(next_ids.size), next_ids] = 1.0

## Model

In [None]:
# Set up models

# Shared embedding.
# The encoder is fed indices from INPUT_TOK_TO_IDX and the decoder indices
# from TARGET_TOK_TO_IDX. Both index spaces start at 0 and are offset by the
# special tokens, so the table must have at least as many rows as the larger
# of the two; otherwise high indices fall outside the embedding.
embed_layer = keras.layers.Embedding(
    input_dim=max(NUM_INPUT_TOKENS, NUM_TARGET_TOKENS),
    output_dim=EMBED_SIZE
)

# Encoder parts: embed the source indices and keep only the final LSTM
# states (hidden and cell), which summarise the input sequence.
encoder_inputs = keras.Input(shape=(MAX_INPUT_SEQ_LEN, ))
encoder_embedding = embed_layer(encoder_inputs)
encoder_lstm = keras.layers.LSTM(LATENT_DIM, return_state=True)
encoder_lstm_outputs, enc_state_h, enc_state_c = encoder_lstm(encoder_embedding)
encoder_states = [enc_state_h, enc_state_c]

# Decoder parts: the LSTM starts from the encoder states and emits one output
# per time step; the dense softmax layer turns each output into a
# distribution over the target vocabulary.
decoder_inputs = keras.Input(shape=(MAX_TARGET_SEQ_LEN, ))
decoder_embedding = embed_layer(decoder_inputs)
decoder_lstm = keras.layers.LSTM(LATENT_DIM, return_sequences=True, return_state=True)
decoder_lstm_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)
decoder_dense = keras.layers.Dense(NUM_TARGET_TOKENS, activation="softmax")
decoder_outputs = decoder_dense(decoder_lstm_outputs)

# Training model: (encoder input, decoder input) -> per-step token distribution
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Encoder for inference: encoder input -> [state_h, state_c]
encoder = keras.Model(encoder_inputs, encoder_states)

# Decoder for inference. It re-uses the trained embedding, LSTM and dense
# layers, but takes its initial states as explicit inputs and also returns
# the LSTM's final states. Because it re-uses `decoder_inputs`, it expects a
# sequence of length MAX_TARGET_SEQ_LEN.
decoder_state_input_h = keras.Input(shape=(LATENT_DIM, ))
decoder_state_input_c = keras.Input(shape=(LATENT_DIM, ))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_infer_lstm_outputs, dec_state_h, dec_state_c = decoder_lstm(decoder_embedding, initial_state=decoder_states_inputs)
decoder_states = [dec_state_h, dec_state_c]
decoder_infer_outputs = decoder_dense(decoder_infer_lstm_outputs)
decoder = keras.Model([decoder_inputs] + decoder_states_inputs, [decoder_infer_outputs] + decoder_states)

In [None]:
# Show the structure of the training model (encoder + decoder end to end)
model.summary()

In [None]:
# Show the structure of the inference encoder (encoder input -> LSTM states)
encoder.summary()

In [None]:
# Show the structure of the inference decoder
# (decoder input + states -> token distribution + updated states)
decoder.summary()

In [None]:
# Compile and train the model on the selected device.
# The last 20% of the samples are held out for validation.
train_inputs = [ENCODER_INPUT_DATA, DECODER_INPUT_DATA]
fit_options = dict(batch_size=BATCH_SIZE, epochs=EPOCHS, validation_split=0.2)

with tf.device(gpus[0].name):
    model.compile(
        optimizer=tf.keras.optimizers.legacy.Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(train_inputs, DECODER_TARGET_DATA, **fit_options)

## Inference

In [None]:
def decode_sequence(input_seq):
    '''
    Greedily decode the model's prediction for one input sequence.

    input_seq is a batch of one encoded input sequence, i.e. an integer
    array of shape (1, MAX_INPUT_SEQ_LEN) such as ENCODER_INPUT_DATA[i:i+1].

    Returns the predicted target tokens joined by single spaces, without the
    BOS/EOS markers. At most MAX_TARGET_SEQ_LEN - 1 tokens are produced.
    '''

    # Run the input seq through the encoder to get the encoder states.
    initial_states = list(encoder.predict(input_seq, verbose=0))

    # The decoder input holds the whole prefix decoded so far, starting with
    # the BOS token; unused slots stay 0 (PAD).
    target_seq = np.zeros((1, MAX_TARGET_SEQ_LEN), dtype="int32")
    target_seq[0, 0] = TARGET_TOK_TO_IDX[BOS]

    # Iterate until 1 of the 2 stop conditions is met:
    #  1. Decoder predicts the EOS token
    #  2. The target sequence has no free slot left
    decoded_sequence = []
    for step in range(MAX_TARGET_SEQ_LEN - 1):
        # Always restart from the encoder states and feed the full prefix.
        # The LSTM states returned by the decoder are the states after ALL
        # MAX_TARGET_SEQ_LEN positions (padding included), so carrying them
        # into the next step would not match what the model saw in training.
        # The output at position `step` depends only on the encoder states
        # and on tokens 0..step, which is exactly the prefix decoded so far.
        output_tokens, _, _ = decoder.predict([target_seq] + initial_states, verbose=0)

        # Get the token's index as the argmax of the decoder's output
        # distribution at the current position.
        sampled_token_idx = int(np.argmax(output_tokens[0, step]))
        # Convert index to the actual token (None if the index is unmapped).
        sampled_token = TARGET_IDX_TO_TOK.get(sampled_token_idx)

        # Stop on EOS, or on an index with no token, which cannot be decoded.
        if sampled_token is None or sampled_token == EOS:
            break

        decoded_sequence.append(sampled_token)

        # Extend the prefix with the predicted token for the next step.
        target_seq[0, step + 1] = sampled_token_idx

    return " ".join(decoded_sequence)

In [None]:
def calculate_bleu(refs, hyp):
    '''
    Score a hypothesis against its references with BLEU and return the mean
    of the unigram, bigram, trigram and 4-gram scores.

    refs is an iterable of reference strings and hyp is the hypothesis
    string; all of them are tokenized by splitting on single spaces.
    '''
    # One uniform weight tuple per n-gram order: (1,), (1/2, 1/2), ...
    ngram_weights = [(1.0 / order, ) * order for order in range(1, 5)]

    reference_tokens = [reference.split(" ") for reference in refs]
    hypothesis_tokens = hyp.split(" ")

    scores = nltk.translate.bleu_score.sentence_bleu(
        reference_tokens,
        hypothesis_tokens,
        weights=ngram_weights
    )
    return mean(scores)

### Example Inference

In [None]:
# Get 1 input sequence for demo.
# Slicing (rather than indexing) keeps the leading batch dimension, so the
# result is a batch of one sequence.
input_seq = ENCODER_INPUT_DATA[56:57]

In [None]:
# Display the input sequence as raw token indices
input_seq

In [None]:
# Display the input sequence converted back into tokens.
# Unused trailing slots hold index 0, so they show up as the PAD token.
[INPUT_IDX_TO_TOK[idx] for idx in input_seq[0]]

In [None]:
# Decode the demo sequence and print the predicted target text
print(decode_sequence(input_seq))

### Full Test

In [None]:
# Maximum number of unique input sequences to decode in the full test
TEST_DECODE_MAX = 20

In [None]:
# Gather 1 prediction per unique input sequence, up to TEST_DECODE_MAX.
# Input sequences exist more than once in INPUT_TEXTS because each
# has more than 1 translation in the linked data. The model's prediction
# depends only on the input, so only the first occurrence is decoded.
preds = {}
for i, input_tokens in enumerate(INPUT_TEXTS):
    if len(preds) >= TEST_DECODE_MAX:
        break

    input_text = " ".join(input_tokens)
    if input_text in preds:
        continue

    # Slice to keep the batch dimension expected by decode_sequence.
    input_seq = ENCODER_INPUT_DATA[i:i+1]

    with tf.device(gpus[0].name):
        decoded_seq = decode_sequence(input_seq)

    preds[input_text] = decoded_seq

In [None]:
# Report the results.
# For every decoded source text, print:
#  1. The source text
#  2. The model's hypothesis translation
#  3. The set of acceptable reference translations
#  4. Whether the hypothesis exactly matches one of the references
#  5. The BLEU score of the hypothesis against the references
for source_text in preds:
    hypothesis = preds[source_text]
    references = SRC_TO_TGT_MAP[source_text]
    print("----------------------")
    print(f"Source: \"{source_text}\"")
    print(f"Hypothesis Translation: \"{hypothesis}\"")
    print(f"Reference Translations: \"{references}\"")
    print(f"Hypothesis in references: {hypothesis in references}")
    print(f"BLEU Score: {calculate_bleu(references, hypothesis)}")