In [1]:
import os, gc, math, random, sqlite3, ujson, numpy as np, torch
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Iterable
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.cross_encoder import CrossEncoder
from torch.utils.data import DataLoader
import faiss

random.seed(42); np.random.seed(42); torch.manual_seed(42)

@dataclass
class Cfg:
    lang: str = "rus"
    corpus_limit_docs: int = 120_000
    max_chars: int = 1000
    subset_queries: int = 20_000
    bi_model: str = "intfloat/multilingual-e5-base"
    bi_batch: int = 16
    bi_epochs: int = 2
    bi_max_len: int = 192
    topk: int = 200
    xe_model: str = "BAAI/bge-reranker-v2-m3"
    xe_batch: int = 2
    xe_epochs: int = 1
    xe_max_len: int = 128
    xe_cap: int = 6000
    hn_neg_per_pos: int = 1
    rerank_topk: int = 20

cfg = Cfg()
cfg

2025-11-08 17:47:31.597093: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762624051.619409     223 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762624051.626234     223 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Cfg(lang='rus', corpus_limit_docs=120000, max_chars=1000, subset_queries=20000, bi_model='intfloat/multilingual-e5-base', bi_batch=16, bi_epochs=2, bi_max_len=192, topk=200, xe_model='BAAI/bge-reranker-v2-m3', xe_batch=2, xe_epochs=1, xe_max_len=128, xe_cap=6000, hn_neg_per_pos=1, rerank_topk=20)

In [2]:
HF_REPO_ID = "PaDaS-Lab/webfaq-retrieval"
HF_REVISION = "main"

def hf_path(lang: str, filename: str) -> str:
    return hf_hub_download(repo_id=HF_REPO_ID, repo_type="dataset",
                           filename=f"{lang}/{filename}", revision=HF_REVISION,
                           local_dir=None, local_dir_use_symlinks=True)

def iter_jsonl(local_path: str) -> Iterable[dict]:
    with open(local_path, "rb") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                yield ujson.loads(line.decode("utf-8"))
            except:
                yield ujson.loads(line)

def load_queries(lang: str, split: str, limit: int | None = None) -> Dict[str, str]:
    p = hf_path(lang, f"queries-{split}.jsonl")
    items = []
    for row in iter_jsonl(p):
        qid = str(row.get("_id") or row.get("id"))
        q = (row.get("text") or row.get("query") or "").strip()
        if q:
            items.append((qid, q))
    random.shuffle(items)
    if limit:
        items = items[:limit]
    return dict(items)

def load_qrels(lang: str, split: str) -> Dict[str, Set[str]]:
    p = hf_path(lang, f"qrels-{split}.jsonl")
    m: Dict[str, Set[str]] = {}
    for row in iter_jsonl(p):
        qid = str(row.get("query-id") or row.get("_id") or row.get("qid"))
        did = str(row.get("corpus-id") or row.get("doc_id") or row.get("document_id") or row.get("pid"))
        if not qid or not did:
            continue
        if qid not in m:
            m[qid] = set()
        m[qid].add(did)
    return m

def build_sqlite_corpus(lang: str, out_path: str, limit_docs: int, max_chars: int) -> Tuple[str, int]:
    p = hf_path(lang, "corpus.jsonl")
    con = sqlite3.connect(out_path)
    cur = con.cursor()
    cur.execute("DROP TABLE IF EXISTS docs")
    cur.execute("CREATE TABLE docs (id TEXT PRIMARY KEY, text TEXT)")
    con.commit()
    added = 0
    batch = []
    with open(p, "rb") as f:
        for raw in f:
            if not raw.strip():
                continue
            row = ujson.loads(raw.decode("utf-8"))
            did = str(row.get("_id") or row.get("doc_id") or row.get("corpus-id") or row.get("id"))
            text = (row.get("text") or "").strip()
            title = (row.get("title") or "").strip()
            if title and title.lower() != "text":
                text = (title + " " + text).strip()
            if max_chars:
                text = text[:max_chars]
            batch.append((did, text))
            if len(batch) >= 1000:
                cur.executemany("INSERT OR REPLACE INTO docs(id,text) VALUES(?,?)", batch)
                con.commit()
                added += len(batch)
                batch = []
                if added >= limit_docs:
                    break
    if batch and added < limit_docs:
        cur.executemany("INSERT OR REPLACE INTO docs(id,text) VALUES(?,?)", batch)
        con.commit()
        added += len(batch)
    cur.close()
    con.close()
    return out_path, added


In [3]:
def _dcg(scores: List[int]) -> float:
    s = 0.0
    for i, rel in enumerate(scores, start=1):
        if rel > 0:
            s += (2 ** rel - 1) / math.log2(i + 1)
    return s

def _idcg(n: int) -> float:
    return _dcg([1] * n) if n > 0 else 0.0

def recall_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int = 10) -> float:
    a = b = 0
    for qid, rels in qrels.items():
        b += 1
        top = run.get(qid, [])[:k]
        a += 1 if any(d in rels for d in top) else 0
    return a / b if b else 0.0

def mrr_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int = 10) -> float:
    s = 0.0
    n = 0
    for qid, rels in qrels.items():
        n += 1
        top = run.get(qid, [])[:k]
        rr = 0.0
        for i, d in enumerate(top, start=1):
            if d in rels:
                rr = 1.0 / i
                break
        s += rr
    return s / n if n else 0.0

def ndcg_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int = 10) -> float:
    s = 0.0
    n = 0
    for qid, rels in qrels.items():
        n += 1
        top = run.get(qid, [])[:k]
        gains = [1 if d in rels else 0 for d in top]
        idcg = _idcg(min(len(rels), k))
        s += (_dcg(gains) / (idcg if idcg > 0 else 1.0))
    return s / n if n else 0.0


In [4]:
urls_ok = True
sqlite_path, added = build_sqlite_corpus(cfg.lang, "corpus.sqlite", cfg.corpus_limit_docs, cfg.max_chars)
con = sqlite3.connect(sqlite_path)
cur = con.cursor()
fetch = lambda did: (cur.execute("SELECT text FROM docs WHERE id=?", (did,)).fetchone() or [""])[0]

queries_tr = load_queries(cfg.lang, "train", limit=cfg.subset_queries)
queries_te = load_queries(cfg.lang, "test", limit=None)
qrels_tr = load_qrels(cfg.lang, "train")
qrels_te = load_qrels(cfg.lang, "test")

len(queries_tr), len(queries_te), added


(20000, 10000, 120000)

In [5]:
import random, gc, torch
cfg.bi_model   = "intfloat/multilingual-e5-small"
cfg.bi_epochs  = 1
cfg.bi_max_len = 128
cfg.bi_batch   = 32
BI_CAP         = 6000


In [6]:

import os, gc, random, torch
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
from sentence_transformers import SentenceTransformer, InputExample, losses

BI_CAP         = getattr(cfg, "bi_cap", 6000)
MODEL_NAME     = cfg.bi_model
EPOCHS         = cfg.bi_epochs
MAX_LEN        = cfg.bi_max_len  
INIT_BS        = cfg.bi_batch   
LR             = 2e-5
MAX_NORM       = 1.0     
SEED           = 42

def format_e5(text: str, is_query: bool) -> str:
    return ("query: " if is_query else "passage: ") + text.strip()

random.seed(SEED)
pairs = []
for qid, q in queries_tr.items():
    rel = qrels_tr.get(qid, [])
    if not rel:
        continue
    did = next(iter(rel))
    t = fetch(did)
    if t:
        pairs.append((q, t))
random.shuffle(pairs)
pairs = pairs[:BI_CAP]
print(f"pairs={len(pairs)}")

class PairDataset(Dataset):
    def __init__(self, pairs):
        self.ex = [InputExample(texts=[format_e5(q, True), format_e5(p, False)]) for q, p in pairs]
    def __len__(self): return len(self.ex)
    def __getitem__(self, idx): return self.ex[idx]

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
print("device:", device)

model = SentenceTransformer(MODEL_NAME, device=device)
model.max_seq_length = MAX_LEN
loss_fn = losses.MultipleNegativesRankingLoss(model)

def train_with_bs(batch_size: int) -> bool:
    try:
        ds = PairDataset(pairs)
        dl = DataLoader(ds, shuffle=True, batch_size=batch_size, collate_fn=model.smart_batching_collate, drop_last=True)
        t_total = len(dl) * EPOCHS
        optimizer = AdamW(model.parameters(), lr=LR)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1, int(0.1*t_total)), num_training_steps=t_total)

        global_step = 0
        for epoch in range(EPOCHS):
            model.train()
            pbar = tqdm(dl, desc=f"Epoch {epoch+1}/{EPOCHS} | bs={batch_size}", leave=True)
            for features, labels in pbar:
                features = [{k: v.to(device) for k, v in f.items()} for f in features]
                labels = labels.to(device)

                loss = loss_fn(features, labels)
                loss.backward()
                clip_grad_norm_(model.parameters(), MAX_NORM)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1
                pbar.set_postfix(loss=float(loss.detach().cpu()))
            pbar.close()
        return True
    except RuntimeError as e:
        msg = str(e).lower()
        if ("out of memory" in msg or ("cuda" in msg and "memory" in msg)) and device == "cuda":
            torch.cuda.empty_cache()
            return False
        raise

bs = int(INIT_BS)
while bs >= 8:
    print(f"try batch_size={bs}")
    ok = train_with_bs(bs)
    if ok:
        break
    print(f"OOM → batch_size: {bs} → {max(8, bs//2)}")
    bs = max(8, bs // 2)

if bs < 8:
    raise RuntimeError("Не удалось подобрать batch_size ≥ 8")

bi = model
gc.collect()
len(pairs)


pairs=6000
device: cuda




try batch_size=32


Epoch 1/1 | bs=32:   0%|          | 0/187 [00:00<?, ?it/s]

6000

In [7]:
import numpy as np, faiss, gc, torch
from tqdm.auto import tqdm

def l2norm(x: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / n

if torch.cuda.is_available():
    try:
        bi.to("cuda")
    except Exception:
        pass

DOC_CAP = getattr(cfg, "doc_cap", None)
ids_all = [r[0] for r in cur.execute("SELECT id FROM docs").fetchall()]
ids = ids_all[:DOC_CAP] if DOC_CAP else ids_all

sample_id = ids[0]
dim = bi.encode([format_e5(fetch(sample_id), False)],
                batch_size=1, convert_to_numpy=True, show_progress_bar=False).shape[1]

index = faiss.IndexFlatIP(dim)

B = getattr(cfg, "index_batch", 256 if torch.cuda.is_available() else 64)

print(f"Indexing on device={'cuda' if torch.cuda.is_available() else 'cpu'} | docs={len(ids)} | dim={dim} | start_batch={B}")
pbar = tqdm(total=len(ids), desc="Indexing (CUDA encode)", leave=True)

i = 0
while i < len(ids):
    curr = min(B, len(ids) - i)
    try:
        chunk = ids[i:i+curr]
        texts = [format_e5(fetch(_id), False) for _id in chunk]
        em = bi.encode(texts, batch_size=curr, convert_to_numpy=True, show_progress_bar=False).astype("float32")
        em = l2norm(em)
        index.add(em)
        i += curr
        pbar.update(curr)
        del texts, em, chunk
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if (i // max(1, B)) % 10 == 0:
            gc.collect()
    except RuntimeError as e:
        msg = str(e).lower()
        if ("out of memory" in msg or ("cuda" in msg and "memory" in msg)) and torch.cuda.is_available() and B > 16:
            torch.cuda.empty_cache()
            B = max(16, B // 2)
            print(f"OOM {i} → уменьшаю batch до {B}")
            continue
        raise
pbar.close()

(len(ids), dim, index.ntotal)


Indexing on device=cuda | docs=120000 | dim=384 | start_batch=256


Indexing (CUDA encode):   0%|          | 0/120000 [00:00<?, ?it/s]

(120000, 384, 120000)

In [13]:
from tqdm.auto import tqdm

hn_pool = {}
HN_BS = getattr(cfg, "hn_batch", max(32, cfg.bi_batch * 2))
print(f"Build hard negatives | queries={len(queries_tr)} | hn_batch={HN_BS} | topk={cfg.topk}")

tr_qids = list(queries_tr.keys())
for i in tqdm(range(0, len(tr_qids), HN_BS), desc="Search train (CUDA)"):
    qb = tr_qids[i:i+HN_BS]
    qt = [format_e5(queries_tr[q], True) for q in qb]
    qe = bi.encode(qt, batch_size=HN_BS, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    qe = l2norm(qe)
    _, I = index.search(qe, cfg.topk)
    for j, idxs in enumerate(I):
        cand = [ids[k] for k in idxs if k >= 0]
        rel = qrels_tr.get(qb[j], set())
        hn_pool[qb[j]] = [d for d in cand if d not in rel]
    del qt, qe, I
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
len(hn_pool)


Build hard negatives | queries=20000 | hn_batch=64 | topk=200


Search train (CUDA):   0%|          | 0/313 [00:00<?, ?it/s]

20000

In [18]:
import os, gc, math, torch, random
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from tqdm.auto import tqdm
from sentence_transformers import InputExample

try:
    if torch.cuda.is_available() and 'bi' in globals():
        bi.to('cpu'); torch.cuda.empty_cache()
except Exception:
    pass

def vram_info():
    if not torch.cuda.is_available(): return (0,0)
    free, total = torch.cuda.mem_get_info()
    return (free//(1024**3), total//(1024**3))

XE_MODEL_BIG   = getattr(cfg, "xe_model", "BAAI/bge-reranker-v2-m3")
XE_MODEL_SMALL = "ai-forever/ruBert-tiny2"
XE_EPOCHS      = int(getattr(cfg, "xe_epochs", 1))
XE_MAX_LEN     = int(getattr(cfg, "xe_max_len", 128))
XE_BATCH       = max(1, int(getattr(cfg, "xe_batch", 2)))
XE_LR_BIG      = float(getattr(cfg, "xe_lr", 5e-6))
XE_LR_SMALL    = 2e-5
WARMUP_FR      = float(getattr(cfg, "xe_warmup_fr", 0.20))
WEIGHT_DECAY   = float(getattr(cfg, "xe_wd", 0.01))
GRAD_ACCUM     = int(getattr(cfg, "xe_grad_accum", 2))
CLIP_NORM      = 1.0
SAVE_DIR       = "./outputs/crossencoder"
SEED           = 42

random.seed(SEED); torch.manual_seed(SEED)

if "xe_examples" not in globals() or not len(xe_examples):
    raise RuntimeError("xe_examples пуст (сначала выполни ячейку 8)")

pos = sum(1 for e in xe_examples if float(e.label) >= 0.5)
neg = len(xe_examples) - pos
print(f"XE examples: {len(xe_examples)} | pos={pos} neg={neg} | VRAM free/total: {vram_info()} GB")

class XEDataset(Dataset):
    def __init__(self, examples, tokenizer, max_len: int):
        self.ex = examples; self.tok = tokenizer; self.max_len = max_len
    def __len__(self): return len(self.ex)
    def __getitem__(self, i):
        e = self.ex[i]; q, p = e.texts
        enc = self.tok(q, p, truncation=True, padding="max_length",
                       max_length=self.max_len, return_tensors="pt")
        out = {k: v.squeeze(0) for k,v in enc.items()}
        out["labels"] = torch.tensor(float(e.label), dtype=torch.float32)
        return out

def train_once(model_name, lr_init, max_len_init, bs_init):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.unk_token

    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    try: model.gradient_checkpointing_enable()
    except Exception: pass

    try:
        model.to(device)
    except RuntimeError as e:
        if "out of memory" in str(e).lower() and device == "cuda":
            torch.cuda.empty_cache()
            return False, f"OOM при переносе {model_name} на CUDA"
        raise

    bs = int(bs_init)
    ml = int(max_len_init)
    lr = float(lr_init)
    use_cuda = (device == "cuda")

    pos_w = torch.tensor([max(1.0, neg / max(1, pos))], device=device if use_cuda else "cpu")
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    while True:
        try:
            ds = XEDataset(xe_examples, tokenizer, ml)
            dl = DataLoader(ds, shuffle=True, batch_size=bs, pin_memory=use_cuda)
            total = max(1, len(dl) * XE_EPOCHS)
            opt = AdamW([
                {"params":[p for n,p in model.named_parameters() if p.requires_grad and not any(x in n for x in ("bias","LayerNorm.weight","layer_norm.weight","norm.weight","ln.weight"))], "weight_decay": WEIGHT_DECAY},
                {"params":[p for n,p in model.named_parameters() if p.requires_grad and     any(x in n for x in ("bias","LayerNorm.weight","layer_norm.weight","norm.weight","ln.weight"))], "weight_decay": 0.0},
            ], lr=lr)
            warm = max(1, int(WARMUP_FR * total))
            sch = get_linear_schedule_with_warmup(opt, num_warmup_steps=warm, num_training_steps=total)
            loss_fn = BCEWithLogitsLoss(pos_weight=pos_w.to(device if use_cuda else "cpu"))

            print(f"CrossEncoder train → model={model_name}, device={device}, bs={bs}, acc={GRAD_ACCUM}, "
                  f"epochs={XE_EPOCHS}, max_len={ml}, steps/epoch={len(dl)}, lr={lr}, warmup={warm}")

            step_acc = 0
            for ep in range(1, XE_EPOCHS+1):
                model.train()
                running = 0.0
                pbar = tqdm(dl, desc=f"XE Epoch {ep}/{XE_EPOCHS}", leave=True)
                for step, batch in enumerate(pbar, 1):
                    batch = {k: v.to(device, non_blocking=use_cuda) for k,v in batch.items()}
                    if use_cuda:
                        with torch.cuda.amp.autocast():
                            out = model(input_ids=batch["input_ids"],
                                        attention_mask=batch["attention_mask"],
                                        token_type_ids=batch.get("token_type_ids", None))
                            logits = out.logits.squeeze(-1)
                            loss = loss_fn(logits, batch["labels"]) / GRAD_ACCUM
                        scaler.scale(loss).backward()
                    else:
                        out = model(input_ids=batch["input_ids"],
                                    attention_mask=batch["attention_mask"],
                                    token_type_ids=batch.get("token_type_ids", None))
                        logits = out.logits.squeeze(-1)
                        loss = loss_fn(logits, batch["labels"]) / GRAD_ACCUM
                        loss.backward()

                    step_acc += 1
                    if step_acc % GRAD_ACCUM == 0:
                        clip_grad_norm_(model.parameters(), CLIP_NORM)
                        if use_cuda:
                            scaler.step(opt); scaler.update()
                        else:
                            opt.step()
                        opt.zero_grad(set_to_none=True); sch.step()

                    running += float(loss.detach().cpu()) * GRAD_ACCUM
                    if step % 10 == 0:
                        pbar.set_postfix(loss=f"{running/step:.4f}")
                pbar.close()

            os.makedirs(SAVE_DIR, exist_ok=True)
            model.save_pretrained(SAVE_DIR)
            tokenizer.save_pretrained(SAVE_DIR)
            print(f"Saved to: {SAVE_DIR}")
            del dl, ds, model; gc.collect()
            if use_cuda: torch.cuda.empty_cache()
            return True, "ok"

        except RuntimeError as e:
            msg = str(e).lower()
            oom = ("out of memory" in msg) or ("cuda" in msg and "memory" in msg) or ("same device" in msg)
            if oom and use_cuda:
                torch.cuda.empty_cache()
                if bs > 1:
                    bs = max(1, bs // 2); print(f"OOM →  batch до {bs}"); continue
                if ml > 64:
                    ml = max(64, ml // 2); print(f"OOM → max_len до {ml}"); continue
                return False, f"OOM на CUDA при bs=1, max_len={ml}"
            raise

ok, msg = train_once(XE_MODEL_BIG, XE_LR_BIG, XE_MAX_LEN, XE_BATCH)
if not ok and torch.cuda.is_available():
    print("ruBert-tiny2 на CUDA для быстрого обучения:", msg)
    ok2, msg2 = train_once(XE_MODEL_SMALL, XE_LR_SMALL, min(96, XE_MAX_LEN), max(4, XE_BATCH))
    if not ok2:
        raise RuntimeError(f"Не удалось обучить XE даже на компактной модели: {msg2}")


XE examples: 6000 | pos=3000 neg=3000 | VRAM free/total: (13, 15) GB
CrossEncoder train → model=BAAI/bge-reranker-v2-m3, device=cuda, bs=2, acc=2, epochs=1, max_len=128, steps/epoch=3000, lr=5e-06, warmup=600


  scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)


XE Epoch 1/1:   0%|          | 0/3000 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


✓ Saved to: ./outputs/crossencoder


In [19]:
from tqdm.auto import tqdm

run_te = {}
B_TEST = getattr(cfg, "test_batch", max(32, cfg.bi_batch * 2))
print(f"Search test | queries={len(queries_te)} | batch={B_TEST} | topk={cfg.topk}")

te_qids = list(queries_te.keys())
for i in tqdm(range(0, len(te_qids), B_TEST), desc="Search test (CUDA)"):
    qb = te_qids[i:i+B_TEST]
    qt = [format_e5(queries_te[q], True) for q in qb]
    qe = bi.encode(qt, batch_size=B_TEST, convert_to_numpy=True, show_progress_bar=False).astype("float32")
    qe = l2norm(qe)
    _, I = index.search(qe, cfg.topk)
    for j, idxs in enumerate(I):
        run_te[qb[j]] = [ids[k] for k in idxs if k >= 0]
    del qt, qe, I
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

bi_R10   = recall_at_k(run_te, qrels_te, 10)
bi_MRR10 = mrr_at_k(run_te, qrels_te, 10)
bi_nDCG10= ndcg_at_k(run_te, qrels_te, 10)
{"R@10": bi_R10, "MRR@10": bi_MRR10, "nDCG@10": bi_nDCG10}


Search test | queries=10000 | batch=64 | topk=200


Search test (CUDA):   0%|          | 0/157 [00:00<?, ?it/s]

{'R@10': 0.2497, 'MRR@10': 0.2036152380952377, 'nDCG@10': 0.2147046798742635}

In [20]:
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers.cross_encoder import CrossEncoder

RERANK_TOPK = getattr(cfg, "rerank_topk", 20)
pred_bs      = getattr(cfg, "rerank_batch", 32)

xe_path_or_name = "./outputs/crossencoder"

dev = "cuda" if torch.cuda.is_available() else "cpu"
cross = CrossEncoder(xe_path_or_name, num_labels=1, device=dev, max_length=getattr(cfg, "xe_max_len", 128))

reranked = {}

print(f"Reranking → device={dev}, topk={RERANK_TOPK}, pred_batch={pred_bs}")
for qid, cand_ids in tqdm(run_te.items(), desc="Reranking (CUDA)"):
    subset = cand_ids[:RERANK_TOPK]
    pairs = [(queries_te[qid], fetch(cid)) for cid in subset]
    b = pred_bs
    while True:
        try:
            scores = cross.predict(pairs, batch_size=b).tolist()
            break
        except RuntimeError as e:
            msg = str(e).lower()
            if ("out of memory" in msg or ("cuda" in msg and "memory" in msg)) and dev == "cuda" and b > 4:
                torch.cuda.empty_cache()
                b = max(4, b // 2)
                continue
            raise
    order = np.argsort(-np.array(scores))
    reranked[qid] = [subset[i] for i in order]

pipe_R10   = recall_at_k(reranked, qrels_te, 10)
pipe_MRR10 = mrr_at_k(reranked, qrels_te, 10)
pipe_nDCG10= ndcg_at_k(reranked, qrels_te, 10)
{"R@10": pipe_R10, "MRR@10": pipe_MRR10, "nDCG@10": pipe_nDCG10}


Reranking → device=cuda, topk=20, pred_batch=32


Reranking (CUDA):   0%|          | 0/10000 [00:00<?, ?it/s]

{'R@10': 0.2589, 'MRR@10': 0.22455869047619045, 'nDCG@10': 0.23293820572715604}

In [21]:
import json, os

os.makedirs("outputs", exist_ok=True)
metrics = {
    "bi": {
        "R@10": float(bi_R10),
        "MRR@10": float(bi_MRR10),
        "nDCG@10": float(bi_nDCG10),
    },
    "pipeline": {
        "R@10": float(pipe_R10),
        "MRR@10": float(pipe_MRR10),
        "nDCG@10": float(pipe_nDCG10),
    },
}
with open("outputs/metrics.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, ensure_ascii=False, indent=2)
metrics


{'bi': {'R@10': 0.2497,
  'MRR@10': 0.2036152380952377,
  'nDCG@10': 0.2147046798742635},
 'pipeline': {'R@10': 0.2589,
  'MRR@10': 0.22455869047619045,
  'nDCG@10': 0.23293820572715604}}