# Fine-Tuning NLLB

The purpose of this notebook is document the process of fine-tuning an NLLB model for translating from Literary Tibetan to English. Some of the code is notebook is based on and adapted from the the tutorial ['How To Fine Tune a NLLB 200 Model for Translating A New Language'](https://cointegrated.medium.com/how-to-fine-tune-a-nllb-200-model-for-translating-a-new-language-a37fc706b865)

I've also drawn on ['Neural Machine Translation With Keras NLP'](https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/) for the preprocessing steps.

In [2]:
from transformers.optimization import Adafactor
from transformers import get_constant_schedule_with_warmup
from transformers.optimization import Adafactor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.auto import trange
import numpy as np
import pathlib
import random

## Preprocessing Text Pairs

### Loading the Data

The code below loads in the text pairs as a list, [Tibetan, English].

In [3]:
text_file = pathlib.Path('../../data/training-batches/training-batch-1.txt')

with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
xx=[]
yy=[]
for line in lines:
    try:
        tib, eng = line.split(",")[:2]
        eng = eng.lower()
        xx.append(tib)
        yy.append(eng)
    except:
        pass

### Tokenize the Text Pairs

Below, I've used NLLB's pretrained tokenizers to tokenize the data.

In [4]:
max_length = 128  # token sequences will be truncated
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

In [5]:
def get_batch_pairs(batch_size, data):
    (l1, long1), (l2, long2) = random.sample([('bo', 'bo'), ('en', 'eng_Latn')], 2)
    xx, yy = [], []
    for _ in range(batch_size):
        item = data.iloc[random.randint(0, len(data)-1)]
        xx.append(item[l1])
        yy.append(item[l2])
    return xx, yy, long1, long2

## Training the Model

### Pre-Trained Model
Here, I've downloaded the pre-trained NLLB model.

In [6]:
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


### Optimizer

Below, I've selected the Adafactor optimizer for training. The values passed to the optimizer are taken from the tutorial mentioned above and are arbitrary.

In [7]:
optimizer = Adafactor(
    [p for p in model.parameters() if p.requires_grad],
    scale_parameter=False,
    relative_step=False,
    lr=1e-4,
    clip_threshold=1.0,
    weight_decay=1e-3,
)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=1000)

In [8]:
batch_size = 16
training_steps = 10000  # Set a large number of steps,
# and then just interrupt the training manually
losses = []  # with this list, I do very simple tracking of average loss
MODEL_SAVE_PATH = 'nllb'

In [9]:
def train(xx, yy, batch_size, optimizer):
    x, y, loss = None, None, None

    tq = trange(len(losses), training_steps)
    for i in tq:
        lang1 = 'bo'
        lang2 = 'eng_Latn'
        tokenizer.src_lang = lang1
        x = tokenizer(xx, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        tokenizer.src_lang = lang2
        y = tokenizer(yy, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        # -100 is a magic value ignored in the loss function
        # because we don't want the model to learn to predict padding ids
        y.input_ids[y.input_ids == tokenizer.pad_token_id] = -100

        loss = model(**x, labels=y.input_ids).loss
        loss.backward()
        losses.append(loss.item())

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        if i % 1000 == 0:
            # each 1000 steps, I report average loss at these steps
            print(i, np.mean(losses[-1000:]))

        if i % 1000 == 0 and i > 0:
            model.save_pretrained(MODEL_SAVE_PATH+str(i))

In [10]:
train(xx[0], yy[0], batch_size, optimizer)

  0%|          | 0/10000 [00:00<?, ?it/s]

0 5.024219512939453


KeyboardInterrupt: 