# Reelix - Reranker Training

In [None]:
from google.colab import drive

# Google Drive is used for loading data/models and for saving results,
# so it has to be mounted before the later cells touch those paths.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import random
import torch

# Reproducibility: seed Python's and PyTorch's random number generators
# from a single constant.
SEED = 42

# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it here
# affects processes spawned later rather than the already-running kernel.
os.environ["PYTHONHASHSEED"] = str(SEED)

random.seed(SEED)        # Python stdlib RNG
torch.manual_seed(SEED)  # PyTorch default generator

if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
import sys

# Make the project folder on the mounted Drive importable; the two local
# modules imported below are expected to live there.
RERANKER_SRC_DIR = '/content/drive/My Drive/Colab Notebooks/MovieRecs/Reelix Reranker Model'
sys.path.append(RERANKER_SRC_DIR)

# Project-local model definition and triplet dataset builder
from reranker_model import RerankerModel
from reranker_dataset import TripletDataset

In [None]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, random_split

# Model configs
MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 512  # forwarded to TripletDataset as max_length

# Training configs
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
TRAIN_FRACTION = 0.9  # share of the triplets used for training; the rest is validation

# Load tokenizer and dataset.
# NOTE: TripletDataset is given the tokenizer *name*; the tokenizer object
# loaded here is not passed to it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): this path is relative, so it only resolves to the mounted Drive
# when the working directory is /content (the mount's parent) - confirm.
file_path = 'drive/My Drive/Colab Notebooks/MovieRecs/training_dataset/reranker_train_36k_hard-negative_0805.jsonl'
dataset = TripletDataset(file_path, tokenizer_name=MODEL_NAME, max_length=MAX_LEN)

# Split into train/val.
# A dedicated generator seeded with SEED makes the split depend only on SEED,
# not on how much of the global RNG stream was consumed before this cell ran.
# Re-running the cell therefore reproduces the same train/val membership.
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
print(f"Train size: {len(train_dataset)} | Val size: {len(val_dataset)}")

# Only the training loader is shuffled; validation order is left as-is.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Initialize model and move it to the GPU when one is available
model = RerankerModel(model_name=MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Train size: 32400 | Val size: 3600


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

RerankerModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [None]:
from torch.optim import AdamW
from transformers import get_scheduler

# Parameters whose name contains one of these markers (biases and LayerNorm
# weights) are excluded from weight decay.
no_decay = ["bias", "LayerNorm.weight"]

# Partition the model parameters in a single pass, preserving their order.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

# Two optimizer groups: weight decay 0.01 for the regular parameters,
# and 0.0 for the excluded ones.
optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
]

# AdamW over both groups with the shared base learning rate
optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)

# One scheduler step is taken per training batch; 10% of all steps are warmup.
num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)

# Linear schedule: the learning rate ramps from 0 to LR during warmup,
# then decays from LR back to 0 over the remaining steps.
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [None]:
from datetime import datetime
import zoneinfo

def generate_model_dir(strategy_name: str, base_model: str) -> str:
    """Build the output directory path for one training run.

    The path is the fixed ``movie_reranker`` prefix followed by the strategy
    name, the base model name and a ``YYYYMMDD_HHMM`` timestamp taken in the
    America/Los_Angeles time zone, all joined with underscores. The directory
    itself is not created here.

    Args:
        strategy_name: Label of the training strategy (e.g. "hard-negative").
        base_model: Short label of the base encoder (e.g. "bert").

    Returns:
        The composed directory path as a string.
    """
    pacific = zoneinfo.ZoneInfo("America/Los_Angeles")
    stamp = datetime.now(pacific).strftime("%Y%m%d_%H%M")
    prefix = "drive/My Drive/Colab Notebooks/MovieRecs/Reelix Reranker Model/movie_reranker"
    return "{}_{}_{}_{}".format(prefix, strategy_name, base_model, stamp)

# Resolve this run's output directory, create it if needed, and echo the path.
model_dir = generate_model_dir("hard-negative", "bert")
os.makedirs(model_dir, exist_ok=True)
print(model_dir)

drive/My Drive/Colab Notebooks/MovieRecs/Reelix Reranker Model/movie_reranker_hard-negative_bert_20250805_1653


In [None]:
# Checkpoints go into a "checkpoints" sub-folder of this run's model directory.
save_dir = f"{model_dir}/checkpoints"
os.makedirs(save_dir, exist_ok=True)
print(save_dir)

def save_checkpoint(model, optimizer, scheduler, epoch, filename):
    """Write a training checkpoint to ``save_dir``.

    The checkpoint bundles the given epoch value with the state dicts of the
    model, optimizer and scheduler, and is written with ``torch.save`` to
    ``<save_dir>/<filename>``. ``save_dir`` is read from the notebook's global
    scope.

    Args:
        model: Object exposing ``state_dict()`` (the reranker being trained).
        optimizer: Optimizer exposing ``state_dict()``.
        scheduler: LR scheduler exposing ``state_dict()``.
        epoch: Epoch value stored under the ``'epoch'`` key, exactly as passed.
        filename: File name of the checkpoint inside ``save_dir``.
    """
    training_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    checkpoint_path = os.path.join(save_dir, filename)
    torch.save(training_state, checkpoint_path)


drive/My Drive/Colab Notebooks/MovieRecs/Reelix Reranker Model/movie_reranker_hard-negative_bert_20250805_1653/checkpoints


In [None]:
from tqdm import tqdm

best_val_loss = float('inf')
best_epoch = -1

# Margin of the pairwise hinge loss; the same value is used for training and
# validation, so it is defined once instead of inside both inner loops.
margin = 1.0


def _score_pair(batch):
    """Score the positive and the negative candidate of each triplet in ``batch``.

    Reads ``input_ids_*``, ``attention_mask_*`` and ``token_type_ids_*`` for the
    ``pos`` and ``neg`` side from ``batch``, moves them to ``device`` and runs
    ``model`` once per side.

    Returns:
        Tuple ``(scores_pos, scores_neg)`` of model outputs.
    """
    scores = []
    for side in ("pos", "neg"):
        scores.append(model(
            input_ids=batch[f"input_ids_{side}"].to(device),
            attention_mask=batch[f"attention_mask_{side}"].to(device),
            token_type_ids=batch[f"token_type_ids_{side}"].to(device),
        ))
    return scores[0], scores[1]


def _margin_ranking_loss(scores_pos, scores_neg):
    """Batch mean of the hinge loss ``max(0, margin - (pos - neg))``."""
    return torch.mean(torch.clamp(margin - (scores_pos - scores_neg), min=0))


# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} - Training"):
        # Forward pass + pairwise margin ranking loss
        scores_pos, scores_neg = _score_pair(batch)
        loss = _margin_ranking_loss(scores_pos, scores_neg)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step; the LR scheduler advances once per batch
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    # Compute training loss for the epoch (mean of the per-batch losses)
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

    # Validation evaluation for the epoch
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    margin_sum = 0.0
    margin_values = []

    with torch.no_grad():
        for batch in val_loader:
            scores_pos, scores_neg = _score_pair(batch)

            # Compute loss
            val_loss += _margin_ranking_loss(scores_pos, scores_neg).item()

            # Pairwise accuracy - how often pos > neg
            correct += (scores_pos > scores_neg).sum().item()
            total += scores_pos.size(0)

            # Track the raw score gap (pos - neg) of every validation pair
            score_gap = scores_pos - scores_neg
            margin_sum += score_gap.sum().item()
            margin_values.extend(score_gap.detach().cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    accuracy = correct / total
    # Mean score gap over all validation pairs. This was previously never
    # computed, so the print below raised NameError on `avg_margin`.
    avg_margin = margin_sum / total

    print(f"Validation Loss: {avg_val_loss:.4f} | Pairwise Accuracy: {accuracy:.4f} | Avg score margin: {avg_margin:.4f}")

    # Save the best model & checkpoint (lowest validation loss so far).
    # best_epoch is 1-based; the checkpoint stores the 0-based loop index.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), os.path.join(model_dir, "reranker_best.pt"))
        save_checkpoint(model, optimizer, scheduler, epoch, "reranker_best_checkpoint.pt") # Also save the checkpoint
        print(f"Best model saved at epoch {best_epoch}")


  return forward_call(*args, **kwargs)
Epoch 1 - Training: 100%|██████████| 2025/2025 [41:26<00:00,  1.23s/it]


Epoch 1 Train Loss: 0.3844
Validation Loss: 0.2539 | Pairwise Accuracy: 0.8867
Best model saved at epoch 1


Epoch 2 - Training: 100%|██████████| 2025/2025 [41:12<00:00,  1.22s/it]


Epoch 2 Train Loss: 0.2147
Validation Loss: 0.2244 | Pairwise Accuracy: 0.9022
Best model saved at epoch 2


Epoch 3 - Training: 100%|██████████| 2025/2025 [41:06<00:00,  1.22s/it]


Epoch 3 Train Loss: 0.1333
Validation Loss: 0.2343 | Pairwise Accuracy: 0.9047


In [None]:
# Persist the model as it stands after the last epoch: the bare weights plus a
# full checkpoint (model/optimizer/scheduler state) tagged with the last loop index.
final_weights_path = os.path.join(model_dir, "reranker_final.pt")
torch.save(model.state_dict(), final_weights_path)
save_checkpoint(model, optimizer, scheduler, epoch, "reranker_latest_checkpoint.pt")
print("Final model saved after training.")

# Report which epoch produced the lowest validation loss
print(f"\nBest model: epoch {best_epoch} with val loss {best_val_loss:.4f}")