# Reelix - Reranker Training

In [None]:
# Mount Google Drive so the relative "drive/My Drive/..." paths used by later
# cells resolve (presumably relative to Colab's default /content working dir).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os
import random
import numpy as np
import torch

SEED = 42


def _seed_everything(seed: int) -> None:
    """Seed every RNG the training run touches so results are repeatable.

    Covers Python's ``random``, NumPy's global RNG and PyTorch's CPU generator,
    plus all CUDA devices when a GPU is present. On GPU it also forces
    deterministic cuDNN kernels and turns off cuDNN autotuning, trading some
    speed for reproducibility.
    """
    # NOTE: this only affects interpreters started after this point; the hash
    # seed of the already-running process was fixed when it started.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)

In [None]:

from reranker_model import RerankerModel
from reranker_dataset import TripletDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, random_split

# Model configs
MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 512  # maximum tokens per (query, candidate) pair fed to the encoder

# Training configs
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3

# Load tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): 'XXX.jsonl' looks like a placeholder; point this at the real triplet file.
file_path = 'drive/My Drive/Colab Notebooks/MovieRecs/training_dataset/XXX.jsonl'
dataset = TripletDataset(file_path, tokenizer_name=MODEL_NAME, max_length=MAX_LEN)

# Split into train/val (90/10). A dedicated generator seeded with SEED keeps the
# split identical across re-runs, regardless of how much of the global RNG
# stream earlier cells (or repeated cell executions) have consumed.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
print(f"Train size: {len(train_dataset)} | Val size: {len(val_dataset)}")

# Shuffle only the training data; validation metrics do not depend on order,
# so a fixed order keeps evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model and move it to the GPU when one is available
model = RerankerModel(model_name=MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


In [None]:
from datetime import datetime
import zoneinfo

def generate_model_dir(
    strategy_name: str,
    base_model: str,
    *,
    base: str = "drive/My Drive/Colab Notebooks/MovieRecs/models/movie_reranker",
    tz_name: str = "America/Los_Angeles",
) -> str:
    """Build a timestamped output directory path for one training run.

    The result has the form ``{base}_{strategy_name}_{base_model}_{YYYYmmdd_HHMM}``.
    Nothing is created on disk; callers must ``os.makedirs`` it themselves.
    The timestamp has minute resolution, so two runs started within the same
    minute get the same path.

    Args:
        strategy_name: Label for the training strategy (e.g. "first-train").
        base_model: Label for the underlying encoder (e.g. "bert").
        base: Path prefix the labels and timestamp are appended to.
        tz_name: IANA time-zone name used to render the timestamp.

    Returns:
        The directory path as a string.
    """
    timestamp = datetime.now(zoneinfo.ZoneInfo(tz_name)).strftime("%Y%m%d_%H%M")
    return f"{base}_{strategy_name}_{base_model}_{timestamp}"


model_dir = generate_model_dir(strategy_name="first-train", base_model="bert")
print(model_dir)

In [None]:
# Set up checkpoints and model saving.
# Checkpoints go in a "checkpoints" subdirectory of model_dir. Creating it with
# makedirs also creates model_dir itself, which later torch.save calls that
# write directly into model_dir need (torch.save does not create directories).
save_dir = os.path.join(model_dir, "checkpoints")
os.makedirs(save_dir, exist_ok=True)

def save_checkpoint(model, optimizer, scheduler, epoch, filename):
    """Write a resumable training checkpoint with ``torch.save``.

    The saved dict has the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'scheduler_state_dict'.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        epoch: Stored as-is; the caller decides whether it is 0- or 1-based.
        filename: Path relative to the module-level ``save_dir``. A path that
            already starts with ``save_dir`` is used unchanged rather than
            having ``save_dir`` prepended a second time.

    Returns:
        The path the checkpoint was written to.
    """
    # Tolerate callers that already prefixed save_dir, instead of silently
    # targeting a nested save_dir/save_dir/... path that does not exist.
    if os.path.normpath(filename).startswith(os.path.normpath(save_dir) + os.sep):
        path = filename
    else:
        path = os.path.join(save_dir, filename)

    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        path,
    )
    return path


In [None]:

from torch.optim import AdamW
from transformers import get_scheduler

# transformers.AdamW is deprecated in favor of torch.optim.AdamW (and is absent
# from newer transformers releases). Both implement decoupled weight decay.

# Exclude biases and layer norm weights from weight decay (matched by substring
# of the parameter name)
no_decay = ["bias", "LayerNorm.weight"]

# Set weight decay coefficient for decay and no_decay (biases and layernorm) groups
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,  # 0.01 for the decay group
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,  # 0 for the no_decay group
    },
]

# Initialize AdamW optimizer. eps=1e-6 reproduces the old transformers.AdamW
# default (torch's default is 1e-8), so optimization matches earlier runs.
optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, eps=1e-6)

# Total optimizer steps over the whole run (one step per training batch);
# warmup covers the first 10% of them.
num_training_steps = len(train_loader) * NUM_EPOCHS
num_warmup_steps = int(0.1 * num_training_steps)

# Linear schedule: LR ramps from 0 to LEARNING_RATE during warmup, then decays
# linearly back to 0 over the remaining steps.
scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [None]:
from tqdm import tqdm

MARGIN = 1.0  # required gap between positive and negative scores in the hinge loss


def _score_triplet_batch(model, batch, device):
    """Run the model on a batch's positive and negative pairs; return (scores_pos, scores_neg)."""
    scores = []
    for side in ("pos", "neg"):
        scores.append(model(
            input_ids=batch[f"input_ids_{side}"].to(device),
            attention_mask=batch[f"attention_mask_{side}"].to(device),
            token_type_ids=batch[f"token_type_ids_{side}"].to(device),
        ))
    return scores[0], scores[1]


def _pairwise_margin_loss(scores_pos, scores_neg, margin=MARGIN):
    """Mean hinge loss max(0, margin - (pos - neg)); zero once pos outscores neg by `margin`."""
    return torch.mean(torch.clamp(margin - (scores_pos - scores_neg), min=0))


best_val_loss = float('inf')
best_epoch = -1

# torch.save does not create directories, so make sure model_dir exists before
# the best-model save below writes into it.
os.makedirs(model_dir, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # ---- Training ----
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} - Training"):
        scores_pos, scores_neg = _score_triplet_batch(model, batch, device)
        loss = _pairwise_margin_loss(scores_pos, scores_neg)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip the global gradient norm to 1.0 to guard against exploding updates
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer and LR-scheduler steps (the scheduler advances once per batch)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            scores_pos, scores_neg = _score_triplet_batch(model, batch, device)
            val_loss += _pairwise_margin_loss(scores_pos, scores_neg).item()

            # Pairwise accuracy: how often the positive outscores the negative
            correct += (scores_pos > scores_neg).sum().item()
            total += scores_pos.size(0)

    avg_val_loss = val_loss / len(val_loader)
    accuracy = correct / total

    print(f"Validation Loss: {avg_val_loss:.4f} | Pairwise Accuracy: {accuracy:.4f}")

    # Save best model (by validation loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), os.path.join(model_dir, "reranker_best.pt"))
        # Also save a resumable checkpoint. save_checkpoint prefixes save_dir
        # itself, so pass only the file name. Note the stored epoch is 0-based.
        save_checkpoint(model, optimizer, scheduler, epoch, "reranker_best_checkpoint.pt")
        print(f"Best model saved at epoch {best_epoch}")


In [None]:
# Save the weights from the last epoch. These may differ from the
# best-validation-loss weights saved during training.
# torch.save does not create directories, so ensure model_dir exists.
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(model_dir, "reranker_final.pt"))
print("Final model saved after training.")

# Report best checkpoint (best_epoch stays -1 if no epoch improved on inf,
# e.g. when NUM_EPOCHS is 0)
if best_epoch > 0:
    print(f"\nBest model: epoch {best_epoch} with val loss {best_val_loss:.4f}")
else:
    print("\nNo best model recorded (no training epochs were run).")