# 1) Hybrid Retrieval: BM25 + FAISS (HNSW) + Torch Ranker (latency target: <100 ms per query)

In [None]:

%%capture
# Environment setup (e.g. for Colab); %%capture hides the pip output.
# NOTE(review): versions are unpinned, so Restart & Run All may resolve different library
# versions over time — consider pinning, and `%pip` so the running kernel's env is targeted.
# NOTE(review): lightgbm, langdetect, unidecode, torchmetrics, matplotlib and nltk are not used
# by the cells visible here; presumably later cells need them — drop any that are unused.
!pip -q install --upgrade pip
!pip -q install datasets transformers sentence-transformers faiss-cpu rank-bm25 torchmetrics scikit-learn lightgbm langdetect unidecode pandas matplotlib tqdm nltk

In [None]:

import os
import random
import re
import time

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import faiss
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed every RNG the notebook may touch: Python's `random`, NumPy's legacy global RNG and torch
# (torch.manual_seed seeds the RNG on all devices, including CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Experiment knobs: dataset language/size and candidate-list depths for BM25, ANN and RRF fusion.
CONFIG = {"language": "en", "N_DOCS": 30000, "N_QUERIES": 3000,
          "TOPK_BM25": 200, "TOPK_ANN": 200, "FUSION_K": 300}

In [None]:

def load_amz(lang="en", n_docs=30000, n_queries=3000, max_titles=10, max_bodies=5,
             min_doc_chars=16, seed=SEED):
    """Build a product-level document corpus and a title-as-query set from Amazon reviews.

    Each product becomes one document: its first ``max_titles`` review titles joined with
    " | ", then its first ``max_bodies`` review bodies joined with spaces. Documents with
    ``min_doc_chars`` characters or fewer are dropped; the rest are shuffled with ``seed``
    and the first ``n_docs`` kept.

    Queries are review titles of the kept products; the product the review belongs to is
    the single document labelled relevant. (title, product) pairs are de-duplicated,
    shuffled with ``seed`` and truncated to ``n_queries``.

    Args:
        lang: dataset configuration name passed to ``load_dataset`` (e.g. "en").
        n_docs: maximum number of product documents.
        n_queries: maximum number of queries.
        max_titles: review titles concatenated per product.
        max_bodies: review bodies concatenated per product.
        min_doc_chars: documents must be strictly longer than this many characters.
        seed: ``random_state`` used for both shuffles.

    Returns:
        (docs, queries): ``docs`` has columns ``product_id`` and ``doc_text``; ``queries`` has
        ``query`` and ``relevant_pid``. Both carry a fresh 0..n-1 index.

    NOTE(review): a query title can be one of the titles concatenated into its own document,
        so lexical retrieval may match it verbatim — evaluation numbers are likely optimistic.
    NOTE(review): identical generic titles on different products stay separate queries, each
        with only one product labelled relevant.
    NOTE(review): the `amazon_reviews_multi` Hub dataset has been reported as no longer
        downloadable — confirm it still loads in your environment.
    """
    ds = load_dataset("amazon_reviews_multi", lang, split="train")
    reviews = ds.to_pandas()[["product_id", "review_title", "review_body"]].dropna()

    prod = (
        reviews.groupby("product_id")
        .agg({"review_title": lambda s: " | ".join(s.head(max_titles).astype(str)),
              "review_body": lambda s: " ".join(s.head(max_bodies).astype(str))})
        .reset_index()
    )
    prod["doc_text"] = (prod["review_title"].fillna("") + " " + prod["review_body"].fillna("")).str.strip()
    prod = (
        prod[prod["doc_text"].str.len() > min_doc_chars]
        .sample(frac=1, random_state=seed)
        .head(n_docs)
        .reset_index(drop=True)
    )

    kept_pids = set(prod["product_id"])
    queries = (
        reviews[reviews["product_id"].isin(kept_pids)][["review_title", "product_id"]]
        .dropna()
        .rename(columns={"review_title": "query", "product_id": "relevant_pid"})
        .drop_duplicates()
        .sample(frac=1, random_state=seed)
        .head(n_queries)
        .reset_index(drop=True)
    )
    return prod[["product_id", "doc_text"]], queries


docs_df, queries_df = load_amz(CONFIG["language"], CONFIG["N_DOCS"], CONFIG["N_QUERIES"])
# One document per product: downstream labelling relies on product_id being unique.
assert docs_df["product_id"].is_unique

In [None]:

# Unicode word runs, optionally joined by an apostrophe so contractions ("don't") stay whole.
_TOKEN_RE = re.compile(r"\w+(?:['\u2019]\w+)*")


def tok(txt):
    """Lower-case ``txt`` and split it into word tokens for BM25.

    Punctuation is dropped, so "great!", "great," and "great" all yield the token "great"
    (plain whitespace splitting would keep them as three different terms).

    Args:
        txt: any value; it is converted with ``str()`` first.

    Returns:
        list of str tokens; empty for empty or whitespace-only input.
    """
    return _TOKEN_RE.findall(str(txt).lower())
# Token list per product document, in docs_df row order, so BM25 doc ids are docs_df positions.
corpus_tok = list(map(tok, docs_df["doc_text"].tolist()))
bm25 = BM25Okapi(corpus_tok)
# Dense encoder used by `encode` for both documents and queries.
dense = SentenceTransformer("intfloat/multilingual-e5-base", device=device)

In [None]:

def encode(texts, bs=128, prefix="query: "):
    """Embed ``texts`` with the E5 encoder as L2-normalised float32 vectors.

    multilingual-e5 models expect every input to start with "query: " (search queries) or
    "passage: " (documents); without the prefix embedding quality degrades. The default is
    "query: " so query-side callers such as ``encode(qs, 128)`` are correct unchanged; pass
    ``prefix="passage: "`` for documents, or ``prefix=""`` for the raw text.

    Args:
        texts: sequence of strings.
        bs: batch size handed to ``dense.encode``.
        prefix: string prepended to every text before encoding.

    Returns:
        np.ndarray of shape (len(texts), dim), dtype float32.
    """
    if len(texts) == 0:
        # np.vstack / faiss need a 2-D array even when there is nothing to embed.
        return np.zeros((0, dense.get_sentence_embedding_dimension()), dtype="float32")
    prefixed = [f"{prefix}{t}" for t in texts]
    # SentenceTransformer.encode batches internally, so no manual chunking is needed.
    emb = dense.encode(prefixed, batch_size=bs, normalize_embeddings=True,
                       convert_to_numpy=True, show_progress_bar=False)
    return np.asarray(emb, dtype="float32")


doc_emb = encode(docs_df["doc_text"].tolist(), 128, prefix="passage: ")
# HNSW graph with M=32 links per node. IndexHNSWFlat's default metric is L2; because the vectors are
# unit-norm, ascending L2 distance orders results the same as descending cosine similarity.
index = faiss.IndexHNSWFlat(doc_emb.shape[1], 32)
index.hnsw.efConstruction = 200
# Keep the search beam at least as wide as the k requested by ann_search (FAISS may already
# widen it to k internally; this makes the intent explicit).
index.hnsw.efSearch = max(128, CONFIG["TOPK_ANN"])
index.add(doc_emb)

In [None]:

def bm25_search(qs, k):
    """Return the top-``k`` BM25 documents for each query, best first.

    Args:
        qs: iterable of query strings; each is tokenised with ``tok``.
        k: documents to return per query. Clamped to the corpus size, because
            ``np.argpartition`` raises when k exceeds the number of scores; k <= 0 yields
            empty results.

    Returns:
        list with one ``(doc_ids, scores)`` tuple per query. Both are 1-D arrays of length
        ``min(k, n_docs)`` sorted by descending BM25 score; doc ids are corpus positions.
    """
    out = []
    for q in qs:
        scores = bm25.get_scores(tok(q))
        kk = min(k, len(scores))
        if kk <= 0:
            out.append((np.empty(0, dtype=np.intp), scores[:0]))
            continue
        # O(n) selection of the top kk, then sort only those kk by score.
        top = np.argpartition(scores, -kk)[-kk:]
        top = top[np.argsort(-scores[top])]
        out.append((top, scores[top]))
    return out
def ann_search(qs, k):
    """Dense k-nearest-neighbour lookup in the FAISS ``index`` for each query string.

    Returns:
        (ids, dists): arrays of shape (len(qs), k) exactly as produced by ``index.search``.
        ``dists`` are the index's distances (presumably L2 for the IndexHNSWFlat built
        earlier, so smaller means closer), not similarities. FAISS may fill ids with -1
        when it finds fewer than k neighbours.
    """
    query_vectors = encode(qs, 128)
    distances, neighbour_ids = index.search(query_vectors, k)
    return neighbour_ids, distances
def rrf(bm_idx, ann_idx, k=300, K=60):
    """Fuse per-query BM25 and ANN rankings with Reciprocal Rank Fusion.

    Each document scores ``sum(1 / (K + rank))`` over the lists it appears in, where rank is
    1-based. The fused list is sorted by that score; ties keep first-seen order, with the
    BM25 list scanned first.

    Args:
        bm_idx: per-query sequences of BM25 doc ids, best first.
        ann_idx: per-query sequences of ANN doc ids, best first (e.g. rows of the FAISS label
            matrix). Negative ids are skipped: FAISS pads results with -1 when it returns fewer
            than k neighbours, and -1 must not be treated as a document.
        k: maximum length of each fused list.
        K: RRF smoothing constant (60 is the value commonly used since Cormack et al., 2009).

    Returns:
        list of 1-D int arrays, one per query.

    Raises:
        ValueError: if the two inputs cover a different number of queries.
    """
    if len(bm_idx) != len(ann_idx):
        raise ValueError(f"rrf: {len(bm_idx)} BM25 rankings vs {len(ann_idx)} ANN rankings")
    fused = []
    for qi in range(len(bm_idx)):
        scores = {}
        for ranking in (bm_idx[qi], ann_idx[qi]):
            for rank, doc in enumerate(ranking, start=1):
                if doc < 0:
                    continue
                scores[doc] = scores.get(doc, 0.0) + 1.0 / (K + rank)
        best = sorted(scores.items(), key=lambda kv: -kv[1])[:k]
        fused.append(np.array([doc for doc, _ in best], dtype=int))
    return fused

In [None]:

# 80/20 split of queries; seeded from SEED so one constant controls every RNG in the notebook.
train_q, dev_q = train_test_split(queries_df, test_size=0.2, random_state=SEED)
def _first_positions(ids):
    """Map each id in ``ids`` to the 0-based position of its first occurrence."""
    positions = {}
    for r, d in enumerate(ids):
        positions.setdefault(int(d), r)
    return positions


def make_split(qdf, missing_rank=9999):
    """Build (features, labels) tensors for the candidate ranker from a query frame.

    For each query, BM25 and ANN candidates are retrieved and fused with RRF. Every fused
    candidate becomes one row with features ``[bm25_rank, ann_rank]`` (0-based position in
    each list, or ``missing_rank`` when absent) and label 1 iff its ``product_id`` equals the
    query's ``relevant_pid``. Queries whose relevant product is not among the fused
    candidates contribute no positive and are skipped entirely.

    Args:
        qdf: DataFrame with ``query`` and ``relevant_pid`` columns.
        missing_rank: feature value for a candidate absent from one of the two lists.

    Returns:
        X: float32 tensor of shape (n_rows, 2).
        y: float32 tensor of shape (n_rows, 1).

    NOTE(review): ranks are fed unscaled and ``missing_rank`` sits far outside the range of
        real ranks — consider scaling (e.g. log1p or reciprocal rank) before the MLP.
    """
    qs = qdf["query"].tolist()
    bm_idx = [ids for ids, _ in bm25_search(qs, CONFIG["TOPK_BM25"])]
    ann_idx, _ = ann_search(qs, CONFIG["TOPK_ANN"])
    fused = rrf(bm_idx, ann_idx, CONFIG["FUSION_K"])
    # Positional product-id lookup, instead of materialising a docs_df row per candidate.
    doc_pids = docs_df["product_id"].to_numpy()

    X, y = [], []
    for i, q in enumerate(qdf.itertuples()):
        # Rank dicts turn the per-candidate list scans into O(1) lookups.
        bm_rank = _first_positions(bm_idx[i])
        ann_rank = _first_positions(ann_idx[i])
        feats, labs = [], []
        for d in fused[i]:
            d = int(d)
            feats.append([bm_rank.get(d, missing_rank), ann_rank.get(d, missing_rank)])
            labs.append(1 if doc_pids[d] == q.relevant_pid else 0)
        if not any(labs):
            continue
        X.extend(feats)
        y.extend(labs)
    # reshape keeps the (0, 2) feature shape even when no query produced rows.
    X = torch.tensor(np.asarray(X, dtype=np.float32).reshape(-1, 2))
    y = torch.tensor(np.asarray(y, dtype=np.float32)).unsqueeze(1)
    return X, y
# Feature/label tensors for the ranker: training queries first, then the held-out dev queries.
Xtr, ytr = make_split(train_q)
Xdv, ydv = make_split(dev_q)

In [None]:

import torch.nn as nn, torch.optim as optim, torch
class MLP(nn.Module):
    """Two-layer perceptron producing one logit per candidate feature row.

    Args:
        in_dim: number of input features (default 2, matching the ``[bm25_rank, ann_rank]``
            rows built by ``make_split``).
        hidden: width of the hidden ReLU layer.
    """

    def __init__(self, in_dim=2, hidden=32):
        super().__init__()
        # Module names (net.0 / net.2) are unchanged, so existing state_dicts still load.
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x):
        """Return raw logits of shape (batch, 1) for ``x`` of shape (batch, in_dim)."""
        return self.net(x)
model = MLP().to(device)
opt = optim.AdamW(model.parameters(), lr=3e-3)
# TODO(review): each kept query has exactly one positive among up to FUSION_K candidates;
# consider BCEWithLogitsLoss(pos_weight=...) or a listwise loss for the imbalance.
loss = nn.BCEWithLogitsLoss()
N_EPOCHS, BATCH_SIZE = 3, 512
n_train = len(Xtr)
assert n_train > 0, "make_split produced no training rows (no query had its product among the fused candidates)"
# Rows arrive grouped by query; reshuffle every epoch (seeded) so mini-batches mix queries.
shuffle_gen = torch.Generator().manual_seed(SEED)
for ep in range(N_EPOCHS):
    model.train()
    tot = 0.0
    perm = torch.randperm(n_train, generator=shuffle_gen)
    for i in range(0, n_train, BATCH_SIZE):
        batch = perm[i:i + BATCH_SIZE]
        xb, yb = Xtr[batch].to(device), ytr[batch].to(device)
        opt.zero_grad()
        p = model(xb)
        l = loss(p, yb)
        l.backward()
        opt.step()
        tot += l.item() * len(xb)
    print("ep", ep, "loss", tot / n_train)