<a href="https://colab.research.google.com/github/h4ck4l1/datasets/blob/main/NLP_with_RNN_and_Attention/NMT_with_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [61]:
import os,warnings
from IPython.display import clear_output
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")
!pip3 install -q -U "tensorflow-text==2.13.0"
!pip3 install -q -U einops
!pip3 install plotly
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_text as tf_text
np.printoptions(precision=2)
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pio.templates.default = "plotly_dark"
import einops
from zipfile import ZipFile
from typing import Any
%xmode Minimal
tf.get_logger().setLevel("ERROR")
clear_output()

In [None]:
# Use a TPU when one is attached (e.g. a Colab TPU runtime). Otherwise fall back
# to a single-device strategy on the first GPU, or on the CPU if there is no GPU.
# OneDeviceStrategy requires an explicit device, so calling it with no
# arguments raises a TypeError.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
except (ValueError, tf.errors.NotFoundError):
    device = "/gpu:0" if tf.config.list_physical_devices("GPU") else "/cpu:0"
    strategy = tf.distribute.OneDeviceStrategy(device=device)

In [62]:
class ShapeCheck():
    """Record named tensor dimensions and verify that they stay consistent.

    Each instance remembers the size first seen for every axis name (as parsed
    by ``einops.parse_shape``). It raises if a later tensor reports a different
    size for the same name. Checks are skipped when TensorFlow is not executing
    eagerly (e.g. while tracing a ``tf.function``).
    """

    def __init__(self):
        # axis name -> size first observed for that name
        self.shapes = {}

    def __call__(self,tensor,names,**kwargs):
        """Check ``tensor`` against the einops axis pattern ``names``.

        Args:
            tensor: tensor whose shape is parsed with ``einops.parse_shape``.
            names: space-separated axis pattern, e.g. ``"batch units"``.
            **kwargs: accepted for call-site compatibility; unused.

        Raises:
            ValueError: if an axis name was previously recorded with a
                different size.
        """
        if not tf.executing_eagerly():
            return

        for name,dim in einops.parse_shape(tensor,names).items():
            # First sighting records the size; later sightings must match it.
            expected = self.shapes.setdefault(name,dim)
            if expected != dim:
                # Report the axis and pattern rather than dumping the
                # (potentially huge) tensor contents.
                raise ValueError(
                    f"Dimension mismatch for axis '{name}' in pattern '{names}'\n"
                    f"found dimension :{expected}\n"
                    f"new dimension given :{dim}"
                )


In [63]:
# Source archive of English/Spanish sentence pairs (fetched over HTTPS).
origin = "https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"

In [64]:
# Download the archive only. It is unpacked explicitly below into ./spa-eng, so
# letting get_file also extract it into the Keras cache is redundant. Some Keras
# versions also change what get_file returns when extract=True.
file_path = keras.utils.get_file(fname="spa-eng.zip",origin=origin,extract=False)
with ZipFile(file_path,"r") as archive:
    archive.extractall("spa-eng")
# The corpus contains accented Spanish characters: read it as UTF-8 explicitly
# instead of relying on the platform's default encoding.
with open(os.path.join("spa-eng","spa-eng","spa.txt"),"r",encoding="utf-8") as f:
    text = f.read()
# Each non-blank line is "english<TAB>spanish". Blank lines are skipped so the
# unpacking into exactly two sequences cannot break.
en_text,es_text = zip(*[line.split("\t") for line in text.splitlines() if line.strip()])
for en,es in zip(en_text[:10],es_text[:10]):
    print(en,es)

Go. Ve.
Go. Vete.
Go. Vaya.
Go. Váyase.
Hi. Hola.
Run! ¡Corre!
Run. Corred.
Who? ¿Quién?
Fire! ¡Fuego!
Fire! ¡Incendio!


In [65]:
def text_preprocess(sentence:str):
    """Standardize raw sentences and wrap them in ``[START]`` / ``[END]`` markers.

    Pipeline:
      1. NFKD-normalize (accented letters decompose, so the accent is dropped
         by the filter below), then lowercase.
      2. Remove every character except spaces, ``a-z`` and ``. , ! ? ¿``.
      3. Surround that punctuation with spaces so each mark becomes its own token.
      4. Strip the ends and add the sentinel tokens.
    """
    lowered = tf.strings.lower(tf_text.normalize_utf8(sentence,"NFKD"))
    kept = tf.strings.regex_replace(lowered,r"[^ a-z.,!?¿]","")
    spaced = tf.strings.regex_replace(kept,r"[.,!?¿]",r" \0 ")
    return tf.strings.join(["[START]",tf.strings.strip(spaced),"[END]"],separator=" ")

In [66]:
# Show the standardized form of the first ten sentence pairs.
en_sample = text_preprocess(en_text[:10]).numpy()
es_sample = text_preprocess(es_text[:10]).numpy()
for en,es in zip(en_sample,es_sample):
    print(f"{en}   ---->{es}")

b'[START] go . [END]'   ---->b'[START] ve . [END]'
b'[START] go . [END]'   ---->b'[START] vete . [END]'
b'[START] go . [END]'   ---->b'[START] vaya . [END]'
b'[START] go . [END]'   ---->b'[START] vayase . [END]'
b'[START] hi . [END]'   ---->b'[START] hola . [END]'
b'[START] run ! [END]'   ---->b'[START] corre ! [END]'
b'[START] run . [END]'   ---->b'[START] corred . [END]'
b'[START] who ? [END]'   ---->b'[START] \xc2\xbf quien ? [END]'
b'[START] fire ! [END]'   ---->b'[START] fuego ! [END]'
b'[START] fire ! [END]'   ---->b'[START] incendio ! [END]'


In [67]:
vocab_size = 5000

def _build_vectorizer(corpus):
    """Create a ragged TextVectorization layer (using text_preprocess) fitted on ``corpus``."""
    layer = keras.layers.TextVectorization(max_tokens=vocab_size,standardize=text_preprocess,ragged=True)
    layer.adapt(corpus)
    return layer

en_vec_layer = _build_vectorizer(en_text)
es_vec_layer = _build_vectorizer(es_text)

In [68]:
def preprocess(en_inputs,es_inputs):
    """Vectorize a batch of sentence pairs for teacher forcing.

    Returns ``((encoder_ids, decoder_input_ids), decoder_target_ids)``. The
    decoder inputs drop the last Spanish token and the targets drop the first,
    so each target position holds the token that follows the input position.
    Ragged rows are padded densely via ``to_tensor()``.
    """
    en_ids = en_vec_layer(en_inputs).to_tensor()
    es_ids = es_vec_layer(es_inputs).to_tensor()
    decoder_inputs = es_ids[:,:-1]
    decoder_targets = es_ids[:,1:]
    return (en_ids,decoder_inputs),decoder_targets

In [75]:
AUTO = tf.data.AUTOTUNE
# Seeded generator so the ~80/20 train/validation split is reproducible.
split_rng = np.random.default_rng(42)
all_indices = split_rng.uniform(size=len(en_text))
train_indices = all_indices < 0.8
# Complement of the training mask, so every pair lands in exactly one split
# (a separate `> 0.8` test would drop values exactly equal to 0.8).
test_indices = ~train_indices
en_text = np.array(en_text)
es_text = np.array(es_text)

def _make_dataset(mask,shuffle):
    """Batched, vectorized tf.data pipeline over the sentence pairs selected by ``mask``."""
    ds = tf.data.Dataset.from_tensor_slices((en_text[mask],es_text[mask]))
    if shuffle:
        # Buffer covers the whole subset for a full shuffle.
        ds = ds.shuffle(int(mask.sum()))
    return ds.batch(64).map(preprocess,num_parallel_calls=AUTO).prefetch(AUTO)

train_ds = _make_dataset(train_indices,shuffle=True)
# Validation order does not affect metrics, so it is not shuffled.
valid_ds = _make_dataset(test_indices,shuffle=False)

In [77]:
# Inspect one batch: shapes first, then the first two rows of each tensor.
(en_in,es_in),tar_in = next(iter(train_ds))
print(en_in.shape,es_in.shape,tar_in.shape)
print(en_in[:2],es_in[:2],tar_in[:2])

(64, 21) (64, 22) (64, 22)
tf.Tensor(
[[   2    5  536    1  108   35 4821    4    3    0    0    0    0    0
     0    0    0    0    0    0    0]
 [   2   29 2297   22   11    3    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0]], shape=(2, 21), dtype=int64) tf.Tensor(
[[  2  11 817   6 526 177  81   1   4   3   0   0   0   0   0   0   0   0
    0   0   0   0]
 [  2  13   5 862  58  12   3   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0]], shape=(2, 22), dtype=int64) tf.Tensor(
[[ 11 817   6 526 177  81   1   4   3   0   0   0   0   0   0   0   0   0
    0   0   0   0]
 [ 13   5 862  58  12   3   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0]], shape=(2, 22), dtype=int64)


In [137]:
class Encoder(keras.layers.Layer):
    """Bidirectional-LSTM encoder over source (English) token ids.

    Token id 0 is masked as padding (``mask_zero=True``). The forward and
    backward sequence outputs are summed (``merge_mode="sum"``), so the output
    keeps ``units`` channels.
    """

    def __init__(self,units: int=256,vec_layer: keras.layers.TextVectorization=en_vec_layer,**kwargs):
        super(Encoder,self).__init__(**kwargs)
        self.units = units
        self.vec_layer = vec_layer
        self.vocab_size = vec_layer.vocabulary_size()

        self.embedder = keras.layers.Embedding(self.vocab_size,units,mask_zero=True)
        recurrent = keras.layers.LSTM(
            units,
            return_state=True,
            return_sequences=True,
            recurrent_initializer="glorot_uniform",
        )
        self.encoder_unit = keras.layers.Bidirectional(recurrent,merge_mode="sum")

    def call(self,encoder_inputs):
        """Encode a (batch, encoder_sequence) id tensor.

        Returns:
            ``(encoder_outputs, encoder_state)``. ``encoder_state`` is the list
            of state tensors that Bidirectional returns after the sequence
            output. For an LSTM these are presumably forward h, c followed by
            backward h, c; only the first two are shape-checked here.
        """
        check = ShapeCheck()
        check(encoder_inputs,"batch encoder_sequence")
        embedded = self.embedder(encoder_inputs)
        check(embedded,"batch encoder_sequence units")
        encoder_outputs,*encoder_state = self.encoder_unit(embedded)
        check(encoder_state[0],"batch units")
        check(encoder_state[1],"batch units")
        return encoder_outputs,encoder_state

In [None]:
# Stand-alone encoder used for the shape smoke tests below.
encoder = Encoder(units=256)

In [87]:
# Shape-check the encoder on every training batch.
for batch in train_ds:
    (en_in,es_in),tar_in = batch
    encoder(en_in)

In [88]:
# Shape-check the encoder on every validation batch.
for batch in valid_ds:
    (en_in,es_in),tar_in = batch
    encoder(en_in)

In [138]:
class Decoder(keras.layers.Layer):
    """Unidirectional LSTM decoder over target (Spanish) token ids.

    Token id 0 is masked as padding. The LSTM returns both its full output
    sequence and its final states, and it can be resumed from a given state
    via ``decoder_initial_state``.
    """

    def __init__(self,units:int=256,vec_layer:keras.layers.TextVectorization=es_vec_layer,**kwargs):
        super(Decoder,self).__init__(**kwargs)
        self.units = units
        self.vec_layer = vec_layer
        self.vocab_size = vec_layer.vocabulary_size()

        self.embedder = keras.layers.Embedding(self.vocab_size,units,mask_zero=True)
        self.decoder_unit = keras.layers.LSTM(
            units,
            return_state=True,
            return_sequences=True,
            recurrent_initializer="glorot_uniform",
        )

    def call(self,decoder_inputs,decoder_initial_state=None):
        """Decode a (batch, decoder_sequence) id tensor.

        Returns:
            ``(decoder_outputs, decoder_state)``, where ``decoder_state`` is the
            list of LSTM state tensors following the sequence output.
        """
        check = ShapeCheck()
        check(decoder_inputs,"batch decoder_sequence")
        embedded = self.embedder(decoder_inputs)
        check(embedded,"batch decoder_sequence units")
        decoder_outputs,*decoder_state = self.decoder_unit(embedded,initial_state=decoder_initial_state)
        check(decoder_outputs,"batch decoder_sequence units")
        check(decoder_state[0],"batch units")
        check(decoder_state[1],"batch units")
        return decoder_outputs,decoder_state

In [98]:
# Stand-alone decoder, shape-checked on every training batch.
decoder = Decoder(units=256)
for batch in train_ds:
    (en_in,es_in),tar_in = batch
    decoder(es_in)

In [99]:
# Shape-check the decoder on every validation batch.
for batch in valid_ds:
    (en_in,es_in),tar_in = batch
    decoder(es_in)

In [139]:
class CrossAttention(keras.layers.Layer):
    """Single-head cross-attention from decoder outputs onto encoder outputs.

    The decoder sequence is the query and the encoder sequence is the value.
    The attended result is added back to the decoder outputs (residual) and
    layer-normalized. Head-averaged attention weights from the last call are
    kept in ``self.attention_scores``.
    """

    def __init__(self,units=256,**kwargs):
        super(CrossAttention,self).__init__(**kwargs)
        self.mha = keras.layers.MultiHeadAttention(num_heads=1,key_dim=units)
        self.add = keras.layers.Add()
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self,encoder_outputs,decoder_outputs):
        check = ShapeCheck()
        check(encoder_outputs,"batch encoder_sequence units")
        check(decoder_outputs,"batch decoder_sequence units")

        attended,scores = self.mha(
            query=decoder_outputs,
            value=encoder_outputs,
            return_attention_scores=True,
        )
        check(attended,"batch decoder_sequence units")
        check(scores,"batch num_heads decoder_sequence encoder_sequence")
        # Average over the heads axis -> (batch, decoder_sequence, encoder_sequence).
        self.attention_scores = tf.reduce_mean(scores,axis=1)
        residual = self.add([attended,decoder_outputs])
        return self.layer_norm(residual)

In [102]:
attention_layer = CrossAttention()

# Encoder and Decoder return (outputs, state) tuples, but CrossAttention expects
# only the output sequences, so unpack before calling it.
for dataset in (train_ds,valid_ds):
    for (en_in,es_in),tar_in in dataset:
        encoder_outputs,_ = encoder(en_in)
        decoder_outputs,_ = decoder(es_in)
        attention_layer(encoder_outputs,decoder_outputs)

In [140]:
class Translator(keras.Model):
    """Attention-based English -> Spanish sequence-to-sequence model.

    ``call`` runs teacher-forced decoding over whole sequences. The methods
    ``text_to_encoder_outputs``, ``get_decoder_initial_state``,
    ``get_next_token`` and ``tokens_to_text`` are step-wise helpers for
    inference.
    """

    def __init__(self,units=256,**kwargs):

        super(Translator,self).__init__(**kwargs)

        self.encoder_layer = Encoder(units=units)
        self.decoder_layer = Decoder(units=units)
        self.attention_layer = CrossAttention(units=units)

        # Token <-> id lookups sharing the decoder vectorizer's vocabulary.
        vocabulary = self.decoder_layer.vec_layer.get_vocabulary()
        self.words_to_ids = keras.layers.StringLookup(
            vocabulary=vocabulary,
            oov_token="[UNK]",
            mask_token=""
        )
        self.ids_to_words = keras.layers.StringLookup(
            vocabulary=vocabulary,
            oov_token="[UNK]",
            mask_token="",
            invert=True
        )
        # Shape-(1,) id tensors for the sentinel tokens.
        self.start_token = self.words_to_ids(["[START]"])
        self.end_token = self.words_to_ids(["[END]"])

        self.out = keras.layers.Dense(self.decoder_layer.vec_layer.vocabulary_size())


    def call(self,inputs):
        """Teacher-forced forward pass.

        Args:
            inputs: ``(encoder_inputs, decoder_inputs)`` token-id tensors.

        Returns:
            Logits of shape (batch, decoder_sequence, vocab_size).
        """
        encoder_inputs,decoder_inputs = inputs
        encoder_outputs,self.encoder_state = self.encoder_layer(encoder_inputs)
        decoder_outputs,self.decoder_state = self.decoder_layer(decoder_inputs)
        # CrossAttention.call's signature is (encoder_outputs, decoder_outputs).
        # The decoder sequence is the query, so the logits line up with the
        # target positions.
        attention_outputs = self.attention_layer(encoder_outputs,decoder_outputs)
        total_outputs = self.out(attention_outputs)

        # Remove the propagated padding mask so Keras does not apply it to the
        # loss/metrics automatically (presumably padding is handled in the loss).
        try:
            del total_outputs._keras_mask
        except AttributeError:
            pass

        return total_outputs


    def text_to_encoder_outputs(self,texts):
        """Vectorize raw source strings and run the encoder.

        Returns the encoder's ``(encoder_outputs, encoder_state)`` tuple.
        """
        texts = tf.convert_to_tensor(texts)
        en_vec_outputs = self.encoder_layer.vec_layer(texts).to_tensor()
        return self.encoder_layer(en_vec_outputs)

    def get_decoder_initial_state(self,encoder_outputs):
        """Build the decoding loop's starting values.

        Args:
            encoder_outputs: encoder output tensor; only its batch size is used.

        Returns:
            ``(start_tokens, done, state)``: (batch, 1) ``[START]`` ids, (batch, 1)
            all-False flags, and the decoder LSTM's initial state.
        """
        batch_size = tf.shape(encoder_outputs)[0]
        # tf.fill needs a scalar value; start_token has shape (1,).
        start_id = tf.reshape(self.start_token,[])
        start_tokens = tf.fill(dims=[batch_size,1],value=start_id)
        done = tf.zeros(shape=[batch_size,1],dtype=tf.bool)
        embedding = self.decoder_layer.embedder(start_tokens)
        return start_tokens,done,self.decoder_layer.decoder_unit.get_initial_state(embedding)

    def get_next_token(self,encoder_inputs,next_token,done,state,temperature=0.0):
        """Run one decoding step and choose the next token of every sequence.

        Args:
            encoder_inputs: encoder output sequence, (batch, encoder_sequence,
                units), e.g. the first element returned by
                ``text_to_encoder_outputs``. A rank-2 tensor of source token ids
                is also accepted; it is encoded first.
            next_token: (batch, 1) ids of the most recently emitted tokens.
            done: (batch, 1) bool flags for sequences that already finished.
            state: decoder LSTM state to continue from.
            temperature: 0 selects greedy argmax; a positive value samples from
                ``logits / temperature``.

        Returns:
            ``(next_token, done, state)``. Tokens of finished sequences (including
            the ``[END]`` token itself) are replaced by 0.
        """
        encoder_outputs = tf.convert_to_tensor(encoder_inputs)
        if encoder_outputs.shape.rank == 2:
            encoder_outputs,_ = self.encoder_layer(encoder_outputs)

        decoder_outputs,state = self.decoder_layer(next_token,decoder_initial_state=state)
        attention_outputs = self.attention_layer(encoder_outputs,decoder_outputs)
        # Logits for the last decoded position: (batch, vocab_size).
        logits = self.out(attention_outputs)[:,-1,:]

        if temperature:
            next_token = tf.random.categorical(logits/temperature,num_samples=1)
        else:
            next_token = tf.argmax(logits,axis=-1)[:,tf.newaxis]

        done = done | (next_token == self.end_token)
        next_token = tf.where(done,tf.constant(0,dtype=tf.int64),next_token)
        return next_token,done,state

    def tokens_to_text(self,tokens):
        """Convert (batch, sequence) token ids into one cleaned string per row."""
        texts = self.ids_to_words(tokens)
        # Join along the token axis only. Without an axis, reduce_join collapses
        # the whole batch into a single string.
        texts = tf.strings.reduce_join(texts,axis=-1,separator=" ")
        texts = tf.strings.regex_replace(texts,r"^ *\[START\] *","")
        texts = tf.strings.regex_replace(texts,r" *\[END\] *$","")
        texts = tf.strings.strip(texts)
        return texts


In [141]:
# Full translator with the default 256 hidden units.
model = Translator(units=256)

In [None]:
# Run the model once over every training batch (inputs only, targets dropped).
for batch_inputs in train_ds.map(lambda inputs,targets:inputs):
    en_in,es_in = batch_inputs
    model((en_in,es_in))

In [None]:
# Run the model once over every validation batch (inputs only, targets dropped).
for batch_inputs in valid_ds.map(lambda inputs,targets:inputs):
    en_in,es_in = batch_inputs
    model((en_in,es_in))