# Comparing Bi-Encoders and Cross-Encoders on NFCorpus Dataset
**TAS-B Bi-Encoder**

This notebook shows how to:
1. Load the NFCorpus dataset (corpus, queries, qrels)
2. Perform bi-encoder retrieval using FAISS HNSW.
3. Save or inspect intermediate variables (embeddings, top-k results).
4. Rerank the top-k results with a cross-encoder.
5. Compare the evaluations at each step.


In [2]:
# Cell 0: Select the correct kernel
!jupyter kernelspec list

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
Available kernels:
  wse-env    /home/dl5214/.local/share/jupyter/kernels/wse-env
  python3    /home/shaoyu/anaconda3/share/jupyter/kernels/python3


# Part I: Preparations
1. Import packages.
2. Load data.
3. Prepare the queries for evaluation.

In [1]:
# Cell 1: Imports and Logging
import json
import logging
import torch
import faiss
import pytrec_eval
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder, InputExample
from tqdm import tqdm
import os
import h5py
from torch.utils.data import DataLoader
import random

logging.basicConfig(level=logging.INFO)
print('Imports complete.')

Imports complete.




In [17]:
torch.cuda.empty_cache()

In [2]:
# Cell 2: Data Loading Functions
def load_scifact_corpus(corpus_file):
    """
    Loads the Scifact corpus from a JSONL file.
    Each line is a JSON object containing '_id' and 'text'.
    Returns lists of document IDs and document texts.
    """
    doc_ids = []
    doc_texts = []
    with open(corpus_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            doc_id = data["_id"]
            text_str = data.get("text", "")
            doc_ids.append(doc_id)
            doc_texts.append(text_str)
    return doc_ids, doc_texts

def load_scifact_queries(queries_file):
    """
    Loads the Scifact queries from a JSONL file.
    Each line is a JSON object containing '_id' and 'text'.
    Returns lists of query IDs and query texts.
    """
    q_ids = []
    q_texts = []
    with open(queries_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            q_ids.append(data["_id"])
            q_texts.append(data.get("text", ""))
    return q_ids, q_texts

def load_qrels(qrels_file):
    """
    Loads the qrels from a TSV file.
    The first line is a header: 'query-id    corpus-id    score'.
    Returns a dictionary: qrels[query_id][doc_id] = relevance_score.
    """
    qrels = {}
    with open(qrels_file, "r", encoding="utf-8") as f:
        # Skip the header line:
        next(f)
        for line in f:
            line = line.strip()
            if not line:
                continue
            qid, did, rel = line.split()
            rel = int(rel)
            if qid not in qrels:
                qrels[qid] = {}
            qrels[qid][did] = rel
    return qrels

print('Data loading functions ready.')


Data loading functions ready.


In [3]:
# Cell 3: Metric Evaluation Function
# We'll define a single function we can call multiple times,
# whether for the bi-encoder alone or after cross-encoder reranking.

def evaluate_results(qrels, results, k_values=None):
    """
    Evaluates retrieval results using pytrec_eval.
    results: dict of dict => results[qid][doc_id] = score
    qrels: dict of dict => qrels[qid][doc_id] = relevance
    k_values: list of cutoff values (e.g. [5,10,20])
    """
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"map", "ndcg_cut", "recall", "P"})
    scores = evaluator.evaluate(results)

    if k_values is None:
        # If k_values not specified, use typical cutoffs
        k_values = [5, 10, 20, 30, 100]

    ndcg_res = {}
    recall_res = {}
    prec_res = {}
    mrr_vals = []

    # Compute mean average precision (MAP) across all queries
    map_vals = [scores[qid]["map"] for qid in scores]
    map_res = np.mean(map_vals)

    # Compute NDCG@k, Recall@k, P@k
    for k in k_values:
        ndcg_vals = []
        recall_vals = []
        prec_vals = []
        for qid in scores:
            ndcg_key = f"ndcg_cut_{k}"
            recall_key = f"recall_{k}"
            prec_key = f"P_{k}"
            # If the metric isn't found for some reason, skip
            if ndcg_key in scores[qid]:
                ndcg_vals.append(scores[qid][ndcg_key])
            if recall_key in scores[qid]:
                recall_vals.append(scores[qid][recall_key])
            if prec_key in scores[qid]:
                prec_vals.append(scores[qid][prec_key])

        ndcg_res[f"NDCG@{k}"] = np.mean(ndcg_vals) if ndcg_vals else 0.0
        recall_res[f"Recall@{k}"] = np.mean(recall_vals) if recall_vals else 0.0
        prec_res[f"P@{k}"] = np.mean(prec_vals) if prec_vals else 0.0

    # Compute MRR (Mean Reciprocal Rank)
    for qid in results:
        ranked_docs = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
        for rank, (doc_id, _) in enumerate(ranked_docs, start=1):
            if qrels.get(qid, {}).get(doc_id, 0) > 0:  # Relevant document
                mrr_vals.append(1 / rank)
                break
        else:
            mrr_vals.append(0)  # No relevant document found in results

    mrr_res = np.mean(mrr_vals)

    return ndcg_res, map_res, recall_res, prec_res, mrr_res

print('Evaluation function ready.')

Evaluation function ready.


In [4]:
# Cell 4: Load the SciFact Data
corpus_file = "datasets/nfcorpus/corpus.jsonl"
queries_file = "datasets/nfcorpus/queries.jsonl"
qrels_file = "datasets/nfcorpus/qrels/test.tsv"

doc_ids, doc_texts = load_scifact_corpus(corpus_file)
query_ids, query_texts = load_scifact_queries(queries_file)
qrels = load_qrels(qrels_file)

print(f"Loaded {len(doc_ids)} documents and {len(query_ids)} queries.")
print(f"Unique qrels: {len(qrels)}")

Loaded 3633 documents and 3237 queries.
Unique qrels: 323


In [5]:
# Cell 5: Filter queries to those that appear in qrels
filtered_query_ids = [qid for qid in query_ids if qid in qrels]
filtered_query_texts = [query_texts[query_ids.index(qid)] for qid in filtered_query_ids]

print(f"Filtered queries for evaluation: {len(filtered_query_ids)}")


Filtered queries for evaluation: 323


# Part II: Bi-Encoder Retrieval with FAISS HNSW

We'll:
1. Load a SentenceTransformer (E5) model.
2. Encode all documents with prefix **"passage: "**.
3. Build a FAISS HNSW index.
4. Encode queries with prefix **"query: "**.
5. Perform approximate nearest neighbor (ANN) search.
6. Evaluate results.


In [6]:
# Cell 6: Bi-Encoder Setup
# Check if CUDA is available
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Set the device to "cuda" if available, otherwise fallback to "cpu"
device = "cuda" if cuda_available else "cpu"

# Load the Bi-Encoder model on the specified device
bi_model = SentenceTransformer("msmarco-distilbert-base-tas-b", device=device)
print(f"Bi-Encoder model loaded on device: {device}")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: msmarco-distilbert-base-tas-b


CUDA available: True
Bi-Encoder model loaded on device: cuda


In [7]:
# Cell 7: Encode Corpus with TAS-B
# We'll store these embeddings in a variable so we can reuse or inspect them.
corpus_embeddings = bi_model.encode(
    doc_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

print("Corpus embeddings shape:", corpus_embeddings.shape)

Batches:   0%|          | 0/57 [00:00<?, ?it/s]

Corpus embeddings shape: (3633, 768)


In [9]:
# Cell 8: Build the FAISS HNSW Index
dimension = corpus_embeddings.shape[1]
M = 15  # number of connections per node in HNSW
efConstruction = 300
efSearch = 1000

logging.info("Building HNSW index...")
index_hnsw = faiss.IndexHNSWFlat(dimension, M)
index_hnsw.hnsw.efConstruction = efConstruction
index_hnsw.hnsw.efSearch = efSearch

# Add corpus embeddings to index
index_hnsw.add(corpus_embeddings)

# If GPU available, move index to GPU
if torch.cuda.is_available():
    logging.info("Moving HNSW index to GPU...")
    faiss_res = faiss.StandardGpuResources()
    index_hnsw = faiss.index_cpu_to_gpu(faiss_res, 0, index_hnsw)

print("HNSW index ready.")

INFO:root:Building HNSW index...
INFO:root:Moving HNSW index to GPU...


HNSW index ready.


In [10]:
# Cell 9: Encode Queries with TAS-B
# We'll store query embeddings so we can re-run or examine.

query_embeddings = bi_model.encode(
    filtered_query_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)
print("Query embeddings shape:", query_embeddings.shape)

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Query embeddings shape: (323, 768)


## Perform ANN Search and Evaluate (Bi-Encoder Only)

We'll retrieve the top-k docs from FAISS for each query, then convert them to a `results` dict for evaluation.

In [11]:
# Cell 10: Bi-Encoder ANN Search
top_k = 500
logging.info("Searching query embeddings in HNSW index...")
D, I = index_hnsw.search(query_embeddings, top_k)

# Convert results to pytrec_eval format
# results[qid][docid] = score
bi_encoder_results = {}
for q_idx, qid in enumerate(filtered_query_ids):
    bi_encoder_results[qid] = {}
    for rank in range(top_k):
        doc_idx = I[q_idx, rank]
        # for a distance D[q_idx, rank], we can invert or just use negative distance
        # Some prefer 1/dist, or -dist, or use it directly.
        # We'll do a small trick: score = 1 / distance.
        # If distance is 0, we can handle that or set a high value.
        dist = D[q_idx, rank]
        score = 1/float(dist) if dist != 0 else 1e9
        doc_id = doc_ids[doc_idx]
        bi_encoder_results[qid][doc_id] = score

print("Bi-Encoder top-k results ready.")

INFO:root:Searching query embeddings in HNSW index...


Bi-Encoder top-k results ready.


In [13]:
# Print the number of retrieved documents for the first 5 queries along with their scores
print("=== Query Results ===")
for query_id, doc_scores in list(bi_encoder_results.items())[:5]:  # Limit to first 5 queries
    num_docs = len(doc_scores)  # Number of documents retrieved for this query
    print(f"Query ID: {query_id}")
    print(f"Number of retrieved documents: {num_docs}")
    print(f"Scores (truncated): {str(doc_scores)[:90]}")
    print("-" * 50)

=== Query Results ===
Query ID: PLAIN-2
Number of retrieved documents: 500
Scores (truncated): {'MED-10': 3.042715626366151, 'MED-2429': 2.8025310152433374, 'MED-2439': 2.78169012566604
--------------------------------------------------
Query ID: PLAIN-12
Number of retrieved documents: 500
Scores (truncated): {'MED-2519': 2.060865946861249, 'MED-2514': 2.0394279051095356, 'MED-1928': 2.005049787385
--------------------------------------------------
Query ID: PLAIN-23
Number of retrieved documents: 500
Scores (truncated): {'MED-1961': 2.648593253886031, 'MED-1169': 2.257566035513047, 'MED-5088': 2.2232102440477
--------------------------------------------------
Query ID: PLAIN-33
Number of retrieved documents: 500
Scores (truncated): {'MED-2723': 2.8594476924407637, 'MED-2489': 2.6964898428889374, 'MED-5006': 2.60171977712
--------------------------------------------------
Query ID: PLAIN-44
Number of retrieved documents: 500
Scores (truncated): {'MED-2240': 2.568471889128318, 'MED-2791

In [14]:
# Cell 11: Evaluate Bi-Encoder Results
k_values = [5, 10, 20, 30, 100]
ndcg_bi, map_bi, recall_bi, prec_bi, mrr_bi = evaluate_results(qrels, bi_encoder_results, k_values)

print("=== Bi-Encoder (E5) Results ===")
# print("NDCG@k:", ndcg_bi)
# print("MAP:", map_bi)
# print("Recall@k:", recall_bi)
# print("Precision@k:", prec_bi)
print("NDCG@k:")
for k, v in ndcg_bi.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_bi:.4f}")

print("\nRecall@k:")
for k, v in recall_bi.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_bi.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_bi:.4f}")

=== Bi-Encoder (E5) Results ===
NDCG@k:
  NDCG@5 : 0.2810
  NDCG@10: 0.2564
  NDCG@20: 0.2360
  NDCG@30: 0.2285
  NDCG@100: 0.2303

MAP: 0.1132

Recall@k:
  Recall@5: 0.0917
  Recall@10: 0.1197
  Recall@20: 0.1507
  Recall@30: 0.1678
  Recall@100: 0.2338

Precision@k:
  P@5     : 0.2446
  P@10    : 0.1938
  P@20    : 0.1399
  P@30    : 0.1134
  P@100   : 0.0593

MRR: 0.4441


# Part III: Cross-Encoder Reranking

We'll now demonstrate how to:
1. Take the **top-k** documents from the **bi-encoder** results.
2. Rerank them with a **CrossEncoder**.
3. Evaluate again using the same `evaluate_results` function.


In [12]:
# Cell 12: Define a function to rerank with a CrossEncoder

def cross_encode_rerank(cross_encoder_model, 
                        query_ids_list, query_texts_list,
                        topk_results, corpus_text_map,
                        rerank_top_k=500, final_top_k=100):
    """
    cross_encoder_model: a CrossEncoder from sentence-transformers
    query_ids_list: list of query IDs (strings)
    query_texts_list: list of corresponding query texts
    topk_results: dict => topk_results[qid][doc_id] = some bi-encoder score
    corpus_text_map: a dict doc_id -> doc_text
    rerank_top_k: how many documents to consider for reranking from the bi-encoder results
    final_top_k: how many documents to keep after cross-encoder reranking

    returns: dict => reranked_results[qid][doc_id] = cross-encoder score
    """
    reranked_results = {}

    # Use tqdm to show total progress for all queries
    with tqdm(total=len(query_ids_list), desc="Cross-encoding queries", unit="query", ncols=80) as pbar:
        for idx, qid in enumerate(query_ids_list):
            # Get top rerank_top_k document candidates for the query
            doc_candidates = sorted(topk_results[qid].items(), key=lambda x: x[1], reverse=True)[:rerank_top_k]
            doc_ids = [doc_id for doc_id, _ in doc_candidates]
            query_text = query_texts_list[idx]

            # Prepare query-document pairs for the cross-encoder
#             pair_texts = [(f"query: {query_text}", f"passage: {corpus_text_map[doc_id]}") for doc_id in doc_ids]
            pair_texts = [(query_text, corpus_text_map[doc_id]) for doc_id in doc_ids]
            
            # Perform cross-encoder inference, ensure no additional progress bars are shown
            scores = cross_encoder_model.predict(pair_texts, show_progress_bar=False)

            # Combine document IDs with their Cross-Encoder scores
            reranked_doc_scores = {doc_ids[i]: float(scores[i]) for i in range(len(doc_ids))}

            # Sort reranked results by Cross-Encoder scores and keep final_top_k
            reranked_results[qid] = dict(sorted(reranked_doc_scores.items(), key=lambda x: x[1], reverse=True)[:final_top_k])

            # Update the progress bar
            pbar.update(1)

    return reranked_results

print('Cross-encoder rerank function defined.')

Cross-encoder rerank function defined.


### Prepare Data Structures for Reranking

- We already have `bi_encoder_results`, which is `dict[qid][doc_id] = score`.
- We need a quick way to map doc IDs to actual text for the CrossEncoder.


In [13]:
# Cell 13: Create a dictionary doc_id -> doc_text
doc_id_to_text = {}
for i, did in enumerate(doc_ids):
    doc_id_to_text[did] = doc_texts[i]

print(f"doc_id_to_text dictionary has {len(doc_id_to_text)} entries.")

doc_id_to_text dictionary has 3633 entries.


## 3.1: MiniLM CrossEncoder Reranking

### Load a CrossEncoder
Choose any cross-encoder checkpoint from Hugging Face or sentence-transformers.
For demonstration, we'll pick something like `'cross-encoder/ms-marco-MiniLM-L-6-v2'`.


In [17]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

CrossEncoder loaded.


### Rerank the top-k results from the Bi-Encoder

We'll use the cross-encoder function defined above and pass in the top-k results from the Bi-Encoder.


In [18]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_minilm = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

INFO:root:Reranking top 300 candidates with cross-encoder, keeping top 100...
Cross-encoding queries: 100%|██████████████| 323/323 [06:30<00:00,  1.21s/query]

Cross-encoder reranking complete.





### Evaluate Cross-Encoder Reranked Results


In [19]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce, mrr_ce = evaluate_results(qrels, cross_encoder_results_minilm, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== MiniLM Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_ce:.4f}")

=== MiniLM Cross-Encoder (Rerank) Results ===
NDCG@k:
  NDCG@5 : 0.3715
  NDCG@10: 0.3338
  NDCG@20: 0.3059
  NDCG@30: 0.2931
  NDCG@100: 0.2861

MAP: 0.1486

Recall@k:
  Recall@5: 0.1267
  Recall@10: 0.1555
  Recall@20: 0.1826
  Recall@30: 0.1967
  Recall@100: 0.2590

Precision@k:
  P@5     : 0.3121
  P@10    : 0.2378
  P@20    : 0.1743
  P@30    : 0.1404
  P@100   : 0.0690

MRR: 0.5546


## 3.2: Electra CrossEncoder Reranking
To compare the performance of different cross-encoders, here we use `'cross-encoder/ms-marco-electra-base'`.


In [20]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-electra-base"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

CrossEncoder loaded.


In [21]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_electra = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

INFO:root:Reranking top 300 candidates with cross-encoder, keeping top 100...
Cross-encoding queries: 100%|██████████████| 323/323 [46:06<00:00,  8.57s/query]

Cross-encoder reranking complete.





In [22]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce, mrr_ce = evaluate_results(qrels, cross_encoder_results_electra, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== Electra Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_ce:.4f}")

=== Electra Cross-Encoder (Rerank) Results ===
NDCG@k:
  NDCG@5 : 0.3168
  NDCG@10: 0.2873
  NDCG@20: 0.2596
  NDCG@30: 0.2515
  NDCG@100: 0.2535

MAP: 0.1223

Recall@k:
  Recall@5: 0.1104
  Recall@10: 0.1363
  Recall@20: 0.1593
  Recall@30: 0.1785
  Recall@100: 0.2468

Precision@k:
  P@5     : 0.2762
  P@10    : 0.2130
  P@20    : 0.1471
  P@30    : 0.1200
  P@100   : 0.0625

MRR: 0.4823


## 3.3: Tiny-BERT CrossEncoder Reranking
To compare the performance of different cross-encoders, here we use `'cross-encoder/ms-marco-TinyBERT-L-2-v2'`.


In [None]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

In [None]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_tinybert = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

In [None]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce, mrr_ce = evaluate_results(qrels, cross_encoder_results_tinybert, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== Tiny-BERT Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_ce:.4f}")

# Part IV: Fine Tuning Cross-Encoders
Below demonstrates the process of fine-tuning Cross-Encoders on query-document pairs for relevance scoring. It includes functions to load training data, prepare input examples, and fine-tune a Cross-Encoder model using training data. The fine-tuned model can then be used for downstream reranking tasks or other applications requiring precise query-document matching.

In [14]:
def build_cross_encoder_input_examples(qrels, query_dict, doc_dict, neg_ratio=10):
    """
    Converts qrels and additional negative samples to InputExample objects for CrossEncoder training.
    Ensures the inclusion of both positive and negative samples.

    qrels: dict of dict => qrels[qid][did] = relevance_score (e.g., 1 for positive samples)
    query_dict: dict of query_id -> query_text
    doc_dict: dict of doc_id -> doc_text
    neg_ratio: Ratio of negative samples to positive samples (e.g., 10 means 10 negatives per positive).

    Returns:
        examples: List of InputExample
        num_pos: Number of positive examples
        num_neg: Number of negative examples
    """
    examples = []
    num_pos = 0
    num_neg = 0

    for qid in qrels:
        if qid not in query_dict:
            continue
        
        # Positive samples (label=1)
        positive_docs = [did for did, score in qrels[qid].items() if score == 1]
        for did in positive_docs:
            if did in doc_dict:
                examples.append(InputExample(texts=[query_dict[qid], doc_dict[did]], label=1))
                num_pos += 1
        
        # Negative samples (label=0)
        all_docs = set(doc_dict.keys())
        negative_candidates = list(all_docs - set(qrels[qid].keys()))  # Exclude positive and qrels docs
        negative_samples = random.sample(negative_candidates, min(len(negative_candidates), len(positive_docs) * neg_ratio))
        
        for did in negative_samples:
            if did in doc_dict:
                examples.append(InputExample(texts=[query_dict[qid], doc_dict[did]], label=0))
                num_neg += 1

    print(f"Generated {num_pos} positive samples and {num_neg} negative samples.")
    return examples


def fine_tune_cross_encoder(model_name,
                            train_examples,
                            output_model_path="fine_tuned_cross_encoder",
                            epochs=1,
                            batch_size=8,
                            lr=1e-5,
                            warmup_ratio=0.1):
    """
    Fine-tunes a CrossEncoder on the given (query, doc, label) examples.
    model_name: e.g. "cross-encoder/ms-marco-MiniLM-L-6-v2"
    train_examples: list of InputExample
    output_model_path: where to save the fine-tuned model
    epochs, batch_size, lr, warmup_ratio: training parameters
    """
    # Initialize a CrossEncoder with 1 output (for regression or binary classification)
    cross_encoder = CrossEncoder(model_name, num_labels=1)

    # Create a DataLoader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)

    # Total steps
    total_steps = len(train_dataloader) * epochs
    warmup_steps = int(total_steps * warmup_ratio)

    # Fit the model
    cross_encoder.fit(
        train_dataloader=train_dataloader,
        epochs=epochs,
        warmup_steps=warmup_steps,
        optimizer_params={'lr': lr},
    )

    # Save the fine-tuned model
    cross_encoder.save(output_model_path)

    # Reload the model from disk to ensure it's saved correctly
    fine_tuned_model = CrossEncoder(output_model_path)
    return fine_tuned_model

print("CrossEncoder fine-tuning functions ready.")

CrossEncoder fine-tuning functions ready.


In [14]:
# Prepare Training Data for CrossEncoder

# 1) Training data filepath:
train_qrels_file = "datasets/nfcorpus/qrels/train.tsv"   # or wherever your train file is
train_queries_file = "datasets/nfcorpus/queries.jsonl"
train_corpus_file = "datasets/nfcorpus/corpus.jsonl"

# 2) Load queries and docs for training
train_qids, train_qtexts = load_scifact_queries(train_queries_file)
train_docids, train_doctexts = load_scifact_corpus(train_corpus_file)

# Build dictionaries for easy lookup
train_query_dict = {q: t for q, t in zip(train_qids, train_qtexts)}
train_doc_dict = {d: t for d, t in zip(train_docids, train_doctexts)}

# 3) Load the training qrels and build InputExample for cross-encoder
train_qrels = load_qrels(train_qrels_file)  # Returns a nested dictionary
# train_examples = build_cross_encoder_input_examples(
#     [(qid, did, train_qrels[qid][did]) for qid in train_qrels for did in train_qrels[qid]],
#     train_query_dict,
#     train_doc_dict
# )
train_examples = build_cross_encoder_input_examples(
    qrels=train_qrels,  # Pass the nested dictionary directly
    query_dict=train_query_dict,
    doc_dict=train_doc_dict,
    neg_ratio=10  # Adjust as needed
)
print(f"Number of training examples: {len(train_examples)}")

Generated 110575 positive samples and 938541 negative samples.
Number of training examples: 1049116


## 4.1 Fine Tuning MiniLM Cross Encoder

In [None]:
#######################################################
# Fine-Tune the MiniLM CrossEncoder
#######################################################
# 1) Name of the base cross-encoder
base_cross_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# 2) Output path for the fine-tuned model
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_minilm"

# 3) Fine-tune
# Adjust epochs, batch_size, lr, etc. as needed.
fine_tuned_cross_encoder = fine_tune_cross_encoder(
    model_name=base_cross_model_name,
    train_examples=train_examples,
    output_model_path=fine_tuned_model_path,
    epochs=3,        # example
    batch_size=16,   # example
    lr=2e-5,         # example
    warmup_ratio=0.1
)

print("Fine-tuning complete. Saved to:", fine_tuned_model_path)

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/65570 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [16]:
#######################################################
# Rerank with the Fine-Tuned CrossEncoder
#######################################################

logging.info("Reranking top-k with the fine-tuned cross-encoder...")
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_minilm"
fine_tuned_cross_encoder = CrossEncoder(fine_tuned_model_path)

rerank_top_k = 300
final_top_k = 100

cross_encoder_results_finetuned = cross_encode_rerank(
    cross_encoder_model=fine_tuned_cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,   # from your bi-encoder retrieval
    corpus_text_map=doc_id_to_text,
    rerank_top_k=rerank_top_k,
    final_top_k=final_top_k
)

print("Fine-tuned CrossEncoder reranking complete.")

INFO:root:Reranking top-k with the fine-tuned cross-encoder...
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
Cross-encoding queries: 100%|██████████████| 323/323 [06:29<00:00,  1.21s/query]

Fine-tuned CrossEncoder reranking complete.





In [17]:
#######################################################
# Evaluate the Fine-Tuned CrossEncoder Results
#######################################################
# 1) Use the same evaluate_results function
k_values = [5, 10, 20, 30, 100]
ndcg_ft, map_ft, recall_ft, prec_ft, mrr_ft = evaluate_results(qrels, cross_encoder_results_finetuned, k_values)

# 2) Print or log the metrics
print("=== Fine-Tuned Cross-Encoder (MiniLM) Results ===")
print("NDCG@k:")
for k, v in ndcg_ft.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ft:.4f}")

print("\nRecall@k:")
for k, v in recall_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print(f"\nMRR: {mrr_ft:.4f}")

=== Fine-Tuned Cross-Encoder (MiniLM) Results ===
NDCG@k:
  NDCG@5 : 0.3933
  NDCG@10: 0.3684
  NDCG@20: 0.3396
  NDCG@30: 0.3265
  NDCG@100: 0.3128

MAP: 0.1662

Recall@k:
  Recall@5: 0.1273
  Recall@10: 0.1714
  Recall@20: 0.2107
  Recall@30: 0.2325
  Recall@100: 0.3021

Precision@k:
  P@5     : 0.3548
  P@10    : 0.2923
  P@20    : 0.2146
  P@30    : 0.1737
  P@100   : 0.0819

MRR: 0.5434


## 4.2 Fine Tuning TinyBERT Cross Encoder

In [None]:
#######################################################
# Fine-Tune the MiniLM CrossEncoder
#######################################################
# 1) Name of the base cross-encoder
base_cross_model_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"

# 2) Output path for the fine-tuned model
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_tinybert"

# 3) Fine-tune
# Adjust epochs, batch_size, lr, etc. as needed.
fine_tuned_cross_encoder = fine_tune_cross_encoder(
    model_name=base_cross_model_name,
    train_examples=train_examples,
    output_model_path=fine_tuned_model_path,
    epochs=3,        # example
    batch_size=16,   # example
    lr=2e-5,         # example
    warmup_ratio=0.1
)

print("Fine-tuning complete. Saved to:", fine_tuned_model_path)

In [None]:
#######################################################
# Rerank with the Fine-Tuned CrossEncoder
#######################################################

logging.info("Reranking top-k with the fine-tuned cross-encoder...")
fine_tuned_model_path = "./models/scifact_fine_tuned_cross_encoder_tinybert"
fine_tuned_cross_encoder = CrossEncoder(fine_tuned_model_path)

rerank_top_k = 300
final_top_k = 100

cross_encoder_results_finetuned = cross_encode_rerank(
    cross_encoder_model=fine_tuned_cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,   # from your bi-encoder retrieval
    corpus_text_map=doc_id_to_text,
    rerank_top_k=rerank_top_k,
    final_top_k=final_top_k
)

print("Fine-tuned CrossEncoder reranking complete.")

In [None]:
#######################################################
# Evaluate the Fine-Tuned CrossEncoder Results
#######################################################
# 1) Use the same evaluate_results function
k_values = [5, 10, 20, 30, 100]
ndcg_ft, map_ft, recall_ft, prec_ft, mrr_ft = evaluate_results(qrels, cross_encoder_results_finetuned, k_values)

# 2) Print or log the metrics
print("=== Fine-Tuned Cross-Encoder (Tiny-BERT) Results ===")
print("NDCG@k:")
for k, v in ndcg_ft.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ft:.4f}")

print("\nRecall@k:")
for k, v in recall_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print(f"\nMRR: {mrr_ft:.4f}")

## 4.3 Fine Tuning MiniLM Cross Encoder (Limited Samples)

In [15]:
def build_cross_encoder_limited_examples(qrels, query_dict, doc_dict, pos_limit=2000, neg_limit=20000):
    """
    Converts qrels and additional limited negative samples to InputExample objects for CrossEncoder training.
    Ensures the inclusion of both positive and negative samples with given limits.

    qrels: dict of dict => qrels[qid][did] = relevance_score (e.g., 1 for positive samples)
    query_dict: dict of query_id -> query_text
    doc_dict: dict of doc_id -> doc_text
    pos_limit: Limit on the number of positive samples (e.g., 2000).
    neg_limit: Limit on the number of negative samples (e.g., 20000).

    Returns:
        examples: List of InputExample
        num_pos: Number of positive examples
        num_neg: Number of negative examples
    """
    examples = []
    num_pos = 0
    num_neg = 0

    all_positive_samples = []
    all_negative_candidates = []

    # Gather all positive and negative candidates
    for qid in qrels:
        if qid not in query_dict:
            continue

        # Positive samples (label=1)
        positive_docs = [did for did, score in qrels[qid].items() if score == 1]
        for did in positive_docs:
            if did in doc_dict:
                all_positive_samples.append((qid, did, 1))

        # Negative candidates (not in qrels or positive docs)
        all_docs = set(doc_dict.keys())
        negative_candidates = list(all_docs - set(qrels[qid].keys()))
        for did in negative_candidates:
            if did in doc_dict:
                all_negative_candidates.append((qid, did, 0))

    # Randomly sample positives and negatives
    selected_positives = random.sample(all_positive_samples, min(len(all_positive_samples), pos_limit))
    selected_negatives = random.sample(all_negative_candidates, min(len(all_negative_candidates), neg_limit))

    # Build InputExample for positive samples
    for qid, did, label in selected_positives:
        examples.append(InputExample(texts=[query_dict[qid], doc_dict[did]], label=label))
        num_pos += 1

    # Build InputExample for negative samples
    for qid, did, label in selected_negatives:
        examples.append(InputExample(texts=[query_dict[qid], doc_dict[did]], label=label))
        num_neg += 1

    print(f"Generated {num_pos} positive samples and {num_neg} negative samples (limited).")
    return examples

In [None]:
# Prepare Training Data for CrossEncoder

# 1) Training data filepath:
train_qrels_file = "datasets/nfcorpus/qrels/train.tsv"   # or wherever your train file is
train_queries_file = "datasets/nfcorpus/queries.jsonl"
train_corpus_file = "datasets/nfcorpus/corpus.jsonl"

# 2) Load queries and docs for training
train_qids, train_qtexts = load_scifact_queries(train_queries_file)
train_docids, train_doctexts = load_scifact_corpus(train_corpus_file)

# Build dictionaries for easy lookup
train_query_dict = {q: t for q, t in zip(train_qids, train_qtexts)}
train_doc_dict = {d: t for d, t in zip(train_docids, train_doctexts)}

# 3) Load the training qrels and build InputExample for cross-encoder
train_qrels = load_qrels(train_qrels_file)  # Returns a nested dictionary
# train_examples = build_cross_encoder_input_examples(
#     [(qid, did, train_qrels[qid][did]) for qid in train_qrels for did in train_qrels[qid]],
#     train_query_dict,
#     train_doc_dict
# )
train_examples = build_cross_encoder_limited_examples(
    qrels=train_qrels,  # Pass the nested dictionary directly
    query_dict=train_query_dict,
    doc_dict=train_doc_dict
)
print(f"Number of training examples: {len(train_examples)}")

In [None]:
#######################################################
# Fine-Tune the MiniLM CrossEncoder
#######################################################
# 1) Name of the base cross-encoder
base_cross_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# 2) Output path for the fine-tuned model
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_minilm_limited"

# 3) Fine-tune
# Adjust epochs, batch_size, lr, etc. as needed.
fine_tuned_cross_encoder = fine_tune_cross_encoder(
    model_name=base_cross_model_name,
    train_examples=train_examples,
    output_model_path=fine_tuned_model_path,
    epochs=3,        # example
    batch_size=16,   # example
    lr=2e-5,         # example
    warmup_ratio=0.1
)

print("Fine-tuning complete. Saved to:", fine_tuned_model_path)

In [16]:
#######################################################
# Rerank with the Fine-Tuned CrossEncoder
#######################################################

logging.info("Reranking top-k with the fine-tuned cross-encoder...")
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_minilm_limited"
fine_tuned_cross_encoder = CrossEncoder(fine_tuned_model_path)

rerank_top_k = 300
final_top_k = 100

cross_encoder_results_finetuned = cross_encode_rerank(
    cross_encoder_model=fine_tuned_cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,   # from your bi-encoder retrieval
    corpus_text_map=doc_id_to_text,
    rerank_top_k=rerank_top_k,
    final_top_k=final_top_k
)

print("Fine-tuned CrossEncoder reranking complete.")

INFO:root:Reranking top-k with the fine-tuned cross-encoder...
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
Cross-encoding queries: 100%|██████████████| 323/323 [06:35<00:00,  1.22s/query]

Fine-tuned CrossEncoder reranking complete.





In [17]:
#######################################################
# Evaluate the Fine-Tuned CrossEncoder Results
#######################################################
# 1) Use the same evaluate_results function
k_values = [5, 10, 20, 30, 100]
ndcg_ft, map_ft, recall_ft, prec_ft, mrr_ft = evaluate_results(qrels, cross_encoder_results_finetuned, k_values)

# 2) Print or log the metrics
print("=== Fine-Tuned Cross-Encoder (MiniLM, Limited) Results ===")
print("NDCG@k:")
for k, v in ndcg_ft.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ft:.4f}")

print("\nRecall@k:")
for k, v in recall_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print(f"\nMRR: {mrr_ft:.4f}")

=== Fine-Tuned Cross-Encoder (MiniLM, Limited) Results ===
NDCG@k:
  NDCG@5 : 0.3808
  NDCG@10: 0.3450
  NDCG@20: 0.3163
  NDCG@30: 0.3038
  NDCG@100: 0.2955

MAP: 0.1560

Recall@k:
  Recall@5: 0.1299
  Recall@10: 0.1609
  Recall@20: 0.1890
  Recall@30: 0.2078
  Recall@100: 0.2713

Precision@k:
  P@5     : 0.3276
  P@10    : 0.2505
  P@20    : 0.1842
  P@30    : 0.1478
  P@100   : 0.0722

MRR: 0.5586


## 4.4 Fine Tuning TinyBERT Cross Encoder (Limited Samples)

In [None]:
#######################################################
# Fine-Tune the MiniLM CrossEncoder
#######################################################
# 1) Name of the base cross-encoder
base_cross_model_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"

# 2) Output path for the fine-tuned model
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_tinybert_limited"

# 3) Fine-tune
# Adjust epochs, batch_size, lr, etc. as needed.
fine_tuned_cross_encoder = fine_tune_cross_encoder(
    model_name=base_cross_model_name,
    train_examples=train_examples,
    output_model_path=fine_tuned_model_path,
    epochs=3,        # example
    batch_size=16,   # example
    lr=2e-5,         # example
    warmup_ratio=0.1
)

print("Fine-tuning complete. Saved to:", fine_tuned_model_path)

In [18]:
#######################################################
# Rerank with the Fine-Tuned CrossEncoder
#######################################################

logging.info("Reranking top-k with the fine-tuned cross-encoder...")
fine_tuned_model_path = "./models/nfcorpus_fine_tuned_cross_encoder_tinybert_limited"
fine_tuned_cross_encoder = CrossEncoder(fine_tuned_model_path)

rerank_top_k = 300
final_top_k = 100

cross_encoder_results_finetuned = cross_encode_rerank(
    cross_encoder_model=fine_tuned_cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,   # from your bi-encoder retrieval
    corpus_text_map=doc_id_to_text,
    rerank_top_k=rerank_top_k,
    final_top_k=final_top_k
)

print("Fine-tuned CrossEncoder reranking complete.")

INFO:root:Reranking top-k with the fine-tuned cross-encoder...
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
Cross-encoding queries: 100%|██████████████| 323/323 [00:36<00:00,  8.75query/s]

Fine-tuned CrossEncoder reranking complete.





In [19]:
#######################################################
# Evaluate the Fine-Tuned CrossEncoder Results
#######################################################
# 1) Use the same evaluate_results function
k_values = [5, 10, 20, 30, 100]
ndcg_ft, map_ft, recall_ft, prec_ft, mrr_ft = evaluate_results(qrels, cross_encoder_results_finetuned, k_values)

# 2) Print or log the metrics
print("=== Fine-Tuned Cross-Encoder (Tiny-BERT, Limited) Results ===")
print("NDCG@k:")
for k, v in ndcg_ft.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ft:.4f}")

print("\nRecall@k:")
for k, v in recall_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ft.items():
    print(f"  {k:<8}: {v:.4f}")

print(f"\nMRR: {mrr_ft:.4f}")

=== Fine-Tuned Cross-Encoder (Tiny-BERT, Limited) Results ===
NDCG@k:
  NDCG@5 : 0.3549
  NDCG@10: 0.3192
  NDCG@20: 0.2917
  NDCG@30: 0.2801
  NDCG@100: 0.2757

MAP: 0.1409

Recall@k:
  Recall@5: 0.1224
  Recall@10: 0.1538
  Recall@20: 0.1798
  Recall@30: 0.1929
  Recall@100: 0.2562

Precision@k:
  P@5     : 0.3046
  P@10    : 0.2307
  P@20    : 0.1667
  P@30    : 0.1358
  P@100   : 0.0684

MRR: 0.5326
