<a href="https://colab.research.google.com/github/kartik727/neural-machine-translation/blob/master/Seq2Seq_with_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seq2Seq with attention

## Dependencies
The primary library used for modelling and training is Trax.

## Data - TensorFlow Datasets (TFDS)
1. OPUS (`'opus'`)

In [1]:
# install trax

# !pip install trax

In [2]:
import random
import numpy as np
import re
import nltk

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

In [3]:
import tensorflow_datasets as tfds

In [4]:
import tensorflow as tf

In [5]:
from collections import defaultdict

In [6]:
# Load the 'train' split of the TFDS 'opus' dataset. Per the TFDS docs,
# batch_size=-1 returns the full split as a single batch of tensors.
# NOTE(review): shuffle_files=True may change the example order between runs,
# so the fixed-size subset taken later may not be reproducible — confirm.
dataset_train = tfds.load('opus', split='train', batch_size=-1, shuffle_files=True)

In [7]:
# Convert the loaded dataset to NumPy form; it is indexed by the 'en' and 'de'
# keys further down.
ds_np = tfds.as_numpy(dataset_train)

In [8]:
# Utils Namespace

class Namespace:
    """Lightweight attribute bag: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        # Construction and later updates share the same code path.
        self.update(**kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) one instance attribute per keyword argument."""
        vars(self).update(kwargs)

In [9]:
# Configuration values keyed by name; the two entries hold one vocabulary
# size per language suffix (_en / _de).
config_dict = dict(
    vocab_size_en=20_000,
    vocab_size_de=20_000,
)



In [10]:
def preprocess_sentence(w):
  """Normalise one sentence and wrap it in <start>/<end> markers.

  Steps: decode (if bytes), lowercase, pad the punctuation marks ? . ! , ¿
  with spaces, replace every other non-letter character with a space,
  collapse whitespace, then add the start and end tokens.

  Unlike an ASCII-only filter, letters outside a-z (e.g. German ä, ö, ü, ß)
  are kept, so words containing them are not split into fragments.

  Args:
    w: The sentence, either as UTF-8 encoded bytes or as str.

  Returns:
    A str of the form '<start> ... <end>'.
  """
  if isinstance(w, bytes):
    w = w.decode()
  w = w.lower().strip()

  # creating a space between a word and the punctuation following it
  # eg: "he is a boy." => "he is a boy ."
  # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
  w = re.sub(r"([?.!,¿])", r" \1 ", w)

  # replacing everything with a space except letters (of any alphabet)
  # and the punctuation marks ".", "?", "!", ",", "¿"
  kept_punctuation = "?.!,¿"
  w = "".join(ch if ch.isalpha() or ch in kept_punctuation else " " for ch in w)

  # collapsing runs of whitespace and trimming both ends
  w = " ".join(w.split())

  # adding a start and an end token to the sentence
  # so that the model knows when to start and stop predicting.
  return '<start> ' + w + ' <end>'

In [11]:
def preprocess_data(data):
    """Run preprocess_sentence over every item of `data`; return the results as a list."""
    cleaned = []
    for sentence in data:
        cleaned.append(preprocess_sentence(sentence))
    return cleaned

In [12]:
# Number of sentence pairs kept for this experiment.
N_SAMPLES = 10000

# Slice *before* preprocessing so that only the examples actually kept are
# cleaned, instead of preprocessing the whole corpus and discarding the rest.
# The resulting lists are the same as slicing afterwards.
preprocessed_data_en = preprocess_data(ds_np['en'][:N_SAMPLES])
preprocessed_data_de = preprocess_data(ds_np['de'][:N_SAMPLES])

In [13]:
# Drop the reference to the raw corpus so its memory can be reclaimed.
del ds_np

In [14]:
def tokenize(lang):
  """Fit a Keras tokenizer on `lang` and encode it as a padded id matrix.

  Args:
    lang: Sequence of preprocessed sentences (strings).

  Returns:
    A `(tensor, lang_tokenizer)` tuple: the post-padded integer sequences and
    the fitted tokenizer.
  """
  keras_pre = tf.keras.preprocessing

  # filters='' disables the tokenizer's character filtering.
  fitted_tokenizer = keras_pre.text.Tokenizer(filters='')
  fitted_tokenizer.fit_on_texts(lang)

  # padding='post' puts the padding after each sequence rather than before it.
  padded_ids = keras_pre.sequence.pad_sequences(
      fitted_tokenizer.texts_to_sequences(lang),
      padding='post',
  )
  return padded_ids, fitted_tokenizer

In [15]:
# token_en is the (padded id array, fitted tokenizer) tuple for the 'en' side.
token_en = tokenize(preprocessed_data_en)

In [16]:
# The English sentences are tokenized now; drop the reference to the string list.
del preprocessed_data_en

In [17]:
# token_de is the (padded id array, fitted tokenizer) tuple for the 'de' side.
token_de = tokenize(preprocessed_data_de)

In [18]:
def input_encoder_fn(input_vocab_size, embedding_size, n_encoder_layers):
    """Build the input encoder: an embedding followed by a stack of LSTMs.

    Args:
        input_vocab_size: Vocabulary size passed to the embedding layer.
        embedding_size: Embedding dimension, also used as the LSTM size.
        n_encoder_layers: Number of LSTM layers in the stack.

    Returns:
        A `tl.Serial` layer combining the embedding and the LSTM stack.
    """
    embedding = tl.Embedding(input_vocab_size, embedding_size)
    lstm_stack = [tl.LSTM(embedding_size) for _ in range(n_encoder_layers)]
    return tl.Serial(embedding, lstm_stack)

def pre_attention_decoder_fn(mode, target_vocab_size, embedding_size):
    """Build the decoder stage that runs before attention.

    The stage shifts the target tokens right, embeds them and runs a single
    LSTM over the embeddings.

    Args:
        mode: Mode string forwarded to `tl.ShiftRight` (the model in this
            notebook is built with 'train').
        target_vocab_size: Vocabulary size passed to the embedding layer.
        embedding_size: Embedding dimension, also used as the LSTM size.

    Returns:
        A `tl.Serial` layer: ShiftRight -> Embedding -> LSTM.
    """
    pre_attention_decoder = tl.Serial(
        # Forward `mode` so the shift follows the requested mode instead of
        # always using the layer's default.
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, embedding_size),
        tl.LSTM(embedding_size)
    )
    return pre_attention_decoder

def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Arrange encoder/decoder activations as attention inputs.

    Args:
        encoder_activations: Used as both the keys and the values.
        decoder_activations: Used as the queries; its second dimension sets
            the size of the mask's third axis.
        inputs: Input token ids; positions equal to 0 are masked out.

    Returns:
        `(queries, keys, values, mask)`. The mask starts as `inputs != 0`
        reshaped to `(inputs.shape[0], 1, 1, inputs.shape[1])` and is then
        broadcast along its third axis to `decoder_activations.shape[1]`.
    """
    n_rows, n_cols = inputs.shape[0], inputs.shape[1]
    padding_mask = fastnp.reshape(inputs != 0, (n_rows, 1, 1, n_cols))

    # Adding zeros of shape (1, 1, n_queries, 1) broadcasts the mask so that
    # there is one row per query position.
    n_queries = decoder_activations.shape[1]
    broadcast_mask = padding_mask + fastnp.zeros((1, 1, n_queries, 1))

    return decoder_activations, encoder_activations, encoder_activations, broadcast_mask

def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            embedding_size=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Assemble the LSTM sequence-to-sequence model with attention.

    Args:
        input_vocab_size: Vocabulary size for the input encoder's embedding.
        target_vocab_size: Vocabulary size for the decoder embedding and the
            final dense layer.
        embedding_size: Width of the embeddings, LSTMs and attention layer.
        n_encoder_layers: Number of LSTM layers in the input encoder.
        n_decoder_layers: Number of LSTM layers after attention.
        n_attention_heads: Number of attention heads.
        attention_dropout: Dropout rate passed to the attention layer.
        mode: Mode string passed to the pre-attention decoder and attention.

    Returns:
        A `tl.Serial` model ending in `tl.LogSoftmax`.
    """
    encoder = input_encoder_fn(input_vocab_size, embedding_size, n_encoder_layers)
    decoder = pre_attention_decoder_fn(mode, target_vocab_size, embedding_size)

    attention = tl.Residual(
        tl.AttentionQKV(
            embedding_size,
            n_heads=n_attention_heads,
            dropout=attention_dropout,
            mode=mode,
        )
    )
    post_attention_lstms = [tl.LSTM(embedding_size) for _ in range(n_decoder_layers)]

    return tl.Serial(
        # Duplicate both inputs: one copy feeds the encoder/decoder, the other
        # stays on the stack for later layers.
        tl.Select([0, 1, 0, 1]),
        tl.Parallel(encoder, decoder),
        tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
        attention,
        # Keep stack items 0 and 2 (presumably dropping the attention mask at
        # index 1 — confirm against the AttentionQKV outputs).
        tl.Select([0, 2]),
        post_attention_lstms,
        tl.Dense(target_vocab_size),
        tl.LogSoftmax(),
    )

In [19]:
class Data_Iter:
    """Endless mini-batch iterator over two parallel token arrays.

    Each `next()` call returns `(batch_en, batch_de, mask)`, where `mask` is an
    integer array that is 1 wherever `batch_de` is non-zero and 0 elsewhere.
    The iterator never raises StopIteration: once the cursor has passed the
    last example it starts again from the beginning. The final batch of a pass
    may contain fewer than `batch_size` rows.
    """

    def __init__(self, data_en, data_de, batch_size):
        self.data_en, self.data_de = data_en, data_de
        self.batch_size = batch_size
        self.idx = 0            # start position of the next batch
        self.l = len(data_en)   # number of examples (wrap-around threshold)

    def __iter__(self):
        return self

    def __next__(self):
        # Wrap around once the cursor has moved past the last example.
        start = 0 if self.idx >= self.l else self.idx
        window = slice(start, start + self.batch_size)

        batch_en, batch_de = self.data_en[window], self.data_de[window]
        weights = np.array(batch_de != 0, dtype=int)

        self.idx = start + self.batch_size
        return batch_en, batch_de, weights

In [20]:
# Batches of 64 examples drawn from the padded id arrays (element 0 of each
# tokenize() result).
data_iter = Data_Iter(token_en[0], token_de[0], 64)

In [21]:
# Training task: cross-entropy loss optimised with Adam.
# NOTE(review): the schedule arguments are presumably (n_warmup_steps,
# max_value), and an explicit lr_schedule presumably takes precedence over the
# optimizer's learning_rate — confirm against the trax docs.
train_task = training.TrainTask(
    labeled_data=data_iter,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(learning_rate=0.01),
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint=10,
)

# Evaluation gets its own iterator. Data_Iter is stateful, so sharing a single
# instance with the training task would make every evaluation consume batches
# from the training stream.
# TODO: this still evaluates on the training examples; use a held-out split.
eval_iter = Data_Iter(data_iter.data_en, data_iter.data_de, data_iter.batch_size)

eval_task = training.EvalTask(
    labeled_data=eval_iter,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)



In [22]:
# Directory handed to the training loop as its output location.
# NOTE(review): '/content/...' is a Colab-style absolute path; adjust it when
# running elsewhere.
output_dir = '/content/output_dir/'
# Build the training loop around a freshly constructed model with default sizes.
training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

  "jax.host_id has been renamed to jax.process_index. This alias "
  "jax.host_count has been renamed to jax.process_count. This alias "


In [None]:
# Run the loop for 20 steps.
training_loop.run(20)

  "jax.host_id has been renamed to jax.process_index. This alias "
