<a href="https://colab.research.google.com/github/h4ck4l1/datasets/blob/main/NLP_with_RNN_and_Attention/NMT_with_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
# from google.colab import auth
# auth.authenticate_user()
import os,warnings
from IPython.display import clear_output
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")
!pip3 install -q -U "tensorflow-text==2.13.0"
!pip3 install -q -U einops
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_text as tf_text
# `np.printoptions` is a context manager and has no effect when called bare;
# `np.set_printoptions` applies the precision globally for the notebook.
np.set_printoptions(precision=2)
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pio.templates.default = "plotly_dark"
import einops
from zipfile import ZipFile
from typing import Any
# %xmode Minimal
tf.get_logger().setLevel("ERROR")
clear_output()

In [24]:
# resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
# tf.config.experimental_connect_to_cluster(resolver)
# tf.tpu.experimental.initialize_tpu_system(resolver)
# strategy = tf.distribute.TPUStrategy(resolver)
# strategy = tf.distribute.OneDeviceStrategy(device="/device:GPU:0")

In [25]:
# class ShapeCheck():

#     def __init__(self):
#         self.shapes = {}

#     def __call__(self,tensor,names,**kwargs):

#         if not tf.executing_eagerly():
#             return

#         for name,dim in einops.parse_shape(tensor,names).items():

#             if name not in self.shapes:
#                 self.shapes[name] = dim

#             elif self.shapes[name] == dim:
#                 continue

#             else:
#                 raise ValueError(f"Dimension mismatch for tensor {tensor}\nfound dimension :{self.shapes[name]}\nnew dimension given :{dim}")


In [26]:
# Download the English/Spanish sentence pairs and split them into two parallel tuples.
with tf.device("/job:localhost"):
    origin = "http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"
    # Only download here: the archive is unpacked explicitly just below, so
    # `file_path` must be the path of the .zip file itself.
    file_path = keras.utils.get_file(fname="spa-eng.zip",origin=origin)
    with ZipFile(file_path,"r") as f:
        f.extractall("spa-eng")
    # The corpus contains accented characters; decode it as UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open("spa-eng/spa-eng/spa.txt","r",encoding="utf-8") as f:
        text = f.read()
    # Each line is "<english>\t<spanish>"; transpose the pairs into two tuples.
    en_text,es_text = zip(*[line.split("\t") for line in text.splitlines()])
    # Show the first ten pairs as a sanity check.
    for en,es in zip(en_text[:10],es_text[:10]):
        print(en,"---->",es)

Go. ----> Ve.
Go. ----> Vete.
Go. ----> Vaya.
Go. ----> Váyase.
Hi. ----> Hola.
Run! ----> ¡Corre!
Run. ----> Corred.
Who? ----> ¿Quién?
Fire! ----> ¡Fuego!
Fire! ----> ¡Incendio!


In [27]:
def text_preprocess(sentence:str):
    """Standardize raw text and wrap it in [START] / [END] markers.

    Used as the `standardize` callable of the TextVectorization layers, so in
    practice `sentence` is a string tensor (TODO confirm the `str` annotation
    is still intended).

    Steps: NFKD unicode normalization, lower-casing, removal of every character
    that is not a space, a-z or one of . , ! ? ¿, padding of that punctuation
    with spaces, whitespace stripping, and finally joining with the markers.
    """
    decomposed = tf_text.normalize_utf8(sentence,"NFKD")
    lowered = tf.strings.lower(decomposed)
    # Drop everything outside the allowed character set.
    filtered = tf.strings.regex_replace(lowered,r"[^ a-z.,!?¿]","")
    # Surround punctuation with spaces so each mark becomes its own token.
    spaced = tf.strings.regex_replace(filtered,r"[.,!?¿]",r" \0 ")
    trimmed = tf.strings.strip(spaced)
    return tf.strings.join(["[START]",trimmed,"[END]"],separator=" ")

In [28]:
# for en,es in zip(text_preprocess(en_text[:10]).numpy(),text_preprocess(es_text[:10]).numpy()):
#     print(f"{en}   ---->{es}")

In [29]:
def get_layers(vocab_size=5000):
    """Create the English and Spanish TextVectorization layers.

    Both layers use `text_preprocess` for standardization, emit ragged
    tensors and keep at most `vocab_size` tokens. They are adapted on the
    module-level `en_text` / `es_text` corpora.

    Returns:
        (en_vec_layer, es_vec_layer)
    """
    def _new_layer():
        return keras.layers.TextVectorization(max_tokens=vocab_size,standardize=text_preprocess,ragged=True)

    english_layer = _new_layer()
    spanish_layer = _new_layer()
    for layer,corpus in ((english_layer,en_text),(spanish_layer,es_text)):
        layer.adapt(corpus)
    return english_layer,spanish_layer

In [30]:
def preprocess(en_inputs,es_inputs):
    """Turn a batch of raw sentence pairs into model inputs and labels.

    Relies on the module-level `en_vec_layer` / `es_vec_layer`. The ragged
    token ids are densified with `to_tensor()`.

    Returns:
        ((english_ids, spanish_ids_without_last_column), spanish_ids_without_first_column)
        i.e. the decoder input and the label are the same sequence shifted by one.
    """
    source_ids = en_vec_layer(en_inputs).to_tensor()
    target_ids = es_vec_layer(es_inputs).to_tensor()
    decoder_ids = target_ids[:,:-1]
    label_ids = target_ids[:,1:]
    return (source_ids,decoder_ids),label_ids

In [31]:
AUTO = tf.data.AUTOTUNE
def get_datasets(en_text,es_text,batch_size=64):
    """Randomly split the sentence pairs ~80/20 and build batched tf.data pipelines.

    Args:
        en_text: sequence of English sentences.
        es_text: sequence of Spanish sentences, aligned with `en_text`.
        batch_size: number of sentence pairs per batch.

    Returns:
        (train_ds, valid_ds, train_size, valid_size) where the datasets yield
        whatever `preprocess` returns for each batch and the sizes are the
        number of sentence pairs that landed in each split.
    """
    en_text = np.array(en_text)
    es_text = np.array(es_text)

    # One uniform draw per pair decides its split. The validation mask is the
    # exact complement of the training mask so no pair can be dropped.
    draws = np.random.uniform(size=len(en_text))
    train_indices = draws < 0.8
    valid_indices = ~train_indices

    # Count the selected pairs; len() of a boolean mask would give the size
    # of the whole corpus for both splits.
    train_size = int(train_indices.sum())
    valid_size = int(valid_indices.sum())

    def _build(mask):
        # Shared pipeline: slice -> shuffle -> batch -> vectorize -> prefetch.
        return (
            tf.data.Dataset
            .from_tensor_slices((en_text[mask],es_text[mask]))
            .shuffle(len(en_text))
            .batch(batch_size)
            .map(preprocess,num_parallel_calls=AUTO)
            .prefetch(AUTO)
        )

    train_ds = _build(train_indices)
    valid_ds = _build(valid_indices)
    return train_ds,valid_ds,train_size,valid_size


# train_ds,valid_ds,train_size,valid_size = get_datasets(en_text,es_text)

In [32]:
# for (en_in,es_in),tar_in in train_ds.take(1):
#     print(en_in.shape,es_in.shape,tar_in.shape)
#     print(en_in[:2],es_in[:2],tar_in[:2])

In [33]:
class Encoder(keras.layers.Layer):
    """Embeds source token ids and runs them through a bidirectional LSTM.

    The forward and backward LSTM outputs are summed (`merge_mode="sum"`), so
    the returned sequence keeps `units` features per timestep.
    """

    def __init__(self,vec_layer: keras.layers.TextVectorization,units: int=256,**kwargs):
        """
        Args:
            vec_layer: adapted TextVectorization layer; its vocabulary size
                determines the embedding table size.
            units: embedding width and LSTM size.
        """
        super().__init__(**kwargs)
        self.units = units
        self.vec_layer = vec_layer
        self.vocab_size = vec_layer.vocabulary_size()

        # Id 0 is treated as padding and masked downstream (mask_zero=True).
        self.embedder = keras.layers.Embedding(self.vocab_size,units,mask_zero=True)
        recurrent = keras.layers.LSTM(
            units,
            return_state=True,
            return_sequences=True,
            recurrent_initializer="glorot_uniform",
        )
        self.encoder_unit = keras.layers.Bidirectional(recurrent,merge_mode="sum")

    def call(self,encoder_inputs):
        """Encode a batch of token ids.

        Side effect: the states returned by the bidirectional LSTM are kept
        in `self.encoder_state` (a list).

        Returns:
            The per-timestep encoder outputs.
        """
        embedded = self.embedder(encoder_inputs)
        results = self.encoder_unit(embedded)
        # First element is the output sequence, the remainder are the states.
        encoder_outputs = results[0]
        self.encoder_state = list(results[1:])
        return encoder_outputs

In [34]:
# '''Test whether encoder is working for all inputs'''
# with strategy.scope():
#     encoder = Encoder()

# for en_in in train_ds.map(lambda x,y:x[0]).take(1):
#     encoder(en_in)

# for en_in in valid_ds.map(lambda x,y:x[0]).take(1):
#     encoder(en_in)

In [35]:
class Decoder(keras.layers.Layer):
    """Embeds target token ids and runs them through a unidirectional LSTM."""

    def __init__(self,vec_layer:keras.layers.TextVectorization,units:int=256,**kwargs):
        """
        Args:
            vec_layer: adapted TextVectorization layer; its vocabulary size
                determines the embedding table size.
            units: embedding width and LSTM size.
        """
        super().__init__(**kwargs)
        self.units = units
        self.vec_layer = vec_layer
        self.vocab_size = vec_layer.vocabulary_size()

        # Id 0 is treated as padding and masked downstream (mask_zero=True).
        self.embedder = keras.layers.Embedding(self.vocab_size,units,mask_zero=True)
        self.decoder_unit = keras.layers.LSTM(
            units,
            return_state=True,
            return_sequences=True,
            recurrent_initializer="glorot_uniform",
        )

    def call(self,decoder_inputs,decoder_initial_state=None):
        """Decode a batch of token ids.

        Args:
            decoder_inputs: target-side token ids.
            decoder_initial_state: optional LSTM state to start from; passed
                straight through as `initial_state`.

        Side effect: the states returned by the LSTM are kept in
        `self.decoder_state` (a list) so step-wise decoding can resume from them.

        Returns:
            The per-timestep decoder outputs.
        """
        embedded = self.embedder(decoder_inputs)
        results = self.decoder_unit(embedded,initial_state=decoder_initial_state)
        # First element is the output sequence, the remainder are the states.
        decoder_outputs = results[0]
        self.decoder_state = list(results[1:])
        return decoder_outputs

In [36]:
# '''Testing for Decoder Errors'''
# with strategy.scope():
#     decoder = Decoder()


# for es_in in train_ds.map(lambda x,y:x[1]).take(1):
#     decoder(es_in)

# for es_in in valid_ds.map(lambda x,y:x[1]).take(1):
#     decoder(es_in)

In [37]:
class CrossAttention(keras.layers.Layer):
    """Single-head attention from decoder outputs over encoder outputs,
    followed by a residual connection and layer normalization."""

    def __init__(self,units=256,**kwargs):
        """
        Args:
            units: `key_dim` of the attention head.
        """
        super().__init__(**kwargs)

        self.mha = keras.layers.MultiHeadAttention(num_heads=1,key_dim=units)
        self.add = keras.layers.Add()
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self,encoder_outputs,decoder_outputs):
        """Attend with the decoder outputs as query and the encoder outputs as value.

        Side effect: the attention scores, averaged over the heads axis
        (axis 1), are kept in `self.attention_scores`.

        Returns:
            LayerNorm(attention_output + decoder_outputs).
        """
        context,scores = self.mha(
            query=decoder_outputs,
            value=encoder_outputs,
            return_attention_scores=True,
        )
        self.attention_scores = tf.reduce_mean(scores,axis=1)
        residual = self.add([context,decoder_outputs])
        return self.layer_norm(residual)

In [38]:
# '''Testing for Attention Errors'''
# with strategy.scope():
#     attention_layer = CrossAttention()

# for en_in,es_in in train_ds.map(lambda x,y:x).take(1):
#     attention_layer(encoder(en_in),decoder(es_in))


# for en_in,es_in in valid_ds.map(lambda x,y:x).take(1):
#     attention_layer(encoder(en_in),decoder(es_in))

In [39]:
class Translator(keras.Model):
    """Encoder -> decoder -> cross-attention -> dense translation model.

    Besides the training-time `call`, the class offers helpers for step-wise
    decoding (`get_decoder_initial_state`, `get_next_token`) and for turning
    token ids back into text (`tokens_to_text`).
    """

    def __init__(self,input_vec_layer,output_vec_layer,units=256,**kwargs):
        """
        Args:
            input_vec_layer: adapted TextVectorization layer for the source language.
            output_vec_layer: adapted TextVectorization layer for the target language.
            units: width shared by the encoder, decoder and attention layers.
        """
        super(Translator,self).__init__(**kwargs)

        self.encoder_layer = Encoder(units=units,vec_layer=input_vec_layer)
        self.decoder_layer = Decoder(units=units,vec_layer=output_vec_layer)
        self.attention_layer = CrossAttention(units=units)

        # Forward and inverse lookups built from the target vocabulary.
        target_vocabulary = self.decoder_layer.vec_layer.get_vocabulary()
        self.words_to_ids = keras.layers.StringLookup(
            vocabulary=target_vocabulary,
            oov_token="[UNK]",
            mask_token=""
        )
        self.ids_to_words = keras.layers.StringLookup(
            vocabulary=target_vocabulary,
            oov_token="[UNK]",
            mask_token="",
            invert=True
        )
        # Both are looked up from a one-element list, i.e. one-element tensors.
        self.start_token = self.words_to_ids(["[START]"])
        self.end_token = self.words_to_ids(["[END]"])

        # Projects attention outputs to one logit per target-vocabulary entry.
        self.out = keras.layers.Dense(self.decoder_layer.vec_layer.vocabulary_size())


    def call(self,inputs,decoder_initial_state=None):
        """Compute target-vocabulary logits for every decoder timestep.

        Args:
            inputs: pair `(encoder_inputs, decoder_inputs)` of token-id tensors.
            decoder_initial_state: optional state forwarded to the decoder LSTM.

        Returns:
            Logits produced by the final Dense layer, with any attached Keras
            mask removed.
        """
        encoder_inputs,decoder_inputs = inputs

        encoder_outputs = self.encoder_layer(encoder_inputs)
        decoder_outputs = self.decoder_layer(decoder_inputs,decoder_initial_state)
        attention_outputs = self.attention_layer(encoder_outputs,decoder_outputs)
        total_outputs = self.out(attention_outputs)

        # Drop the propagated mask if one is attached; a tensor without the
        # attribute is simply left untouched.
        try:
            del total_outputs._keras_mask
        except AttributeError:
            pass

        return total_outputs


    def text_to_encoder_outputs(self,texts):
        """Vectorize raw source `texts` and return the encoder outputs for them."""
        texts = tf.convert_to_tensor(texts)
        en_vec_outputs = self.encoder_layer.vec_layer(texts).to_tensor()
        return self.encoder_layer(en_vec_outputs)

    def get_decoder_initial_state(self,encoder_outputs):
        """Build the starting point for step-wise decoding.

        Args:
            encoder_outputs: tensor whose leading dimension is the batch size.

        Returns:
            (start_tokens, done, state): a `[batch, 1]` tensor filled with the
            [START] id, a `[batch, 1]` all-False tensor, and the decoder
            LSTM's initial state for the embedded start tokens.
        """
        # Dynamic shape: the static batch dimension may be None in graph mode.
        batch_size = tf.shape(encoder_outputs)[0]
        # `tf.fill` expects a scalar value while `start_token` holds one
        # element in a rank-1 tensor, so collapse it to rank 0 first.
        start_id = tf.reshape(self.start_token,[])
        start_tokens = tf.fill(dims=[batch_size,1],value=start_id)
        done = tf.zeros(shape=[batch_size,1],dtype=tf.bool)
        embedding = self.decoder_layer.embedder(start_tokens)
        return start_tokens,done,self.decoder_layer.decoder_unit.get_initial_state(embedding)

    def get_next_token(self,encoder_inputs,next_token,done,state,temperature=0.0):
        """Run one decoding step and pick the next token for every sequence.

        Args:
            encoder_inputs: source token ids (the encoder is re-run on them).
            next_token: token ids fed to the decoder for this step.
            done: boolean tensor marking sequences that already emitted [END].
            state: decoder LSTM state to resume from.
            temperature: 0 (falsy) selects the arg-max token; any other value
                samples from the logits divided by `temperature`.

        Returns:
            (next_token, done, state): the chosen ids as `[batch, 1]` (0 for
            finished sequences), the updated `done` flags and the decoder
            state stored by the decoder layer during this step.
        """
        total_out = self((encoder_inputs,next_token),state)
        # Both branches look only at the last timestep so they agree even
        # when more than one token is fed in.
        last_step_logits = total_out[:,-1,:]

        if temperature:
            next_token = tf.random.categorical(last_step_logits/temperature,num_samples=1)
        else:
            next_token = tf.argmax(last_step_logits,axis=-1)[:,tf.newaxis]

        done = done | (next_token == self.end_token)
        # Finished sequences keep emitting the padding id.
        next_token = tf.where(done,tf.constant(0,tf.int64),next_token)
        return next_token,done,self.decoder_layer.decoder_state

    def tokens_to_text(self,tokens):
        """Convert token ids to space-joined strings without the [START] / [END] markers."""
        texts = self.ids_to_words(tokens)
        texts = tf.strings.reduce_join(texts,separator=" ",axis=-1)
        # Strip a leading "[START]" and a trailing "[END]" together with the
        # spaces around them.
        texts = tf.strings.regex_replace(texts,r"^ *\[START\] *","")
        texts = tf.strings.regex_replace(texts,r" *\[END\] *$","")
        texts = tf.strings.strip(texts)
        return texts


In [40]:
# '''Testing for Total Translator errors'''
# with strategy.scope():
#     model = Translator()


# for x in train_ds.map(lambda x,y:x).take(1):
#     model(x)
# for x in valid_ds.map(lambda x,y:x).take(1):
#     model(x)

In [41]:
# '''Example Text Generation'''
# next_token,done,state = model.get_decoder_initial_state(model.encoder_layer(en_in))
# tokens_list = []
# for i in range(10):
#     next_token,done,state = model.get_next_token(en_in,next_token,done,state=state,temperature=1)
#     tokens_list.append(next_token)

# tokens_list = tf.concat(tokens_list,axis=-1)
# model.tokens_to_text(tokens_list)[:10]

In [42]:
def custom_loss(y_true,y_pred):

    '''
        Masked sparse categorical cross-entropy computed from logits.

        y_pred will be [batch sequence vocab_size]
        y_true will be [batch sequence]
        Positions where y_true is 0 are excluded: their loss is zeroed and
        the total is divided by the number of non-zero positions.
    '''
    per_token_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True,reduction="none")(y_true,y_pred)
    keep = tf.cast(y_true != 0,per_token_loss.dtype)
    masked_total = tf.reduce_sum(per_token_loss * keep)
    return masked_total/tf.reduce_sum(keep)


def custom_metric(y_true,y_pred):

    '''
        Masked accuracy: fraction of non-zero target positions predicted correctly.

        y_pred will be [batch sequence vocab_size]             with dtype = tf.float32
        y_true will be [batch sequence]                        with dtype = tf.int64
        Positions where y_true is 0 are excluded from both the numerator
        and the denominator.
    '''

    y_pred = tf.cast(tf.argmax(y_pred,-1),y_true.dtype)
    mask = tf.cast(y_true != 0,tf.float32)
    accuracy = tf.cast(y_pred == y_true,tf.float32)
    # Mask the matches as well: a 0 predicted on a 0-valued position would
    # otherwise count as a hit although the denominator ignores it.
    accuracy *= mask
    return tf.reduce_sum(accuracy)/tf.reduce_sum(mask)




In [43]:
# Hyper-parameters.
BATCH_SIZE = 64 #8*strategy.num_replicas_in_sync
UNITS = 256

# Vectorization layers must exist before the datasets are built, because
# `preprocess` looks them up as globals.
en_vec_layer,es_vec_layer = get_layers(vocab_size=5000)
(train_ds,valid_ds,train_size,valid_size) = get_datasets(en_text,es_text,batch_size=BATCH_SIZE)
train_steps,valid_steps = train_size//BATCH_SIZE,valid_size//BATCH_SIZE

# with strategy.scope():
model = Translator(input_vec_layer=en_vec_layer,output_vec_layer=es_vec_layer,units=UNITS)
# `custom_loss` is also listed as a metric so it shows up under its own key in the history.
model.compile(
    loss=custom_loss,
    optimizer="adam",
    metrics=[custom_metric,custom_loss],
    steps_per_execution=20,
)

In [46]:
# `val_custom_metric` is an accuracy, so higher is better. With the default
# mode="auto" Keras only infers "max" for monitor names ending in "acc",
# "accuracy" or "auc"; this name would be minimized, stopping on (and
# restoring) the worst-accuracy weights. State the direction explicitly.
early_stop = keras.callbacks.EarlyStopping(patience=15,monitor='val_custom_metric',mode="max",restore_best_weights=True)
# Lower validation loss is better; stated explicitly for symmetry with the above.
check = keras.callbacks.ModelCheckpoint(filepath="/content/nmt",monitor="val_custom_loss",mode="min",save_best_only=True)
history = model.fit(
    train_ds,
    epochs=100,
    validation_data=valid_ds,
    callbacks=[early_stop,check]
    # steps_per_epoch=train_steps,
    # validation_steps=valid_steps
    )

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100


<keras.src.callbacks.History at 0x7b7ba6b43e80>

In [48]:
# Names of the series recorded during training (`history` is the object returned by `model.fit`).
history.history.keys()

dict_keys(['loss', 'custom_metric', 'custom_loss', 'val_loss', 'val_custom_metric', 'val_custom_loss'])

In [49]:
# Training vs validation loss per epoch; traces are named so the legend is readable.
fig = go.Figure()
fig.add_trace(go.Scatter(y=model.history.history['custom_loss'],mode="lines",name="train"))
fig.add_trace(go.Scatter(y=model.history.history['val_custom_loss'],mode="lines",name="validation"))
fig.update_layout(title="Loss Train v/s Validation",xaxis_title="epoch",yaxis_title="custom_loss")
fig.show()

In [53]:
# Training vs validation masked accuracy per epoch; traces are named so the legend is readable.
fig = go.Figure()
fig.add_trace(go.Scatter(y=model.history.history['custom_metric'],mode="lines",name="train"))
fig.add_trace(go.Scatter(y=model.history.history['val_custom_metric'],mode="lines",name="validation"))
fig.update_layout(title="MaskedAccuracy Train v/s Validation",xaxis_title="epoch",yaxis_title="custom_metric")
# Accuracy is a fraction, so pin the axis to [0, 1].
fig.update_yaxes(range=[0,1])
fig.show()