In [2]:
from datasets import load_dataset 
from IPython.display import display, Markdown
from ipywidgets import widgets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import torch
import time
import requests
from torch import optim
from torch.nn import functional as F
from transformers import AdamW, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm_notebook
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

sns.set()  # apply seaborn's default theme to matplotlib figures drawn afterwards

In [3]:
# Notebook configuration.
model_repo, model_path, max_sequence_length = (
    'google/mt5-small',            # pretrained checkpoint id on the Hugging Face hub
    './mt5_translation_small.pt',  # fine-tuned state_dict, read back with torch.load
    20,                            # pad/truncate length for source and target token ids
)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_repo)  # tokenizer matching the base checkpoint

In [5]:
# Base seq2seq weights from the hub, moved onto the default CUDA device.
# Every later cell also calls .cuda(), so the notebook requires a GPU.
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).cuda()

In [6]:
# Language code -> tag token prepended to the source text to select the output language.
languages_mapping = {lang_code: f'<{lang_code}>' for lang_code in ('en', 'jv')}

In [7]:
# Register the language tag tokens as additional special tokens so the
# tokenizer treats each tag as a single unit.
special_tokens_dict = dict(additional_special_tokens=[*languages_mapping.values()])
tokenizer.add_special_tokens(special_tokens_dict)
# The vocabulary size changed, so resize the embedding matrix to len(tokenizer).
# Presumably this must happen before load_state_dict so parameter shapes match
# the fine-tuned checkpoint -- TODO confirm.
model.resize_token_embeddings(len(tokenizer))

Embedding(250102, 768)

In [8]:
def encode_input_str(text, target_language, tokenizer, sequence_length, lang_mapping=languages_mapping):
    """Tokenize source text prefixed with the tag of the language to translate into.

    Args:
        text: source sentence.
        target_language: key of ``lang_mapping`` naming the desired output language.
        tokenizer: Hugging Face tokenizer used for encoding.
        sequence_length: every result is padded/truncated to this many ids.
        lang_mapping: language code -> tag token string.

    Returns:
        1-D tensor of ``sequence_length`` token ids.
    """
    tagged_text = lang_mapping[target_language] + text
    encoded_batch = tokenizer.encode(
        text=tagged_text,
        return_tensors='pt',
        padding='max_length',
        truncation=True,  # clip anything longer than sequence_length tokens
        max_length=sequence_length,
    )
    # A single string encoded with return_tensors='pt' comes back with a
    # leading batch axis of size 1; drop it.
    return encoded_batch[0]

def encode_target_str(text, tokenizer, sequence_length):
    """Tokenize a target sentence (no language tag) into fixed-length ids.

    Args:
        text: target-side sentence.
        tokenizer: Hugging Face tokenizer used for encoding.
        sequence_length: the result is padded/truncated to this many ids.

    Returns:
        1-D tensor of token ids (the first row of the encoded batch).
    """
    encoding_options = {
        'return_tensors': 'pt',
        'padding': 'max_length',
        'truncation': True,
        'max_length': sequence_length,
    }
    batched_ids = tokenizer.encode(text=text, **encoding_options)
    return batched_ids[0]

def format_translation_data(translations, tokenizer, sequence_length, lang_mapping=languages_mapping):
    """Build one (input_ids, target_ids) pair with a randomly chosen direction.

    Two distinct languages are drawn from ``lang_mapping``. One supplies the
    source text; the other is the target, and its tag token is prepended to
    the source by ``encode_input_str``.

    Args:
        translations: per-example mapping from language code to sentence
            (presumably a dict like {'en': ..., 'jv': ...} -- TODO confirm).
        tokenizer: Hugging Face tokenizer used for encoding.
        sequence_length: pad/truncate length for both sides.
        lang_mapping: language code -> tag token; needs at least two entries
            (np.random.choice with replace=False raises otherwise).

    Returns:
        (input_token_ids, target_token_ids) as 1-D tensors, or None when
        either sentence is missing or None.
    """
    languages = list(lang_mapping.keys())
    input_language, target_language = np.random.choice(languages, size=2, replace=False)

    # Per-example sentences; a missing key is treated like an explicit None.
    input_text = translations.get(input_language)
    target_text = translations.get(target_language)

    if input_text is None or target_text is None:
        return None

    # Forward lang_mapping so a caller-supplied mapping is used for the tag
    # token too, instead of silently falling back to the module default.
    input_token_ids = encode_input_str(
        input_text, target_language, tokenizer, sequence_length, lang_mapping=lang_mapping)

    target_token_ids = encode_target_str(target_text, tokenizer, sequence_length)

    return input_token_ids, target_token_ids

def transform_batch(batch, tokenizer, sequence_length=None):
    """Encode a raw batch of translation pairs into stacked CUDA id tensors.

    Args:
        batch: object whose ``'translation'`` entry iterates over per-example
            translation mappings (e.g. a pandas DataFrame slice -- TODO confirm).
        tokenizer: Hugging Face tokenizer used for encoding.
        sequence_length: pad/truncate length; when omitted, the notebook-level
            ``max_sequence_length`` is used (the original behavior).

    Returns:
        (batch_input_ids, batch_target_ids), each of shape
        (n_kept, sequence_length) on CUDA. n_kept counts the examples where
        both translations were present.

    Raises:
        RuntimeError: if every example in the batch was skipped.
    """
    if sequence_length is None:
        sequence_length = max_sequence_length

    inputs = []
    targets = []
    for translations_set in batch['translation']:
        formatted_data = format_translation_data(translations_set, tokenizer, sequence_length)

        if formatted_data is None:
            # One side of the pair is missing; drop this example.
            continue

        input_lang_ids, target_lang_ids = formatted_data
        inputs.append(input_lang_ids)
        targets.append(target_lang_ids)

    if not inputs:
        # Stacking an empty list fails with an opaque torch error; report the real cause.
        raise RuntimeError(
            'transform_batch: every example in the batch is missing a source or target translation')

    # All rows share one length (padded/truncated to sequence_length), so stack
    # them along a new batch axis.
    batch_input_ids = torch.stack(inputs).cuda()
    batch_target_ids = torch.stack(targets).cuda()

    return batch_input_ids, batch_target_ids

def get_data_generator(dataset, tokenizer, batch_size=32, random_state=None):
    """Yield encoded batches from a freshly shuffled copy of ``dataset``.

    Args:
        dataset: table with a 'translation' column. The ``.sample``/``.iloc``
            calls assume a pandas DataFrame (TODO confirm).
        tokenizer: Hugging Face tokenizer forwarded to ``transform_batch``.
        batch_size: rows per batch (must be >= 1); the last batch may be smaller.
        random_state: forwarded to ``DataFrame.sample`` to make the shuffle
            reproducible. None keeps the original unseeded behavior.

    Yields:
        (input_ids, target_ids) tensor pairs produced by ``transform_batch``.

    Raises:
        ValueError: if ``batch_size`` is less than 1 (raised on first iteration).
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    shuffled = dataset.sample(frac=1, random_state=random_state)
    for start in range(0, len(shuffled), batch_size):
        # .iloc makes the positional row slice explicit; plain [] row slicing
        # can be label-based for some index types.
        raw_batch = shuffled.iloc[start:start + batch_size]
        yield transform_batch(raw_batch, tokenizer)

In [9]:
# Restore the fine-tuned weights.
# map_location='cpu' deserializes the tensors without needing the device they
# were saved from; load_state_dict then copies them into the model's existing
# (CUDA) parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (on torch >= 1.13, weights_only=True restricts loading to tensors).
model.load_state_dict(torch.load(model_path, map_location='cpu'))

<All keys matched successfully>

In [28]:
test_sentence = 'What are you thinking about so early in the morning?'
print('Raw input text:', test_sentence)

# Encode with the same fixed length the batching helpers use
# (max_sequence_length), not model.config.max_length, which is a default for
# generate() rather than an encoding setting.
input_ids = encode_input_str(
    text=test_sentence,
    target_language='jv',
    tokenizer=tokenizer,
    sequence_length=max_sequence_length,
    lang_mapping=languages_mapping,
)
# Add a batch axis of size 1 and move to the GPU, where the model lives.
input_ids = input_ids.unsqueeze(0).cuda()

print('Truncated input:', tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(input_ids[0])
))

Raw input text: What are you thinking about so early in the morning?
Truncated input: <jv> What are you thinking about so early in the morning?</s><pad><pad><pad><pad><pad>


In [29]:
# The encoded input is right-padded, so pass an explicit attention mask that
# hides pad positions instead of relying on generate() to infer one.
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
# Beam search with 10 beams, returning all 10 candidates (best first).
output_tokens = model.generate(
    input_ids,
    attention_mask=attention_mask,
    num_beams=10,
    num_return_sequences=10,
)
for translation in tokenizer.batch_decode(output_tokens, skip_special_tokens=True):
    print(translation)

Apa sing dadi impenmu ing wayah bangun esuk?
Apa sing dadi impenmu esuk umun-umun awan?
Apa sing dadi impenmu ing wayah bangun esuk, kuwi 
Apa sing dadi impenmu ing wayah bangun esuk melek?
Apa sing dadi impenmu ing wayah esuk melek?
Apa sing dadi impenmu ing wayah esuk, esuke-e
Apa sing dadi impenmu ing wayah bangun esuk, awan 
Apa sing dadi impenmu ing wayah esuk?
Apa sing dadi impenmu ing wayah bangun esuk melek melek
Apa sing dadi impenmu ing wayah esuk, awan awan apa


In [30]:
# Take the top-ranked beam and transliterate it from Latin letters to
# Javanese script via an external web service.
decoded_first_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
url = 'https://latin-java.vercel.app/api/get_javanese_script'
latin_text = {'latin_text': decoded_first_text}
# Bounded wait so a hung service cannot block the kernel indefinitely;
# fail loudly on HTTP errors instead of parsing an error page.
x = requests.post(url, json=latin_text, timeout=30)
x.raise_for_status()
# Parse the body as JSON. Never eval() a network response: that would
# execute whatever the remote server sends.
javanese_script = x.json()['javanese_script']

display(Markdown(f"## English: \n ### {test_sentence}"))
display(Markdown(f"## Javanese: \n ### {decoded_first_text}"))
display(Markdown(f"\n ## Javanese script: \n ## {javanese_script}\n ‎"))

## English: 
 ### What are you thinking about so early in the morning?

## Javanese: 
 ### Apa sing dadi impenmu ing wayah bangun esuk?


 ## Javanese script: 
 ## ꧋ꦲꦥꦱꦶꦁꦢꦢꦶꦲꦶꦩ꧀ꦥꦺꦤ꧀ꦩꦸꦲꦶꦁꦮꦪꦃꦧꦔꦸꦤ꧀ꦲꦺꦱꦸꦏ꧀?
 ‎