# Fine-Tuning NLLB

The purpose of this notebook is to document the process of fine-tuning an NLLB model for translating from Literary Tibetan to English.

Some of the code in this notebook is based on the tutorial ['How To Fine Tune a NLLB 200 Model for Translating A New Language'](https://cointegrated.medium.com/how-to-fine-tune-a-nllb-200-model-for-translating-a-new-language-a37fc706b865). However, the training loop and preprocessing have been heavily revised.

In [None]:
from transformers.optimization import Adafactor
from transformers import get_constant_schedule_with_warmup
from transformers.optimization import Adafactor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import matplotlib.pyplot as plt

from tqdm.auto import trange
import numpy as np
import random
import gc
import torch
import os

## Preprocessing Text Pairs

### Loading the Data

The code below loads the text pairs as a list of [Tibetan, English] pairs and then randomly groups them into batches. It is a helper function for the custom training loop.

In [None]:
def get_batch_pairs(data, batch_size, num_batches, data_dir='../../data/training-batches'):
    """Load one shard of Tibetan-English pairs and draw random batches from it.

    Each line of the shard is split on commas. The first field is the Tibetan
    source and the second is the English target, which is lower-cased. Lines
    with fewer than two fields are skipped. Batches are drawn without
    replacement, so no pair appears twice across the returned batches.

    Args:
        data: File name of the shard inside ``data_dir``.
        batch_size: Number of pairs per batch.
        num_batches: Number of batches to draw.
        data_dir: Directory holding the shard files. Defaults to the directory
            this helper has always read from.

    Returns:
        A list of ``num_batches`` items, each ``[tibetan_texts, english_texts]``.
        Both inner lists have length ``batch_size``.

    Raises:
        ValueError: If the shard has fewer than ``batch_size * num_batches``
            usable pairs.
    """
    print(f'Loading {data}...', end='\r')

    data_path = os.path.join(data_dir, data)

    pairs = []
    # Explicit UTF-8: the shards contain Tibetan script, and the platform's
    # default encoding is not guaranteed to be UTF-8.
    with open(data_path, encoding='utf-8') as f:
        # Iterating the file keeps the final line even when the file has no
        # trailing newline (split("\n")[:-1] silently dropped it).
        for raw_line in f:
            fields = raw_line.rstrip('\n').split(',')
            if len(fields) < 2:
                continue  # blank or malformed line
            # NOTE(review): a plain comma split keeps only the text before the
            # first comma in the English field. Confirm the shard format; quoted
            # CSV fields would need the csv module.
            tib, eng = fields[:2]
            pairs.append([tib, eng.lower()])

    print(f'Batching {data}... ', end='\r')

    needed = batch_size * num_batches
    if needed > len(pairs):
        raise ValueError(
            f'{data} has {len(pairs)} usable pairs, but {num_batches} batches '
            f'of {batch_size} need {needed}'
        )

    # One O(n) sample without replacement, instead of repeated randint + del
    # (each del is O(n)).
    chosen = random.sample(pairs, needed)

    batches = []
    for k in range(num_batches):
        chunk = chosen[k * batch_size:(k + 1) * batch_size]
        batches.append([[tib for tib, _ in chunk], [eng for _, eng in chunk]])

    print(f'Training on {data}             ')

    return batches

## Training the Model

### Pre-Trained Model
Here, I've downloaded the pre-trained NLLB model and its associated tokenizer.

'nllb-checkpoint-0' is my proof-of-concept NLLB model, fine-tuned on this dataset for 5 epochs.

In [None]:
# To start from the untouched base model instead of the proof-of-concept checkpoint:
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").cuda()

# Continue fine-tuning from the local checkpoint, moved onto the GPU.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/home/j/Documents/Projects/MLotsawa/notebooks/nllb/nllb-checkpoint-0"
).cuda()

# The tokenizer always comes from the base model id.
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

### Optimizer

Below, I've selected the Adafactor optimizer for training. The values passed to the optimizer are taken from the tutorial mentioned above and are arbitrary.

In [None]:
# Adafactor with a fixed learning rate: relative_step and scale_parameter are
# disabled, so lr is used as given. Only trainable parameters are optimized.
optimizer = Adafactor(
    list(filter(lambda param: param.requires_grad, model.parameters())),
    lr=1e-4,
    relative_step=False,
    scale_parameter=False,
    clip_threshold=1.0,
    weight_decay=1e-3,
)

# Warm up over the first 1000 steps, then keep the learning rate constant.
scheduler = get_constant_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
)

In [None]:
# Prefix for per-epoch checkpoints; train() appends the epoch index, giving
# 'nllb-checkpoint-0.0', 'nllb-checkpoint-0.1', ...
MODEL_SAVE_PATH = "nllb-checkpoint-0."

### Training Loop

Below is a custom training loop. The first draft was adapted from the previously mentioned tutorial, but it has since been substantially rewritten.

In [None]:
# Loss history that persists across train() calls; train() appends to both.
all_losses, epoch_losses = [], []

In [None]:
def _plot_loss_history(step_losses, per_epoch_losses):
    """Redraw the loss curves: per-step losses on top, per-epoch losses below.

    Closes the current figure first so that repeated calls do not stack plots.
    """
    plt.close()  # a no-op when no figure is open, so no try/except is needed

    plt.subplot(2, 1, 1)
    plt.plot(range(len(step_losses)), step_losses)

    plt.subplot(2, 1, 2)
    plt.plot(range(len(per_epoch_losses)), per_epoch_losses)

    plt.show()


def train(data_dir, optimizer, batch_size=16, epochs=1, steps_per_shard=100):
    """Fine-tune the global ``model`` on every shard in ``data_dir``, ``epochs`` times.

    Each epoch visits the shard files listed in ``data_dir`` in random order.
    For each shard, ``steps_per_shard`` random batches are loaded with
    ``get_batch_pairs``, and one optimizer/scheduler step is taken per batch.
    After every shard the loss curves are redrawn. After every epoch the model
    is saved to ``MODEL_SAVE_PATH + str(epoch)``.

    Relies on the notebook globals ``model``, ``tokenizer``, ``scheduler``,
    ``MODEL_SAVE_PATH``, ``all_losses`` and ``epoch_losses``. It leaves
    ``model`` in training mode; call ``model.eval()`` before inference.

    Args:
        data_dir: Directory whose entries are the shard files.
        optimizer: Optimizer that steps ``model``'s parameters.
        batch_size: Pairs per training batch.
        epochs: Number of passes over all shards.
        steps_per_shard: Batches drawn (and steps taken) per shard.

    Returns:
        ``[all_losses, epoch_losses]``. The first holds every successful
        per-step loss recorded so far. The second holds the mean step loss of
        each completed epoch, or NaN if no step in that epoch succeeded.
    """
    global all_losses, epoch_losses

    x, y, loss = None, None, None
    gc.collect()
    torch.cuda.empty_cache()

    # from_pretrained() returns the model in eval mode (dropout disabled);
    # fine-tuning needs training mode.
    model.train()

    for epoch in range(epochs):
        shards = os.listdir(data_dir)
        random.shuffle(shards)

        losses = []  # every successful step loss in this epoch

        for shard_idx, shard in enumerate(shards):
            # NOTE(review): get_batch_pairs reads shards from its own default
            # directory, not from data_dir, so data_dir must be that directory.
            batches = get_batch_pairs(shard, batch_size, steps_per_shard)

            shard_losses = []
            desc = f'Epoch {epoch}.{shard_idx}, Shard: {shard}'

            for step in trange(steps_per_shard, desc=desc):
                xx, yy = batches[step][0], batches[step][1]

                try:
                    # NOTE(review): 'bo' does not look like an NLLB language
                    # code (Tibetan is likely 'bod_Tibt'), so the source may
                    # not get a Tibetan language token. Switching would change
                    # the input format the existing checkpoint was trained on;
                    # confirm before changing.
                    tokenizer.src_lang = 'bo'
                    x = tokenizer(xx, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
                    tokenizer.src_lang = 'eng_Latn'
                    y = tokenizer(yy, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
                    # -100 is ignored by the loss, so the model is not trained
                    # to predict padding ids.
                    y.input_ids[y.input_ids == tokenizer.pad_token_id] = -100

                    loss = model(**x, labels=y.input_ids).loss
                    loss.backward()
                    step_loss = loss.item()
                    losses.append(step_loss)
                    shard_losses.append(step_loss)

                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()

                    print('loss: ' + str(np.mean(losses)), end="\r")

                except RuntimeError as e:  # usually CUDA out-of-memory
                    # Report instead of silently skipping, so a persistent
                    # non-OOM error is visible.
                    print(f'\nSkipping batch {step} of {shard}: {str(e).split(chr(10), 1)[0]}')
                    optimizer.zero_grad(set_to_none=True)
                    x, y, loss = None, None, None
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue

            print('loss: ' + str(np.mean(losses) if losses else float('nan')))

            # Add only this shard's losses. `losses` spans the whole epoch, so
            # extending with it would re-append earlier shards every time.
            all_losses.extend(shard_losses)
            _plot_loss_history(all_losses, epoch_losses)

        # Mean over the epoch: less noisy than the last step's loss, and safe
        # when every step in the epoch failed.
        epoch_losses.append(float(np.mean(losses)) if losses else float('nan'))
        model.save_pretrained(MODEL_SAVE_PATH + str(epoch))

    return [all_losses, epoch_losses]

In [None]:
# history is [all_losses, epoch_losses] as returned by train().
history = train(
    data_dir='../../data/training-batches',
    optimizer=optimizer,
    batch_size=32,
    epochs=100,
)