In [None]:
import os
import sys

# Make `src` importable as a package. The membership check keeps the cell
# idempotent: re-running it must not append the same entry again and again.
_project_root = os.path.abspath("..")
if _project_root not in sys.path:
    sys.path.append(_project_root)

# `os.getenv` deals in strings, so the fallback is given as a string as well.
SEED = int(os.getenv("SEED", "42"))

# 1. Dataset Preparation

In [None]:
from src.data import mMARCO


# NOTE: the current implementation of `collate_fn` in `mMARCO` doubles each
# batch in size, so the effective batch size is 16 * 2 = 32.
# (Presumably because every triplet yields a positive and a negative pair --
# TODO confirm against `collate_fn`.)
batch_size = 16

# Direct instantiation is only used for a quick smoke test here; training
# below goes through the dataloaders instead.
mmarco = mMARCO(seed=SEED, shuffle_buffer_size=64, shuffle=False)
# Reaches into the private `_data` attribute -- for inspection only.
print(mmarco._data.info.splits)

# Print the first sample only, to eyeball the record structure.
for sample in mmarco:
    print(sample)
    break

# 2. Training

In [None]:
import torch

# Pick the compute backend. Priority: Apple's Metal (MPS) > CUDA > CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Pointwise cross-encoders are supposed to output a class label (e.g., *relevant*
or *not-relevant*) or a relevance score (e.g., from 0 to 1). Now, this model
returns raw logits (practically a relevance score). However, the dataset does
not have any gold labels, which could be used to compute a loss.

Fortunately, the `(query, positive, negative)` triplets can easily be framed as
such:
- `(query, positive)=1` now means *relevant*
- `(query, negative)=0` now means *not-relevant*

Et voilà, this can be used for training.

During training, the sigmoid has to be applied to the raw logits to turn them
into a probability that can be compared against the binary labels. During
inference, however, we want the raw logits to build a sorted ranking.

NOTE: There are quite a few cross-entropy loss (CEL) functions...
- `nn.CrossEntropyLoss` 
  - **use case**: multi-class classification
  - **input**: raw logits for each class (`input`), and the target class indices (`target`).
- `nn.BCELoss`
  - **use case**: binary classification or multi-label classification (independent classes)
  - **input**: probabilities (so sigmoid must already be applied) and the target labels as floats in `[0, 1]` (`target`).
- `nn.BCEWithLogitsLoss` ← ✅
  - **use case**: binary classification or multi-label classification (independent classes)
  - **input**: raw logits (so sigmoid is applied internally), and the target labels as floats in `[0, 1]` (`target`).


### Hyperparameters

In [None]:
from src.models import CrossEncoderBERT
import torch.optim as optim
import torch.nn as nn

# Tokenizer settings, forwarded to `model.tokenize` by the training code.
padding = "longest"
max_length = 256

model = CrossEncoderBERT(
    enable_gradient_checkpointing=True,
)
# Takes raw logits and applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
learning_rate = 2e-5
epochs = 3
# epochs = 20
# The optimizer must be created after the model, since it captures the
# model's parameters.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)


# Train / validation / test dataloaders built from the dataset object created
# in the dataset-preparation cell.
train_dl, val_dl, test_dl = mmarco.as_dataloaders(seed=SEED, batch_size=batch_size)

### Progress Bars

In [None]:
from string import Formatter
import ipywidgets as widgets

# Epoch-level progress bar: one tick per finished epoch, `epochs` ticks in
# total. Advanced by `update_progress`.
progress_widget = widgets.IntProgress(
    value=0,
    min=0,
    max=epochs,
    description=f"Epoch: 0/{epochs}",
)

def update_progress(epoch: int):
    """Mark the zero-based `epoch` as finished on the epoch progress bar.

    Writes the one-based epoch number to both the bar value and its
    "Epoch: <done>/<total>" description.
    """
    finished = epoch + 1
    progress_widget.value = finished
    progress_widget.description = f"Epoch: {finished}/{epochs}"
    

def _make_loss_bar(label: str) -> widgets.FloatProgress:
    """Build a loss bar over [0.0, 1.0] that starts full and dark red."""
    return widgets.FloatProgress(
        value=1,
        min=0.0,
        max=1.0,
        description=f"{label}: 1.000",
        bar_style="",
        style={"bar_color": "darkred"},
    )


# One bar per split; both are refreshed through `update_loss`.
training_loss_widget = _make_loss_bar("Train")
val_loss_widget = _make_loss_bar("Val")


def update_loss(loss: float, widget: widgets.FloatProgress, name: str = "Loss"):
    """Show `loss` on `widget` and colour the bar by loss bucket.

    Args:
        loss: The loss value to display.
        widget: The progress bar to update.
        name: Prefix used in the bar's description ("<name>: <loss>").

    Colour buckets: below 0.25 green, below 0.50 orange, below 0.75 red,
    anything else (including NaN) dark red.
    """
    widget.value = loss
    widget.description = f"{name}: {loss:.3f}"

    if loss < 0.25:
        bar_color = "green"
    elif loss < 0.50:
        bar_color = "orange"
    elif loss < 0.75:
        bar_color = "red"
    else:
        bar_color = "darkred"

    widget.style = {"bar_color": bar_color}

class Counter:

    def __init__(self, name: str = "Counter: {count}", *, show: bool = True) -> None:
        self.name = name
        self.tmpl = Formatter().parse(name)
        placerholders = [field_name for _, field_name, _, _ in self.tmpl if field_name]
        if "count" not in placerholders:
            raise ValueError("The format string 'name' must contain a 'count' placeholder.")
        self.count: int | float = 0
        self.widget = widgets.Label(
            value="Iterations: 0"
        )

        if show:
            display(self.widget)
    
    def update(self, count: int | float):
        self.count += count
        self.widget.value = self.name.format(count=self.count)

    def set(self, count: int | float):
        self.widget.value = self.name.format(count=count)

    def clear(self):
        self.count = 0
        self.widget.value = self.name.format(count=self.count)


### Training Loop

In [None]:
from datetime import datetime
from typing import cast

from src.data.mmarco import mMARCOBatch
from IPython.display import display

def tokenize(batch: mMARCOBatch):
    """Tokenize the query/candidate pairs of `batch` and move them to `device`.

    Uses the notebook-level `padding` and `max_length` settings and returns
    the result of `model.tokenize(...)` after `.to(device)`.
    """
    queries = batch["queries"]  # pyright: ignore[reportGeneralTypeIssues]
    candidates = batch["candidates"]  # pyright: ignore[reportGeneralTypeIssues]
    encoded = model.tokenize(
        queries,
        candidates,
        padding=padding,
        max_length=max_length,
    )
    return encoded.to(device)

start = datetime.now()
print(f"Starting at {start.strftime('%Y-%m-%d %H:%M:%S')}")

# Live status widgets: counters on the first row, loss bars on the second.
batch_counter = Counter("Batches: {count}", show=False)
iter_counter = Counter("Iterations: {count}", show=False)
deviation_label = Counter("σ: {count}", show=False)
display(widgets.VBox([
    widgets.HBox([progress_widget, batch_counter.widget, iter_counter.widget]),
    widgets.HBox([training_loss_widget, deviation_label.widget, val_loss_widget]),
]))

model.to(device)

# Mixed precision is used on every non-CPU device.
use_amp = device.type != "cpu"
# float16 autocast needs gradient scaling, otherwise small gradients can
# underflow to zero. Scaling is only enabled on CUDA; on every other device
# the scaler is a transparent pass-through (plain backward / optimizer.step).
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

# One list of per-batch losses per epoch.
train_losses = [[] for _ in range(epochs)]
val_losses = [[] for _ in range(epochs)]

for epoch in range(epochs):

    # --- training loop ---
    model.train()
    for batch in train_dl:
        optimizer.zero_grad()

        batch = cast(mMARCOBatch, batch)
        # Named `encoded` rather than `input` to avoid shadowing the builtin.
        encoded = tokenize(batch)
        labels = batch["labels"].float().to(device)

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            logits = model.forward(
                input_ids=encoded["input_ids"],  # type: ignore
                attention_mask=encoded["attention_mask"],  # type: ignore
                token_type_ids=encoded["token_type_ids"],  # type: ignore
            ).squeeze(-1)
            loss = criterion(logits, labels)

        # Backward pass and optimizer step go through the scaler (no-ops
        # around the plain calls when scaling is disabled).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; every `.item()` call forces a device sync.
        loss_value = loss.item()
        train_losses[epoch].append(loss_value)

        update_loss(loss_value, training_loss_widget, "Train")
        batch_counter.update(1)
        iter_counter.update(len(batch["queries"]))

        # Population standard deviation of this epoch's batch losses so far.
        deviation_label.set(
            torch.tensor(train_losses[epoch]).std(unbiased=False).item()
        )

    # --- validation loop ---
    model.eval()
    with torch.no_grad():
        for batch in val_dl:
            batch = cast(mMARCOBatch, batch)
            encoded = tokenize(batch)
            logits = model.forward(  # type: ignore
                input_ids=encoded["input_ids"],  # type: ignore
                attention_mask=encoded["attention_mask"],  # type: ignore
                token_type_ids=encoded["token_type_ids"],  # type: ignore
            ).squeeze(-1)
            loss = criterion(logits, batch["labels"].float().to(device))

            loss_value = loss.item()
            val_losses[epoch].append(loss_value)
            update_loss(loss_value, val_loss_widget, "Val")

    update_progress(epoch)
    # NOTE: the debugging `break` that used to follow here was removed; it
    # ended training after the first epoch regardless of `epochs`.


model.eval()
stop = datetime.now()
elapsed = stop - start
print(f"Stopping at {stop.strftime('%Y-%m-%d %H:%M:%S')}. Elapsed time: {elapsed}")

### Saving The Model

In [None]:
# Persist only the weights (state dict), not the whole model object.
torch.save(model.state_dict(), "bert_mmarco.pth")

# load with (same file name as saved above):
# model = CrossEncoderBERT()
# model.load_state_dict(torch.load("bert_mmarco.pth", map_location=device))
# model.eval()

### Plotting

During training and validation, the loss was collected for each batch.
However, I am plotting the loss here at epoch-level. The losses are
aggregated by epoch and a deviation is computed.

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from typing import Sequence

def _summarize(loss_matrix: Sequence[Sequence[float]]):
    """Helper to summarize a matrix of losses (list of lists) into means, stds, and counts per epoch."""
    means: list[float] = []
    stds: list[float] = []
    counts: list[int] = []
    for losses in loss_matrix:
        if len(losses) == 0:
            means.append(float("nan"))
            stds.append(float("nan"))
            counts.append(0)
        else:
            arr = np.asarray(losses, dtype=float)
            means.append(float(arr.mean()))
            stds.append(float(arr.std(ddof=0)))  # population std for stability
            counts.append(len(arr))
    return means, stds, counts

# Per-epoch statistics for both splits.
train_means, train_stds, train_counts = _summarize(train_losses)
val_means, val_stds, val_counts = _summarize(val_losses)

# Epochs are shown one-based on the x axis.
epochs_axis = np.arange(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(8, 5), dpi=110)

# Mean loss curves.
ax.plot(epochs_axis, train_means, marker='o', label='Train Mean Loss', color='C0')
ax.plot(epochs_axis, val_means, marker='o', label='Val Mean Loss', color='C1')

# One-standard-deviation bands around each curve (NaN entries leave gaps).
train_mean_arr, train_std_arr = np.array(train_means), np.array(train_stds)
val_mean_arr, val_std_arr = np.array(val_means), np.array(val_stds)
train_lower, train_upper = train_mean_arr - train_std_arr, train_mean_arr + train_std_arr
val_lower, val_upper = val_mean_arr - val_std_arr, val_mean_arr + val_std_arr

ax.fill_between(epochs_axis, train_lower, train_upper, color='C0', alpha=0.20, linewidth=0)
ax.fill_between(epochs_axis, val_lower, val_upper, color='C1', alpha=0.20, linewidth=0)

# Mark the epoch with the lowest mean validation loss, provided at least one
# epoch has a (non-NaN) validation mean.
if not all(math.isnan(v) for v in val_means):
    best_epoch_idx = int(np.nanargmin(val_means))
    best_x = epochs_axis[best_epoch_idx]
    best_y = val_means[best_epoch_idx]
    ax.scatter(best_x, best_y, s=120, color='C1', edgecolor='k', zorder=5)
    ax.annotate(
        f"Best Val\nEpoch {best_epoch_idx+1}\n{best_y:.4f}",
        (best_x, best_y),
        textcoords="offset points",
        xytext=(10, -5),
        ha='left',
        va='top',
        fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='C1', alpha=0.8),
    )

ax.set_xlabel('Epoch')
ax.set_ylabel('Binary Cross-Entropy Loss')
ax.set_title('Training vs Validation Loss')
ax.grid(True, alpha=0.25)
ax.legend()
ax.margins(x=0.05)
plt.tight_layout()

# Text summary, one line per epoch.
train_rows = zip(train_means, train_stds, train_counts)
val_rows = zip(val_means, val_stds, val_counts)
for i, ((tr_m, tr_s, tr_c), (va_m, va_s, va_c)) in enumerate(zip(train_rows, val_rows), start=1):
    print(f"Epoch {i:02d} | Train: {tr_m:.4f} ± {tr_s:.4f} (n={tr_c}) | Val: {va_m:.4f} ± {va_s:.4f} (n={va_c})")

plt.show()

# 3. Evaluation

In [None]:
# Evaluation stub: switches the model to eval mode and initialises the
# accumulator, but nothing is computed on the test split yet.
model.eval()
with torch.no_grad():
    test_loss = 0.0
    # TODO: also perform accuracy and other metrics

# 4. Inference

In [None]:
def rank(query: str, candidates: list[str]) -> list[tuple[str, float]]:
    """Rank `candidates` by their relevance to `query`, best first.

    Args:
        query: The search query.
        candidates: The passages to score against the query.

    Returns:
        `(candidate, score)` pairs sorted by descending score, where the
        score is the sigmoid of the model's logit for that pair. An empty
        candidate list yields an empty ranking.
    """
    candidates = list(candidates)
    # Nothing to score; skip the model instead of tokenizing an empty batch.
    if not candidates:
        return []

    model.eval()
    # Tokenize with the same padding / truncation settings as in training.
    inputs = model.tokenize(
        [query] * len(candidates),
        candidates,
        padding=padding,
        max_length=max_length,
    ).to(device)
    with torch.no_grad():
        logits = model.forward(
            input_ids=inputs["input_ids"],  # type: ignore
            attention_mask=inputs["attention_mask"],  # type: ignore
            token_type_ids=inputs["token_type_ids"],  # type: ignore
        )
    # `reshape(-1)` always produces a 1-D tensor, so a single candidate still
    # gives a list of scores rather than a bare float.
    scores = torch.sigmoid(logits).reshape(-1).tolist()
    return sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)


In [None]:
# Small demo: one query with a clearly relevant passage, some loosely related
# passages and some off-topic distractors.
test_query = "What is the capital of France?"
test_candidates = [
    "Paris is the capital of France.",
    "Marseille is a city in France.",
    "Lyon is known for its cuisine.",
    "France is in Europe.",
    "Macron is the president of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain.", 
    "Rome is the capital of Italy."
]

ranked_results = rank(test_query, test_candidates)

# Results arrive sorted by descending score.
for candidate, score in ranked_results:
    print(f"Score: {score:.4f} - Candidate: {candidate}")
    