# Training

## Installation

In [None]:
!pip install sentence-transformers==3.4.1


In [None]:
!pip install pytrec_eval

Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... done
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp311-cp311-linux_x86_64.whl size=308657 sha256=88ed1b1bf7394d39d259a829dc6f1dd9241443792abcd83d725adb26b08be0ad
  Stored in directory: /root/.cache/pip/wheels/0f/89/42/86aecdb99975f1840c27bc37fdfed72116abcf82e2c9dc76a8
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5


## Imports

In [None]:
"""
This example shows how to train a Cross-Encoder on the MS MARCO passage ranking dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is to the given query.

The resulting Cross-Encoder can then be used for passage re-ranking: you retrieve, for example, 100 passages
for a given query (e.g. with ElasticSearch) and pass each (query, retrieved passage) pair to the Cross-Encoder
for scoring. You then sort the results by the Cross-Encoder score.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import logging
from collections import defaultdict
import numpy as np
import sys
import pytrec_eval
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(message)s',datefmt='%Y-%m-%d %H:%M:%S')
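
The re-ranking flow described in the docstring above (retrieve candidates, score each query/passage pair, sort by score) can be sketched without a trained model. This is a minimal illustration, not the library's implementation: `rerank` and `toy_overlap_score` are hypothetical names, and the word-overlap scorer only stands in for a real Cross-Encoder.

```python
from typing import Callable, List, Sequence, Tuple


def rerank(query: str,
           passages: Sequence[str],
           score_fn: Callable[[List[Tuple[str, str]]], Sequence[float]],
           top_k: int = 10) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return the top_k passages, best first."""
    pairs = [(query, passage) for passage in passages]
    scores = score_fn(pairs)
    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return [(passage, float(score)) for passage, score in ranked[:top_k]]


def toy_overlap_score(pairs):
    """Hypothetical stand-in scorer: fraction of query words found in the passage."""
    scores = []
    for query, passage in pairs:
        q_words = set(query.lower().split())
        p_words = set(passage.lower().split())
        scores.append(len(q_words & p_words) / max(len(q_words), 1))
    return scores


# Candidates as they might come back from a first-stage retriever such as BM25.
candidates = [
    "Paris is the capital of France",
    "Berlin is the capital of Germany",
    "The Eiffel Tower is in Paris",
]
reranked = rerank("capital of France", candidates, toy_overlap_score, top_k=2)
for passage, score in reranked:
    print(f"{score:.2f}  {passage}")
```

With a trained model, `score_fn` would be the Cross-Encoder's `predict` method, which takes a list of (query, passage) pairs and returns one score per pair.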

## Training preparation

### Initialize hyperparameters (e.g., batch size, etc)

#### To avoid losing the trained model if the Google Colab session disconnects, we suggest storing it on your Google Drive. The cell below mounts Google Drive via google.colab and sets the save path.


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
base_path = "./gdrive/MyDrive/cross-encoder-reranker-ir-course-2025model/"

Mounted at /content/gdrive


In [None]:
!mkdir -p $base_path
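
If the notebook is also run outside Colab, the mount step can be made optional. The sketch below is one possible approach under stated assumptions: `resolve_base_path` and `local_root` are hypothetical names introduced here, and the Drive location assumes the `/content/gdrive` mount point used above.

```python
import os


def resolve_base_path(folder_name: str, local_root: str = ".") -> str:
    """Return a Google Drive path inside Colab, or a local path everywhere else."""
    try:
        # google.colab is only importable inside a Colab runtime.
        from google.colab import drive
    except ImportError:
        return os.path.join(local_root, folder_name)
    drive.mount("/content/gdrive")
    return os.path.join("/content/gdrive/MyDrive", folder_name)


base_path = resolve_base_path("cross-encoder-reranker-ir-course-2025model")
os.makedirs(base_path, exist_ok=True)  # same effect as `mkdir -p`
print("Models will be stored under:", base_path)
```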

In [None]:
# First, we define the training hyperparameters

train_batch_size = 32
num_epochs = 1
# We train the network as a binary classification task:
# given [query, passage], is the label 0 = irrelevant or 1 = relevant?
# We use a positive-to-negative ratio: For 1 positive sample (label 1) we include 4 negative samples (label 0)
# in our training setup. For the negative samples, we use the triplets provided by MS Marco that
# specify (query, positive sample, negative sample).
pos_neg_ration = 4

# Maximal number of training samples we want to use
max_train_samples = 5e6 #2e7
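
To make these settings concrete, the sketch below reproduces the labelling rule used later in the data-loading cell (`cnt % (pos_neg_ration + 1) == 0` marks a positive) and derives how many positives and optimizer steps the values above imply. The helper name `label_for` is hypothetical, and the step count assumes the last partial batch is kept.

```python
import math

train_batch_size = 32
pos_neg_ratio = 4                 # mirrors pos_neg_ration in the cell above
max_train_samples = int(5e6)


def label_for(cnt: int, ratio: int) -> int:
    """Every (ratio + 1)-th sample is a positive; the rest are negatives."""
    return 1 if cnt % (ratio + 1) == 0 else 0


first_labels = [label_for(cnt, pos_neg_ratio) for cnt in range(10)]
num_positives = math.ceil(max_train_samples / (pos_neg_ratio + 1))
steps_per_epoch = math.ceil(max_train_samples / train_batch_size)

print(first_labels)      # one positive followed by four negatives, repeated
print(num_positives)     # positives among the training samples
print(steps_per_epoch)   # optimizer steps for one epoch
```

With these values, one in five samples is a positive, so 5 million samples contain 1 million positives, and one epoch takes 156,250 steps at batch size 32.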

## Load model (cross-encoder/ms-marco-MiniLM-L-2-v2)

In [None]:
# We set num_labels=1, which predicts a continuous score between 0 and 1
model_name = 'cross-encoder/ms-marco-MiniLM-L-2-v2'
#model_name = 'cross-encoder/ms-marco-TinyBERT-L-2-v2'
#model_name = 'distilroberta-base'
model = CrossEncoder(model_name, num_labels=1, max_length=512)
model_save_path = base_path  +'finetuned_models/cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda


## Download MSMARCO data + BM25 initial ranking run file

In [None]:
!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf  queries.tar.gz

--2025-05-20 20:18:25--  https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
Resolving msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)... 20.150.34.1
Connecting to msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)|20.150.34.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18882551 (18M) [application/gzip]
Saving to: ‘queries.tar.gz’


2025-05-20 20:18:25 (39.7 MB/s) - ‘queries.tar.gz’ saved [18882551/18882551]

queries.dev.tsv
queries.eval.tsv
queries.train.tsv


In [None]:
### Now we read the MS Marco dataset
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)


#### Read the corpus files, that contain all the passages. Store them in the corpus dict
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage


### Read the train queries, store in queries dict
queries = {}
# queries.tar.gz was extracted into the current working directory above
queries_filepath = 'queries.train.tsv'
with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query



### Now we create our training & dev data
train_samples = []
dev_samples = {}

# We use 200 random queries from the train set for evaluation during training
# Each query has at least one relevant and up to 200 irrelevant (negative) passages
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are randomly
# shuffled versions of qidpidtriples.train.full.2.tsv.gz from the MS Marco website.
# The train-eval split contains 500 random queries that can be used for evaluation during training
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download "+os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()

        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])

            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])


# Read our training file
train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz')
if not os.path.exists(train_filepath):
    logging.info("Download "+os.path.basename(train_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath)

cnt = 0
with gzip.open(train_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn, unit_scale=True):
        qid, pos_id, neg_id = line.strip().split()

        if qid in dev_samples:
            continue

        query = queries[qid]
        if (cnt % (pos_neg_ration+1)) == 0:
            passage = corpus[pos_id]
            label = 1
        else:
            passage = corpus[neg_id]
            label = 0

        train_samples.append(InputExample(texts=[query, passage], label=label))
        cnt += 1

        if cnt >= max_train_samples:
            break

INFO:root:Download collection.tar.gz


  0%|          | 0.00/1.04G [00:00<?, ?B/s]

INFO:root:Download msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz


  0%|          | 0.00/2.31M [00:00<?, ?B/s]

INFO:root:Download msmarco-qidpidtriples.rnd-shuf.train.tsv.gz


  0%|          | 0.00/4.41G [00:00<?, ?B/s]

5.00Mit [00:34, 147kit/s]
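
The shape of `dev_samples` is easier to see on a toy input. The sketch below applies the same grouping logic as the cell above to a few in-memory triples; the toy queries, passages, and the wrapper name `build_dev_samples` are made up for illustration.

```python
toy_queries = {"q1": "what is bm25", "q2": "who wrote hamlet"}
toy_corpus = {
    "p1": "BM25 is a ranking function used by search engines.",
    "p2": "Hamlet is a tragedy written by William Shakespeare.",
    "p3": "The Pacific is the largest ocean.",
    "p4": "Bananas are rich in potassium.",
}
# Each line holds: query id, positive passage id, negative passage id.
toy_triples = ["q1\tp1\tp3", "q1\tp1\tp4", "q2\tp2\tp3", "q2\tp2\tp4"]


def build_dev_samples(triple_lines, queries, corpus,
                      num_dev_queries, num_max_dev_negatives):
    """Group triples by query into {'query', 'positive', 'negative'} entries."""
    dev_samples = {}
    for line in triple_lines:
        qid, pos_id, neg_id = line.strip().split()
        # Only the first num_dev_queries distinct queries become dev queries.
        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {"query": queries[qid], "positive": set(), "negative": set()}
        if qid in dev_samples:
            dev_samples[qid]["positive"].add(corpus[pos_id])
            # Cap the number of negatives kept per query.
            if len(dev_samples[qid]["negative"]) < num_max_dev_negatives:
                dev_samples[qid]["negative"].add(corpus[neg_id])
    return dev_samples


toy_dev = build_dev_samples(toy_triples, toy_queries, toy_corpus,
                            num_dev_queries=1, num_max_dev_negatives=1)
print(toy_dev)
```

Because positives and negatives are stored in sets, a passage that appears in several triples for the same query is kept only once.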


## Initialize dataloader

In [None]:
# We create a DataLoader to load our train samples
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

## Initialize CERerankingEvaluator Class
### The CERerankingEvaluator class evaluates the model on the validation set after every 1k training steps
### Currently, CERerankingEvaluator computes MRR@10 on the validation set. You need to change MRR@10 to NDCG@10 for Exercise 4.
### For that, you can download the CERerankingEvaluator class ([link](https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/cross_encoder/evaluation/CERerankingEvaluator.py)) and upload the modified implementation to Brightspace.
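
As a reference point for the metric mentioned above, here is a minimal sketch of the standard MRR@10 definition: for each query, take the reciprocal rank of the first relevant passage within the top 10 (0 if there is none), then average over queries. It is an independent illustration, not the code of `CERerankingEvaluator`, and the function names are hypothetical.

```python
def mrr_at_k(ranked_relevance, k=10):
    """Reciprocal rank of the first relevant item in the top k, or 0.0 if none."""
    for rank, is_relevant in enumerate(ranked_relevance[:k], start=1):
        if is_relevant:
            return 1.0 / rank
    return 0.0


def mean_mrr_at_k(all_rankings, k=10):
    """Average the per-query reciprocal ranks."""
    return sum(mrr_at_k(ranking, k) for ranking in all_rankings) / len(all_rankings)


# One list per query: relevance (1/0) of the passages, already sorted by model score.
rankings = [
    [0, 1, 0, 0],      # first relevant passage at rank 2  -> 0.5
    [1, 0, 0],         # first relevant passage at rank 1  -> 1.0
    [0] * 10 + [1],    # first relevant passage at rank 11 -> 0.0 (outside top 10)
]
mean_mrr = mean_mrr_at_k(rankings)
print(f"MRR@10 = {mean_mrr:.3f}")
```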




In [None]:
# We add an evaluator, which evaluates the performance during training
# It re-ranks the positive and negative passages of each dev query by model score and reports MRR@10
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')

## Train the model
### You can stop the training after one hour by stopping the run

In [None]:
# Train the model
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=5000,
          output_path=model_save_path,
          use_amp=True)

NameError: name 'model' is not defined
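
To see what `warmup_steps=5000` and `evaluation_steps=1000` mean over a full epoch, the sketch below models a linear warmup followed by a linear decay. It rests on two assumptions that the call above does not set explicitly: that `fit` uses a linear warmup/decay scheduler by default, and that the base learning rate is 2e-5. Check both against the installed library version. The function name `warmup_linear_lr` is hypothetical.

```python
def warmup_linear_lr(step, base_lr, warmup_steps, total_steps):
    """Linear ramp from 0 to base_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


total_steps = 156250      # 5e6 samples / batch size 32, one epoch
warmup_steps = 5000       # as passed to model.fit above
evaluation_steps = 1000   # as passed to model.fit above
base_lr = 2e-5            # assumption: not set in the call above

lr_start = warmup_linear_lr(0, base_lr, warmup_steps, total_steps)
lr_peak = warmup_linear_lr(warmup_steps, base_lr, warmup_steps, total_steps)
lr_end = warmup_linear_lr(total_steps, base_lr, warmup_steps, total_steps)
num_evaluations = total_steps // evaluation_steps

print(lr_start, lr_peak, lr_end)
print("evaluations in a full epoch:", num_evaluations)
```

Under these assumptions, warmup covers the first 3.2% of the epoch, and a run stopped after one hour will have seen only part of the decay phase.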

In [None]:
model.save(model_save_path)
print("✅ Model saved to:", model_save_path)

INFO:sentence_transformers.cross_encoder.CrossEncoder:Save model to ./gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-cross-encoder-ms-marco-TinyBERT-L-2-v2-2025-05-19_21-08-20


✅ Model saved to: ./gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-cross-encoder-ms-marco-TinyBERT-L-2-v2-2025-05-19_21-08-20
