<a href="https://colab.research.google.com/github/kasey-purvor/Latent_semantic_index_SearchEngine/blob/SBERT-Training/SBERT_TRAINING.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SBERT Training**

In the MS MARCO Passage Ranking (and, similarly, Document Ranking) datasets, you’ll find several files: collection, queries, qrels.train, and qrels.dev. Each plays a different role in a training/evaluation pipeline for information retrieval. Here’s how they differ:

1. **Collection**

* What It Is: A large file that contains all the text in the dataset – specifically, each passage (for Passage Ranking) or each document (for Document Ranking).
* Typical Format: For the Passage Ranking set, you’ll see something like pid\t passage_text. Each row links a passage ID (pid) to its actual text.
* Usage: You use this to look up the actual passage text, given its ID.

2. **Queries**

* What It Is: A list of user queries (often real or anonymized user questions).
* Typical Format: Usually qid\t query_text, mapping a query ID (qid) to the textual query.
* Usage: When training or evaluating an IR system, you’ll retrieve passages for these queries.

3. **Qrels.train**

* What It Is: The relevance judgments (so-called “qrels,” short for query relevance set) for the training queries.
* Typical Format: In TREC format, something like qid 0 pid relevance_label.
* Usage: Tells you which (qid, pid) pairs are actually relevant. If you’re doing pointwise training, you treat these as positive examples (label=1). Any other (qid, pid) pair in your candidate set that is not in qrels is usually treated as negative (label=0). MS MARCO judgments are sparse, though, so some of these “negatives” may in fact be relevant but unjudged.

4. **Qrels.dev**

* What It Is: The relevance judgments (qrels) for the development/validation queries.
* Usage: Allows you to evaluate your model on a held-out dev set. For each dev query, you can see which passages are truly relevant (label=1), and compute metrics like MRR or nDCG.
* Similar Format: qid 0 pid relevance_label.
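For reference, the official MS MARCO passage-ranking metric is MRR@10. For each query, take the reciprocal rank of the first relevant passage in the top 10 (0 if there is none), then average over the queries. Below is a minimal sketch in plain Python. It assumes the model's ranked results and qrels.dev have already been loaded into dictionaries. The names `run`, `qrels_dev` and `mrr_at_k` are hypothetical, and the IDs are made up.

```python
from typing import Dict, List, Set


def mrr_at_k(run: Dict[int, List[int]], qrels: Dict[int, Set[int]], k: int = 10) -> float:
    """Mean Reciprocal Rank@k.

    run:   qid -> ranked list of pids (best first), produced by the model
    qrels: qid -> set of relevant pids, as read from qrels.dev
    """
    total = 0.0
    for qid, relevant in qrels.items():
        ranked = run.get(qid, [])[:k]
        for rank, pid in enumerate(ranked, start=1):
            if pid in relevant:
                total += 1.0 / rank  # only the first relevant hit counts
                break
    # queries with no relevant passage in the top k contribute 0
    return total / len(qrels) if qrels else 0.0


# Toy example with made-up IDs (not real MS MARCO data)
qrels_dev = {1: {10}, 2: {20}, 3: {30}}
run = {1: [10, 11, 12], 2: [21, 20, 22], 3: [31, 32, 33]}
print(mrr_at_k(run, qrels_dev))  # (1 + 1/2 + 0) / 3 = 0.5
```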

**Putting It All Together**
* collection gives you the text for each passage.
* queries gives you the text for each user query.
* qrels.train maps training queries to their relevant passages (positives).
* qrels.dev does the same for dev (validation) queries, letting you measure how well your system ranks the correct passages on unseen data.

**In practice, you:**

1. Load the passages from collection into a dictionary: pid -> passage_text.
2. Load queries from queries as qid -> query_text.
3. Use qrels.train to form (query, passage, label) training samples:
* label=1 if (qid, pid) is marked as relevant in qrels.train, and label=0 for sampled pairs that are not (the negatives).
4. When you’re done training, check performance on the dev set using qrels.dev. The dev set is smaller and separate from the training data, so it shows how well your model generalizes before final testing.
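The four steps above can be sketched with the standard library on a toy, in-memory version of the three TSV files. The file contents, IDs and helper names below are made up for illustration. Negatives are assumed to be drawn uniformly at random from passages not judged relevant, at a 1:1 ratio with the positives.

```python
import csv
import io
import random
from collections import defaultdict

# Tiny made-up stand-ins for collection.tsv, queries.train.tsv and qrels.train.tsv
collection_tsv = "0\tpassage about urine color\n1\tpassage about the Manhattan Project\n2\tpassage about tattoos\n"
queries_tsv = "100\turine color meaning\n101\tmanhattan project\n"
qrels_tsv = "100\t0\t0\t1\n101\t0\t1\t1\n"


def read_tsv(text):
    # QUOTE_NONE: treat quote characters inside passages as ordinary text
    return list(csv.reader(io.StringIO(text), delimiter="\t", quoting=csv.QUOTE_NONE))


# 1) pid -> passage_text and 2) qid -> query_text
passages = {int(pid): text for pid, text in read_tsv(collection_tsv)}
queries = {int(qid): text for qid, text in read_tsv(queries_tsv)}

# 3) qid -> set of relevant pids (qrels columns: qid, 0, pid, label)
positives = defaultdict(set)
for qid, _, pid, label in read_tsv(qrels_tsv):
    if int(label) > 0:
        positives[int(qid)].add(int(pid))

rng = random.Random(0)
all_pids = list(passages)
samples = []  # (query_text, passage_text, label)
for qid, pos_pids in positives.items():
    for pid in pos_pids:
        samples.append((queries[qid], passages[pid], 1))
    # one random negative per positive: any pid not judged relevant for this query
    candidates = [p for p in all_pids if p not in pos_pids]  # fine for a toy collection
    for pid in rng.sample(candidates, min(len(pos_pids), len(candidates))):
        samples.append((queries[qid], passages[pid], 0))

for s in samples:
    print(s)
```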

#**Step-by-Step Approach in this Notebook:**
1. Prepare Training Data
2. Load a pre-trained SBERT model
3. Convert data for SBERT training
4. Train SBERT
5. Evaluate model
6. Save the fine-tuned model ready to be used post FAISS

In [None]:
pip install -U sentence-transformers



In [None]:
!pip install ftfy

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Installing collected packages: ftfy
Successfully installed ftfy-6.3.1


In [None]:
import pandas as pd
import numpy as np

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# file_path = "qrels.train.tsv"  # local copy, if not reading from Drive
file_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/qrels.train.tsv'

df_qrels = pd.read_csv(
    file_path,
    sep='\t',           # tab-separated
    header=None,
    names=["qid", "unused", "pid", "rel"]  # column names for clarity
)

df_qrels.head()

       qid  unused  pid  rel
0  1185869       0    0    1
1  1185868       0   16    1
2   597651       0   49    1
3   403613       0   60    1
4  1183785       0  389    1


In [None]:
df_qrels.shape

(532761, 4)

The DataFrame has 532,761 rows and 4 columns:

* qid is the query ID
* unused is a placeholder column (0 in every row shown)
* pid is the passage ID
* rel is the relevance label

In [None]:
import pandas as pd

#file_path = "queries.train.tsv"
file_path = '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/queries.train.tsv'

df_queries = pd.read_csv(
    file_path,
    sep='\t',
    header=None,
    names=["qid", "query_text"]
)

df_queries.head()

      qid                                query_text
0  121352                            define extreme
1  634306  what does chattel mean on credit history
2  920825   what was the great leap forward brainly
3  510633       tattoo fixers how much does it cost
4  737889         what is decentralization process.


In [None]:
df_queries.shape

(808731, 2)

In [None]:
df_collection = pd.read_csv(
    '/content/drive/MyDrive/Colab_Notebooks/Information Retrieval/Search Engine Project/collection.tsv',
    sep='\t',
    header=None,
    names=['pid', 'passage']
)
df_collection.head()


   pid                                             passage
0    0  The presence of communication amid scientific ...
1    1  The Manhattan Project and its atomic bomb help...
2    2  Essay on The Manhattan Project - The Manhattan...
3    3  The Manhattan Project was the name for a proje...
4    4  versions of each volume as well as complementa...


In [None]:
df_collection.shape

(8841823, 2)

In [None]:
df_collection.loc[49, "passage"]

'Colorâ\x80\x94urine can be a variety of colors, most often shades of yellow, from very pale or colorless to very dark or amber. Unusual or abnormal urine colors can be the result of a disease process, several medications (e.g., multivitamins can turn urine bright yellow), or the result of eating certain foods.'

As you can see, some entries have encoding problems. The em dash after "Color" has been mangled into "â\x80\x94" (UTF-8 bytes decoded as Latin-1). I need to clean these entries.

In [None]:
from ftfy import fix_text

# Fix encoding in 'passage'
df_collection["passage"] = df_collection["passage"].apply(fix_text)

# Check result
print(df_collection.loc[49, "passage"])


KeyboardInterrupt: 
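Running `fix_text` over all 8,841,823 passages is slow, and the cell above was interrupted before it finished. One possible speed-up is to repair only the strings that look mis-encoded. This is an assumption, not the notebook's method. The helper below also swaps in a much simpler repair than ftfy. It reverses a single "UTF-8 decoded as Latin-1" error, which is the pattern seen in passage 49, and leaves the text unchanged if that round trip fails. The names `SUSPICIOUS` and `repair_latin1_mojibake` are hypothetical.

```python
import re

# Characters that typically appear when UTF-8 bytes were decoded as Latin-1:
# the lead bytes shown as 'Â', 'Ã', 'â', or C1 control characters (U+0080..U+009F).
SUSPICIOUS = re.compile("[\u00c2\u00c3\u00e2\u0080-\u009f]")


def repair_latin1_mojibake(text: str) -> str:
    """Undo one UTF-8 -> Latin-1 mis-decoding; return the text unchanged if that fails."""
    if not SUSPICIOUS.search(text):
        return text  # fast path: most passages are clean
    try:
        return text.encode("latin-1").decode("utf-8")
    except (UnicodeEncodeError, UnicodeDecodeError):
        return text  # not this kind of mojibake, so keep the original


sample = "Color\u00e2\u0080\u0094urine can be a variety of colors"
print(repair_latin1_mojibake(sample))            # em dash (U+2014) restored after "Color"
print(repair_latin1_mojibake("define extreme"))  # unchanged
```

The same regex could also act as a pre-filter for ftfy. For example, `mask = df_collection["passage"].str.contains(SUSPICIOUS.pattern)` could be followed by `df_collection.loc[mask, "passage"] = df_collection.loc[mask, "passage"].apply(fix_text)`. This filter is untested on the full collection, and it may miss mojibake patterns outside the character class.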

In [None]:
print(df_collection.loc[49, "passage"])

Above, you can see that the passages were fixed. Let's now do the same for the query text.

In [None]:
from ftfy import fix_text

# Fix encoding in 'query_text'
df_queries["query_text"] = df_queries["query_text"].apply(fix_text)



##**Now we can merge all tables into a dataframe:**

In [None]:
# Merge qrels with queries on 'qid'
df_merged = pd.merge(df_qrels, df_queries, on='qid', how='left')

# Merge the resulting positives with the collection on 'pid'
df_merged = pd.merge(df_merged, df_collection, on='pid', how='left')

df_merged.drop(["unused"], axis=1, inplace=True)

df_merged.head()

In [None]:
df_merged.shape

Saving the merged DataFrame:

In [None]:
df_merged.to_csv('df_merged.csv', index=False)
from google.colab import files
files.download('df_merged.csv')

You can see that the dataframe has 532,761 rows (data points)

**The DataFrame above cannot be used to train the model yet, because it contains only positive (relevant) query-passage pairs. I need to add negative pairs (rel=0) as well.**

In [None]:
pip install joblib


In [None]:
# all_pids is the list of all passage IDs from the collection
all_pids = df_collection['pid'].unique().tolist()


In [None]:
df_merged.head()

Checking how many CPU cores I have:

In [None]:
import multiprocessing

num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)


In [None]:
import random
import pandas as pd
from joblib import Parallel, delayed

# Assumes these exist from earlier cells:
# df_qrels: qid, pid, rel=1
# df_queries: qid -> query_text
# df_collection: pid -> passage
# df_merged: the existing positive pairs (rel=1) with query/passage text
# all_pids: a list of all passage IDs from df_collection

SEED = 42
random.seed(SEED)  # makes the query sub-sample reproducible

################################
# 1) AGGRESSIVE SUB-SAMPLE (e.g. 10k QUERIES)
################################
all_qids = df_qrels["qid"].unique().tolist()
random.shuffle(all_qids)

subsample_size = 10_000  # more aggressive sub-sample
sub_qids = all_qids[:subsample_size]

################################
# 2) SMALLER CHUNK SIZE (e.g. 500)
################################
chunk_size = 500
chunks = [sub_qids[i : i + chunk_size] for i in range(0, len(sub_qids), chunk_size)]

################################
# Build a dictionary of positives for each qid
################################
pos_dict = {}
for row in df_qrels.itertuples():
    q = row.qid
    p = row.pid
    if q not in pos_dict:
        pos_dict[q] = set()
    pos_dict[q].add(p)

################################
# Negative sampling function
################################
def sample_negatives_for_qid(qid):
    """
    For a single qid:
      1) Determine how many positives we have (pos_count).
      2) Sample the same number of negatives from all_pids that are not in pos_pids (1:1 ratio).
      3) Return a list of (qid, pid, 0) negative pairs.

    Uses rejection sampling: draw random pids and discard positives and repeats.
    This avoids rebuilding set(all_pids) (8.8M entries) once per query, and it
    terminates quickly because positives are a tiny fraction of the collection.
    """
    pos_pids = pos_dict[qid]
    pos_count = len(pos_pids)

    # Per-query RNG: reproducible, and forked worker processes do not all
    # inherit (and repeat) the same global random state.
    rng = random.Random(SEED * 10_000_000 + qid)

    chosen_neg_pids = set()
    while len(chosen_neg_pids) < pos_count:
        pid = rng.choice(all_pids)
        if pid not in pos_pids:
            chosen_neg_pids.add(pid)

    return [(qid, pid, 0) for pid in chosen_neg_pids]

################################
# 3) Process each chunk in a loop
################################
all_neg_samples = []  # global list of negative samples

for idx, chunk_qids in enumerate(chunks):
    print(f"Processing chunk {idx+1}/{len(chunks)}: {len(chunk_qids)} queries")

    # Negative sampling for the chunk's queries in parallel
    neg_samples_list = Parallel(n_jobs=2, backend="multiprocessing")(
        delayed(sample_negatives_for_qid)(qid) for qid in chunk_qids
    )

    # Flatten chunk results
    chunk_neg_samples = [item for sublist in neg_samples_list for item in sublist]

    # Append to global list
    all_neg_samples.extend(chunk_neg_samples)

    print(f"  -> Generated {len(chunk_neg_samples)} negatives in this chunk.")

################################
# 4) Convert negative pairs to DataFrame
################################
df_neg = pd.DataFrame(all_neg_samples, columns=["qid", "pid", "rel"])
print("Total negative pairs across all chunks:", len(df_neg))

################################
# 5) Merge with queries & collection to get text
################################
df_neg_merged = pd.merge(df_neg, df_queries, on='qid', how='left')
df_neg_merged = pd.merge(df_neg_merged, df_collection, on='pid', how='left')

# No renaming needed: the columns (qid, pid, rel, query_text, passage)
# already match df_merged.

################################
# 6) Concatenate negatives (rel=0) with positives (rel=1)
################################
df_all = pd.concat([df_merged, df_neg_merged], ignore_index=True)

print("Final dataset size (pos+neg):", df_all.shape[0])
print(df_all.head(5))


Now that I have generated about 10k negative query-passage pairs (one per positive of each sampled query), I need to split df_all into positives and negatives:

In [None]:
df_all_pos = df_all[df_all["rel"] == 1].copy()
df_all_neg = df_all[df_all["rel"] == 0].copy()


**Sub-Sample 10k positives for class balance:**

In [None]:
df_all_pos_sub = df_all_pos.sample(n=10000, random_state=42)


Keep the negatives for all the queries I kept:


In [None]:
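# allowed_qids was not defined earlier; here it is assumed to be the set of
# query IDs sub-sampled during negative generation (sub_qids).
allowed_qids = set(sub_qids)
df_all_neg_sub = df_all_neg[df_all_neg["qid"].isin(allowed_qids)].copy()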


In [None]:
df_all_sub = pd.concat([df_all_pos_sub, df_all_neg_sub], ignore_index=True)


Check the class balance:

In [None]:
num_pos = df_all_sub[df_all_sub["rel"] == 1].shape[0]
num_neg = df_all_sub[df_all_sub["rel"] == 0].shape[0]
print("Positives:", num_pos, "Negatives:", num_neg)


After the negative query-passage pairs have been generated, I will need to:

1. Inspect Class Balance
2. Final Data Cleaning (i.e., remove duplicates)
3. Split Data into Train/Validation/Test (75-5-20)
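Below is a minimal sketch of steps 2 and 3 in plain Python. Two choices in it are assumptions, not part of the plan above. First, only exact (qid, pid, label) repeats count as duplicates. Second, the 75-5-20 split is done by query, so that no query appears in more than one split. The function name `split_by_query` and the toy data are hypothetical.

```python
import random


def split_by_query(samples, train=0.75, val=0.05, seed=42):
    """Drop exact duplicates, then split (qid, pid, label) rows 75/5/20 by query.

    Splitting on qid rather than on rows keeps all pairs of a given query
    in a single split, so the model is not evaluated on queries it trained on.
    """
    unique = list(dict.fromkeys(samples))  # removes duplicates, keeps first-seen order
    qids = sorted({qid for qid, _, _ in unique})
    random.Random(seed).shuffle(qids)

    n_train = int(len(qids) * train)
    n_val = int(len(qids) * val)
    train_q = set(qids[:n_train])
    val_q = set(qids[n_train:n_train + n_val])

    splits = {"train": [], "val": [], "test": []}
    for row in unique:
        qid = row[0]
        key = "train" if qid in train_q else "val" if qid in val_q else "test"
        splits[key].append(row)
    return splits


# Toy data: 20 made-up queries, each with one positive and one negative pair,
# plus one duplicated row.
toy = [(q, q * 10, 1) for q in range(20)] + [(q, q * 10 + 1, 0) for q in range(20)]
toy.append((0, 0, 1))
splits = split_by_query(toy)
print({k: len(v) for k, v in splits.items()})  # {'train': 30, 'val': 2, 'test': 8}
```

On the DataFrame, the de-duplication step would be something like `df_all_sub.drop_duplicates(subset=["qid", "pid"])`.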



df_merged joins qrels (the positive, relevant pairs) with the queries and the collection, so it contains all the positive training examples.

The next step is to generate negative examples, since qrels.train.tsv only contains positive labels (relevant query-passage pairs).

# Choosing and importing a relevant SBERT Model:

In [None]:
from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v4')
embeddings = model.encode(sentences)
print(embeddings)


In [None]:
embeddings.shape
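`model.encode` returns one embedding per input sentence, so each row of `embeddings` corresponds to one sentence. For retrieval, query and passage embeddings are commonly compared with cosine similarity, and sentence-transformers provides `util.cos_sim` for this. The sketch below spells out the computation in plain Python on made-up 3-dimensional vectors, to show how a ranking follows from the scores. The helpers `cosine` and `rank_passages` are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def rank_passages(query_vec, passage_vecs):
    """Return passage indices sorted from most to least similar to the query."""
    scores = [cosine(query_vec, p) for p in passage_vecs]
    return sorted(range(len(passage_vecs)), key=lambda i: scores[i], reverse=True)


# Made-up 3-dimensional vectors; real embeddings from this model are
# much higher-dimensional rows of the array returned by model.encode.
query = [1.0, 0.0, 1.0]
passages = [[0.0, 1.0, 0.0], [1.0, 0.1, 0.9], [0.5, 0.5, 0.0]]
print(rank_passages(query, passages))  # [1, 2, 0]
```

With FAISS, you can get the same ranking by L2-normalizing the embeddings and searching an inner-product index, because the inner product of unit vectors equals their cosine similarity.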