<a href="https://colab.research.google.com/github/h4ck4l1/datasets/blob/main/NLP_with_RNN_and_Attention/Bidirectional_RNN_with_BeamSearch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
url = "https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"

In [2]:
from google.colab import auth
auth.authenticate_user()
import os,warnings
os.environ["TF_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")
import numpy as np
import tensorflow as tf
from tensorflow import keras
from zipfile import ZipFile
import plotly.graph_objects as go
tf.get_logger().setLevel("ERROR")

In [3]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

In [4]:
with tf.device("/job:localhost"):
    file_path = keras.utils.get_file(fname="/content/spa-eng.zip",origin=url,extract=True)
    with ZipFile(file_path,"r") as f:
        f.extractall("/content/spa-eng")
    with open("/content/spa-eng/spa-eng/spa.txt","r") as f:
        text = f.read()
    text = text.replace("¿","").replace("¡","")
    en_text,es_text = zip(*[line.split("\t") for line in text.splitlines()])

In [5]:
def get_layers(en_text,es_text,vocab_size=1000,seq_length=50):

    en_vec_layer = keras.layers.TextVectorization(vocab_size,output_sequence_length=seq_length)
    es_vec_layer = keras.layers.TextVectorization(vocab_size,output_sequence_length=seq_length)
    en_vec_layer.adapt(en_text)
    es_vec_layer.adapt([f"soseq {s} eoseq" for s in es_text])
    return en_vec_layer,es_vec_layer

In [6]:
def get_data(en_vec_layer,es_vec_layer,en_text,es_text,train_size):

    x_train = en_vec_layer(tf.constant(en_text[:train_size]))
    x_valid = en_vec_layer(tf.constant(en_text[train_size:]))
    x_dec_train = es_vec_layer(tf.constant([f"soseq {s}" for s in es_text[:train_size]]))
    x_dec_valid = es_vec_layer(tf.constant([f"soseq {s}" for s in es_text[train_size:]]))
    y_train = es_vec_layer(tf.constant([f"{s} eoseq" for s in es_text[:train_size]]))
    y_valid = es_vec_layer(tf.constant([f"{s} eoseq" for s in es_text[train_size:]]))
    return (x_train,x_dec_train),y_train,(x_valid,x_dec_valid),y_valid

In [7]:
class BeamSearch(keras.Model):

    def __init__(self,vocab_size=1000,embed_size=128,**kwargs):

        super(BeamSearch,self).__init__(**kwargs)
        self.en_embed = keras.layers.Embedding(vocab_size,embed_size,mask_zero=True)
        self.es_embed = keras.layers.Embedding(vocab_size,embed_size,mask_zero=True)
        self.encoder = keras.layers.Bidirectional(keras.layers.LSTM(256,return_state=True))
        self.decoder = keras.layers.LSTM(512,return_sequences=True)
        self.out = keras.layers.Dense(vocab_size,"softmax")

    def call(self,inputs):

        encoder_inputs = inputs[0]
        decoder_inputs = inputs[1]
        encoder_embed_out = self.en_embed(encoder_inputs)
        decoder_embed_out = self.es_embed(decoder_inputs)
        encoder_out, *encoder_state_out = self.encoder(encoder_embed_out)
        final_encoder_state = [tf.concat(encoder_state_out[::2],axis=-1),tf.concat(encoder_state_out[1::2],axis=-1)]
        decoder_out = self.decoder(decoder_embed_out,initial_state=final_encoder_state)
        return self.out(decoder_out)


In [8]:
def piecewise(epoch,lr):
    if epoch < 6:
        return lr
    elif epoch < 10:
        return 5e-4
    else:
        return 5e-4 * tf.math.exp(-0.1695*(epoch-10))

with strategy.scope():
    BATCH_SIZE = 50*8
    train_size = 100_000
    valid_size = len(en_text) - train_size
    train_steps = train_size//BATCH_SIZE
    valid_steps = valid_size//BATCH_SIZE
    en_vec_layer,es_vec_layer = get_layers(en_text,es_text)
    X_train,y_train,X_valid,y_valid = get_data(en_vec_layer,es_vec_layer,en_text,es_text,train_size)
    lr_call = keras.callbacks.LearningRateScheduler(piecewise)
    beam_model = BeamSearch()
    beam_model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=keras.optimizers.AdamW(learning_rate=1e-3),
        metrics=["accuracy"],
        steps_per_execution=20
    )

In [9]:
beam_model.fit(
    X_train,
    y_train,
    validation_data=(X_valid,y_valid),
    epochs=20,
    batch_size=BATCH_SIZE,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    callbacks=[lr_call]
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7cc52438b0d0>

In [20]:
def beam_sentence(sentence:str,beam_width=3):

    translation_list = []
    X_inp = en_vec_layer(tf.constant([sentence]))
    X_dec_inp = es_vec_layer(tf.constant(["soseq"]))
    first_out = beam_model.predict((X_inp,X_dec_inp),verbose=0)[0,0]
    top_beam_proba,top_beam_indices = tf.math.top_k(first_out,beam_width)

    def get_translation_with_proba(translation:str):
        proba_total = 0
        for word_id in range(1,50):
            X_dec_inp = es_vec_layer(tf.constant(["soseq "+translation]))
            output = beam_model.predict((X_inp,X_dec_inp),verbose=0)[0,word_id]
            pred_word = es_vec_layer.get_vocabulary()[np.argmax(output)]
            if pred_word == "eoseq":
                break
            proba_total += np.math.log(np.max(output))
            translation += " " + pred_word
        return translation.strip(),proba_total

    for i in range(beam_width):
        first_word = es_vec_layer.get_vocabulary()[top_beam_indices[i]]
        total_sentence,proba_total = get_translation_with_proba(first_word)
        proba_total += np.math.log(top_beam_proba[i])
        translation_list.append((proba_total,total_sentence))

    return translation_list

In [21]:
sentence = "I love cats and dogs"
print(beam_sentence(sentence))

[(-3.478966907935391, 'amo a los perros y gatos'), (-8.099883536242025, 'me encanta la música y los gatos'), (-5.053357050505417, '[UNK] a los perros y los gatos')]
