
# **SBERT Training**

In the MS MARCO Passage Ranking (and, similarly, Document Ranking) datasets, you’ll find several files: collection, queries, qrels.train and qrels.dev. Each plays a different role in a training/evaluation pipeline for information retrieval:

1. **Collection**

* What It Is: A large file that contains all the text in the dataset – specifically, each passage (for Passage Ranking) or each document (for Document Ranking).
* Typical Format: For the Passage Ranking set, you’ll see something like pid\tpassage_text. Each row links a passage ID (pid) to its actual text.
* Usage: You use this to look up the actual passage text, given its ID.

2. **Queries**

* What It Is: A list of user queries (often real or anonymized user questions).
* Typical Format: Usually qid\t query_text, mapping a query ID (qid) to the textual query.
* Usage: When training or evaluating an IR system, you’ll retrieve passages for these queries.

3. **Qrels.train**

* What It Is: The relevance judgments (so-called “qrels,” short for query relevance set) for the training queries.
* Typical Format: In TREC format, something like qid 0 pid relevance_label.
* Usage: Tells you which (qid, pid) pairs are actually relevant. If you’re doing pointwise training, you treat these as positive examples (label=1). Any other (qid, pid) that appears in your candidate set but is not in qrels is treated as negative (label=0).

4. **Qrels.dev**

* What It Is: The relevance judgments (qrels) for the development/validation queries.
* Usage: Allows you to evaluate your model on a held-out dev set. For each dev query, you can see which passages are truly relevant (label=1), and compute metrics like MRR or nDCG.
* Similar Format: qid 0 pid relevance_label.

**Putting It All Together**
* collection gives you the text for each passage.
* queries gives you the text for each user query.
* qrels.train maps training queries to their relevant passages (positives).
* qrels.dev does the same for dev (validation) queries, letting you measure how well your system ranks the correct passages on unseen data.

**In practice, you:**

1. Load the passages from collection into a dictionary: pid -> passage_text.
2. Load queries from queries as qid -> query_text.
3. Use qrels.train to form (query, passage, label) training samples:
* label=1 if (qid, pid) is marked as relevant in qrels, and label=0 otherwise (i.e., unjudged pairs are treated as negatives).
4. When you’re done training, check performance on the dev set using qrels.dev. This dev set is smaller and separate, so you can measure how well your re-ranker generalizes before final testing.
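The steps above can be sketched in plain Python as follows. This is a minimal illustration and not the notebook's pandas pipeline. The inline TSV strings, the helper names (`load_id_to_text`, `load_qrels`, `build_pointwise_samples`) and the hand-picked candidate pairs are all hypothetical. `quoting=csv.QUOTE_NONE` is an assumption so that quote characters inside passages are read literally.

```python
import csv
import io
from collections import defaultdict

# Hypothetical miniature versions of collection.tsv, queries.train.tsv and qrels.train.tsv,
# using the same tab-separated column layouts as the real files.
COLLECTION_TSV = (
    "0\tThe Manhattan Project was a research effort during World War II.\n"
    "1\tUrine can be a variety of colors, most often shades of yellow.\n"
    "2\tSantorini is warm and sunny in September.\n"
)
QUERIES_TSV = "10\twhat was the manhattan project\n11\twhat color is urine\n"
QRELS_TSV = "10\t0\t0\t1\n11\t0\t1\t1\n"


def read_tsv(text):
    """Iterate over the rows of a tab-separated string (quotes are treated as ordinary characters)."""
    return csv.reader(io.StringIO(text), delimiter="\t", quoting=csv.QUOTE_NONE)


def load_id_to_text(text):
    """Steps 1 and 2: map the ID column (pid or qid) to the text column."""
    return {row[0]: row[1] for row in read_tsv(text)}


def load_qrels(text):
    """Map each qid to the set of pids judged relevant (rows are: qid, 0, pid, rel)."""
    qrels = defaultdict(set)
    for qid, _unused, pid, rel in read_tsv(text):
        if int(rel) > 0:
            qrels[qid].add(pid)
    return qrels


def build_pointwise_samples(candidate_pairs, queries, passages, qrels):
    """Step 3: label each (qid, pid) candidate 1 if it appears in qrels, otherwise 0."""
    return [
        (queries[qid], passages[pid], 1 if pid in qrels.get(qid, set()) else 0)
        for qid, pid in candidate_pairs
    ]


passages = load_id_to_text(COLLECTION_TSV)
queries = load_id_to_text(QUERIES_TSV)
qrels = load_qrels(QRELS_TSV)

# Hand-picked candidate pairs; ("10", "2") is not in qrels, so it becomes a negative.
samples = build_pointwise_samples([("10", "0"), ("10", "2"), ("11", "1")], queries, passages, qrels)
for query, passage, label in samples:
    print(label, query, "->", passage)
```

In the notebook itself, the same lookups are done with pandas `read_csv` and `merge`. The dictionary version mainly makes the qid/pid/label relationships explicit.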

# **Step-by-Step Approach in this Notebook:**
1. Prepare Training Data
2. Load a pre-trained SBERT model
3. Convert data for SBERT training
4. Train SBERT
5. Evaluate model
6. Save the fine-tuned model so it is ready to be used after the FAISS retrieval stage

In [None]:
pip install -U sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-4.0.2-py3-none-any.whl.metadata (13 kB)

In [2]:
!pip install ftfy

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Installing collected packages: ftfy
Successfully installed ftfy-6.3.1


In [None]:
import pandas as pd
import numpy as np

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [5]:
# file_path = "qrels.train.tsv"  # alternative: a local copy of the file
file_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/qrels.train.tsv'

df_qrels = pd.read_csv(
    file_path,
    sep='\t',           # tab-separated
    header=None,
    names=["qid", "unused", "pid", "rel"]  # column names for clarity
)

df_qrels.head()

|   | qid | unused | pid | rel |
|---|---|---|---|---|
| 0 | 1185869 | 0 | 0 | 1 |
| 1 | 1185868 | 0 | 16 | 1 |
| 2 | 597651 | 0 | 49 | 1 |
| 3 | 403613 | 0 | 60 | 1 |
| 4 | 1183785 | 0 | 389 | 1 |


In [6]:
df_qrels.shape

(532761, 4)

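4 columns, 532,761 rows:

* qid is the query ID
* unused is a placeholder column that is always 0 (the "iteration" field of the TREC qrels format)
* pid is the passage ID
* rel is the relevance label (1 = relevant)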

In [7]:
import pandas as pd

#file_path = "queries.train.tsv"
file_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/queries.train.tsv'

df_queries = pd.read_csv(
    file_path,
    sep='\t',
    header=None,
    names=["qid", "query_text"]
)

df_queries.head()

|   | qid | query_text |
|---|---|---|
| 0 | 121352 | define extreme |
| 1 | 634306 | what does chattel mean on credit history |
| 2 | 920825 | what was the great leap forward brainly |
| 3 | 510633 | tattoo fixers how much does it cost |
| 4 | 737889 | what is decentralization process. |


In [8]:
df_queries.shape

(808731, 2)

In [9]:
df_collection = pd.read_csv(
    '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/collection.tsv',
    sep='\t',
    header=None,
    names=['pid', 'passage']
)
df_collection.head()


|   | pid | passage |
|---|---|---|
| 0 | 0 | The presence of communication amid scientific ... |
| 1 | 1 | The Manhattan Project and its atomic bomb help... |
| 2 | 2 | Essay on The Manhattan Project - The Manhattan... |
| 3 | 3 | The Manhattan Project was the name for a proje... |
| 4 | 4 | versions of each volume as well as complementa... |


In [10]:
df_collection.shape

(8841823, 2)

In [11]:
df_collection.loc[49, "passage"]

'Colorâ\x80\x94urine can be a variety of colors, most often shades of yellow, from very pale or colorless to very dark or amber. Unusual or abnormal urine colors can be the result of a disease process, several medications (e.g., multivitamins can turn urine bright yellow), or the result of eating certain foods.'

As the output above shows, some entries have encoding problems: "â\x80\x94" is an em dash (U+2014) whose UTF-8 bytes were decoded as Latin-1 (so-called mojibake). These passages need to be cleaned.
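To illustrate what goes wrong: an em dash is stored in UTF-8 as the three bytes E2 80 94. Reading those bytes as Latin-1 yields the three characters "â\x80\x94". The hypothetical helper `undo_latin1_mojibake` below reverses exactly that one corruption by re-encoding as Latin-1 and decoding as UTF-8. It is only a sketch of the mechanism. ftfy's `fix_text`, used in the next cell, detects this and many other corruption patterns and is the more robust tool.

```python
def undo_latin1_mojibake(text):
    """Reverse the 'UTF-8 bytes decoded as Latin-1' corruption for a single string.

    Re-encoding as Latin-1 recovers the original bytes, which are then decoded as UTF-8.
    If the text is not this kind of mojibake, one of the two steps fails and the
    text is returned unchanged.
    """
    try:
        return text.encode("latin-1").decode("utf-8")
    except (UnicodeEncodeError, UnicodeDecodeError):
        return text


print(undo_latin1_mojibake("Colorâ\x80\x94urine can be a variety of colors"))
print(undo_latin1_mojibake("café"))  # genuine Latin-1 text: decoding as UTF-8 fails, so it is kept as-is
```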

In [12]:
from ftfy import fix_text

# Fix encoding in 'passage'
df_collection["passage"] = df_collection["passage"].apply(fix_text)

# Check result
print(df_collection.loc[49, "passage"])

KeyboardInterrupt: 
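The KeyboardInterrupt above shows that calling `fix_text` on all 8.8 million passages is slow. One possible speed-up is to call the fixer only on strings that contain at least one non-ASCII character. This is an assumption, not the notebook's method. It relies on the fact that encoding mojibake always shows up as non-ASCII characters. Caveat: `fix_text` also performs other clean-ups (for example, HTML entity unescaping) that can change pure-ASCII text, so the result is not guaranteed to be identical to applying it everywhere. `fix_non_ascii_only` and `demo_fixer` are hypothetical helpers. The demo uses a stand-in fixer so it runs without ftfy.

```python
import pandas as pd


def fix_non_ascii_only(series, fixer):
    """Apply `fixer` only to entries containing at least one non-ASCII character.

    Returns the fixed Series (a copy) and the number of entries passed to `fixer`.
    """
    needs_fix = series.str.contains(r"[^\x00-\x7f]", regex=True, na=False)
    fixed = series.copy()
    fixed.loc[needs_fix] = series[needs_fix].map(fixer)
    return fixed, int(needs_fix.sum())


def demo_fixer(text):
    """Stand-in for ftfy.fix_text so the demo runs without ftfy: undo UTF-8-read-as-Latin-1."""
    try:
        return text.encode("latin-1").decode("utf-8")
    except (UnicodeEncodeError, UnicodeDecodeError):
        return text


passages = pd.Series(["plain text", "Colorâ\x80\x94urine", "another plain passage"])
fixed, n_fixed = fix_non_ascii_only(passages, demo_fixer)
print(n_fixed, fixed.tolist())
```

In the notebook, this would be called as `df_collection["passage"], n = fix_non_ascii_only(df_collection["passage"], fix_text)`.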

In [None]:
print(df_collection.loc[49, "passage"])

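Once the cleaning cell has run to completion (the first attempt above was stopped with a KeyboardInterrupt), re-printing passage 49 should show the corrected text. Next, apply the same fix to the query text.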

In [None]:
from ftfy import fix_text

# Fix encoding in 'query text'
df_queries["query_text"] = df_queries["query_text"].apply(fix_text)

In [None]:
df_queries.head()

In [None]:
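# 121352 is a query ID, not a row index, so select the row by its 'qid' value
print(df_queries.loc[df_queries["qid"] == 121352, "query_text"].iloc[0])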

## **Now we can merge all tables into a single dataframe:**

In [None]:
# Merge qrels with queries on 'qid'
df_merged = pd.merge(df_qrels, df_queries, on='qid', how='left')

# Merge the resulting positives with the collection on 'pid'
df_merged = pd.merge(df_merged, df_collection, on='pid', how='left')

df_merged.drop(["unused"], axis=1, inplace=True)

df_merged.head()

In [None]:
df_merged.shape
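Because both merges use how='left', any qid or pid without a match shows up as missing (NaN) text rather than being dropped. A quick check for such gaps is sketched below on hypothetical toy frames. It also uses pandas' `indicator=True` option, which adds a `_merge` column recording whether each row found a match.

```python
import pandas as pd

# Toy stand-ins for df_qrels, df_queries and df_collection (pid 99 has no passage)
qrels = pd.DataFrame({"qid": [1, 2, 3], "unused": 0, "pid": [10, 11, 99], "rel": 1})
queries = pd.DataFrame({"qid": [1, 2, 3], "query_text": ["q one", "q two", "q three"]})
collection = pd.DataFrame({"pid": [10, 11], "passage": ["p ten", "p eleven"]})

merged = qrels.merge(queries, on="qid", how="left")
# indicator=True adds '_merge': 'both' if the pid was found in the collection, 'left_only' otherwise
merged = merged.merge(collection, on="pid", how="left", indicator=True)

missing = merged[["query_text", "passage"]].isna().sum().to_dict()
print(missing)
print(merged.loc[merged["_merge"] == "left_only", ["qid", "pid"]])
```

On the real data, `df_merged[["query_text", "passage"]].isna().sum()` should be all zeros if every qrels entry has a matching query and passage.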

Saving the merged dataframe:

In [None]:
df_merged.to_csv('df_merged.csv', index=False)
from google.colab import files
files.download('df_merged.csv')

In [None]:
df_merged = pd.read_csv('df_merged.csv')

From here on, df_merged.csv can be loaded directly instead of re-running the merge cells above. The negative-sampling step below still needs df_qrels, df_queries, df_collection and all_pids, so those loading cells must still be run.

The merged dataframe has 532,761 rows (data points), one per qrels.train entry.

**The dataframe above cannot yet be used to train the model: it only contains positive pairs (rel = 1). Negative pairs (rel = 0) have to be added.**

In [None]:
pip install joblib


In [None]:
# all_pids is the list of all passage IDs from the collection
all_pids = df_collection['pid'].unique().tolist()


In [None]:
df_merged.head()

In [None]:
df_merged['rel'].unique()

Checking how many CPU cores are available:

In [None]:
import multiprocessing

num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)

**Negative Sampling:**

The MS MARCO training qrels only contain query-passage pairs with positive relevance, so negative pairs (rel = 0) have to be generated. Here's the plan:

Progressive approach:
* Start by generating 50,000 negative pairs (sampling ~5,000 queries with 10 random negative passages each)
* Then subsample the positives to 50,000
* This gives a balanced dataset of 100,000 examples, which is sufficient for an initial training run
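The plan above uses random negatives: passages drawn uniformly from the collection that are not listed as relevant for the query. A vectorised NumPy variant of that idea is sketched below as an alternative to the loop-based cell that follows. The function name, the toy IDs and the choice to redraw only the slots that hit a positive are assumptions for illustration.

```python
import numpy as np


def sample_random_negatives(qids, pos_dict, all_pids, per_query, seed=42, max_rounds=20):
    """Return a (len(qids), per_query) array of random pids, none of them a positive for its query.

    Draws for all queries at once, then redraws only the slots that hit a positive.
    Duplicates within a row are not removed (rare when the collection is large).
    """
    rng = np.random.default_rng(seed)
    all_pids = np.asarray(all_pids)
    draws = rng.choice(all_pids, size=(len(qids), per_query))
    for _ in range(max_rounds):
        # True where the drawn pid is one of that query's relevant passages
        hits_positive = np.array(
            [[pid in pos_dict.get(qid, ()) for pid in row] for qid, row in zip(qids, draws)],
            dtype=bool,
        ).reshape(draws.shape)
        if not hits_positive.any():
            return draws
        draws[hits_positive] = rng.choice(all_pids, size=int(hits_positive.sum()))
    raise RuntimeError("Could not avoid positives; the collection may be too small relative to the positives.")


all_pids = np.arange(1000)
pos_dict = {"q1": {0, 1, 2}, "q2": {3}}
qids = ["q1", "q2"]
negs = sample_random_negatives(qids, pos_dict, all_pids, per_query=10)

# Flatten into (qid, pid, 0) rows, the same shape as the notebook's all_neg_samples
rows = [(qid, int(pid), 0) for qid, row in zip(qids, negs) for pid in row]
print(negs.shape, len(rows))
```

With the notebook's data, this would be called roughly as `sample_random_negatives(sub_qids, pos_dict, np.array(all_pids), per_query=10)`, using the `pos_dict` and `all_pids` built in the next cell.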

In [None]:
import random
import pandas as pd

################################
# 0) Ensure keys are of the same type across DataFrames and all_pids
################################
# Convert key columns to string for consistency
df_qrels['qid'] = df_qrels['qid'].astype(str)
df_qrels['pid'] = df_qrels['pid'].astype(str)
df_queries['qid'] = df_queries['qid'].astype(str)
df_collection['pid'] = df_collection['pid'].astype(str)

# Ensure all_pids are strings as well
all_pids = [str(pid) for pid in all_pids]

# Debug: Print data types to verify consistency
print("df_qrels dtypes:")
print(df_qrels.dtypes)
print("\ndf_queries dtypes:")
print(df_queries.dtypes)
print("\ndf_collection dtypes:")
print(df_collection.dtypes)
print("\nType of first element in all_pids:", type(all_pids[0]))

################################
# 1) Sample 5000 QUERIES to generate 50k negatives
################################
all_qids = df_qrels["qid"].unique().tolist()
random.seed(42)  # make the query subsample and the sampled negatives reproducible
random.shuffle(all_qids)

subsample_size = 5000  # Target ~5000 queries
sub_qids = all_qids[:subsample_size]
print(f"Selected {len(sub_qids)} unique queries for negative sampling")

################################
# 2) Process in reasonable chunks
################################
chunk_size = 100
chunks = [sub_qids[i : i + chunk_size] for i in range(0, len(sub_qids), chunk_size)]
print(f"Split into {len(chunks)} chunks of {chunk_size} queries each")

################################
# Build a dictionary of positives for each qid
################################
pos_dict = {}
for row in df_qrels.itertuples(index=False):
    q = row.qid
    p = row.pid
    if q not in pos_dict:
        pos_dict[q] = set()
    pos_dict[q].add(p)

################################
# Improved negative sampling function
################################
def sample_negatives_for_qid(qid, max_samples=10):
    """
    For a given qid, efficiently sample negative passages.

    Args:
        qid: The query ID
        max_samples: Maximum number of negative samples to generate

    Returns:
        List of (qid, pid, 0) tuples representing negative samples
    """
    pos_pids = pos_dict.get(qid, set())

    # Fixed number of negatives per query to prevent excessive sampling
    neg_samples = []
    seen_pids = set()  # O(1) duplicate check instead of scanning neg_samples
    attempts = 0
    max_attempts = max_samples * 20  # Allow more failed attempts

    while len(neg_samples) < max_samples and attempts < max_attempts:
        attempts += 1
        # Pick a random pid from all_pids
        pid = random.choice(all_pids)
        # Only add if it's not a positive for this query and has not been sampled already
        if pid not in pos_pids and pid not in seen_pids:
            seen_pids.add(pid)
            neg_samples.append((qid, pid, 0))

    return neg_samples

################################
# 3) Process each chunk sequentially with progress tracking
################################
all_neg_samples = []  # Global list to store negative samples
target_negatives = 50000
negatives_per_query = min(10, target_negatives // len(sub_qids) + 1)

print(f"Target: {target_negatives} negatives at {negatives_per_query} per query")

for idx, chunk_qids in enumerate(chunks):
    print(f"Processing chunk {idx+1}/{len(chunks)}: {len(chunk_qids)} queries")

    # Process queries sequentially
    for i, qid in enumerate(chunk_qids):
        neg_samples = sample_negatives_for_qid(qid, max_samples=negatives_per_query)
        all_neg_samples.extend(neg_samples)

        # Progress update every 50 queries
        if (i+1) % 50 == 0 or i+1 == len(chunk_qids):
            print(f"  -> Processed {i+1}/{len(chunk_qids)} queries in current chunk")
            print(f"  -> Total negatives so far: {len(all_neg_samples)}")

    # Check if we've reached our target
    if len(all_neg_samples) >= target_negatives:
        print(f"Reached target of {target_negatives} negatives. Stopping.")
        break

################################
# 4) Convert negative pairs to DataFrame
################################
df_neg = pd.DataFrame(all_neg_samples[:target_negatives], columns=["qid", "pid", "rel"])
print("\nTotal negative pairs:", len(df_neg))
print("Sample of negative pairs:")
print(df_neg.head())

################################
# 5) Merge negative pairs with query and passage text
################################
print("Merging negative pairs with query and passage text...")
df_neg_merged = pd.merge(df_neg, df_queries, on='qid', how='left')
df_neg_merged = pd.merge(df_neg_merged, df_collection, on='pid', how='left')

################################
# 6) Subsample positives to match number of negatives
################################
print(f"Subsampling {len(df_neg)} positives from {len(df_merged)} total positives...")
df_pos_subsampled = df_merged.sample(n=len(df_neg), random_state=42)

################################
# 7) Create final balanced dataset
################################
df_balanced = pd.concat([df_pos_subsampled, df_neg_merged], ignore_index=True)
df_balanced = df_balanced.sample(frac=1.0, random_state=42)  # Shuffle the dataset

print("\nFinal balanced dataset size:", df_balanced.shape[0])
print("Class distribution:")
print(df_balanced["rel"].value_counts())
print("\nSample of the balanced dataset:")
print(df_balanced.head(5))

# Save the balanced dataset
df_balanced.to_csv('balanced_dataset_100k.csv', index=False)
print("Balanced dataset saved to 'balanced_dataset_100k.csv'")

df_qrels dtypes:
qid       object
unused     int64
pid       object
rel        int64
dtype: object

df_queries dtypes:
qid           object
query_text    object
dtype: object

df_collection dtypes:
pid        object
passage    object
dtype: object

Type of first element in all_pids: <class 'str'>
Selected 5000 unique queries for negative sampling
Split into 50 chunks of 100 queries each
Target: 50000 negatives at 10 per query
Processing chunk 1/50: 100 queries
  -> Processed 50/100 queries in current chunk
  -> Total negatives so far: 500
  -> Processed 100/100 queries in current chunk
  -> Total negatives so far: 1000
Processing chunk 2/50: 100 queries
  -> Processed 50/100 queries in current chunk
  -> Total negatives so far: 1500
  -> Processed 100/100 queries in current chunk
  -> Total negatives so far: 2000
Processing chunk 3/50: 100 queries
  -> Processed 50/100 queries in current chunk
  -> Total negatives so far: 2500
  -> Processed 100/100 queries in current chunk
  -> Total 

In [None]:
df_balanced

|   | qid | pid | rel | query_text | passage |
|---|---|---|---|---|---|
| 75721 | 1035454 | 8745608 | 0 | who is louis chevrolet | Nocona Western Belt Mens Leather Hair... Mens ... |
| 80184 | 545923 | 792048 | 0 | weather in september in santorini | The primary coordinate point for Mosquero is l... |
| 19864 | 1150869 | 1505975 | 1 | what is the function of the macrophages in the... | function of alveolar macrophagesThe function o... |
| 76699 | 886769 | 3020855 | 0 | what phones use adaptive fast charging technology | A normal amount of potassium in a typical diet... |
| 92991 | 749145 | 2601929 | 0 | what is french press coffee? | 1 Antigens bind to B cells. 2 Interleukins or... |
| ... | ... | ... | ... | ... | ... |
| 6265 | 54434 | 2472763 | 1 | blind loop syndrome symptoms | The list of signs and symptoms mentioned in va... |
| 54886 | 1036591 | 6230802 | 0 | who is nominated olympic basketball assistant ... | iPAS stands for Internet Prospect Acceleration... |
| 76820 | 581086 | 4795589 | 0 | what can cause acute kidney injury | Artist concept of the asteroid belt. Credit: N... |
| 860 | 52116 | 2455570 | 1 | best thing to eat after a colonic | The best foods to eat after a colon cleanse ar... |


In [None]:
df_negative_pairs = df_balanced[df_balanced['rel'] == 0]
df_positive_pairs = df_balanced[df_balanced['rel'] == 1]

In [None]:
df_negative_pairs

|   | qid | pid | rel | query_text | passage |
|---|---|---|---|---|---|
| 75721 | 1035454 | 8745608 | 0 | who is louis chevrolet | Nocona Western Belt Mens Leather Hair... Mens ... |
| 80184 | 545923 | 792048 | 0 | weather in september in santorini | The primary coordinate point for Mosquero is l... |
| 76699 | 886769 | 3020855 | 0 | what phones use adaptive fast charging technology | A normal amount of potassium in a typical diet... |
| 92991 | 749145 | 2601929 | 0 | what is french press coffee? | 1 Antigens bind to B cells. 2 Interleukins or... |
| 76434 | 1155159 | 4684577 | 0 | what is finpro | Updated June 16, 2016. Cervicalgia is a term u... |
| ... | ... | ... | ... | ... | ... |
| 60263 | 903070 | 6965569 | 0 | what tests are performed at st peters in alban... | Most of us realized that in our addiction we w... |
| 87498 | 909178 | 2637040 | 0 | what treatments are available for hepatitis c | When choosing the location of a park focused o... |
| 82386 | 147120 | 1990522 | 0 | difference between dry cough and wet cough | How to Find Zeros of a Function. f(x) = -2 x +... |
| 54886 | 1036591 | 6230802 | 0 | who is nominated olympic basketball assistant ... | iPAS stands for Internet Prospect Acceleration... |


In [None]:
df_positive_pairs

|   | qid | pid | rel | query_text | passage |
|---|---|---|---|---|---|
| 19864 | 1150869 | 1505975 | 1 | what is the function of the macrophages in the... | function of alveolar macrophagesThe function o... |
| 27701 | 427791 | 2240321 | 1 | is the word documentation singular or plural | The noun documentation is singular. The noun d... |
| 42141 | 1003258 | 2439205 | 1 | where was richard nixon born | Mini Bio (1). Richard Nixon was born on Januar... |
| 45080 | 978868 | 4904002 | 1 | where is brutus ny | Brutus is a town in Cayuga County, New York, U... |
| 16638 | 930527 | 4507932 | 1 | what's in mincemeat pie filling | Mincemeat Pie Filling. a true mincemeat pie co... |
| ... | ... | ... | ... | ... | ... |
| 44131 | 703575 | 1568193 | 1 | what is a time warner p700 error | Re: Twc app on roku error code:p700. ‎04-30-20... |
| 37194 | 687655 | 5620458 | 1 | what is a ipod | Apple's iPod is a small portable music player.... |
| 6265 | 54434 | 2472763 | 1 | blind loop syndrome symptoms | The list of signs and symptoms mentioned in va... |
| 860 | 52116 | 2455570 | 1 | best thing to eat after a colonic | The best foods to eat after a colon cleanse ar... |


Both classes contain 50,000 pairs (confirmed by the value_counts() output further down), so this dataframe can be used to train the model.

In [None]:
from google.colab import files

# Save df_balanced to a CSV file in the Colab environment
df_balanced.to_csv('df_balanced.csv', index=False)

# Download the CSV file to your local machine
files.download('df_balanced.csv')


In [14]:
file_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/df_balanced.csv'

df_balanced = pd.read_csv(file_path)

df_balanced.head(5)

|   | qid | pid | rel | query_text | passage |
|---|---|---|---|---|---|
| 0 | 1147448 | 3292484 | 0 | what law was put into place to end child labor | The basic idea is to use a sentence structure ... |
| 1 | 1146837 | 1911754 | 0 | what pay range is considered middle class | A: Clinical signs in humans usually develop wi... |
| 2 | 1150869 | 1505975 | 1 | what is the function of the macrophages in the... | function of alveolar macrophagesThe function o... |
| 3 | 525889 | 6063694 | 0 | two types of nucleic acids viruses may have | → دِبْلُوماسيّ diplomat diplomat Diplomat διπλ... |
| 4 | 34120 | 6965147 | 0 | average cost of new home construction | People with Down syndrome may have a variety o... |


In [15]:
print(df_balanced['rel'].value_counts())

rel
0    50000
1    50000
Name: count, dtype: int64


This confirms that the two classes are balanced.

# Choosing and importing a relevant SBERT Model:

In [16]:
!pip install evaluate


Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py311-none-any.whl.metadata (7.2 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)

In [17]:
!pip install datasets
!pip install transformers
!pip install accelerate -U
!pip install transformers[torch]
!pip install wandb

Collecting accelerate
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.6.0-py3-none-any.whl (354 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 1.5.2
    Uninstalling accelerate-1.5.2:
      Successfully uninstalled accelerate-1.5.2
Successfully installed accelerate-1.6.0


In [20]:
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import multiprocessing
from datasets import Dataset
from transformers import AutoTokenizer

import evaluate
import torch

In [19]:
from sentence_transformers import InputExample
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import random

# 1. Split the balanced dataset into train and validation sets
train_df, val_df = train_test_split(
    df_balanced,
    test_size=0.1,
    random_state=42,
    stratify=df_balanced['rel']  # Maintain class balance in splits
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

# 2. Convert DataFrames to lists of InputExample objects
def df_to_input_examples(df):
    examples = []
    for _, row in df.iterrows():
        # SBERT expects InputExample objects with texts and a label
        examples.append(InputExample(
            texts=[row['query_text'], row['passage']],
            label=float(row['rel'])
        ))
    return examples

train_examples = df_to_input_examples(train_df)
val_examples = df_to_input_examples(val_df)

# Sample some examples to verify
print("\nSample training examples:")
for i in range(3):
    ex = random.choice(train_examples)
    print(f"Query: {ex.texts[0][:50]}...")
    print(f"Passage: {ex.texts[1][:50]}...")
    print(f"Label: {ex.label}\n")

Training samples: 90000
Validation samples: 10000

Sample training examples:
Query: which phase of meiosis is a ceavage furrow...
Passage: Cytokinesis is not a phase of mitosis but rather a...
Label: 1.0

Query: distance from volcano hi to kona hi...
Passage: Finding Your Way There: Hawai`i Volcanoes National...
Label: 1.0

Query: what does miscegenation mean...
Passage: Medical Definition of miscegenation. : a mixture o...
Label: 1.0



In [None]:

model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v4')
model = model.to('cuda')  #GPU acceleration

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

# Define loss
train_loss = losses.CosineSimilarityLoss(model=model)

# Create evaluator for validation set
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(val_examples, name='val-eval')

# Fine-tune the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=3,
    evaluation_steps=100,  # Validation every 100 steps
    output_path='/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/SBERT_MODEL/'
)



| Step | Training Loss | Validation Loss | Val-eval Pearson Cosine | Val-eval Spearman Cosine |
|---|---|---|---|---|
| 100 | No log | No log | 0.966384 | 0.86569 |
| 200 | No log | No log | 0.967562 | 0.865752 |
| 300 | No log | No log | 0.968365 | 0.86581 |
| 400 | No log | No log | 0.968413 | 0.865849 |
| 500 | 0.059500 | No log | 0.968424 | 0.865871 |
| 600 | 0.059500 | No log | 0.968458 | 0.865884 |
| 700 | 0.059500 | No log | 0.968582 | 0.865896 |
| 800 | 0.059500 | No log | 0.968688 | 0.8659 |
| 900 | 0.059500 | No log | 0.968811 | 0.865897 |
| 1000 | 0.040000 | No log | 0.968461 | 0.865881 |
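For reference, the CosineSimilarityLoss used as `train_loss` works as follows. It embeds the query and the passage separately, takes the cosine similarity of the two embeddings and regresses it onto the label with mean squared error (its default). Positives (label 1.0) are therefore pulled towards cosine 1, and random negatives (label 0.0) towards cosine 0. The NumPy sketch below reproduces that computation on made-up two-dimensional embeddings. It is an illustration of the objective, not the library's code.

```python
import numpy as np


def cosine_similarity_loss(query_embs, passage_embs, labels):
    """Mean squared error between row-wise cosine similarities and the target labels."""
    q = query_embs / np.linalg.norm(query_embs, axis=1, keepdims=True)
    p = passage_embs / np.linalg.norm(passage_embs, axis=1, keepdims=True)
    cos = np.sum(q * p, axis=1)  # cosine similarity of each (query, passage) pair
    return float(np.mean((cos - labels) ** 2)), cos


query_embs = np.array([[1.0, 0.0], [1.0, 0.0]])
passage_embs = np.array([[1.0, 0.0], [1.0, 1.0]])  # pair 1 points the same way; pair 2 is 45 degrees off
labels = np.array([1.0, 0.0])                      # pair 1 is a positive, pair 2 a negative
loss, cos = cosine_similarity_loss(query_embs, passage_embs, labels)
print(cos, loss)
```

In this toy example, the positive already has cosine 1 and costs nothing. The negative sits at cosine of about 0.707 and contributes (0.707 - 0)^2 = 0.5, giving a mean loss of 0.25 over the two pairs.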


In [None]:
model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v4')
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.CosineSimilarityLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    output_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/SBERT_MODEL/'  # <-- your output directory
)


| Step | Training Loss |
|---|---|
| 500 | 0.0597 |
| 1000 | 0.0411 |
| 1500 | 0.0337 |
| 2000 | 0.0317 |
| 2500 | 0.0298 |
| 3000 | 0.0282 |
| 3500 | 0.026 |
| 4000 | 0.0253 |
| 4500 | 0.0256 |
| 5000 | 0.0249 |


**Training and monitoring evaluation:**

In [None]:
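# Define a dev evaluator: it classifies each (query, passage) pair by thresholding its
# similarity score and reports accuracy, F1, precision, recall and average precision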
evaluator = evaluation.BinaryClassificationEvaluator(
    sentences1=val_df['query_text'].tolist(),
    sentences2=val_df['passage'].tolist(),
    labels=val_df['rel'].tolist(),
    name='msmarco-dev'
)

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=3,  # Start with 3 epochs (adjust based on eval performance)
    warmup_steps=1000,
    output_path='./sbert-reranker',  # Save model here
    evaluation_steps=5000,  # Evaluate every 5k steps
    show_progress_bar=True
)

**Post-Training Validation:**
* Load MS MARCO’s official dev set (qrels.dev.tsv and queries.dev.tsv), used here as the held-out test set

In [None]:
#Read qrels.dev.tsv
test_qrels = pd.read_csv(
    '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/qrels.dev.tsv',
    sep='\t',
    header=None,
    names=['qid', 'unused', 'pid', 'rel']  # "unused" corresponds to the '0' column
)

# read queries.dev.tsv
test_queries = pd.read_csv(
    '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/queries.dev.tsv',
    sep='\t',
    header=None,
    names=['qid', 'query_text']
)

For final validation, we use MS MARCO's dev set, which was not used for training or for tuning the model's parameters. In principle, it should therefore give an unbiased estimate of performance.

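**When the BinaryClassificationEvaluator runs, it does the following:**
1. Feeds each (query, passage) pair to the model to get a similarity score (e.g., cosine similarity between the two embeddings).

2. Compares the predicted scores with the ground-truth rel labels. It searches for the similarity threshold that best separates relevant from non-relevant pairs and reports classification metrics such as accuracy, F1, precision, recall and average precision.

Because qrels.dev contains only relevant pairs (rel = 1), negative pairs have to be added to the dev data before these metrics are meaningful.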
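The threshold search behind the accuracy metric can be sketched as follows. Try every observed score as a cut-off, predict "relevant" for pairs at or above it, and keep the cut-off with the highest accuracy. This is a simplified NumPy illustration with a hypothetical function name and made-up scores. It is not the library's implementation, which also reports F1, precision, recall and average precision for several similarity functions.

```python
import numpy as np


def best_accuracy_threshold(scores, labels):
    """Return (accuracy, threshold) for the cut-off that best separates labels 1 and 0.

    A pair is predicted relevant when its score is >= threshold. Every observed score is
    tried as a threshold, plus +inf (predict nothing relevant).
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    best_acc, best_thr = -1.0, None
    for thr in np.append(np.unique(scores), np.inf):
        acc = float(np.mean((scores >= thr).astype(int) == labels))
        if acc > best_acc:
            best_acc, best_thr = acc, float(thr)
    return best_acc, best_thr


scores = [0.91, 0.85, 0.40, 0.35, 0.60]
labels = [1, 1, 0, 0, 0]
print(best_accuracy_threshold(scores, labels))

# With positives only, the lowest threshold labels everything relevant and accuracy is trivially 1.0
print(best_accuracy_threshold([0.1, 0.2], [1, 1]))
```

The second call shows why a dev set made only of qrels.dev pairs cannot distinguish a good model from a bad one under this metric.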

In [None]:
# 1. Merge qrels.dev with queries.dev on 'qid'
#    Resulting columns: [qid, unused, pid, rel, query_text]
test_merged = pd.merge(test_qrels, test_queries, on='qid', how='left')

# 2. Use string pids on both sides: df_collection['pid'] was cast to str during
#    negative sampling, and pandas will not merge int64 keys with object (string) keys
test_merged['pid'] = test_merged['pid'].astype(str)
df_collection['pid'] = df_collection['pid'].astype(str)

# 3. Merge again with collection on 'pid' to get the passage text
#    Resulting columns: [qid, unused, pid, rel, query_text, passage]
test_merged = pd.merge(test_merged, df_collection, on='pid', how='left')

# 4. (Optional) drop the 'unused' column
test_merged.drop(columns=['unused'], inplace=True)

print("Dev set (test) shape:", test_merged.shape)
print(test_merged.head())


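Ranking: for each dev query, rank its candidate passages by their similarity to the query and compute the Mean Reciprocal Rank (MRR), i.e. 1 / the rank of the first relevant passage, averaged over queries. Note that qrels.dev lists only relevant passages. If a query's candidates are taken from qrels.dev alone, every candidate is relevant and MRR is trivially 1.0, so negative candidates must be included.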

In [None]:
import numpy as np
from sentence_transformers import SentenceTransformer


def compute_mrr_for_query(sorted_rels):
    """Reciprocal rank of the first relevant passage in a ranked list of labels (0.0 if none)."""
    for rank, label in enumerate(sorted_rels, start=1):
        if label == 1:
            return 1.0 / rank
    return 0.0

# 1) Load your trained model
model = SentenceTransformer('./sbert-reranker')

# qrels.dev only lists relevant passages, so ranking those alone would always put a
# relevant passage first (MRR = 1.0). Pad each query's candidates with random passages
# from the collection as negatives. Random negatives are much easier than the official
# top-1000 candidate lists, so the resulting MRR is optimistic.
num_random_negatives = 99
rng = np.random.default_rng(42)
passage_by_pid = df_collection.set_index('pid')['passage']

all_mrrs = []  # We'll store MRR for each query

# 2) Group test_merged by qid (encoding ~100 passages per query over the whole dev set is slow;
#    consider evaluating a random subset of qids first)
for qid, sub_df in test_merged.groupby('qid'):
    query_text = sub_df.iloc[0]['query_text']

    # Random negatives, excluding any passage judged relevant for this query
    neg = passage_by_pid.iloc[rng.integers(0, len(passage_by_pid), size=num_random_negatives)]
    neg = neg[~neg.index.isin(sub_df['pid'])]

    passages = sub_df['passage'].tolist() + neg.tolist()
    rel_labels = sub_df['rel'].tolist() + [0] * len(neg)

    # 3) Encode query and passages
    query_emb = model.encode([query_text])[0]
    passage_embs = model.encode(passages)

    # 4) Compute similarity scores (dot-product or cosine)
    #    Dot-product example:
    sims = passage_embs @ query_emb
    #    Or for cosine:
    # query_emb_norm = query_emb / np.linalg.norm(query_emb)
    # passage_embs_norm = passage_embs / np.linalg.norm(passage_embs, axis=1, keepdims=True)
    # sims = passage_embs_norm @ query_emb_norm

    # 5) Sort passages by descending similarity
    sorted_indices = np.argsort(-sims)
    sorted_rels = [rel_labels[i] for i in sorted_indices]

    # 6) Compute MRR for this query
    #    MRR is 1 / the rank of the first relevant document
    mrr_val = compute_mrr_for_query(sorted_rels)
    all_mrrs.append(mrr_val)

# Average across all queries
final_mrr = np.mean(all_mrrs)
print(f"Mean Reciprocal Rank (MRR) across all dev queries: {final_mrr:.4f}")
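As a small worked example of the metric, MRR averages 1/rank of the first relevant passage over queries. The standalone sketch below uses hypothetical score lists. It also shows the cut-off variant used for the official MS MARCO passage ranking metric, MRR@10, where a query whose first relevant passage falls outside the top k contributes 0. `reciprocal_rank` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def reciprocal_rank(scores, labels, k=None):
    """1/rank of the highest-scored relevant item (label 1); 0.0 if none is within the top k."""
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
    ranked = np.asarray(labels)[order]
    if k is not None:
        ranked = ranked[:k]
    hits = np.flatnonzero(ranked == 1)
    return 1.0 / (hits[0] + 1) if hits.size else 0.0


queries = [
    ([0.9, 0.2, 0.5], [1, 0, 0]),  # relevant passage ranked 1st -> 1/1
    ([0.3, 0.8, 0.6], [1, 0, 0]),  # relevant passage ranked 3rd -> 1/3
    ([0.1, 0.4, 0.7], [0, 1, 0]),  # relevant passage ranked 2nd -> 1/2
]
rrs = [reciprocal_rank(s, l) for s, l in queries]
print(rrs, np.mean(rrs))  # mean of 1, 1/3 and 1/2

# With a cut-off of k=2, the second query's relevant passage (rank 3) no longer counts
print(reciprocal_rank([0.3, 0.8, 0.6], [1, 0, 0], k=2))
```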
