# Listwise Cross-Encoder

In [3]:
import os
import sys

sys.path.append(os.path.abspath(".."))  # make src available as a package

In [4]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# for Apple's Metal (MPS) backend / framework
if torch.backends.mps.is_available():
    device = torch.device("mps")

print(f"Using device: {device}")

Using device: cuda


# 1. Training

WARNING:

Candidates seem to be very unevenly distributed in length: some are very short, others
very long. This causes a lot of padding in each batch, and I haven't yet figured out how
to deal with that (if at all).
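A common mitigation for this kind of padding is length bucketing: sort sequences by length and batch neighbours together. The sketch below only quantifies the effect on toy lengths; `total_padding_fraction` and `length_bucketed_batches` are hypothetical helpers, not part of `src`. How this maps onto the listwise setup depends on how `CrossEncoderLongformer.batch_tokenize` packs a query and its candidates, which is not checked here; all candidates of one query still have to stay together.

```python
def total_padding_fraction(batches: list[list[int]], lengths: list[int]) -> float:
    """Share of token slots that are padding when each batch is padded to its longest sequence."""
    slots = sum(max(lengths[i] for i in batch) * len(batch) for batch in batches)
    tokens = sum(lengths[i] for batch in batches for i in batch)
    return 1 - tokens / slots


def length_bucketed_batches(lengths: list[int], batch_size: int) -> list[list[int]]:
    """Sort indices by length, then chunk, so each batch holds sequences of similar length."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i : i + batch_size] for i in range(0, len(order), batch_size)]


lengths = [12, 480, 30, 500, 25, 470, 15, 510]  # toy token counts, not real MS MARCO statistics
naive = [list(range(i, i + 4)) for i in range(0, len(lengths), 4)]  # batches in dataset order
bucketed = length_bucketed_batches(lengths, batch_size=4)

print(f"naive:    {total_padding_fraction(naive, lengths):.1%} padding")     # 49.5%
print(f"bucketed: {total_padding_fraction(bucketed, lengths):.1%} padding")  # 5.5%
```

Sorting by length loses some randomness in batch composition; a usual compromise is to shuffle the buckets themselves.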

In [13]:
epochs = 3
batch_size = 4
max_grad_norm = 1.0
# learning_rate = 2e-5
learning_rate = 2e-4
# optimizer = AdamW # see below

In [14]:
from src.data import MSMARCO

train_dl, val_dl, test_dl = MSMARCO.as_dataloaders(batch_size=batch_size)

In [15]:
from typing import cast
from IPython.display import display
from transformers import BatchEncoding
from src.data.msmarco import MSMARCOBatch
from src.models.CrossEncoderLongformer import CrossEncoderLongformer, print_encoded
import torch
import torch.optim as optim

def prep_labels(logits, cand_mask, batch: MSMARCOBatch) -> torch.Tensor:
    """Scatter the per-query label lists onto a (B, Cmax) tensor; padded slots get -1."""
    B, Cmax = logits.shape
    labels = torch.full((B, Cmax), -1.0, device=logits.device)  # -1 marks padded candidate slots
    # Map per-sample labels onto the True positions of cand_mask
    for b, lbls in enumerate(batch["labels"]):
        valid_idx = cand_mask[b].nonzero(as_tuple=False).squeeze(-1)  # indices where a candidate exists
        n = min(len(lbls), valid_idx.numel())  # candidates dropped by truncation lose their labels
        if n > 0:
            labels[b, valid_idx[:n]] = torch.as_tensor(lbls[:n], dtype=torch.float32, device=logits.device)

    return labels


# convenience function for batch tokenization
def tokenize(batch: MSMARCOBatch):
    return CrossEncoderLongformer.batch_tokenize(
        batch["queries"],
        batch["candidates"],
    ).to(device)
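A torch-free mirror of `prep_labels` on a toy mask may make the alignment easier to see. `prep_labels_py` and the toy values are hypothetical. Like the original, it assumes that the candidates that survived truncation are the first ones in each list.

```python
def prep_labels_py(cand_mask: list[list[bool]], labels: list[list[int]]) -> list[list[float]]:
    """Pure-Python mirror of prep_labels: scatter each sample's labels onto the True slots of its mask row."""
    out = []
    for row_mask, lbls in zip(cand_mask, labels):
        row = [-1.0] * len(row_mask)  # -1 marks padded (non-candidate) slots
        valid_idx = [j for j, m in enumerate(row_mask) if m]
        n = min(len(lbls), len(valid_idx))  # extra labels (truncated candidates) are dropped
        for j, lbl in zip(valid_idx[:n], lbls[:n]):
            row[j] = float(lbl)
        out.append(row)
    return out


mask = [[True, True, True, False], [True, True, False, False]]
labels = [[0, 1, 0], [1, 0, 0]]  # second query: 3 labels, but only 2 candidates survived
print(prep_labels_py(mask, labels))
# [[0.0, 1.0, 0.0, -1.0], [1.0, 0.0, -1.0, -1.0]]
```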

In [16]:
# NOTE: uncomment to see the tokenized input
# sample = cast(MSMARCOBatch, next(iter(train_dl)))
# input = tokenize(sample)
# print_encoded(input)

<font color="tomato">NOTE</font>: I ran into memory issues with the current configuration:
memory was _sometimes_ not freed properly and accumulated until the GPU ran out of memory.
I worked around this either by throwing more VRAM at the problem or by retrying until the
issue didn't occur.

The cell below frees up some memory. If nothing else helps, restart the kernel by exiting
the process (uncomment the `os._exit` line).

The `CrossEncoderLongformer` implementation is surely not ideal from a performance
perspective, and `listnet` and other functions would need to be optimized as well. However,
I could not pin down _what exactly_ causes the memory issues. `MSMARCO` is also not
streaming; it materializes the whole dataset (~5GB) when loading.

Because of this, **training is incredibly slow**.

In [1]:
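import gc
import os
import torch  # this cell may run in a fresh kernel (In [1]), so import torch here too
gc.collect()  # drop unreachable Python objects and the tensors they still reference
torch.cuda.empty_cache()  # release unoccupied memory held by PyTorch's caching allocator
print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 2**20:.1f} MiB")
torch.cuda.ipc_collect()  # only relevant if tensors were shared between processes via CUDA IPC
# os._exit(0)  # kill the kernel process so Jupyter restarts it; VRAM often isn't freed until then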

!nvidia-smi



KeyboardInterrupt



<font color="gold">NOTE</font> about the loss:

ListNet top-1 is mathematically a cross-entropy loss (CEL) between two probability
distributions, i.e. two softmaxes over the predicted and the target relevance scores.
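A minimal pure-Python sketch of ListNet top-1 for a single unpadded list, assuming a plain softmax (no temperature) over both the predicted scores and the already transformed targets. It is not the `src.loss.listnet` implementation, which additionally handles masking and label transforms; `listnet_top1` is a hypothetical name.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def listnet_top1(scores: list[float], targets: list[float]) -> float:
    """Cross-entropy between softmax(targets) and softmax(scores) for one list."""
    p = softmax(targets)  # target top-1 distribution
    q = softmax(scores)   # predicted top-1 distribution
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))


# equal scores for all 10 candidates -> uniform prediction -> loss ln(10) ≈ 2.3026
print(listnet_top1([0.0] * 10, [1.0] + [0.0] * 9))
```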

The batch loss with ListNet (top-1) stagnates at ~2.30, which suggests that the model
outputs nearly uniform distributions for every list. I don't know why yet: whether the
model cannot reach better training performance, or whether there is an inherent issue
with the loss or the dataset.

My guess is that the batch size is too small (4) and the predicted distributions are
near uniform. In that case, ln(10) ≈ 2.3026 is exactly what to expect: if the prediction
is uniform over C candidates, the cross-entropy equals ln(C) regardless of the target
distribution, and 10 is likely the (average) number of candidates per query.
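The ln(C) value follows directly from the definition, writing y for the target relevance scores, s for the predicted scores, and using that the target probabilities sum to one:

$$
\mathcal{L}_{\text{top-1}} = -\sum_{i=1}^{C} p_i \ln q_i,
\qquad p = \operatorname{softmax}(y), \quad q = \operatorname{softmax}(s)
$$

$$
q_i = \tfrac{1}{C}\ \ \forall i \;\Longrightarrow\; \mathcal{L}_{\text{top-1}} = -\ln\tfrac{1}{C}\sum_{i=1}^{C} p_i = \ln C, \qquad \ln 10 \approx 2.3026
$$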

So the loss only drops below ln(C) once the model starts placing a higher probability on
the genuine positives in the list. If many batches don't include any positive (e.g. after
truncation), gradients are near zero and learning is slow.
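How far below ln(C) the loss can go depends on how sharp the target distribution is. The sketch below assumes that `listnet` applies a plain softmax to the `exp2m1`-transformed labels (gain 2^label - 1, so 1 for a positive and 0 for a negative); I have not verified this against `src.loss`. Under that assumption, a list with one positive among ten candidates has a target of only ~0.232 on the positive, so the lowest reachable loss is about 2.229, a confident prediction is penalized, and lists without any positive have a uniform target that pulls predictions back toward uniform.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def cross_entropy(p: list[float], q: list[float]) -> float:
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))


C = 10
gains = [1.0] + [0.0] * (C - 1)  # assumed exp2m1 gains: one positive, nine negatives
p = softmax(gains)               # assumed target: ~0.232 on the positive, ~0.085 on each negative

uniform = cross_entropy(p, softmax([0.0] * C))                      # ln(10) ≈ 2.3026
best = cross_entropy(p, softmax(gains))                             # entropy of p ≈ 2.2292, the floor
overconfident = cross_entropy(p, softmax([3.0] + [0.0] * (C - 1)))  # ≈ 2.6743, worse than uniform

# no positive in the list: all gains are 0, the target is uniform, and any preference is penalized
no_positive = cross_entropy(softmax([0.0] * C), softmax([3.0] + [0.0] * (C - 1)))

print(f"uniform {uniform:.4f}  best {best:.4f}  overconfident {overconfident:.4f}")
```

If the assumption holds, the whole useful range per list is roughly 2.303 down to 2.229, which would make progress hard to see in the batch losses.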

I disabled gradient clipping because I suspect it hurts learning in this case.
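To check that suspicion it helps to know what norm clipping actually does. The sketch below mirrors, as far as I know, the behaviour of `torch.nn.utils.clip_grad_norm_`: one L2 norm over all gradients, a scale factor of `max_norm / (total_norm + 1e-6)` capped at 1, so gradients are only ever shrunk and their direction is preserved. `clip_by_global_norm` is a hypothetical helper working on plain lists.

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> tuple[list[list[float]], float]:
    """Rescale all gradients jointly so that their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never enlarges gradients
    return [[g * scale for g in grad] for grad in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(norm, clipped)  # norm 5.0; gradients scaled down by ~0.2
```

Since clipping only kicks in when the norm exceeds `max_norm`, logging the returned total norm during training would show whether it triggers at all while the loss sits near ln(C).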

- [ ] analyze what _exactly_ is going on in the dataset with distributions,
      truncation and padding.

In [19]:
from src.loss import listnet
import torch.nn as nn


# GPU => prefer bfloat16, so no GradScaler needed
# CPU/MPS => stick to FP32 (no autocast)
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
autocast_enabled = device.type == "cuda"  # assume GPU is Ampere or later (bf16 support)

model = CrossEncoderLongformer().to(device)
model.train()  # ensure dropout is active, in case a HF `from_pretrained` call left submodules in eval mode
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)


for epoch in range(epochs):
    for idx, batch in enumerate(train_dl):
        optimizer.zero_grad(set_to_none=True)

        encoded = tokenize(batch)  # tokenization doesn't need to run under autocast
        with torch.autocast(device_type=device.type, dtype=dtype, enabled=autocast_enabled):
            logits, cand_mask = model(encoded)
            loss = listnet(
                logits,
                prep_labels(logits, cand_mask, batch),
                mask=cand_mask,
                label_transform="exp2m1",  # presumably gain = 2**label - 1
            )

        loss.backward()

        # gradient clipping is disabled (see the note above); to re-enable and log the norm:
        # total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        # print(f"grad norm {total_norm.item():.4f}")

        optimizer.step()

        print(f"epoch {epoch} batch {idx} loss {loss.item()}")

    break  # stop after first epoch for demo purposes

epoch 0 batch 0 loss 2.303225517272949
epoch 0 batch 1 loss 2.3032426834106445
epoch 0 batch 2 loss 2.301866292953491
epoch 0 batch 3 loss 2.3041481971740723
epoch 0 batch 4 loss 2.3010330200195312
epoch 0 batch 5 loss 2.3012452125549316
epoch 0 batch 6 loss 2.3068575859069824
epoch 0 batch 7 loss 2.3022613525390625
epoch 0 batch 8 loss 2.3027734756469727
epoch 0 batch 9 loss 2.3020153045654297
epoch 0 batch 10 loss 2.303412437438965
epoch 0 batch 11 loss 2.3050854206085205
epoch 0 batch 12 loss 2.300124168395996
epoch 0 batch 13 loss 2.3022546768188477
epoch 0 batch 14 loss 2.1789731979370117
epoch 0 batch 15 loss 2.2169346809387207
epoch 0 batch 16 loss 2.303145408630371
epoch 0 batch 17 loss 2.436035633087158
epoch 0 batch 18 loss 2.3047494888305664
epoch 0 batch 19 loss 2.3032405376434326
epoch 0 batch 20 loss 2.2763569355010986
epoch 0 batch 21 loss 2.3016200065612793
epoch 0 batch 22 loss 2.3053817749023438
epoch 0 batch 23 loss 2.301638126373291
epoch 0 batch 24 loss 2.302797079

KeyboardInterrupt: 

In [None]:
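# TODO: for actual inference, rewrite .rank() to take a query and candidates directly
# and return the ranked list.

sample = cast(MSMARCOBatch, next(iter(val_dl)))

top_k = 5  # how many top candidates to display

# FIX: score the validation sample itself; `logits` left over from the training loop belong to a training batch
model.eval()
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=autocast_enabled):
    logits, cand_mask = model(tokenize(sample))

query = sample["queries"][0]
candidates = sample["candidates"][0]
labels = sample["labels"][0]
scores = logits[:1].float().cpu()  # shape: (1, Cmax); cast bf16 -> fp32 for printing
mask = cand_mask[:1].cpu()         # shape: (1, Cmax), dtype: bool

# FIX: if truncation dropped some candidates, align candidates (and their labels) to valid mask entries
valid_flags = mask[0].tolist()
if sum(valid_flags) != len(candidates):
    candidates = [c for c, v in zip(candidates, valid_flags) if v]
    labels = [lbl for lbl, v in zip(labels, valid_flags) if v]

ranked = model.rank(
    candidates,
    scores=scores,
    mask=mask,
)

# FIX: look up labels by candidate text; rank position i is not the candidate's original index
label_of = dict(zip(candidates, labels))

print(f"Query: {query}")
print(f"Top-{top_k} candidates:")
for i, (cand, score) in enumerate(ranked[:top_k]):
    label = label_of.get(cand)
    relevant = "✅" if label == 1 else "❌" if label == 0 else "?"
    print(f"{i+1:2d}. ({score:.4f}; {relevant}) {cand}")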

Query: . what is a corporation?
Top-5 candidates:
 1. (0.1060; ❌) B Corp certification shines a light on the companies leading the global movement...
 2. (0.1030; ❌) 1: a government-owned corporation (as a utility or railroad) engaged in a profit-making enterprise that may require the exercise of powers unique to government (as eminent domain) — called also government corporation, publicly held corporation
 3. (0.1021; ❌) Corporation definition, an association of individuals, created by law or under authority of law, having a continuous existence independent of the existences of its members, and powers and liabilities distinct from those of its members. See more.
 4. (0.1011; ❌) An Association is an organized group of people who share in a common interest, activity, or purpose. 1  Start a business Plan your business. Create your business structure Types of business structures. 2  Change or update your business Add a new location to your existing business. Add an endorsement to your exi
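The printed scores all sit close to 0.1, i.e. roughly 1/10, which is what near-uniform outputs over about ten candidates would look like if these scores are softmax probabilities. That is an assumption: `model.rank` is not shown here. A hypothetical `rank_masked` sketch of masked softmax ranking, for comparison:

```python
import math


def rank_masked(candidates: list[str], scores: list[float], mask: list[bool]) -> list[tuple[str, float]]:
    """Softmax over the valid (mask=True) scores only, then sort candidates by probability, highest first."""
    valid = [(c, s) for c, s, m in zip(candidates, scores, mask) if m]  # drop padded slots
    top = max(s for _, s in valid)  # subtract the max for numerical stability
    exps = [(c, math.exp(s - top)) for c, s in valid]
    z = sum(e for _, e in exps)
    return sorted(((c, e / z) for c, e in exps), key=lambda pair: pair[1], reverse=True)


ranked = rank_masked(["a", "b", "c", "<pad>"], [0.0, 2.0, 1.0, 9.9], [True, True, True, False])
for i, (cand, prob) in enumerate(ranked, start=1):
    print(f"{i:2d}. ({prob:.4f}) {cand}")
```

The padded slot is excluded before the softmax, so its (large) score cannot steal probability mass from real candidates.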