# Information Retrieval: Bi-encoder → FAISS → Cross-encoder Rerank (EN/VN)
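**Objective/Mục tiêu**: Build a toy pipeline: encode a corpus with a bi-encoder, retrieve the top-k candidates via FAISS, and rerank them with a cross-encoder. Compare retrieval quality before and after reranking with MRR (Recall@k can be computed the same way from the ranked lists, but is not computed below).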



# !pip install -q sentence-transformers faiss-cpu datasets scikit-learn


In [None]:

from sentence_transformers import SentenceTransformer, CrossEncoder, util
from datasets import load_dataset
import faiss, numpy as np, json, time

# Load a small corpus: AG News articles used as toy documents. Each `text`
# field holds the article title followed by its short description, not
# just the title.
N_DOCS = 10000     # training articles indexed as the corpus
N_QUERIES = 200    # test articles used as queries
ds = load_dataset("ag_news")
corpus = [x["text"] for x in ds["train"].select(range(N_DOCS))]
# Queries are drawn from the test split, i.e. a different split than the corpus.
queries = [x["text"] for x in ds["test"].select(range(N_QUERIES))]

# First stage: bi-encoder embeddings, L2-normalised so that the inner
# product used by the index below equals cosine similarity.
bi_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
doc_embs = bi_model.encode(
    corpus,
    batch_size=128,
    convert_to_numpy=True,
    show_progress_bar=True,
    normalize_embeddings=True,
)
# FAISS's Python bindings expect a C-contiguous float32 matrix; this is a
# no-op for the usual encoder output and guards against other dtypes/layouts.
doc_embs = np.ascontiguousarray(doc_embs, dtype=np.float32)

# Exhaustive (flat) inner-product index over the document embeddings.
index = faiss.IndexFlatIP(doc_embs.shape[1])
index.add(doc_embs)

def search(q, k=20):
    """Retrieve the top-``k`` corpus documents for query ``q`` with the bi-encoder.

    The query is embedded with ``bi_model`` (L2-normalised, like the document
    embeddings) and looked up in the inner-product FAISS ``index``, so the
    returned scores are cosine similarities.

    Args:
        q: query string.
        k: number of neighbours requested; clamped to the number of indexed
            documents.

    Returns:
        ``(ids, scores)``: two Python lists, best match first. ``ids`` are
        positions in ``corpus``.
    """
    # FAISS pads the result with id -1 when k exceeds the number of indexed
    # vectors; a -1 would then silently alias corpus[-1] / labels[-1].
    k = min(k, index.ntotal)
    if k <= 0:
        return [], []
    qv = bi_model.encode([q], convert_to_numpy=True, normalize_embeddings=True)
    scores, idxs = index.search(qv, k)
    hits = [(int(i), float(s)) for i, s in zip(idxs[0], scores[0]) if i != -1]
    return [i for i, _ in hits], [s for _, s in hits]

# Second stage: a cross-encoder that scores each (query, document) pair
# jointly; used by `rerank` to reorder the bi-encoder shortlist.
reranker = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2"
)

def rerank(q, doc_ids, topk=5):
    """Re-score candidate documents for ``q`` with the cross-encoder.

    Args:
        q: query string.
        doc_ids: candidate positions in ``corpus`` (e.g. the ids from ``search``).
        topk: how many of the re-scored candidates to keep.

    Returns:
        ``(ids, scores)``: the ``topk`` highest-scoring ids and their
        cross-encoder scores as Python floats, best first. Equal scores keep
        the incoming (bi-encoder) order.
    """
    doc_ids = list(doc_ids)
    if not doc_ids or topk <= 0:
        # Nothing to score; some sentence-transformers versions inspect the
        # first pair and fail on an empty batch, so skip the model call.
        return [], []
    pairs = [(q, corpus[i]) for i in doc_ids]
    scores = np.asarray(reranker.predict(pairs), dtype=float)
    # Stable sort on negated scores gives descending order while keeping the
    # input order for ties (reversing an ascending argsort would invert it).
    order = np.argsort(-scores, kind="stable")[:topk]
    return [doc_ids[i] for i in order], [float(scores[i]) for i in order]

# Toy evaluation: treat same-class documents as pseudo-relevant (weak signal).
# Sizes are derived from `corpus` / `queries` so that labels[i] always
# describes corpus[i] and the test samples line up with the query set,
# even if the subset sizes above are changed.
labels = [x["label"] for x in ds["train"].select(range(len(corpus)))]
test_pairs = list(ds["test"].select(range(len(queries))))

def mrr_at_k(k=20):
    """Mean reciprocal rank of the first same-label hit in the bi-encoder top-k.

    For every sample in ``test_pairs`` the retrieved corpus documents are
    scanned in rank order. The first one whose label equals the sample's
    label counts as the hit (a weak, label-based relevance heuristic).
    Samples without a hit contribute 0.

    Args:
        k: number of candidates retrieved per query via ``search``.

    Returns:
        The average reciprocal rank as a Python float.
    """
    reciprocal_ranks = []
    for sample in test_pairs:
        retrieved, _ = search(sample["text"], k=k)
        first_hit = next(
            (
                pos
                for pos, doc_id in enumerate(retrieved, start=1)
                if labels[doc_id] == sample["label"]
            ),
            None,
        )
        reciprocal_ranks.append(0 if first_hit is None else 1.0 / first_hit)
    return float(np.mean(reciprocal_ranks))

mrr_bi = mrr_at_k(20)  # bi-encoder only: MRR over the first 20 FAISS hits

# Second-stage evaluation: rerank the bi-encoder shortlist, score the survivors.
def mrr_after_rerank(k_in=20, k_out=5):
    """Mean reciprocal rank after cross-encoder reranking.

    For each sample in ``test_pairs``, ``k_in`` candidates are retrieved with
    ``search``, reordered by ``rerank`` and truncated to ``k_out``. The first
    document in that list sharing the sample's label counts as the hit.
    Samples without a hit contribute 0.

    Args:
        k_in: shortlist size taken from the bi-encoder.
        k_out: number of reranked documents inspected for a hit.

    Returns:
        The average reciprocal rank as a Python float.
    """
    def _reciprocal_rank(ranked, gold):
        # 1/position of the first same-label document, 0 if none matches.
        for pos, doc_id in enumerate(ranked, start=1):
            if labels[doc_id] == gold:
                return 1.0 / pos
        return 0

    per_query = []
    for sample in test_pairs:
        query, gold = sample["text"], sample["label"]
        shortlist, _ = search(query, k=k_in)
        reranked, _ = rerank(query, shortlist, topk=k_out)
        per_query.append(_reciprocal_rank(reranked, gold))
    return float(np.mean(per_query))

mrr_rr = mrr_after_rerank()
# Bi-encoder MRR at the same cutoff (5) as the reranked list, so the effect of
# reranking is measured like-for-like. MRR@20 is kept for reference; it can
# never be lower than MRR@5 because each query gets more chances to hit.
mrr_bi_at5 = mrr_at_k(k=5)
print({
    "MRR@20_bi": mrr_bi,
    "MRR@5_bi": mrr_bi_at5,
    "MRR@5_after_rerank": mrr_rr,
})



# Single query demo: retrieve 10 candidates, rerank, show the best 5.
q = "Apple announces new iPhone with improved camera"
idxs, scs = search(q, k=10)
rr_ids, rr_scores = rerank(q, idxs, topk=5)
print("Query:", q)
print("Top-5 after rerank:")
for i, s in zip(rr_ids, rr_scores):
    # Flatten newlines so each hit prints on a single line.
    print(s, "||", corpus[i][:120].replace("\n", " "))
