<a href="https://colab.research.google.com/github/kartik727/neural-machine-translation/blob/master/Seq2Seq_with_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seq2Seq with attention

## Dependencies
Primary library used for modelling and training - trax

## Data - Tensorflow Datasets (TFDS)
1. OPUS (`'opus'`)

In [None]:
# instal trax

!pip install trax

In [3]:
import random
import numpy as np
import re
import nltk

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

In [4]:
import tensorflow_datasets as tfds

In [5]:
import tensorflow as tf

In [6]:
from collections import defaultdict

In [7]:
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
    print('GPU device not found')
else:
    print('Found GPU at: {}'.format(device_name))

GPU device not found


In [8]:
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  print('Not connected to a TPU runtime')

Running on TPU  ['10.103.246.234:8470']


In [9]:
dataset_train = tfds.load('opus', split='train', batch_size=-1, shuffle_files=True)

[1mDownloading and preparing dataset 34.29 MiB (download: 34.29 MiB, generated: 188.85 MiB, total: 223.13 MiB) to /root/tensorflow_datasets/opus/medical/0.1.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…






HBox(children=(FloatProgress(value=0.0, description='Generating splits...', max=1.0, style=ProgressStyle(descr…

HBox(children=(FloatProgress(value=0.0, description='Generating train examples...', max=1108752.0, style=Progr…

HBox(children=(FloatProgress(value=0.0, description='Shuffling opus-train.tfrecord...', max=1108752.0, style=P…

[1mDataset opus downloaded and prepared to /root/tensorflow_datasets/opus/medical/0.1.0. Subsequent calls will reuse this data.[0m


In [10]:
ds_np = tfds.as_numpy(dataset_train)

In [11]:
# Utils Namespace

class Namespace:
    def __init__(self, **kwargs):
        self.update(**kwargs)

    def update(self, **kwargs):
        for kw, arg in kwargs.items():
            if type(arg)==dict:
                setattr(self, kw, Namespace(**arg))
            else:
                setattr(self, kw, arg)

    def get_from_list(self, arg_list):
        ret_dict = {}
        for arg in arg_list:
            ret_dict[arg] = getattr(self, arg)
        return  ret_dict

    def get(self):
        return self.get_from_list(self.__dict__.keys())

In [92]:
config_dict = {
    'model' : {
        'input_vocab_size' : 8187,
        'target_vocab_size' : 11976,
        'embedding_size' : 256,
        'n_encoder_layers' : 1,
        'n_decoder_layers' : 1,
        'n_attention_heads' : 1,
        'attention_dropout' : 0.
    },
    'data_size' : 10_000
}

config = Namespace(**config_dict)

In [13]:
def preprocess_sentence(w):
  w = w.decode().lower().strip()

  # creating a space between a word and the punctuation following it
  # eg: "he is a boy." => "he is a boy ."
  # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
  w = re.sub(r"([?.!,¿])", r" \1 ", w)
  w = re.sub(r'[" "]+', " ", w)

  # replacing everything with space except (a-z, A-Z, ".", "?", "!", ",")
  w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)

  w = w.strip()

  # adding a start and an end token to the sentence
  # so that the model know when to start and stop predicting.
  w = '<start> ' + w + ' <end>'
  return w

In [14]:
def preprocess_data(data):
    return [preprocess_sentence(w) for w in data]

In [15]:
preprocessed_data_en, preprocessed_data_de = preprocess_data(ds_np['en'])[:config.data_size],preprocess_data(ds_np['de'])[:config.data_size]

In [16]:
del ds_np

In [17]:
def tokenize(lang):
  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
  lang_tokenizer.fit_on_texts(lang)

  tensor = lang_tokenizer.texts_to_sequences(lang)

  tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
                                                         padding='post')

  return tensor, lang_tokenizer

In [18]:
token_en = tokenize(preprocessed_data_en)

In [58]:
token_en[0][0]

array([ 1, 68,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int32)

In [66]:
token_en[1].sequences_to_texts([token_en[0][1][1:]])

['during treatment with olanzapine , adolescents gained significantly more weight compared with adults . <end>']

In [75]:
token_en[1].texts_to_sequences(['how are You ?'])

[[192, 26, 19, 359]]

In [19]:
del preprocessed_data_en

In [20]:
token_de = tokenize(preprocessed_data_de)

In [21]:
def input_encoder_fn(input_vocab_size, embedding_size, n_encoder_layers):
    input_encoder = tl.Serial( 
        tl.Embedding(input_vocab_size, embedding_size),
        [tl.LSTM(embedding_size) for _ in range(n_encoder_layers)]
    )
    return input_encoder

def pre_attention_decoder_fn(mode, target_vocab_size, embedding_size):
    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(),
        tl.Embedding(target_vocab_size, embedding_size),
        tl.LSTM(embedding_size)
    )
    return pre_attention_decoder

def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    keys = encoder_activations
    values = encoder_activations
    queries = decoder_activations
    mask = inputs != 0
    mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    mask = mask + fastnp.zeros((1, 1, decoder_activations.shape[1], 1))
    return queries, keys, values, mask

def NMTAttn(input_vocab_size=8187,
            target_vocab_size=11976,
            embedding_size=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    input_encoder = input_encoder_fn(input_vocab_size, embedding_size, n_encoder_layers)
    pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, embedding_size)
    model = tl.Serial( 
      tl.Select([0, 1, 0, 1]),
      tl.Parallel(input_encoder, pre_attention_decoder),
      tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
      tl.Residual(tl.AttentionQKV(embedding_size, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)),
      tl.Select([0, 2]),
      [tl.LSTM(embedding_size) for _ in range(n_decoder_layers)],
      tl.Dense(target_vocab_size),
      tl.LogSoftmax()
    )
    return model

In [22]:
class Data_Iter:
    def __init__(self, data_en, data_de, batch_size):
        self.data_en = data_en
        self.data_de = data_de
        self.batch_size = batch_size
        self.idx = 0
        self.l = len(data_en)

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx>=self.l:
            self.idx=0
        batch_en = self.data_en[self.idx:self.idx+self.batch_size]
        batch_de = self.data_de[self.idx:self.idx+self.batch_size]
        mask = np.array(batch_de != 0, dtype=int)
        self.idx += self.batch_size
        return batch_en, batch_de, mask

In [23]:
data_iter = Data_Iter(token_en[0], token_de[0], 64)

In [24]:
train_task = training.TrainTask(
    labeled_data= data_iter,
    loss_layer= tl.CrossEntropyLoss(),
    optimizer= trax.optimizers.Adam(learning_rate=0.01),
    lr_schedule= trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint= 10,
)

eval_task = training.EvalTask(
    labeled_data=data_iter,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)



In [29]:
model = NMTAttn(mode='train', **config.model.get())

In [30]:
output_dir = '/content/output_dir/'
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

  "jax.host_id has been renamed to jax.process_index. This alias "
  "jax.host_count has been renamed to jax.process_count. This alias "


In [31]:
training_loop.run(20)

  "jax.host_id has been renamed to jax.process_index. This alias "



Step      1: Total number of trainable weights: 10078664
Step      1: Ran 1 train steps in 31.06 secs
Step      1: train CrossEntropyLoss |  9.37975311
Step      1: eval  CrossEntropyLoss |  9.38680172
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 161.39 secs
Step     10: train CrossEntropyLoss |  9.33891964
Step     10: eval  CrossEntropyLoss |  9.24802208
Step     10: eval          Accuracy |  0.05151916

Step     20: Ran 10 train steps in 178.19 secs
Step     20: train CrossEntropyLoss |  9.12834263
Step     20: eval  CrossEntropyLoss |  8.93299484
Step     20: eval          Accuracy |  0.07381776


In [76]:
def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    token_length = len(cur_output_tokens)
    padded_length = np.power(2, int(np.ceil(np.log2(token_length + 1))))
    padded = cur_output_tokens + [0]*(padded_length - token_length)
    padded_with_batch = np.expand_dims(padded, axis=0)
    output, _ = NMTAttn((input_tokens, padded_with_batch))
    log_probs = output[0, token_length, :]
    symbol = int(tl.logsoftmax_sample(log_probs, temperature))
    return symbol, float(log_probs[symbol])

def sampling_decode(input_sentence, NMTAttn, input_tokenizer, target_tokenizer, temperature=0.0):
    input_tokens = input_tokenizer.texts_to_sequences([input_sentence])[0]
    cur_output_tokens = [1]
    cur_output = 0
    EOS = 2
    while cur_output != EOS:
        cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)
    sentence = target_tokenizer.sequences_to_text([cur_output_tokens])[0]
    return cur_output_tokens, log_prob, sentence

In [86]:
model = NMTAttn(mode='eval')

# initialize weights from a pre-trained model
model.init_from_file(output_dir+"model.pkl.gz", weights_only=True)
model = tl.Accelerate(model)

IndexError: ignored

In [77]:
input_sentence = 'Hello how are you'
sampling_decode(input_sentence, NMTAttn, token_en[1], token_de[1])

TypeError: ignored

In [84]:
input_sentence = 'how are you'
input_tokens = token_en[1].texts_to_sequences([input_sentence])[0]

token_length = 0
padded_length = np.power(2, int(np.ceil(np.log2(token_length + 1))))
padded = [0]
padded_with_batch = np.expand_dims(padded, axis=0)
x = model((input_tokens, padded_with_batch))

AttributeError: ignored

In [87]:
keras_layer = trax.AsKeras(model)

In [88]:
keras_layer

<trax.trax2keras.AsKeras at 0x7f0052d60710>

In [95]:
inputs = tf.keras.Input(shape=(config.model.input_vocab_size, config.model.target_vocab_size), dtype='int32')
hidden = keras_layer(inputs)
outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(keras_model)

TypeError: ignored