In [None]:
!pip install transformers sentencepiece datasets translate-toolkit --quiet

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output
from IPython.utils import io
import torch
from torch import optim
from torch.nn import functional as F

with io.capture_output() as captured:
  !pip install transformers sentencepiece

from transformers import AdamW, AutoTokenizer, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm

from transformers.models.mt5 import MT5Config, MT5ForConditionalGeneration


In [None]:
!wget https://object.pouta.csc.fi/OPUS-Tatoeba/v2021-07-22/tmx/en-ru.tmx.gz

In [None]:
!gzip -d /content/en-ru.tmx.gz

In [None]:
# Peek at the first 50 raw lines of the TMX file to check that the download and
# decompression produced the expected markup.
with open("en-ru.tmx", mode='r', encoding="utf-8") as tmx_preview:
  preview_lines = [tmx_preview.readline() for _ in range(50)]

for preview_line in preview_lines:
  print(preview_line)

In [None]:
from translate.storage.tmx import tmxfile

# Parse the TMX with English as the source language and Russian as the target.
with open("en-ru.tmx", mode='rb') as tmx_stream:
  tmx_file = tmxfile(tmx_stream, 'en', 'ru')

In [None]:
# One dict per translation unit: the unit's source text under 'en' and its
# target text under 'ru'.
dataset = [{'en': unit.source, 'ru': unit.target}
           for unit in tmx_file.unit_iter()]

In [None]:
# Fixed split in file order (no shuffling): first 10k pairs train, next 5k test.
train_dataset, test_dataset = dataset[:10000], dataset[10000:15000]

In [None]:
import random

# Draw a random integer in [0, 100000] (both ends inclusive) and show it.
k = random.randint(a=0, b=100000)
print(k)

In [None]:
# Hugging Face Hub id of the pretrained multilingual T5 checkpoint.
model_repo = 'google/mt5-base'
config = MT5Config.from_pretrained(pretrained_model_name_or_path=model_repo)

In [None]:
# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_repo)

In [None]:
model = MT5ForConditionalGeneration.from_pretrained(model_repo)
# from_pretrained() loads the weights on the CPU, while the batching code sends
# inputs to the GPU with .cuda(); move the model to the GPU as well (when one
# exists) so forward() does not fail with a device mismatch.
model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

In [None]:
# Language code -> tag string prepended to the input to name the output language.
# Key order ('ru' first) is the order format_translation_data samples from.
LANG_TOKEN_MAPPING = dict(ru='<ru>', en='<en>')

In [None]:
# Sanity check: encode a tagged example to see how the tokenizer handles the
# '<ru>' prefix, padded/truncated to 40 tokens.
example_input_str = '<ru>Привет Мир!.'
input_ids = tokenizer.encode(
    text=example_input_str, return_tensors='pt', padding='max_length',
    truncation=True, max_length=40)
print(input_ids)

# Map the ids of the single encoded row back to their sub-word strings.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(tokens)

In [None]:
def encode_input_str(text, target_lang, tokenizer, seq_len,
                     lang_token_map=LANG_TOKEN_MAPPING):
  """Encode a source sentence prefixed with the target-language tag.

  Args:
    text: source sentence.
    target_lang: key into ``lang_token_map`` naming the output language.
    tokenizer: tokenizer whose ``encode`` accepts the padding/truncation
      keyword arguments used below.
    seq_len: length the encoding is padded/truncated to.
    lang_token_map: language code -> tag string (e.g. '<ru>') prepended to
      ``text``.

  Returns:
    Row 0 of the tokenizer's 'pt' (PyTorch tensor) output, i.e. the token ids
    for this single sentence.
  """
  encode_kwargs = {
      'return_tensors': 'pt',
      'padding': 'max_length',
      'truncation': True,
      'max_length': seq_len,
  }
  tagged_text = lang_token_map[target_lang] + text
  batch_of_one = tokenizer.encode(text=tagged_text, **encode_kwargs)
  return batch_of_one[0]

def encode_target_str(text, tokenizer, seq_len):
  """Encode a target sentence (no language tag) to a fixed number of token ids.

  Args:
    text: target sentence.
    tokenizer: tokenizer whose ``encode`` accepts the padding/truncation
      keyword arguments used below.
    seq_len: length the encoding is padded/truncated to.

  Returns:
    Row 0 of the tokenizer's 'pt' (PyTorch tensor) output.
  """
  encode_kwargs = dict(return_tensors='pt', padding='max_length',
                       truncation=True, max_length=seq_len)
  batch_of_one = tokenizer.encode(text=text, **encode_kwargs)
  return batch_of_one[0]

def format_translation_data(translations, lang_token_map,
                            tokenizer, seq_len=128):
  """Turn one multilingual example into a randomly oriented (input, target) pair.

  Two distinct languages are drawn from the keys of ``lang_token_map``: the
  first is the source side and the second the target, so repeated calls
  produce both translation directions.

  Args:
    translations: dict mapping language code -> sentence (a value may be None).
    lang_token_map: language code -> tag string; its keys are the candidates.
    tokenizer: tokenizer forwarded to the encoders.
    seq_len: padded/truncated token length for both sides.

  Returns:
    (input_token_ids, target_token_ids), or None if either chosen sentence
    is None.
  """
  candidate_langs = list(lang_token_map)
  input_lang, target_lang = np.random.choice(candidate_langs, size=2, replace=False)

  input_text = translations[input_lang]
  target_text = translations[target_lang]
  if any(side is None for side in (input_text, target_text)):
    return None

  return (
      encode_input_str(input_text, target_lang, tokenizer, seq_len, lang_token_map),
      encode_target_str(target_text, tokenizer, seq_len),
  )

def transform_batch(batch, lang_token_map, tokenizer, seq_len=None, device='cuda'):
  """Tokenize a list of translation dicts into stacked input/label tensors.

  Each entry goes through ``format_translation_data`` (random direction,
  fixed-length encoding); entries for which it returns None are skipped.

  Args:
    batch: list of translation dicts keyed by language code.
    lang_token_map: language code -> tag string prepended to the input.
    tokenizer: tokenizer used for encoding; its ``pad_token_id`` (if any) is
      used to mask label padding.
    seq_len: token length for inputs and labels. When None (the default),
      the notebook-level ``max_seq_len`` is used, as before.
    device: device the returned tensors are moved to ('cuda', as before).

  Returns:
    (input_ids, labels), stacked along dim 0 with one row per kept entry.
    Padding positions in ``labels`` are set to -100 so the model's loss
    ignores them.

  Raises:
    ValueError: if no entry of ``batch`` produced a usable pair.
  """
  if seq_len is None:
    seq_len = max_seq_len  # notebook global from the hyperparameter cell

  inputs = []
  targets = []
  for translation_set in batch:
    formatted_data = format_translation_data(translation_set, lang_token_map, tokenizer, seq_len)
    if formatted_data is None:
      continue

    input_ids, target_ids = formatted_data
    inputs.append(input_ids.unsqueeze(0))
    targets.append(target_ids.unsqueeze(0))

  if not inputs:
    # torch.cat([]) would otherwise fail with an unhelpful message.
    raise ValueError('batch contained no translation pairs with both texts present')

  batch_input_ids = torch.cat(inputs).to(device)
  batch_target_ids = torch.cat(targets).to(device)

  # Labels are padded with the tokenizer's pad id; left as-is, the loss also
  # rewards predicting padding. Hugging Face seq2seq models ignore label
  # positions equal to -100.
  pad_id = getattr(tokenizer, 'pad_token_id', None)
  if pad_id is not None:
    batch_target_ids = batch_target_ids.masked_fill(batch_target_ids == pad_id, -100)

  return batch_input_ids, batch_target_ids

def get_data(dataset, lang_token_map, tokenizer, batch_size=32):
  """Yield tokenized batches covering ``dataset`` once, in a fresh random order.

  Args:
    dataset: sequence of translation dicts; it is NOT modified.
    lang_token_map: language code -> tag string, forwarded to ``transform_batch``.
    tokenizer: tokenizer forwarded to ``transform_batch``.
    batch_size: entries per batch (the last batch may be smaller).

  Yields:
    ``transform_batch``'s result for each consecutive slice of the shuffled
    entries (the (input_ids, labels) pair consumed by the train/eval loops).
  """
  # Shuffle indices rather than the list itself. np.random.shuffle(dataset)
  # reordered the caller's list in place, so e.g. test_dataset[16] referred to a
  # different sentence after every pass over the data.
  order = np.random.permutation(len(dataset))
  for start in range(0, len(dataset), batch_size):
    raw_batch = [dataset[idx] for idx in order[start:start + batch_size]]
    yield transform_batch(raw_batch, lang_token_map, tokenizer)

In [None]:
n_epochs = 5
batch_size = 15
print_freq = 100         # batches between training-loss printouts
max_seq_len = 40         # token length of both encoded inputs and labels
lr = 5e-4

# Evaluate on the test set (and possibly checkpoint) every checkpoint_freq steps.
checkpoint_freq = 1000
# File the training loop saves weights to; it was referenced there but never
# defined, which raised NameError at the first save.
model_path = 'mt5_translation.pt'

n_batches = int(np.ceil(len(train_dataset) / batch_size))
total_steps = n_epochs * n_batches
n_warmup_steps = int(total_steps * 0.01)  # first 1% of steps are LR warm-up

print("n_batches", n_batches)
print("total_steps", total_steps)
print("n_warmup_steps", n_warmup_steps)

In [None]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 and
# weight_decay=0.0 are transformers.AdamW's defaults, so the update rule is unchanged.
optimizer = optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
scheduler = get_linear_schedule_with_warmup(optimizer, n_warmup_steps, total_steps)

# Per-step training losses and per-evaluation test losses, filled by the training loop.
losses = []
test_losses = []

In [None]:
def eval_model(model, dataset, max_iters=8):
  """Estimate the model's loss on ``dataset`` from up to ``max_iters`` random batches.

  Uses the notebook globals ``LANG_TOKEN_MAPPING``, ``tokenizer`` and
  ``batch_size`` to build the batches.

  Args:
    model: model whose ``forward(input_ids=..., labels=...)`` returns an
      object with a scalar ``loss`` tensor.
    dataset: list of translation dicts.
    max_iters: maximum number of batches to evaluate.

  Returns:
    Mean loss over the evaluated batches, or NaN if no batch was produced.
  """
  test_generator = get_data(dataset, LANG_TOKEN_MAPPING,
                            tokenizer, batch_size)
  eval_losses = []

  # Disable dropout while measuring, then restore the caller's mode so the
  # training loop keeps training with dropout.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for i, (input_batch, label_batch) in enumerate(test_generator):
        if i >= max_iters:
          break

        model_out = model.forward(
            input_ids = input_batch,
            labels = label_batch)
        eval_losses.append(model_out.loss.item())
  finally:
    model.train(was_training)

  if not eval_losses:
    return float('nan')  # explicit, instead of np.mean([]) with a RuntimeWarning
  return np.mean(eval_losses)

In [None]:
# Lowest test loss seen so far; starts at +infinity so the first finite
# evaluation result becomes the best.
best_test_loss = np.inf

In [None]:
for epoch_idx in range(n_epochs):
  # from_pretrained() returns the model in eval mode (dropout off); switch to
  # train mode so fine-tuning actually uses dropout.
  model.train()
  data_generator = get_data(train_dataset, LANG_TOKEN_MAPPING, tokenizer, batch_size)

  for batch_idx, (input_batch, label_batch) in tqdm(enumerate(data_generator), total=n_batches):
    # Step count across epochs. batch_idx restarts at 0 each epoch and an epoch
    # has n_batches steps (667 with the default config, fewer than
    # checkpoint_freq), so testing batch_idx alone never triggered evaluation.
    global_step = epoch_idx * n_batches + batch_idx + 1

    optimizer.zero_grad()

    model_out = model.forward(
        input_ids = input_batch,
        labels = label_batch)

    loss = model_out.loss
    losses.append(loss.item())

    loss.backward()

    optimizer.step()
    scheduler.step()

    # Print training update info
    if (batch_idx + 1) % print_freq == 0:
      avg_loss = np.mean(losses[-print_freq:])
      print('Epoch: {} | Step: {} | Avg. loss: {:.3f} | lr: {:.6f}'.format(
          epoch_idx+1, batch_idx+1, avg_loss, scheduler.get_last_lr()[0]))

    if global_step % checkpoint_freq == 0:
      test_loss = eval_model(model, test_dataset)
      model.train()  # eval_model may leave the model in eval mode
      test_losses.append(test_loss)
      print('Test loss {:.3f}'.format(test_loss))
      if best_test_loss > test_loss:
        print('Saving model with test loss of {:.3f}'.format(test_loss))
        torch.save(model.state_dict(), model_path)
        best_test_loss = test_loss

  # End-of-epoch snapshot goes to a separate file so it does not overwrite the
  # best-by-test-loss weights saved above.
  torch.save(model.state_dict(), f'{model_path}.last')

In [None]:
# Moving average of the per-step training losses to make the trend readable.
window_size = 50
# One value per full window. The +1 keeps the final window that ends at the
# last step (range(len(losses) - window_size) dropped it).
smoothed_losses = [np.mean(losses[i:i + window_size])
                   for i in range(len(losses) - window_size + 1)]

skip_steps = 100  # leave the noisy start of training out of the plot
fig, ax = plt.subplots()
ax.plot(range(skip_steps, len(smoothed_losses)), smoothed_losses[skip_steps:])
ax.set(title=f'Training loss ({window_size}-step moving average)',
       xlabel='step (window start)', ylabel='loss');

In [None]:
# One point per test-set evaluation made during training.
fig, ax = plt.subplots()
ax.plot(test_losses, marker='o')
ax.set(title='Test loss per evaluation', xlabel='evaluation #', ylabel='loss');

In [None]:
# NOTE(review): which sentence index 16 is depends on whether earlier cells
# shuffled test_dataset in place.
test_sentence = test_dataset[16]['en']
print('Raw input text:', test_sentence)

# Encode with the same max_seq_len used in training. model.config.max_length is
# a generation length limit (transformers' default is 20), not the input length
# the model was fine-tuned on.
input_ids = encode_input_str(
    text = test_sentence,
    target_lang = 'ru',
    tokenizer = tokenizer,
    seq_len = max_seq_len,
    lang_token_map = LANG_TOKEN_MAPPING)

# Add a batch dimension and put the input on whatever device holds the model.
input_ids = input_ids.unsqueeze(0).to(model.device)

print('Truncated input text:', tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(input_ids[0])))

In [None]:
# generate() does not change the module's mode, and the model may still be in
# train mode from the training cells; disable dropout for beam search.
model.eval()
# Allow outputs as long as the training labels instead of the default
# generation length limit.
output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3,
                               max_length=max_seq_len)

for token_set in output_tokens:
  print(tokenizer.decode(token_set, skip_special_tokens=True))