In [None]:
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf 
from sklearn.model_selection import train_test_split
from transformers import TFAutoModelForSeq2SeqLM, ByT5Tokenizer
from tensorflow.keras.callbacks import EarlyStopping
import re
import warnings
warnings.filterwarnings("ignore")

In [2]:
def get_data(url: str, n_rows: int, random_state=42):
    """Load, shuffle and clean an English->Hindi parallel corpus.

    Reads the CSV at ``url``, keeps the first ``n_rows`` rows, drops the first
    column, and names the remaining two columns ``input`` (English) and
    ``target`` (Hindi). Rows where either side is missing are removed.

    Args:
        url: Path, URL or file-like object accepted by ``pd.read_csv``.
        n_rows: Number of leading CSV rows to keep (before missing-value rows
            are dropped, so the result can be shorter).
        random_state: Seed for the row shuffle. A fixed seed makes the
            downstream train/test split reproducible; pass ``None`` for an
            unseeded shuffle.

    Returns:
        A shuffled DataFrame with columns ``["input", "target"]``.

    Raises:
        ValueError: if the CSV does not have exactly two columns after the
            first one (the two-name column assignment fails).
    """
    data = pd.read_csv(url).iloc[:n_rows, 1:]
    # Without a seed this permutation differs on every run, which makes the
    # later seeded train_test_split non-reproducible too.
    data = data.sample(frac=1, random_state=random_state).copy()
    data.columns = ["input", "target"]
    data = data.loc[(~data["input"].isna()) & (~data["target"].isna())].copy()
    return data

In [3]:
# Load the first 100k sentence pairs from the local English->Hindi CSV.
data = get_data(n_rows=100_000, url="data/eng_hindi.csv")

In [4]:
# Peek at a few of the shuffled pairs.
data.head(5)

Unnamed: 0,input,target
3771,when a bunch of cartoonists in Denmark,जब डेनमार्क के कुछ कार्टूनिस्टों ने
62641,"On sex, reach and Pedersen, adds mobile commun...",लिंग रिच और पेडरसेन पर एड्स.मोबाइल संचार: सामा...
1959,The resultant oil slick led to the death of a ...,इस तेल की बनी मोटी परतों के कारण बहुत सी मछलिय...
90314,"Among them, the GangaSahastraNam (Hundred name...",जिनमें श्रीगंगासहस्रनामस्तोत्रम् और आरती सबसे ...
55049,But he was soon sobered by the innate lucidity...,"लेकिन शीर्घ्र ही वे अपने सहज , स्वाभाविक विचार..."


In [5]:
def preprocess_data(data: pd.DataFrame):
    """Return a copy of ``data`` with the ``input`` column transliterated by ``unidecode``.

    Only the English ``input`` column is transformed; ``target`` (Hindi) is
    left untouched, and the caller's DataFrame is not modified.
    """
    # NOTE(review): the stored output after this cell shows a "." appended to
    # sentences and some targets cut short. This function does neither, so
    # that output likely came from a different version of the code; re-run to refresh it.
    ascii_input = data["input"].map(unidecode)
    return data.assign(input=ascii_input)

In [6]:
# Transliterate the English side; the Hindi side is kept as-is.
data = preprocess_data(data=data)

In [7]:
# Inspect the cleaned corpus; pandas truncates the rendered frame to head/tail rows.
data

Unnamed: 0,input,target
3771,when a bunch of cartoonists in Denmark.,जब डेनमार्क के कुछ कार्टूनिस्टों ने.
62641,"On sex, reach and Pedersen, adds mobile commun...",लिंग रिच और पेडरसेन पर एड्स.
1959,The resultant oil slick led to the death of a ...,इस तेल की बनी मोटी परतों के कारण बहुत सी मछलिय...
90314,"Among them, the GangaSahastraNam (Hundred name...",जिनमें श्रीगंगासहस्रनामस्तोत्रम् और आरती सबसे ...
55049,But he was soon sobered by the innate lucidity...,"लेकिन शीर्घ्र ही वे अपने सहज , स्वाभाविक विचार..."
...,...,...
81648,"Apart from him, several other Litterateurs hav...",इनके अलावा और भी अनेक साहित्यकारों ने रामायण स...
92496,And this is very similar to a Q and U example.,और यह बहुत कुछ एक Q और U के उदाहरण जैसा है।.
16893,No wonder hybrid corn has revolutionised agric...,इस बात पर कोई आश्चर्य नहीं होना चाहिए कि अमेरि...
92093,Do n't be worried that babies might be sick an...,इस बात की चिन्ता न करें कि यदि शिशुओं को पीठ प...


In [8]:
def split_data(data: pd.DataFrame, input_col: str="input", target_col: str="target", test_size: float=0.1,
               random_state: int=42):
    """Split a parallel corpus into train/test lists of sentences.

    Args:
        data: DataFrame holding the source and target sentence columns.
        input_col: Name of the source-sentence column.
        target_col: Name of the target-sentence column.
        test_size: Fraction of rows held out for the test split.
        random_state: Seed passed to ``train_test_split`` (previously
            hard-coded to 42, which stays the default).

    Returns:
        ``(x_train, x_test, y_train, y_test)`` as plain Python lists, the
        form the tokenizer helpers consume.
    """
    x_train, x_test, y_train, y_test = train_test_split(
        data[input_col], data[target_col], random_state=random_state, test_size=test_size
    )

    print(f'x_train.shape: {x_train.shape}, x_test.shape: {x_test.shape}, '
          f'y_train.shape: {y_train.shape}, y_test.shape: {y_test.shape}')
    return x_train.to_list(), x_test.to_list(), y_train.to_list(), y_test.to_list()

In [9]:
# 90/10 train/test split with split_data's defaults.
x_train, x_test, y_train, y_test = split_data(data)

x_train.shape: (89998,), x_test.shape: (10000,), y_train.shape: (89998,), y_test.shape: (10000,)


In [10]:
# First training pair: (English source, Hindi target).
(x_train[0], y_train[0])

('The component of probable present radiative forcing by IPCC Fourth Assessment Report.',
 'आईपीसीसी की चतुर्थ मूल्यांकन रिर्पोट द्वारा (radiative forcing)अनुमानित वर्तमान विकिरणशील बाध्यता के घटक (IPCC Fourth Assessment Report).')

In [11]:
# Synchronous data-parallel training across the local devices.
strategy = tf.distribute.MirroredStrategy()
# Byte-level T5 checkpoint downloaded from the Hugging Face Hub.
CHECKPOINT = "google/byt5-small"
# Length limit for inputs, labels and generation. ByT5 tokens are UTF-8 bytes, so this is a
# byte budget: about 200 ASCII characters but only about 66 Devanagari characters (3 bytes each).
N_TOKENS = 200
# Global batch size: 8 examples per replica.
BATCH_SIZE = strategy.num_replicas_in_sync * 8

In [12]:
def tokenize(input: list, target: list, n_tokens: int, mask_label_padding: bool = True):
    """Byte-tokenize aligned source/target sentences for ByT5 seq2seq training.

    Args:
        input: Source (English) sentences. The name shadows the builtin but is
            kept because callers pass it by keyword.
        target: Target (Hindi) sentences, aligned with ``input``.
        n_tokens: Length every ``input_ids`` / ``labels`` row is truncated or
            padded to. ByT5 tokens are UTF-8 bytes, so this is a byte budget.
        mask_label_padding: Replace padding ids in ``labels`` with -100 so
            that the model's built-in loss ignores those positions.

    Returns:
        The tokenizer's ``BatchEncoding`` with ``input_ids``, ``attention_mask``
        and ``labels``.
    """
    tokenizer = ByT5Tokenizer.from_pretrained(CHECKPOINT)
    if len(input) > 0:
        # Show one example so the byte-level tokenization is visible.
        print(f'Example:\n{input[0]}\n{tokenizer.tokenize(input[0])}')
    tokenized_data = tokenizer(text=input, text_target=target,
                               max_length=n_tokens, truncation=True, padding="max_length")
    if mask_label_padding:
        # padding="max_length" fills labels with pad_token_id. The Hugging Face
        # seq2seq loss only skips positions labelled -100, so unmasked padding
        # (most of each row for short sentences) would be trained on.
        pad_id = tokenizer.pad_token_id
        tokenized_data["labels"] = [
            [-100 if token_id == pad_id else token_id for token_id in row]
            for row in tokenized_data["labels"]
        ]
    return tokenized_data

In [13]:
# Smoke-test the tokenizer on the last two training pairs.
tokenize(x_train[-2:], y_train[-2:], N_TOKENS)

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.50k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.59k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/698 [00:00<?, ?B/s]

Example:
She walked out of the plane.
['S', 'h', 'e', ' ', 'w', 'a', 'l', 'k', 'e', 'd', ' ', 'o', 'u', 't', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 'p', 'l', 'a', 'n', 'e', '.']


{'input_ids': [[86, 107, 104, 35, 122, 100, 111, 110, 104, 103, 35, 114, 120, 119, 35, 114, 105, 35, 119, 107, 104, 35, 115, 111, 100, 113, 104, 49, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [76, 35, 102, 100, 113, 35, 100, 111, 118, 114, 35, 103, 114, 35, 118, 114, 112, 104, 35, 77, 100, 102, 110, 108, 104, 35, 70, 107, 100, 113, 48, 112, 114, 119, 108, 114, 113, 47, 49, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [14]:
# Tokenize both splits to the same fixed length (train first, then test).
tokenized_train, tokenized_test = (
    tokenize(input=sources, target=targets, n_tokens=N_TOKENS)
    for sources, targets in ((x_train, y_train), (x_test, y_test))
)

Example:
The component of probable present radiative forcing by IPCC Fourth Assessment Report.
['T', 'h', 'e', ' ', 'c', 'o', 'm', 'p', 'o', 'n', 'e', 'n', 't', ' ', 'o', 'f', ' ', 'p', 'r', 'o', 'b', 'a', 'b', 'l', 'e', ' ', 'p', 'r', 'e', 's', 'e', 'n', 't', ' ', 'r', 'a', 'd', 'i', 'a', 't', 'i', 'v', 'e', ' ', 'f', 'o', 'r', 'c', 'i', 'n', 'g', ' ', 'b', 'y', ' ', 'I', 'P', 'C', 'C', ' ', 'F', 'o', 'u', 'r', 't', 'h', ' ', 'A', 's', 's', 'e', 's', 's', 'm', 'e', 'n', 't', ' ', 'R', 'e', 'p', 'o', 'r', 't', '.']
Example:
The namaskara-mandapa has profuse wood-carvings , while the wall of the shrine has interesting mural paintings .
['T', 'h', 'e', ' ', 'n', 'a', 'm', 'a', 's', 'k', 'a', 'r', 'a', '-', 'm', 'a', 'n', 'd', 'a', 'p', 'a', ' ', 'h', 'a', 's', ' ', 'p', 'r', 'o', 'f', 'u', 's', 'e', ' ', 'w', 'o', 'o', 'd', '-', 'c', 'a', 'r', 'v', 'i', 'n', 'g', 's', ' ', ',', ' ', 'w', 'h', 'i', 'l', 'e', ' ', 't', 'h', 'e', ' ', 'w', 'a', 'l', 'l', ' ', 'o', 'f', ' ', 't', 'h', 'e', '

In [15]:
def return_tf_tensors(data):
    """Wrap tokenized columns in a ``tf.data.Dataset`` of per-example dicts.

    ``data`` maps column names to equally long sequences (e.g. the encoding
    returned by ``tokenize``); each dataset element is a dict holding one row
    of every column. Prefetching uses ``tf.data.AUTOTUNE``; shuffling and
    batching are left to the caller.
    """
    columns = dict(data)
    dataset = tf.data.Dataset.from_tensor_slices(columns)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [16]:
# Convert both tokenized splits into tf.data pipelines (train first, then test).
train_tf_data, test_tf_data = map(return_tf_tensors, (tokenized_train, tokenized_test))

In [17]:
# Print the first two training examples to check feature names and shapes.
for example in train_tf_data.take(2):
    print(example)

{'input_ids': <tf.Tensor: shape=(200,), dtype=int32, numpy=
array([ 87, 107, 104,  35, 102, 114, 112, 115, 114, 113, 104, 113, 119,
        35, 114, 105,  35, 115, 117, 114, 101, 100, 101, 111, 104,  35,
       115, 117, 104, 118, 104, 113, 119,  35, 117, 100, 103, 108, 100,
       119, 108, 121, 104,  35, 105, 114, 117, 102, 108, 113, 106,  35,
       101, 124,  35,  76,  83,  70,  70,  35,  73, 114, 120, 117, 119,
       107,  35,  68, 118, 118, 104, 118, 118, 112, 104, 113, 119,  35,
        85, 104, 115, 114, 117, 119,  49,   1,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    

In [18]:
def fit_model(train_data, val_data, epochs=2, eta=1e-4, early_stopping_patience=1, batch_size=BATCH_SIZE,
              restore_best_weights=True):
    """Fine-tune the pretrained ``CHECKPOINT`` seq2seq model.

    The model is created and compiled inside the global ``strategy`` scope.
    ``compile`` is given no ``loss``, so training relies on the loss the
    Hugging Face model computes from the ``labels`` feature.

    Args:
        train_data: Unbatched ``tf.data.Dataset`` of feature dicts.
        val_data: Unbatched ``tf.data.Dataset`` used to compute ``val_loss``.
        epochs: Maximum number of training epochs.
        eta: Adam learning rate.
        early_stopping_patience: Epochs without ``val_loss`` improvement
            before training stops.
        batch_size: Global batch size; the default scales with the replica count.
        restore_best_weights: Passed to ``EarlyStopping``. When True, the
            weights from the lowest-``val_loss`` epoch are restored instead of
            keeping the last epoch's weights.

    Returns:
        The trained model.
    """
    with strategy.scope():
        model = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
        model.compile(optimizer=tf.keras.optimizers.Adam(eta))

    # summary() prints on its own and returns None; print() around it emitted a stray "None".
    model.summary()
    early_stop = EarlyStopping(monitor="val_loss", patience=early_stopping_patience, mode="min",
                               restore_best_weights=restore_best_weights)
    # Shuffle the whole training set each epoch, then batch, then prefetch. Prefetch
    # must be the last stage so batch preparation overlaps with training steps.
    train_batches = (train_data
                     .shuffle(len(train_data))
                     .batch(batch_size)
                     .prefetch(tf.data.AUTOTUNE))
    # A fixed validation order keeps batch composition, and thus val_loss, comparable
    # across epochs. It also avoids a second full-size shuffle buffer.
    val_batches = val_data.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    model.fit(train_batches, validation_data=val_batches,
              epochs=epochs, callbacks=[early_stop])
    return model

In [19]:
# Fine-tune ByT5 on the training split, validating on the held-out split.
model = fit_model(train_tf_data, test_tf_data)

Downloading tf_model.h5:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at google/byt5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Model: "tft5_for_conditional_generation"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 shared (Embedding)          multiple                  565248    
                                                                 
 encoder (TFT5MainLayer)     multiple                  217657472 
                                                                 
 decoder (TFT5MainLayer)     multiple                  81980288  
                                                                 
 lm_head (Dense)             multiple                  565248    
                                                                 
Total params: 299,637,760
Trainable params: 299,637,760
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/2
Epoch 2/2


## Inference

Translate English sentences with the fine-tuned model.

In [20]:
def inference_tokenize(input: list, n_tokens: int):
    """Tokenize source sentences for ``model.generate``.

    Returns ``(tokenizer, encodings)``:
    - ``encodings`` holds TensorFlow tensors truncated/padded to ``n_tokens`` byte tokens.
    - The tokenizer is returned so the caller can decode the generated ids with it.
    """
    byte_tokenizer = ByT5Tokenizer.from_pretrained(CHECKPOINT)
    encodings = byte_tokenizer(
        text=input,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return byte_tokenizer, encodings

In [23]:
def inference(txt: str):
    """Translate one English sentence with the module-level fine-tuned ``model``.

    Prints the source and the generated Hindi text.

    Returns:
        ``(txt, translation)``.
    """
    inference_tokenizer, tokenized_data = inference_tokenize(input=[txt], n_tokens=N_TOKENS)
    pred = model.generate(**tokenized_data, max_new_tokens=N_TOKENS)
    # skip_special_tokens drops <pad>, </s> and similar tokens. The previous regex
    # ("<.*?>") also deleted any generated text that sat between '<' and '>'.
    result = inference_tokenizer.decode(pred[0], skip_special_tokens=True)
    print(f"ENGLISH:\n{txt}\n\nHINDI:\n{result}")
    return (txt, result)

In [31]:
# Translate a simple hand-written sentence.
txt, result = inference("What is your name?")

ENGLISH:
What is your name?

HINDI:
आपका नाम क्या है?.


In [33]:
# Translate a held-out test sentence.
txt, result = inference(x_test[1])

ENGLISH:
The imports of grey goods from Japan increased from 75.

HINDI:
जापान से काफी उपभोक्ताओं की कमी से बढ़ती गई थी।.


In [34]:
# Translate another held-out test sentence.
txt, result = inference(x_test[500])

ENGLISH:
the uses of it are different in Hindi.

HINDI:
इसके उपयोग अलग हैं।.


In [36]:
# Translate a held-out sentence containing quoted speech.
txt, result = inference(x_test[1100])

ENGLISH:
Everyone wants to hear their news anchor say, "Mister Splashy Pants.

HINDI:
हर समाचार को अपने समाचार अनुसार कहना चाहते हैं, “मिस्टर स्प्लैशी पान्ट्स.
