# Comparing Bi-Encoders and Cross-Encoders on Scifact Dataset
**TAS-B Bi-Encoder**

This notebook shows how to:
1. Load the SciFact dataset (corpus, queries, qrels)
2. Perform bi-encoder retrieval using FAISS HNSW.
3. Save or inspect intermediate variables (embeddings, top-k results).
4. Rerank the top-k results with a cross-encoder.
5. Compare the evaluations at each step.


In [2]:
# Cell 0: Select the correct kernel
!jupyter kernelspec list

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
Available kernels:
  wse-env    /home/dl5214/.local/share/jupyter/kernels/wse-env
  python3    /home/shaoyu/anaconda3/share/jupyter/kernels/python3


# Part I: Preparations
1. Import packages.
2. Load data.
3. Prepare the queries for evaluation.

In [9]:
# Cell 1: Imports and Logging
import json
import logging
import torch
import faiss
import pytrec_eval
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from tqdm import tqdm
import os
import h5py

logging.basicConfig(level=logging.INFO)
print('Imports complete.')

Imports complete.


In [2]:
# Cell 2: Data Loading Functions
def load_scifact_corpus(corpus_file):
    """
    Loads the Scifact corpus from a JSONL file.
    Each line is a JSON object containing '_id' and 'text'.
    Returns lists of document IDs and document texts.
    """
    doc_ids = []
    doc_texts = []
    with open(corpus_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            doc_id = data["_id"]
            text_str = data.get("text", "")
            doc_ids.append(doc_id)
            doc_texts.append(text_str)
    return doc_ids, doc_texts

def load_scifact_queries(queries_file):
    """
    Loads the Scifact queries from a JSONL file.
    Each line is a JSON object containing '_id' and 'text'.
    Returns lists of query IDs and query texts.
    """
    q_ids = []
    q_texts = []
    with open(queries_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            q_ids.append(data["_id"])
            q_texts.append(data.get("text", ""))
    return q_ids, q_texts

def load_qrels(qrels_file):
    """
    Loads the qrels from a TSV file.
    The first line is a header: 'query-id    corpus-id    score'.
    Returns a dictionary: qrels[query_id][doc_id] = relevance_score.
    """
    qrels = {}
    with open(qrels_file, "r", encoding="utf-8") as f:
        # Skip the header line:
        next(f)
        for line in f:
            line = line.strip()
            if not line:
                continue
            qid, did, rel = line.split()
            rel = int(rel)
            if qid not in qrels:
                qrels[qid] = {}
            qrels[qid][did] = rel
    return qrels

print('Data loading functions ready.')


Data loading functions ready.


In [31]:
# Cell 3: Metric Evaluation Function
# We'll define a single function we can call multiple times,
# whether for the bi-encoder alone or after cross-encoder reranking.

def evaluate_results(qrels, results, k_values=None):
    """
    Evaluates retrieval results using pytrec_eval.
    results: dict of dict => results[qid][doc_id] = score
    qrels: dict of dict => qrels[qid][doc_id] = relevance
    k_values: list of cutoff values (e.g. [5,10,20])
    """
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"map", "ndcg_cut", "recall", "P"})
    scores = evaluator.evaluate(results)

    if k_values is None:
        # If k_values not specified, use typical cutoffs
        k_values = [5, 10, 20, 30, 100]

    ndcg_res = {}
    recall_res = {}
    prec_res = {}
    mrr_vals = []

    # Compute mean average precision (MAP) across all queries
    map_vals = [scores[qid]["map"] for qid in scores]
    map_res = np.mean(map_vals)

    # Compute NDCG@k, Recall@k, P@k
    for k in k_values:
        ndcg_vals = []
        recall_vals = []
        prec_vals = []
        for qid in scores:
            ndcg_key = f"ndcg_cut_{k}"
            recall_key = f"recall_{k}"
            prec_key = f"P_{k}"
            # If the metric isn't found for some reason, skip
            if ndcg_key in scores[qid]:
                ndcg_vals.append(scores[qid][ndcg_key])
            if recall_key in scores[qid]:
                recall_vals.append(scores[qid][recall_key])
            if prec_key in scores[qid]:
                prec_vals.append(scores[qid][prec_key])

        ndcg_res[f"NDCG@{k}"] = np.mean(ndcg_vals) if ndcg_vals else 0.0
        recall_res[f"Recall@{k}"] = np.mean(recall_vals) if recall_vals else 0.0
        prec_res[f"P@{k}"] = np.mean(prec_vals) if prec_vals else 0.0

    # Compute MRR (Mean Reciprocal Rank)
    for qid in results:
        ranked_docs = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
        for rank, (doc_id, _) in enumerate(ranked_docs, start=1):
            if qrels.get(qid, {}).get(doc_id, 0) > 0:  # Relevant document
                mrr_vals.append(1 / rank)
                break
        else:
            mrr_vals.append(0)  # No relevant document found in results

    mrr_res = np.mean(mrr_vals)

    return ndcg_res, map_res, recall_res, prec_res, mrr_res

print('Evaluation function ready.')

Evaluation function ready.


In [3]:
# Cell 4: Load the SciFact Data
corpus_file = "datasets/scifact/corpus.jsonl"
queries_file = "datasets/scifact/queries.jsonl"
qrels_file = "datasets/scifact/qrels/test.tsv"

doc_ids, doc_texts = load_scifact_corpus(corpus_file)
query_ids, query_texts = load_scifact_queries(queries_file)
qrels = load_qrels(qrels_file)

print(f"Loaded {len(doc_ids)} documents and {len(query_ids)} queries.")
print(f"Unique qrels: {len(qrels)}")

Loaded 5183 documents and 1109 queries.
Unique qrels: 300


In [4]:
# Cell 5: Filter queries to those that appear in qrels
filtered_query_ids = [qid for qid in query_ids if qid in qrels]
filtered_query_texts = [query_texts[query_ids.index(qid)] for qid in filtered_query_ids]

print(f"Filtered queries for evaluation: {len(filtered_query_ids)}")


Filtered queries for evaluation: 300


# Part II: Bi-Encoder Retrieval with FAISS HNSW

We'll:
1. Load a SentenceTransformer (TAS-B) model.
2. Encode all documents with prefix **"passage: "**.
3. Build a FAISS HNSW index.
4. Encode queries with prefix **"query: "**.
5. Perform approximate nearest neighbor (ANN) search.
6. Evaluate results.


In [5]:
# Cell 6: Bi-Encoder Setup
# Check if CUDA is available
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Set the device to "cuda" if available, otherwise fallback to "cpu"
device = "cuda" if cuda_available else "cpu"

# Load the Bi-Encoder model on the specified device
bi_model = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-tas-b", device=device)
print(f"Bi-Encoder model loaded on device: {device}")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/msmarco-distilbert-base-tas-b


CUDA available: True
Bi-Encoder model loaded on device: cuda


In [6]:
# Cell 7: Encode Corpus with E5 (prefix 'passage: ')
# We'll store these embeddings in a variable so we can reuse or inspect them.
corpus_embeddings = bi_model.encode(
    doc_texts,  # Directly use the text without adding "passage: " prefix
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

print("Corpus embeddings shape:", corpus_embeddings.shape)

Batches:   0%|          | 0/81 [00:00<?, ?it/s]

Corpus embeddings shape: (5183, 768)


In [10]:
# Directory path to save embeddings
embedding_dir = "./embeddings"
os.makedirs(embedding_dir, exist_ok=True)  # Create directory if it doesn't exist

# Save the embeddings as a .npy file
embedding_file = os.path.join(embedding_dir, "tasb_corpus_embeddings_scifact.npy")
np.save(embedding_file, corpus_embeddings)

print(f"Corpus embeddings saved to: {embedding_file}")

Corpus embeddings saved to: ./embeddings/tasb_corpus_embeddings_scifact.npy


In [11]:
# Save the embeddings as an HDF5 file
hdf5_file = os.path.join(embedding_dir, "tasb_corpus_embeddings_scifact.h5")
with h5py.File(hdf5_file, "w") as f:
    # Use gzip compression to reduce file size
    f.create_dataset("embeddings", data=corpus_embeddings, compression="gzip")
    f.create_dataset("shape", data=corpus_embeddings.shape)  # Store shape for later use

print(f"Corpus embeddings saved in HDF5 format to: {hdf5_file}")

Corpus embeddings saved in HDF5 format to: ./embeddings/tasb_corpus_embeddings_scifact.h5


In [17]:
# Cell 8: Build the FAISS HNSW Index
dimension = corpus_embeddings.shape[1]
M = 15  # number of connections per node in HNSW
efConstruction = 300
efSearch = 1000

logging.info("Building HNSW index...")
index_hnsw = faiss.IndexHNSWFlat(dimension, M)
index_hnsw.hnsw.efConstruction = efConstruction
index_hnsw.hnsw.efSearch = efSearch

# Add corpus embeddings to index
index_hnsw.add(corpus_embeddings)

# If GPU available, move index to GPU
if torch.cuda.is_available():
    logging.info("Moving HNSW index to GPU...")
    faiss_res = faiss.StandardGpuResources()
    index_hnsw = faiss.index_cpu_to_gpu(faiss_res, 0, index_hnsw)

print("HNSW index ready.")

INFO:root:Building HNSW index...
INFO:root:Moving HNSW index to GPU...


HNSW index ready.


In [18]:
# Cell 9: Encode Queries with E5 (prefix 'query: ')
# We'll store query embeddings so we can re-run or examine.

query_embeddings = bi_model.encode(
    filtered_query_texts, # Directly use the text without adding "query: " prefix
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)
print("Query embeddings shape:", query_embeddings.shape)

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Query embeddings shape: (300, 768)


## Perform ANN Search and Evaluate (Bi-Encoder Only)

We'll retrieve the top-k docs from FAISS for each query, then convert them to a `results` dict for evaluation.

In [19]:
# Cell 10: Bi-Encoder ANN Search
top_k = 500
logging.info("Searching query embeddings in HNSW index...")
D, I = index_hnsw.search(query_embeddings, top_k)

# Convert results to pytrec_eval format
# results[qid][docid] = score
bi_encoder_results = {}
for q_idx, qid in enumerate(filtered_query_ids):
    bi_encoder_results[qid] = {}
    for rank in range(top_k):
        doc_idx = I[q_idx, rank]
        # for a distance D[q_idx, rank], we can invert or just use negative distance
        # Some prefer 1/dist, or -dist, or use it directly.
        # We'll do a small trick: score = 1 / distance.
        # If distance is 0, we can handle that or set a high value.
        dist = D[q_idx, rank]
        score = 1/float(dist) if dist != 0 else 1e9
        doc_id = doc_ids[doc_idx]
        bi_encoder_results[qid][doc_id] = score

print("Bi-Encoder top-k results ready.")

INFO:root:Searching query embeddings in HNSW index...


Bi-Encoder top-k results ready.


In [20]:
# Print the number of retrieved documents for the first 5 queries along with their scores
print("=== Query Results ===")
for query_id, doc_scores in list(bi_encoder_results.items())[:5]:  # Limit to first 5 queries
    num_docs = len(doc_scores)  # Number of documents retrieved for this query
    print(f"Query ID: {query_id}")
    print(f"Number of retrieved documents: {num_docs}")
    print(f"Scores (truncated): {str(doc_scores)[:90]}")
    print("-" * 50)

=== Query Results ===
Query ID: 1
Number of retrieved documents: 500
Scores (truncated): {'10743131': 2.1284670717937657, '17463469': 2.1221479853135135, '4346436': 2.116672482804
--------------------------------------------------
Query ID: 3
Number of retrieved documents: 500
Scores (truncated): {'1388704': 3.019901084527792, '23389795': 2.8117035885503445, '8759633': 2.65501032389616
--------------------------------------------------
Query ID: 5
Number of retrieved documents: 500
Scores (truncated): {'32969964': 2.2264597741480254, '17333231': 2.1708924683216106, '20945963': 2.16720962467
--------------------------------------------------
Query ID: 13
Number of retrieved documents: 500
Scores (truncated): {'17450673': 2.8588590901542905, '1263446': 2.7154921306854023, '31942055': 2.567323036680
--------------------------------------------------
Query ID: 36
Number of retrieved documents: 500
Scores (truncated): {'11705328': 2.6144078556202226, '42441846': 2.5070138176866896, '3742488

In [33]:
# Cell 11: Evaluate Bi-Encoder Results
k_values = [5, 10, 20, 30, 100]
ndcg_bi, map_bi, recall_bi, prec_bi, mrr_bi = evaluate_results(qrels, bi_encoder_results, k_values)

print("=== Bi-Encoder (TAS-B) Results ===")
# print("NDCG@k:", ndcg_bi)
# print("MAP:", map_bi)
# print("Recall@k:", recall_bi)
# print("Precision@k:", prec_bi)
print("NDCG@k:")
for k, v in ndcg_bi.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_bi:.4f}")

print("\nRecall@k:")
for k, v in recall_bi.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_bi.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_bi:.4f}")

=== Bi-Encoder (TAS-B) Results ===
NDCG@k:
  NDCG@5 : 0.4822
  NDCG@10: 0.5048
  NDCG@20: 0.5258
  NDCG@30: 0.5311
  NDCG@100: 0.5464

MAP: 0.4611

Recall@k:
  Recall@5: 0.5852
  Recall@10: 0.6500
  Recall@20: 0.7276
  Recall@30: 0.7509
  Recall@100: 0.8377

Precision@k:
  P@5     : 0.1293
  P@10    : 0.0730
  P@20    : 0.0413
  P@30    : 0.0286
  P@100   : 0.0095

MRR: 0.4786


# Part III: Cross-Encoder Reranking

We'll now demonstrate how to:
1. Take the **top-k** documents from the **bi-encoder** results.
2. Rerank them with a **CrossEncoder**.
3. Evaluate again using the same `evaluate_results` function.


In [22]:
# Cell 12: Define a function to rerank with a CrossEncoder

def cross_encode_rerank(cross_encoder_model, 
                        query_ids_list, query_texts_list,
                        topk_results, corpus_text_map,
                        rerank_top_k=500, final_top_k=100):
    """
    cross_encoder_model: a CrossEncoder from sentence-transformers
    query_ids_list: list of query IDs (strings)
    query_texts_list: list of corresponding query texts
    topk_results: dict => topk_results[qid][doc_id] = some bi-encoder score
    corpus_text_map: a dict doc_id -> doc_text
    rerank_top_k: how many documents to consider for reranking from the bi-encoder results
    final_top_k: how many documents to keep after cross-encoder reranking

    returns: dict => reranked_results[qid][doc_id] = cross-encoder score
    """
    reranked_results = {}

    # Use tqdm to show total progress for all queries
    with tqdm(total=len(query_ids_list), desc="Cross-encoding queries", unit="query", ncols=80) as pbar:
        for idx, qid in enumerate(query_ids_list):
            # Get top rerank_top_k document candidates for the query
            doc_candidates = sorted(topk_results[qid].items(), key=lambda x: x[1], reverse=True)[:rerank_top_k]
            doc_ids = [doc_id for doc_id, _ in doc_candidates]
            query_text = query_texts_list[idx]

            # Prepare query-document pairs for the cross-encoder
            pair_texts = [(f"query: {query_text}", f"passage: {corpus_text_map[doc_id]}") for doc_id in doc_ids]

            # Perform cross-encoder inference, ensure no additional progress bars are shown
            scores = cross_encoder_model.predict(pair_texts, show_progress_bar=False)

            # Combine document IDs with their Cross-Encoder scores
            reranked_doc_scores = {doc_ids[i]: float(scores[i]) for i in range(len(doc_ids))}

            # Sort reranked results by Cross-Encoder scores and keep final_top_k
            reranked_results[qid] = dict(sorted(reranked_doc_scores.items(), key=lambda x: x[1], reverse=True)[:final_top_k])

            # Update the progress bar
            pbar.update(1)

    return reranked_results

print('Cross-encoder rerank function defined.')

Cross-encoder rerank function defined.


### Prepare Data Structures for Reranking

- We already have `bi_encoder_results`, which is `dict[qid][doc_id] = score`.
- We need a quick way to map doc IDs to actual text for the CrossEncoder.


In [23]:
# Cell 13: Create a dictionary doc_id -> doc_text
doc_id_to_text = {}
for i, did in enumerate(doc_ids):
    doc_id_to_text[did] = doc_texts[i]

print(f"doc_id_to_text dictionary has {len(doc_id_to_text)} entries.")

doc_id_to_text dictionary has 5183 entries.


## 3.1: MiniLM CrossEncoder Reranking

### Load a CrossEncoder
Choose any cross-encoder checkpoint from Hugging Face or sentence-transformers.
For demonstration, we'll pick something like `'cross-encoder/ms-marco-MiniLM-L-6-v2'`.


In [24]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

CrossEncoder loaded.


### Rerank the top-k results from the Bi-Encoder

We'll use the cross-encoder function defined above and pass in the top-k results from the Bi-Encoder.


In [25]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_minilm = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

INFO:root:Reranking top 300 candidates with cross-encoder, keeping top 100...
Cross-encoding queries: 100%|██████████████| 300/300 [13:10<00:00,  2.63s/query]

Cross-encoder reranking complete.





### Evaluate Cross-Encoder Reranked Results


In [34]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce, mrr_ce = evaluate_results(qrels, cross_encoder_results_minilm, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== MiniLM Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_ce:.4f}")

=== MiniLM Cross-Encoder (Rerank) Results ===
NDCG@k:
  NDCG@5 : 0.6197
  NDCG@10: 0.6423
  NDCG@20: 0.6539
  NDCG@30: 0.6600
  NDCG@100: 0.6676

MAP: 0.6016

Recall@k:
  Recall@5: 0.7034
  Recall@10: 0.7689
  Recall@20: 0.8119
  Recall@30: 0.8397
  Recall@100: 0.8793

Precision@k:
  P@5     : 0.1547
  P@10    : 0.0863
  P@20    : 0.0458
  P@30    : 0.0317
  P@100   : 0.0100

MRR: 0.6154


## 3.2: Electra CrossEncoder Reranking
To compare the performance of different cross-encoders, here we use `'cross-encoder/ms-marco-electra-base'`.


In [29]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-electra-base"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

CrossEncoder loaded.


In [30]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_electra = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

INFO:root:Reranking top 300 candidates with cross-encoder, keeping top 100...
Cross-encoding queries: 100%|██████████████| 300/300 [43:38<00:00,  8.73s/query]

Cross-encoder reranking complete.





In [35]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce, mrr_ce = evaluate_results(qrels, cross_encoder_results_electra, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== Electra Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")
    
print(f"\nMRR: {mrr_ce:.4f}")

=== Electra Cross-Encoder (Rerank) Results ===
NDCG@k:
  NDCG@5 : 0.6138
  NDCG@10: 0.6382
  NDCG@20: 0.6490
  NDCG@30: 0.6576
  NDCG@100: 0.6645

MAP: 0.5996

Recall@k:
  Recall@5: 0.6899
  Recall@10: 0.7586
  Recall@20: 0.8003
  Recall@30: 0.8403
  Recall@100: 0.8760

Precision@k:
  P@5     : 0.1527
  P@10    : 0.0867
  P@20    : 0.0457
  P@30    : 0.0318
  P@100   : 0.0100

MRR: 0.6123


## 3.3: DeBERTa CrossEncoder Reranking
To compare the performance of different cross-encoders, here we use `'cross-encoder/ms-marco-Microsoft/DeBERTa-V3-Large'`.


In [None]:
# Cell 14: CrossEncoder initialization
cross_model_name = "cross-encoder/ms-marco-electra-base"  # or any other suitable model
cross_encoder = CrossEncoder(cross_model_name, device=device)
print("CrossEncoder loaded.")

In [None]:
# Cell 15: Rerank top-k with CrossEncoder

# Parameters to control the number of documents to rerank and return
rerank_top_k = 300  # Number of documents to consider for reranking
final_top_k = 100    # Number of documents to keep after reranking

logging.info(f"Reranking top {rerank_top_k} candidates with cross-encoder, keeping top {final_top_k}...")

cross_encoder_results_electra = cross_encode_rerank(
    cross_encoder_model=cross_encoder,
    query_ids_list=filtered_query_ids,
    query_texts_list=filtered_query_texts,
    topk_results=bi_encoder_results,  # Bi-Encoder results
    corpus_text_map=doc_id_to_text,   # Document text mapping
    rerank_top_k=rerank_top_k,        # Candidates to rerank
    final_top_k=final_top_k           # Final results to keep
)

print("Cross-encoder reranking complete.")

In [None]:
# Cell 16: Evaluate the cross-encoder-based results
k_values = [5, 10, 20, 30, 100]
ndcg_ce, map_ce, recall_ce, prec_ce = evaluate_results(qrels, cross_encoder_results, k_values)

# print("=== Cross-Encoder (Rerank) Results ===")
# print("NDCG@k:", ndcg_ce)
# print("MAP:", map_ce)
# print("Recall@k:", recall_ce)
# print("Precision@k:", prec_ce)

# Print formatted results
print("=== Electra Cross-Encoder (Rerank) Results ===")
print("NDCG@k:")
for k, v in ndcg_ce.items():
    print(f"  {k:<7}: {v:.4f}")

print(f"\nMAP: {map_ce:.4f}")

print("\nRecall@k:")
for k, v in recall_ce.items():
    print(f"  {k:<8}: {v:.4f}")

print("\nPrecision@k:")
for k, v in prec_ce.items():
    print(f"  {k:<8}: {v:.4f}")