# 4) Cross-Encoder Reranking + Distillation to Dual-Encoder

In [None]:

%%capture
!pip -q install --upgrade pip
!pip -q install datasets transformers sentence-transformers faiss-cpu rank-bm25 torchmetrics scikit-learn lightgbm langdetect unidecode pandas matplotlib tqdm nltk

In [None]:

import numpy as np, torch, faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder, InputExample, losses
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
# Seed the numpy and torch generators once, up front, so that the shuffling and
# training performed later in the notebook start from a fixed RNG state.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
# Run models on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:

# Load the first 1% of the MS MARCO v2.1 training split.
train = load_dataset("ms_marco", "v2.1", split="train[:1%]")
# Keep one (query, document) pair per row: the first well-formed answer when the
# row has any, otherwise the first passage text. Rows with no usable text are dropped.
pairs = []
for row in train:
    query_text = row["query"]
    if row["wellFormedAnswers"]:
        doc_text = row["wellFormedAnswers"][0]
    elif row["passages"]["passage_text"]:
        doc_text = row["passages"]["passage_text"][0]
    else:
        doc_text = None
    if doc_text:
        pairs.append((query_text, doc_text))
# Documents come from (at most) the first 30000 pairs, queries from the first 2000.
corpus = [pair[1] for pair in pairs[:30000]]
queries = [pair[0] for pair in pairs[:2000]]

In [None]:

# Fail early with a clear message instead of an opaque shape/index error further down.
assert len(corpus) > 0 and len(queries) > 0, "corpus and queries must be non-empty"
# Dual-encoder ("student") that is used for retrieval here and fine-tuned later.
student = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-tas-b", device=device)
# Embed documents and queries as L2-normalised float32 vectors; with unit-length
# vectors the inner product computed by the index below equals cosine similarity.
doc_vec = student.encode(corpus, batch_size=128, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True).astype("float32")
q_vec = student.encode(queries, batch_size=128, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True).astype("float32")
# Flat inner-product index over all document vectors.
index = faiss.IndexFlatIP(doc_vec.shape[1])
index.add(doc_vec)
# Never ask for more neighbours than the index holds: FAISS fills the missing
# slots with index -1, which would later be misread as "last document in corpus".
top_k = min(50, index.ntotal)
scores, idx = index.search(q_vec, top_k)

In [None]:

# Cross-encoder ("teacher") that scores each retrieved (query, document) pair.
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device=device)
# Build one [query, document] pair per retrieved candidate, query by query.
teacher_pairs = []
for i, q in enumerate(queries):
    for j in idx[i]:
        # FAISS marks "no neighbour found" with -1; indexing corpus[-1] would
        # silently pair the query with the last document, so skip those slots.
        if j < 0:
            continue
        teacher_pairs.append([q, corpus[j]])
# One teacher score per pair, consumed in the same order as teacher_pairs downstream.
teacher_scores = ce.predict(teacher_pairs, batch_size=64, show_progress_bar=True)

In [None]:

# CosineSimilarityLoss regresses cos(query_emb, doc_emb) onto the label, so labels
# must lie inside the cosine range. The teacher's scores can be unbounded logits;
# when any score falls outside [0, 1], squash all of them through a sigmoid first.
# NOTE(review): scores already inside [0, 1] are assumed to be probabilities and
# are left untouched — confirm this matches the installed CrossEncoder's output.
teacher_labels = np.asarray(teacher_scores, dtype=np.float32)
if teacher_labels.size and (teacher_labels.min() < 0.0 or teacher_labels.max() > 1.0):
    teacher_labels = torch.sigmoid(torch.as_tensor(teacher_labels)).numpy()
# One training example per (query, document) pair, labelled with the teacher's score.
train_data = [InputExample(texts=p, label=float(s)) for p, s in zip(teacher_pairs, teacher_labels)]
loader = DataLoader(train_data, batch_size=64, shuffle=True)
loss = losses.CosineSimilarityLoss(student)
# Fine-tune the student for one epoch and write the result to output_path.
student.fit([(loader, loss)], epochs=1, warmup_steps=50, output_path="artifacts_ce_distilled_student")
print("Distillation complete. Saved model.")