<a href="https://colab.research.google.com/github/ganeshmp01/DS-Classwork/blob/main/TransformersTranslation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import os
os.environ['KERAS_BACKEND']= 'tensorflow'

import pathlib
import random
import string
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization

In [5]:
text_file = keras.utils.get_file(
    fname = 'spa-eng.zip',
    origin = "https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract = True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / 'spa.txt'

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip
[1m2638744/2638744[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [6]:
with open(text_file) as f:
  lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
  eng,spa = line.split("\t")
  spa = "[start]" + spa + "[end]"
  text_pairs.append((eng, spa))

In [7]:
num_val_samples = int(0.10 * len(text_pairs))
num_train_samples = len(text_pairs) - 2* num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples:num_train_samples+num_val_samples]
test_pairs = text_pairs[num_train_samples+num_val_samples:]

print(f"{len(train_pairs)} Train Samples")
print(f"{len(val_pairs)} Val Samples")
print(f"{len(test_pairs)} Test Samples")

95172 Train Samples
11896 Val Samples
11896 Test Samples


In [8]:
## Processing
strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[","")
strip_chars = strip_chars.replace("]","")

vocab_size = 15000
sequence_length = 20
batch_size = 64

In [9]:
strip_chars

'!"#$%&\'()*+,-./:;<=>?@\\^_`{|}~¿'

In [10]:
def preprocessing(text):
  lowercase = tf_strings.lower(text)
  return tf_strings.regex_replace(lowercase, f"[{re.escape(strip_chars)}]", "")

In [11]:
eng_vectorization = TextVectorization(max_tokens=vocab_size, output_mode = "int", output_sequence_length = sequence_length )
spa_vectorization = TextVectorization(max_tokens=vocab_size, output_mode = "int", output_sequence_length = sequence_length+1, standardize = preprocessing)

train_eng_text = [pair[0] for pair in train_pairs]
train_spa_text = [pair[1] for pair in train_pairs]

eng_vectorization.adapt(train_eng_text)
spa_vectorization.adapt(train_spa_text)

In [12]:
def format_dataset(eng, spa):
  eng = eng_vectorization(eng)
  spa = spa_vectorization(spa)
  return (
      {
          'encoder_inputs':eng,
          'decoder_inputs':spa[:,:-1],
      },
      spa[:,1:]
  )

def make_dataset(pairs):
  eng_texts, spa_texts = zip(*pairs)
  eng_texts = list(eng_texts)
  spa_texts = list(spa_texts)
  dataset = tf_data.Dataset.from_tensor_slices((eng_texts,spa_texts))
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(format_dataset)
  return dataset.cache().shuffle(2048).prefetch(16)

In [13]:
train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

In [14]:
train_ds

<_PrefetchDataset element_spec=({'encoder_inputs': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'decoder_inputs': TensorSpec(shape=(None, None), dtype=tf.int64, name=None)}, TensorSpec(shape=(None, None), dtype=tf.int64, name=None))>

In [15]:
import keras.ops as ops


class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

In [16]:
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        else:
            return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
            padding_mask = ops.minimum(padding_mask, causal_mask)
        else:
            padding_mask = None

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

In [17]:
embed_dim = 256
num_heads = 8
latent_dim = 2048

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

In [18]:
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

In [19]:
decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
)

In [31]:
epochs = 40 #SEt it to 40
transformer.summary()

In [32]:
transformer.compile("rmsprop",loss ='sparse_categorical_crossentropy',metrics = ['accuracy'])
transformer.fit(train_ds,epochs = epochs, validation_data= val_ds)

Epoch 1/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 61ms/step - accuracy: 0.8270 - loss: 1.2991 - val_accuracy: 0.9017 - val_loss: 0.7696
Epoch 2/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m121s[0m 51ms/step - accuracy: 0.9488 - loss: 0.3988 - val_accuracy: 0.9718 - val_loss: 0.2846
Epoch 3/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 50ms/step - accuracy: 0.9825 - loss: 0.1567 - val_accuracy: 0.9921 - val_loss: 0.1166
Epoch 4/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 50ms/step - accuracy: 0.9898 - loss: 0.1084 - val_accuracy: 0.9977 - val_loss: 0.0606
Epoch 5/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 50ms/step - accuracy: 0.9973 - loss: 0.0346 - val_accuracy: 0.9985 - val_loss: 0.0293
Epoch 6/40
[1m1488/1488[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 49ms/step - accuracy: 0.9996 - loss: 0.0155 - val_accuracy: 0.9996 - val_loss: 0.0111
Ep

<keras.src.callbacks.history.History at 0x7d9f6e696470>

In [33]:
spa_vocab = spa_vectorization.get_vocabulary()
spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    tokenized_input_sentence = eng_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = spa_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(30):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)


In [34]:
input_sentence

'In our park, we have a nice slide for children to play on.'

In [35]:
translated

'[start] mudamos amigas le lo es muchas [UNK] poniendo casi a tarde[end] para        '