In [1]:
# transformers is pinned to 4.33, presumably because Translator.__fix_tokenizer edits private
# NllbTokenizer attributes (lang_code_to_id, fairseq_tokens_to_ids, ...) — verify before upgrading.
!pip install sentencepiece transformers==4.33 datasets sacremoses sacrebleu  -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━

In [10]:
import gc
import random
import numpy as np
import torch
import pandas as pd
from transformers import NllbTokenizer
from transformers import Adafactor, get_scheduler
from transformers import AutoModelForSeq2SeqLM
from transformers import get_constant_schedule_with_warmup
from tqdm.auto import tqdm, trange
from sklearn.model_selection import train_test_split

In [4]:
# Colab-only: mount Google Drive at /content/drive. Translator.train's default
# save_to directory lives under /content/drive/MyDrive, so this must run first.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
class Translator:
  """Fine-tunes an NLLB-200 checkpoint for Amis -> English translation.

  The tokenizer is extended with an ``ami_Latn`` language code and the model's
  token embeddings are resized to the new tokenizer length. Training samples
  random mini-batches of aligned (Amis, English) sentence pairs from an 80/20
  train/test split of the data passed to :meth:`train`.

  Attributes:
    tokenizer: ``NllbTokenizer`` with ``ami_Latn`` registered.
    model: seq2seq model loaded from ``model_name``.
    learningRate: Adafactor learning rate reached after the warmup phase.
    batchSize: number of sentence pairs per training step.
    losses: per-step training losses, appended by :meth:`train`.
    LANGS: (short name, NLLB code) pairs; not referenced elsewhere in this class.
  """

  def __init__(self, learningRate=1e-3, batchSize=16, model_name='facebook/nllb-200-distilled-600M'):
    """Load tokenizer and model, and register the Amis language code.

    Args:
      learningRate: Adafactor learning rate (reached after warmup).
      batchSize: sentence pairs sampled per training step.
      model_name: Hugging Face id or local path of the NLLB checkpoint.
    """
    self.tokenizer = NllbTokenizer.from_pretrained(model_name)
    self.__fix_tokenizer()
    self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # __fix_tokenizer changed the vocabulary; make the embedding matrix match it.
    self.model.resize_token_embeddings(len(self.tokenizer))
    self.learningRate = learningRate
    self.batchSize = batchSize
    self.losses = []
    self.LANGS = [('eng', 'eng_Latn'), ('ami', 'ami_Latn')]

  def __setup(self):
    """Create the Adafactor optimizer and the warmup learning-rate scheduler."""
    self.optimizer = Adafactor(
        [p for p in self.model.parameters() if p.requires_grad],
        scale_parameter=False,
        relative_step=False,  # use the fixed `lr` below, not Adafactor's internal schedule
        lr=self.learningRate,
        clip_threshold=1.0,
        weight_decay=1e-3,
    )
    # Linear warmup from 0 to `learningRate` over 500 steps, then constant.
    self.scheduler = get_constant_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=500,
    )

  def __generate_splits(self, data):
    """Split ``data['Amis']`` / ``data['English']`` into aligned 80/20 train/test sets.

    The fixed ``random_state`` makes the split reproducible across runs.
    """
    self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
        data['Amis'],
        data['English'],
        random_state=104,
        test_size=0.2,
        shuffle=True,
    )

  def __fix_tokenizer(self, new_lang='ami_Latn'):
    """Register ``new_lang`` as an extra language code in the NLLB tokenizer.

    Mutates private tokenizer tables directly (written against the
    transformers==4.33 ``NllbTokenizer`` internals; re-check on upgrade).
    ``new_lang`` receives id ``old_len - 1`` and ``<mask>`` is re-assigned to
    the id after all language codes.
    """
    # Tokenizer length not counting new_lang if it is already an added token.
    old_len = len(self.tokenizer) - int(new_lang in self.tokenizer.added_tokens_encoder)
    self.tokenizer.lang_code_to_id[new_lang] = old_len-1
    self.tokenizer.id_to_lang_code[old_len-1] = new_lang
    # always move "mask" to the last position
    self.tokenizer.fairseq_tokens_to_ids["<mask>"] = len(self.tokenizer.sp_model) + len(self.tokenizer.lang_code_to_id) + self.tokenizer.fairseq_offset

    # Rebuild the token<->id maps so they include the new language code.
    self.tokenizer.fairseq_tokens_to_ids.update(self.tokenizer.lang_code_to_id)
    self.tokenizer.fairseq_ids_to_tokens = {v: k for k, v in self.tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in self.tokenizer._additional_special_tokens:
        self.tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    self.tokenizer.added_tokens_encoder = {}
    self.tokenizer.added_tokens_decoder = {}

  def __cleanup(self):
    """Run the garbage collector and release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

  def __get_batch_pairs(self):
    """Sample ``batchSize`` aligned (Amis, English) training pairs, with replacement.

    Returns:
      Tuple ``(amis_sentences, english_sentences, src_lang_code, tgt_lang_code)``.
    """
    xx, yy = [], []
    for _ in range(self.batchSize):
      # One shared index for both sides keeps each source/target a real translation pair.
      idx = random.randint(0, len(self.x_train) - 1)
      xx.append(self.x_train.iloc[idx])
      yy.append(self.y_train.iloc[idx])
    return xx, yy, 'ami_Latn', 'eng_Latn'

  def __tokenize_and_transition(self, texts, max_length=128):
    """Tokenize a list of sentences (padded, truncated) and move tensors to the model's device."""
    return self.tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(self.model.device)

  def evaluate(self):
    """Plot the exponentially smoothed training-loss curve and return the matplotlib Axes."""
    smoothed = pd.Series(self.losses).ewm(com=100).mean()
    ax = smoothed.plot()
    ax.set(title='Training loss (EWM, com=100)', xlabel='step', ylabel='loss')
    return ax

  def train(self, data, epochs: int=50000, save_interval: int=5000, save_to: str="/content/drive/MyDrive/MTApplication/models/nllb-eng-ami-v1"):
    """Fine-tune the model on random mini-batches of Amis -> English pairs.

    Args:
      data: table with aligned ``'Amis'`` and ``'English'`` columns (e.g. a DataFrame).
      epochs: number of optimizer steps, one random batch each (not full data passes).
      save_interval: save to ``{save_to}/checkpoint_{step}`` every this many steps;
        0 disables intermediate checkpoints.
      save_to: directory for checkpoints and the ``final`` model/tokenizer.
    """
    self.__setup()
    self.__generate_splits(data)
    self.__cleanup()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(device)
    self.model.train()

    for i in tqdm(range(epochs)):
      xx, yy, lang1, lang2 = self.__get_batch_pairs()
      # Drop references to the previous batch's tensors before allocating new ones.
      x = y = loss = None
      try:
        # The NLLB tokenizer tags sequences with the code of its current `src_lang`,
        # so switch it per side: sources get lang1's code, labels get lang2's code.
        self.tokenizer.src_lang = lang1
        self.tokenizer.tgt_lang = lang2
        x = self.__tokenize_and_transition(xx)
        self.tokenizer.src_lang = lang2
        y = self.__tokenize_and_transition(yy)

        # Label id -100 is ignored by the loss, so padding does not contribute.
        y.input_ids[y.input_ids == self.tokenizer.pad_token_id] = -100

        # Forward and backward passes
        self.optimizer.zero_grad()
        loss = self.model(**x, labels=y.input_ids).loss
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        step_loss = loss.item()
        self.losses.append(step_loss)
        if i % 500 == 0:
          print(f"Step {i}, Loss: {step_loss}")

        # Checkpoint saving
        if save_interval and i > 0 and i % save_interval == 0:
          current_loss_avg = np.mean(self.losses[-1000:])
          print(f"Average Loss last 1000 steps: {current_loss_avg}")
          checkpoint_dir = f'{save_to}/checkpoint_{i}'
          self.model.save_pretrained(checkpoint_dir)
          self.tokenizer.save_pretrained(checkpoint_dir)

      except RuntimeError as e:
        # Treat RuntimeErrors (e.g. CUDA out-of-memory) as recoverable:
        # discard this batch, free memory and move on to the next step.
        self.optimizer.zero_grad()
        x = y = loss = None
        self.__cleanup()
        print('RuntimeError:', e)

      except Exception as e:
        # Anything else is treated as fatal; stop and save what we have.
        self.optimizer.zero_grad()
        print('An error occurred:', e)
        break

    self.model.save_pretrained(f'{save_to}/final')
    self.tokenizer.save_pretrained(f'{save_to}/final')

  # TODO: add a translate(source: str) method for inference.
