# Reranker model training

## Set-up

In [None]:
# Set to False to run without Weights & Biases: skips the wandb install/import,
# run initialisation, all wandb.log calls and wandb.finish.
use_wandb = True

In [None]:
# Relative directory names for the input data and for saved checkpoints.
# In Colab they are placed under 'My Drive/mbr-reranking/'; elsewhere they are
# used as-is, relative to the working directory.
DATA_DIR = "data"
MODEL_DIR = "models"

In [None]:
# Detect Google Colab. In Colab, mount Google Drive and keep data/models there;
# otherwise fall back to the local relative directories.
try:
    import google.colab  # noqa: F401  (import only used to detect Colab)
    from google.colab import drive
except ImportError:
    # Only a missing module means "not in Colab". A bare `except:` would also
    # swallow KeyboardInterrupt and turn a failed drive.mount() into a silent
    # fallback to the local directories.
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    # Errors from mounting propagate so a failed mount is visible immediately.
    drive.mount('/content/drive')
    FULL_DATA_DIR = f'/content/drive/My Drive/mbr-reranking/{DATA_DIR}'
    FULL_MODEL_DIR = f'/content/drive/My Drive/mbr-reranking/{MODEL_DIR}'
else:
    FULL_DATA_DIR = DATA_DIR
    FULL_MODEL_DIR = MODEL_DIR

In [None]:
try:
    import sentencepiece
except:
    !pip install sentencepiece
    import sentencepiece

In [None]:
if use_wandb:
    try:
        import wandb
    except:
        !pip install wandb
        import wandb

In [None]:
import locale
# Force locale.getpreferredencoding() to always report UTF-8.
# NOTE(review): presumably a workaround for a non-UTF-8 locale in the notebook
# environment (e.g. for the `!pip` shell commands) — confirm. It replaces only
# this Python-level function and may not change the default encoding used by
# open(); pass encoding= explicitly when reading text files.
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import XLMRobertaTokenizer, XLMRobertaModel
# NOTE(review): transformers' own AdamW is deprecated in recent releases in
# favour of torch.optim.AdamW; defaults such as weight_decay may differ, so
# confirm before switching.
from transformers import AdamW

from tqdm import tqdm

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Model config

In [None]:
# Training hyper-parameters; the whole dict is also passed to wandb.init as the
# run config.
config = dict(
    head_learning_rate=1e-4,   # AdamW LR for the regression head during warm-up
    learning_rate=1e-5,        # AdamW LR for the whole model after warm-up
    lr_end_factor=1e-2,        # LinearLR ends at learning_rate * this factor
    epochs=3,                  # total epochs, warm-up included
    warmup_epochs=1,           # epochs trained with the pretrained encoder frozen
    batch_size=8,              # training batch size after warm-up
    warmup_batch_size=32,      # training batch size during warm-up
    valid_batch_size=128,      # validation batch size
    validation_size=16_000,    # number of trailing examples held out for validation
)

## Reading dataset (pt. 1)

In [None]:
import json


def _read_lines(path):
    """Return the lines of a UTF-8 text file with surrounding whitespace stripped."""
    # Explicit encoding: open()'s default depends on the platform locale, and
    # the source side (`.deu`, presumably German) contains non-ASCII text.
    with open(path, 'r', encoding='utf-8') as fp:
        return [line.strip() for line in fp]


# Per-candidate scores used as regression targets (NumPy file format).
with open(f"{FULL_DATA_DIR}/sampled/train.scores", 'rb') as f:
    scores_all = np.load(f)

# Source sentences and the sampled candidates generated for them, one per line.
original_all = _read_lines(f"{FULL_DATA_DIR}/train.deu")
generated_all = _read_lines(f"{FULL_DATA_DIR}/sampled/train.eng")
# References are currently unused:
# ref_all = _read_lines(f"{FULL_DATA_DIR}/train.eng")

with open(f"{FULL_DATA_DIR}/sampled/train.info.json", 'r', encoding='utf-8') as f:
    metadata = json.load(f)
samples_per_sentence = metadata["samples_per_sentence"]

# Candidates are stored in consecutive groups of `samples_per_sentence` per
# source sentence, so repeat each source to align it one-to-one with them.
original_all = [original_all[i // samples_per_sentence] for i in range(len(generated_all))]
# ref_all = [ref_all[i // samples_per_sentence] for i in range(len(generated_all))]

# Explicit check instead of `assert`, which is stripped under `python -O`.
if not (len(original_all) == len(generated_all) == len(scores_all)):
    raise ValueError(
        f"Length mismatch: {len(original_all)} sources, "
        f"{len(generated_all)} candidates, {len(scores_all)} scores"
    )

n_valid = config["validation_size"]
# `[:-0]` would give an empty training split, so require 0 < n_valid < total.
if not 0 < n_valid < len(original_all):
    raise ValueError(
        f"validation_size must be in the open interval (0, {len(original_all)}), got {n_valid}"
    )

config["train_size"] = len(original_all) - n_valid

# Train/validation split: the last `validation_size` examples are held out.
# NOTE(review): if n_valid is not a multiple of samples_per_sentence, one
# source sentence has candidates on both sides of the split — confirm intent.
scores_train = scores_all[:-n_valid]
original_train = original_all[:-n_valid]
generated_train = generated_all[:-n_valid]

scores_valid = scores_all[-n_valid:]
original_valid = original_all[-n_valid:]
generated_valid = generated_all[-n_valid:]

## Model

In [None]:
# Mean pooling
class MeanPooling(nn.Module):
    """Average token vectors over the positions where ``attention_mask`` is non-zero.

    Args:
        eps: lower bound applied to the per-row token count before dividing,
            so a row whose mask is all zeros yields zeros instead of NaN.
    """

    def __init__(self, eps=1e-9):
        super().__init__()
        self.eps = eps

    def forward(self, hidden_states, attention_mask):
        """Return the masked mean of ``hidden_states`` over dimension 1.

        Args:
            hidden_states: token vectors, presumably (batch, seq, hidden).
            attention_mask: (batch, seq) mask; non-zero marks real tokens.

        Returns:
            A float tensor with dimension 1 reduced away, e.g. (batch, hidden).
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        # The clamp guards against 0/0 for fully masked rows. Any row with at
        # least one real token has a count >= 1, so the clamp is a no-op there.
        sum_mask = input_mask_expanded.sum(1).clamp(min=self.eps)
        return sum_embeddings / sum_mask

# Model
class RegressionModel(nn.Module):
    """Pretrained encoder -> mean pooling -> linear head producing one score per input.

    While ``pretrained_frozen`` is set, the encoder's parameters have
    ``requires_grad=False`` and it runs under ``torch.no_grad()``, so only the
    regression head is trained.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.regression_head = torch.nn.Linear(pretrained_model.config.hidden_size, 1)
        self.pooling = MeanPooling()
        self.pretrained_frozen = False

    def _encode(self, input_ids, attention_mask):
        # Pool the encoder's last hidden layer into one vector per sequence.
        encoder_out = self.pretrained_model(input_ids, attention_mask=attention_mask)
        return self.pooling(encoder_out.last_hidden_state, attention_mask)

    def forward(self, input_ids, attention_mask):
        if not self.pretrained_frozen:
            pooled = self._encode(input_ids, attention_mask)
        else:
            # Frozen encoder: skip building an autograd graph for it.
            with torch.no_grad():
                pooled = self._encode(input_ids, attention_mask)
        return self.regression_head(pooled)

    def _set_pretrained_trainable(self, trainable):
        # Keep the flag and the parameters' requires_grad in sync.
        self.pretrained_frozen = not trainable
        for weight in self.pretrained_model.parameters():
            weight.requires_grad = trainable

    def freeze_pretrained(self):
        self._set_pretrained_trainable(False)

    def unfreeze_pretrained(self):
        self._set_pretrained_trainable(True)

In [None]:
# Load the pre-trained XLM-R base encoder and its tokenizer.
pretrained_model = XLMRobertaModel.from_pretrained('xlm-roberta-base')
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

# Wrap the encoder with mean pooling and a scalar regression head.
model = RegressionModel(pretrained_model)

# Two optimizers: one over the head only (used while the encoder is frozen)
# and one over every parameter (used afterwards).
head_optimizer = AdamW(model.regression_head.parameters(), lr=config["head_learning_rate"])
optimizer = AdamW(model.parameters(), lr=config["learning_rate"])

# The full-model LR decays linearly from learning_rate to
# learning_rate * lr_end_factor over (approximately) the number of
# post-warm-up optimizer steps.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=config["lr_end_factor"],
    total_iters=(config["epochs"] - config["warmup_epochs"]) * len(generated_train) // config["batch_size"],
)

# Gradient scaling for float16 autocast, and MSE against the target scores.
scaler = torch.cuda.amp.GradScaler()
loss_fn = nn.MSELoss()

model = model.to(device)
model.train()

## Loading dataset (pt. 2)

In [None]:
from torch.utils.data import Sampler
import random
import math

class SimilarLengthSampler(Sampler):
    """Batch sampler that yields lists of indices taken from contiguous index bins.

    Indices ``0..len(dataset)-1`` are cut into consecutive bins of
    ``batch_size * 100``. When the dataset's index order follows sequence
    length, each batch therefore only mixes neighbouring lengths. With
    ``shuffle=True`` the bins, the indices inside each bin, and finally the
    batches are all shuffled (the stored bin order is shuffled in place).
    """

    def __init__(self, dataset, batch_size, shuffle=False):
        super().__init__(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Bins are large groups of consecutive indices; any large multiple of
        # the batch size would work.
        total = len(dataset)
        span = batch_size * 100
        self.bins = []
        for start in range(0, total, span):
            self.bins.append(list(range(start, min(start + span, total))))

    def __iter__(self):
        if not self.shuffle:
            ordered_bins = self.bins
        else:
            random.shuffle(self.bins)
            ordered_bins = [random.sample(group, len(group)) for group in self.bins]

        # Concatenate the bins and cut the result into batch-sized chunks.
        flat = [i for group in ordered_bins for i in group]
        size = self.batch_size
        batches = [flat[pos:pos + size] for pos in range(0, len(flat), size)]

        if self.shuffle:
            random.shuffle(batches)

        yield from batches

    def __len__(self):
        # Number of batches, counting a final partial one.
        return math.ceil(len(self.dataset) / self.batch_size)

In [None]:
class RegressionDataset(Dataset):
    """Tokenised ``source <sep> candidate`` pairs with their regression targets.

    Items are served in ascending order of real (unmasked) token count:
    index ``i`` returns the ``i``-th shortest pair. Every pair is padded to
    the longest pair in the dataset (``padding=True``).
    """

    def __init__(self, original, generated, targets):
        pair_texts = [src + tokenizer.sep_token + hyp for src, hyp in zip(original, generated)]
        self.encodings = tokenizer(pair_texts, truncation=True, padding=True)
        self.targets = torch.tensor(targets, dtype=torch.float)

        # Stable sort of positions by how many tokens each pair really uses.
        token_counts = [sum(mask) for mask in self.encodings["attention_mask"]]
        self.sorted_indices = sorted(range(len(pair_texts)), key=token_counts.__getitem__)

    def __getitem__(self, index):
        row = self.sorted_indices[index]
        sample = {name: torch.tensor(column[row]) for name, column in self.encodings.items()}
        sample['labels'] = self.targets[row]
        return sample

    def __len__(self):
        return len(self.targets)

# Create datasets and dataloaders.
# NOTE(review): shuffle=True draws random indices, so the length ordering above
# is not exploited, and SimilarLengthSampler (defined in an earlier cell) is
# unused. Passing batch_sampler=SimilarLengthSampler(...) would batch similar
# lengths together — confirm which was intended.
dataset_train = RegressionDataset(original_train, generated_train, scores_train)
dataloader_train = DataLoader(dataset_train, batch_size=config["batch_size"], shuffle=True)
dataloader_train_warmup = DataLoader(dataset_train, batch_size=config["warmup_batch_size"], shuffle=True)

dataset_valid = RegressionDataset(original_valid, generated_valid, scores_valid)
dataloader_valid = DataLoader(dataset_valid, batch_size=config["valid_batch_size"], shuffle=True)

In [None]:
def truncate_batch(input_ids, attention_mask):
    """Drop the trailing columns that are padding in every row of the batch.

    The cut point is the largest number of attended tokens in any row. This
    assumes padding sits at the end of each row (right padding) — confirm for
    other tokenizers.

    Returns:
        ``(input_ids, attention_mask)`` sliced to that length along dim 1.
    """
    longest = attention_mask.sum(dim=1).max().item()
    keep = slice(None, longest)
    return input_ids[:, keep], attention_mask[:, keep]

# @torch.compile
def train_batch(batch, opt):
    """Run one mixed-precision optimisation step on ``batch`` with ``opt``.

    Uses the notebook-level ``model``, ``loss_fn``, ``scaler`` and ``device``.

    Args:
        batch: dict with 'input_ids', 'attention_mask' and 'labels' tensors.
        opt: the optimizer to step (head-only or full-model).

    Returns:
        The batch loss as a Python float.
    """
    labels = batch['labels'].to(device)

    # Trim padding shared by the whole batch before the host->device copy.
    input_ids, attention_mask = truncate_batch(batch['input_ids'], batch['attention_mask'])
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.autocast(device_type=device, dtype=torch.float16):
        # squeeze(-1), not squeeze(): for a final batch of one example a bare
        # squeeze() also drops the batch dimension. That pits a 0-d prediction
        # against a (1,) target, which MSELoss broadcasts with a warning.
        outputs = model(input_ids, attention_mask=attention_mask).squeeze(-1)
        loss = loss_fn(outputs, labels)

    # Scaled backward/step for float16; the scaler skips steps with inf/NaN grads.
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()

    return loss.detach().item()

## Training

In [None]:
# Two-phase training: for `warmup_epochs` epochs only the regression head is
# trained on a frozen encoder; afterwards the whole model is trained with a
# linearly decaying LR. Each epoch ends with a validation pass, and the
# checkpoint with the lowest validation loss is saved to FULL_MODEL_DIR.
if use_wandb:
    wandb.init(project="mbr-reranking", config=config)
    wandb_run_name = wandb.run.name

model_name = "model"
if use_wandb:
    model_name = f"{wandb_run_name}"

update_loss_every = 10  # refresh the progress-bar text every N steps
wandb_logs_per_epoch = 100

# Ceiling division: number of steps between W&B logs (based on the main-phase loader).
wandb_steps_per_log = (len(dataloader_train) + wandb_logs_per_epoch-1) // wandb_logs_per_epoch

random.seed(0)
torch.manual_seed(0)

# Exponential moving average of the training loss. `running_loss_correction`
# accumulates the total EMA weight, so their ratio is a bias-corrected average
# even in the first steps (both start at 0).
running_loss = 0
running_loss_correction = 0
ema_factor = 0.002

best_valid_loss = float('inf')

# Training loop
model.freeze_pretrained()
opt = head_optimizer
dataloader = dataloader_train_warmup
for epoch in range(config["epochs"]):

    # After the warm-up epochs, unfreeze the encoder and switch to the
    # full-model optimizer and the main-phase loader.
    if epoch >= config["warmup_epochs"]:
        model.unfreeze_pretrained()
        opt = optimizer
        dataloader = dataloader_train

    loss_sum = 0
    loss_items = 0

    # Train. Re-enter train mode every epoch because validation below
    # switches the model to eval mode.
    model.train()
    pbar = tqdm(dataloader, total=len(dataloader))
    for step, batch in enumerate(pbar):
        loss = train_batch(batch, opt)

        loss_sum += loss
        loss_items += 1

        running_loss = (1-ema_factor)*running_loss + ema_factor*loss
        running_loss_correction = (1-ema_factor)*running_loss_correction + ema_factor

        if step % update_loss_every == update_loss_every-1:
            pbar.set_description(f"Train loss: {running_loss / running_loss_correction:.3f}")

        if step % wandb_steps_per_log == wandb_steps_per_log-1:
            # Log the LR of the optimizer actually in use. During warm-up that
            # is the head optimizer, not the one driven by `scheduler`.
            if use_wandb:
                wandb.log({"epoch": epoch + step / len(dataloader), "running_loss": running_loss / running_loss_correction, "lr": opt.param_groups[0]["lr"]})

        torch.cuda.empty_cache()

        # The LR schedule only covers the post-warm-up epochs.
        if epoch >= config["warmup_epochs"]:
            scheduler.step()

    # Validate. Eval mode disables dropout, so the validation loss (which also
    # drives checkpoint selection) is deterministic and comparable across epochs.
    model.eval()
    valid_loss_sum = 0.0
    valid_loss_items = 0
    pbar = tqdm(dataloader_valid, total=len(dataloader_valid))
    with torch.no_grad():
        for step, batch in enumerate(pbar):
            # Same padding trim as in training; trimmed columns are masked out anyway.
            input_ids, attention_mask = truncate_batch(batch['input_ids'], batch['attention_mask'])
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = batch['labels'].to(device)

            with torch.autocast(device_type=device, dtype=torch.float16):
                outputs = model(input_ids, attention_mask=attention_mask).squeeze(-1)
                loss = loss_fn(outputs, labels)

            # .item() keeps the accumulator a Python float rather than a device
            # tensor that would leak into logs, prints and best_valid_loss.
            valid_loss_sum += loss.item()
            valid_loss_items += 1

            if step % update_loss_every == update_loss_every-1:
                pbar.set_description(f"Valid loss: {valid_loss_sum / valid_loss_items:.3f}")

    train_loss = loss_sum / loss_items
    valid_loss = valid_loss_sum / valid_loss_items

    # Log metrics to wandb
    if use_wandb:
        wandb.log({"epoch": epoch+1, "train_loss": train_loss, "valid_loss": valid_loss})

    # Save the model if it's better
    if valid_loss < best_valid_loss:
        print("Saving model")
        torch.save(model.state_dict(), f'{FULL_MODEL_DIR}/{model_name}.pt')
        best_valid_loss = valid_loss

    print(f"Epoch: {epoch}, Train loss: {train_loss}, Valid loss: {valid_loss}")

In [None]:
# Finalize wandb run: close the W&B run started in the training cell
# (only when logging is enabled).
if use_wandb:
    wandb.finish()

## End

If running in Google Colab, release the Colab runtime to end the session:

In [None]:
# When running in Colab, wait briefly and then release the Colab runtime.
if IN_COLAB:
    import time
    # NOTE(review): presumably gives pending output / Drive writes a moment to
    # finish before the runtime is released — confirm.
    time.sleep(15)

    from google.colab import runtime
    runtime.unassign()