### TensorFlow Addons Networks : Sequence-to-Sequence NMT with Attention Mechanism

In [1]:
!wget --quiet http://www.manythings.org/anki/deu-eng.zip
!unzip deu-eng.zip

Archive:  deu-eng.zip
  inflating: deu.txt                 
  inflating: _about.txt              


In [2]:
!pip install tensorflow-addons



In [3]:
import re
import csv
import string
import random
import itertools

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

from pickle import load, dump
from typing import List, Tuple
from unicodedata import normalize
from keras.models import load_model
from keras.utils.vis_utils import plot_model
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from nltk.translate.bleu_score import corpus_bleu
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Dense, LSTM, Embedding

### Data Preprocessing

In [4]:
# Start of sentence
SOS = "<start>"
# End of sentence
EOS = "<end>"
# Relevant punctuation
PUNCTUATION = set("?,!.")


def load_dataset(filename: str) -> str:
    """
    load dataset into memory
    """
    with open(filename, mode = "rt", encoding = "utf-8") as fp:
        return fp.read()


def to_pairs(dataset: str, limit: int = None, shuffle = False) -> List[Tuple[str, str]]:
    """
    Split dataset into pairs of sentences, discards dataset line info.

    e.g.
    input -> 'Go.\tGeh.\tCC-BY 2.0 (France) Attribution: tatoeba.org
    #2877272 (CM) & #8597805 (Roujin)'
    output -> [('Go.', 'Geh.')]

    :param dataset: dataset containing examples of translations between
    two languages
    the examples are delimited by `\n` and the contents of the lines are
    delimited by `\t`
    :param limit: number that limit dataset size (optional)
    :param shuffle: default is True
    :return: list of pairs
    """
    assert isinstance(limit, (int, type(None))), TypeError(
        "the limit value must be an integer"
    )
    lines = dataset.strip().split("\n")
    # Radom dataset
    if shuffle is True:
        random.shuffle(lines)
    number_examples = limit or len(lines)  # if None get all
    pairs = []
    for line in lines[: abs(number_examples)]:
        # take only source and target
        src, trg, _ = line.split("\t")
        pairs.append((src, trg))

    # dataset size check
    assert len(pairs) == number_examples
    return pairs


def separe_punctuation(token: str) -> str:
    """
    Separe punctuation if exists
    """

    if not set(token).intersection(PUNCTUATION):
        return token
    for p in PUNCTUATION:
        token = f" {p} ".join(token.split(p))
    return " ".join(token.split())


def preprocess(sentence: str, add_start_end: bool=True) -> str:
    """

    - convert lowercase
    - remove numbers
    - remove special characters
    - separe punctuation
    - add start-of-sentence <start> and end-of-sentence <end>

    :param add_start_end: add SOS (start-of-sentence) and EOS (end-of-sentence)
    """
    re_print = re.compile(f"[^{re.escape(string.printable)}]")
    # convert lowercase and normalizing unicode characters
    sentence = (
        normalize("NFD", sentence.lower()).encode("ascii", "ignore").decode("UTF-8")
    )
    cleaned_tokens = []
    # tokenize sentence on white space
    for token in sentence.split():
        # removing non-printable chars form each token
        token = re_print.sub("", token).strip()
        # ignore tokens with numbers
        if re.findall("[0-9]", token):
            continue
        # add space between words and punctuation eg: "ok?go!" => "ok ? go !"
        token = separe_punctuation(token)
        cleaned_tokens.append(token)

    # rebuild sentence with space between tokens
    sentence = " ".join(cleaned_tokens)

    # adding a start and an end token to the sentence
    if add_start_end is True:
        sentence = f"{SOS} {sentence} {EOS}"
    return sentence


def dataset_preprocess(dataset: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """
    Returns processed database

    :param dataset: list of sentence pairs
    :return: list of paralel data e.g. 
    (['first source sentence', 'second', ...], ['first target sentence', 'second', ...])
    """
    source_cleaned = []
    target_cleaned = []
    for source, target in dataset:
        source_cleaned.append(preprocess(source))
        target_cleaned.append(preprocess(target))
    return source_cleaned, target_cleaned

### Create Dataset
- limit number of examples
- load dataset into pairs ```[('Be nice.', 'Seien Sie nett!'), ('Beat it.', 'Geh weg!'), ...]```
- preprocessing dataset

In [5]:
NUM_EXAMPLES = 10000 

filename = 'deu.txt'
dataset = load_dataset(filename)

pairs = to_pairs(dataset, limit = NUM_EXAMPLES)
print(f"Dataset size: {len(pairs)}")
raw_data_en, raw_data_ge = dataset_preprocess(pairs)

# show last 5 pairs
for pair in zip(raw_data_en[-5:],raw_data_ge[-5:]):
    print(pair)

Dataset size: 10000
('<start> tom was crying . <end>', '<start> tom flennte . <end>')
('<start> tom was eating . <end>', '<start> tom hat gegessen . <end>')
('<start> tom was famous . <end>', '<start> tom war beruhmt . <end>')
('<start> tom was framed . <end>', '<start> tom wurde reingelegt . <end>')
('<start> tom was fuming . <end>', '<start> tom war wutend . <end>')


### Tokenization

In [6]:
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters = '')
en_tokenizer.fit_on_texts(raw_data_en)

data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,padding = 'post')

ge_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters = '')
ge_tokenizer.fit_on_texts(raw_data_ge)

data_ge = ge_tokenizer.texts_to_sequences(raw_data_ge)
data_ge = tf.keras.preprocessing.sequence.pad_sequences(data_ge,padding = 'post')

In [7]:
def max_len(tensor):
    return max( len(t) for t in tensor)

### Training Prep

In [8]:
X_train,  X_test, Y_train, Y_test = train_test_split(data_en,data_ge,test_size = 0.2)
BATCH_SIZE = 64
BUFFER_SIZE = len(X_train)
steps_per_epoch = BUFFER_SIZE//BATCH_SIZE
embedding_dims = 256
rnn_units = 1024
dense_units = 1024
Dtype = tf.float32   #used to initialize DecoderCell Zero state

Tx = max_len(data_en)
Ty = max_len(data_ge)  

input_vocab_size = len(en_tokenizer.word_index)+1  
output_vocab_size = len(ge_tokenizer.word_index)+ 1
dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder = True)
example_X, example_Y = next(iter(dataset))
print(example_X.shape) 
print(example_Y.shape) 

(64, 9)
(64, 13)


### Build Model

In [9]:
# ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self,input_vocab_size,embedding_dims, rnn_units ):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim = input_vocab_size,
                                                           output_dim = embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,return_sequences = True, 
                                                     return_state = True )
    
# DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self,output_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim = output_vocab_size,
                                                           output_dim = embedding_dims) 
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create attention mechanism with memory = None
        self.attention_mechanism = self.build_attention_mechanism(dense_units,None,BATCH_SIZE*[Tx])
        self.rnn_cell =  self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler = self.sampler,
                                                output_layer = self.dense_layer)

    def build_attention_mechanism(self, units,memory, memory_sequence_length):
        return tfa.seq2seq.LuongAttention(units, memory = memory, 
                                          memory_sequence_length = memory_sequence_length)

    # wrap decodernn cell  
    def build_rnn_cell(self, batch_size ):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell
    
    def build_decoder_initial_state(self, batch_size, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size = batch_size, 
                                                                dtype = Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state = encoder_state) 
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size,embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size,embedding_dims, rnn_units)
optimizer = tf.keras.optimizers.Adam()

### Training

In [10]:
def loss_function(y_pred, y):
   
    #shape of y [batch_size, ty]
    #shape of y_pred [batch_size, Ty, output_vocab_size] 
    sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True,
                                                                                  reduction = 'none')
    loss = sparsecategoricalcrossentropy(y_true = y, y_pred = y_pred)
    mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1
    mask = tf.cast(mask, dtype = loss.dtype)
    loss = mask*loss
    loss = tf.reduce_mean(loss)
    return loss

def train_step(input_batch, output_batch,encoder_initial_cell_state):
    loss = 0
    with tf.GradientTape() as tape:
        encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)
        a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp, 
                                                        initial_state = encoder_initial_cell_state)

        # [last step activations,last memory_state] of encoder passed as input to decoder Network
        
         
        # Prepare correct Decoder input & output sequence data
        decoder_input = output_batch[:,:-1] # ignore <end>
        # compare logits with timestepped +1 version of decoder_input
        decoder_output = output_batch[:,1:] # ignore <start>


        # Decoder Embeddings
        decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)

        # Setting up decoder memory from encoder output and Zero State for AttentionWrapperState
        decoderNetwork.attention_mechanism.setup_memory(a)
        decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,
                                                                           encoder_state = [a_tx, c_tx],
                                                                           Dtype = tf.float32)
        
        # BasicDecoderOutput        
        outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,initial_state = decoder_initial_state,
                                               sequence_length = BATCH_SIZE*[Ty-1])

        logits = outputs.rnn_output
        # Calculate loss

        loss = loss_function(logits, decoder_output)

    # Returns the list of all layer variables / weights.
    variables = encoderNetwork.trainable_variables + decoderNetwork.trainable_variables  
    # differentiate loss wrt variables
    gradients = tape.gradient(loss, variables)

    # grads_and_vars – List of(gradient, variable) pairs.
    grads_and_vars = zip(gradients,variables)
    optimizer.apply_gradients(grads_and_vars)
    return loss

In [11]:
# RNN LSTM hidden and memory state initializer
def initialize_initial_state():
        return [tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))]

In [12]:
epochs = 15
for i in range(1, epochs+1):

    encoder_initial_cell_state = initialize_initial_state()
    total_loss = 0.0

    for ( batch , (input_batch, output_batch)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)
        total_loss += batch_loss
        if (batch+1)%5 == 0:
            print("total loss: {} epoch {} batch {} ".format(batch_loss.numpy(), i, batch+1))

total loss: 3.3131473064422607 epoch 1 batch 5 
total loss: 2.262629747390747 epoch 1 batch 10 
total loss: 2.076655149459839 epoch 1 batch 15 
total loss: 2.0565059185028076 epoch 1 batch 20 
total loss: 1.9498294591903687 epoch 1 batch 25 
total loss: 2.079404830932617 epoch 1 batch 30 
total loss: 1.9453245401382446 epoch 1 batch 35 
total loss: 1.8376885652542114 epoch 1 batch 40 
total loss: 1.7034138441085815 epoch 1 batch 45 
total loss: 1.6252403259277344 epoch 1 batch 50 
total loss: 1.6810463666915894 epoch 1 batch 55 
total loss: 1.746016502380371 epoch 1 batch 60 
total loss: 1.541724681854248 epoch 1 batch 65 
total loss: 1.7413946390151978 epoch 1 batch 70 
total loss: 1.6177066564559937 epoch 1 batch 75 
total loss: 1.5472936630249023 epoch 1 batch 80 
total loss: 1.570003628730774 epoch 1 batch 85 
total loss: 1.6743879318237305 epoch 1 batch 90 
total loss: 1.6717625856399536 epoch 1 batch 95 
total loss: 1.668370246887207 epoch 1 batch 100 
total loss: 1.5403496026992

### Evaluation