In [6]:
import pandas as pd

pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)



df = pd.read_csv('./prior_results/halubench_splade.csv') 

df.columns




Index(['question', 'answer', 'passage', 'splade_ret_docs', 'MAP@3', 'NDCG@3',
       'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10'],
      dtype='object')

In [7]:
df.iloc[0, 3]

'[{"doc_id": "391", "score": 20.2985, "snippet": "September 25, 2011 at Paul Brown Stadium, Cincinnati, Ohio (Blacked Out) Playing versus the San Francisco 49ers for the first time since 2003, the Bengals lost 13-8 with the smallest crowd for a Benga"}, {"doc_id": "432", "score": 19.8393, "snippet": "After winning on the road, the Bengals returned home for Game 2 against the Steelers.  The Bengals scored first in the first quarter when Randy Bullock kicked a 35-yard field goal to make it 3-0.  The"}, {"doc_id": "201", "score": 19.5582, "snippet": " After winning at home, the Bengals traveled down south to take on the Jaguars.  The Jags scored first in the first quarter when Josh Lambo kicked a 32-yard field goal to make it 3-0.  They would make"}, {"doc_id": "199", "score": 19.3586, "snippet": " The Steelers started their 2014 season at home against the Browns.  In the first quarter, the Steelers would score first when Shaun Suisham kicked a 36-yard field goal for a 3-0 lead.  However,

# Halubench Hybrid

In [13]:
import pandas as pd
import json
from typing import List, Dict, Any, Optional
import math
import numpy as np
import ast
import re
import string

# ---------- Safe checks ----------

def is_nan_like(x: Any) -> bool:
    if x is None:
        return True
    if isinstance(x, (list, tuple, np.ndarray, dict)):
        return False
    try:
        if pd.isna(x):
            return True
    except Exception:
        pass
    if isinstance(x, str) and x.strip() == "":
        return True
    return False

def to_python_list(x: Any) -> Any:
    if isinstance(x, np.ndarray):
        return x.tolist()
    return x

# ---------- Text normalization ----------

NOWIKI_PATTERN = re.compile(r"<nowiki>.*?</nowiki>", flags=re.IGNORECASE)
HTML_TAG_PATTERN = re.compile(r"<[^>]+>")  # crude strip of any HTML tags
WHITESPACE_PATTERN = re.compile(r"\s+")

def normalize_text(s: str) -> str:
    if not isinstance(s, str):
        s = str(s)
    s = s.lower()
    # remove <nowiki>...</nowiki>
    s = NOWIKI_PATTERN.sub(" ", s)
    # remove html tags if any
    s = HTML_TAG_PATTERN.sub(" ", s)
    # remove punctuation
    s = s.translate(str.maketrans("", "", string.punctuation))
    # collapse whitespace
    s = WHITESPACE_PATTERN.sub(" ", s).strip()
    return s

# ---------- Parsing helpers ----------

def parse_docs(cell: Any, id_key: str = "doc_id") -> List[Dict[str, Any]]:
    """
    Expect retrieved docs as:
      - JSON/Python list of dicts with {doc_id, score?, snippet?}
      - list of doc_id strings/ints
    Returns a list of dicts with doc_id and snippet (normalized_snippet added).
    """
    if is_nan_like(cell):
        return []

    cell = to_python_list(cell)

    if isinstance(cell, list):
        out = []
        for d in cell:
            if isinstance(d, dict) and id_key in d:
                doc = {**d, id_key: str(d[id_key])}
                snip = doc.get("snippet")
                if isinstance(snip, str):
                    doc["_norm_snippet"] = normalize_text(snip)
                else:
                    doc["_norm_snippet"] = ""
                out.append(doc)
            elif isinstance(d, (str, int)):
                out.append({id_key: str(d), "_norm_snippet": ""})
        return out

    if isinstance(cell, dict):
        if id_key in cell:
            doc = {**cell, id_key: str(cell[id_key])}
            snip = doc.get("snippet")
            doc["_norm_snippet"] = normalize_text(snip) if isinstance(snip, str) else ""
            return [doc]
        return []

    if isinstance(cell, str):
        s = cell.strip()
        if not s:
            return []
        try:
            data = json.loads(s)
            return parse_docs(data, id_key=id_key)
        except json.JSONDecodeError:
            try:
                data = ast.literal_eval(s)
                return parse_docs(data, id_key=id_key)
            except Exception:
                return []

    return []

def parse_passage_as_text_list(cell: Any) -> List[str]:
    """
    Passage holds ground-truth text(s). Accept:
      - single string (one passage)
      - list of strings
      - JSON or Python-literal string for either of the above
    Return list of normalized strings.
    """
    if is_nan_like(cell):
        return []

    cell = to_python_list(cell)

    texts = []
    if isinstance(cell, list):
        for x in cell:
            if isinstance(x, str):
                texts.append(normalize_text(x))
            elif isinstance(x, (int, float)):
                texts.append(normalize_text(str(x)))
            elif isinstance(x, dict):
                # if dict, try common keys
                for k in ("text", "passage", "snippet"):
                    if k in x and isinstance(x[k], str):
                        texts.append(normalize_text(x[k]))
                        break
    elif isinstance(cell, dict):
        for k in ("text", "passage", "snippet"):
            if k in cell and isinstance(cell[k], str):
                texts.append(normalize_text(cell[k]))
                break
    elif isinstance(cell, (str, int, float)):
        s = str(cell).strip()
        if not s:
            return []
        # Try JSON then literal_eval; if fails, treat as plain string
        try:
            data = json.loads(s)
            return parse_passage_as_text_list(data)
        except json.JSONDecodeError:
            try:
                data = ast.literal_eval(s)
                return parse_passage_as_text_list(data)
            except Exception:
                texts.append(normalize_text(s))
    return texts

# ---------- Simple text-based relevance matching ----------

def is_match(gt_text: str, cand_text: str) -> bool:
    """
    Heuristic match:
      - substring either way after normalization
    """
    if not gt_text or not cand_text:
        return False
    return (gt_text in cand_text) or (cand_text in gt_text)

def build_binary_relevance_from_text(
    retrieved: List[Dict[str, Any]],
    gt_texts: List[str]
) -> List[int]:
    """
    For each retrieved item (in order), produce 1 if any GT text matches the snippet text, else 0.
    """
    rel = []
    for d in retrieved:
        cand = d.get("_norm_snippet", "")
        hit = any(is_match(gt, cand) for gt in gt_texts)
        rel.append(1 if hit else 0)
    return rel

# ---------- RRF fusion ----------

def rrf_fuse_lists(
    lists: List[List[Dict[str, Any]]],
    k_rrf: int = 60,
    id_key: str = "doc_id"
) -> List[Dict[str, Any]]:
    scores = {}
    rep = {}
    for lst in lists:
        # If not guaranteed ranked, enable sorting:
        # lst = sorted(lst, key=lambda d: d.get("score", float("-inf")), reverse=True)
        for rank, d in enumerate(lst, start=1):
            did = d.get(id_key)
            if did is None:
                continue
            did = str(did)
            scores[did] = scores.get(did, 0.0) + 1.0 / (k_rrf + rank)
            if did not in rep:
                rep[did] = d

    fused = []
    for did, s in scores.items():
        base = rep[did]
        item = {
            "doc_id": did,
            "rrf_score": float(s),
        }
        if "snippet" in base:
            item["snippet"] = base["snippet"]
        if "_norm_snippet" in base:
            item["_norm_snippet"] = base["_norm_snippet"]
        if "score" in base:
            item["orig_score"] = base["score"]
        fused.append(item)

    fused.sort(key=lambda x: x["rrf_score"], reverse=True)
    return fused

# ---------- Metrics ----------

def average_precision_at_k_from_binary(rel: List[int], k: int) -> float:
    """
    rel: list of 0/1 indicating relevance at each rank position (already aligned with retrieved order).
    """
    # AP@k for binary relevance with unknown total relevant count -> standard IR AP:
    # sum(precision@i when rel[i]=1)/num_relevant
    # If no relevant, return 0
    rel_at_k = rel[:k]
    num_relevant = sum(rel)
    if num_relevant == 0:
        return 0.0
    hits = 0
    ap_sum = 0.0
    for i, r in enumerate(rel_at_k, start=1):
        if r == 1:
            hits += 1
            ap_sum += hits / i
    return ap_sum / num_relevant

def dcg_at_k_from_binary(rel: List[int], k: int) -> float:
    dcg = 0.0
    for i, g in enumerate(rel[:k], start=1):
        if g:
            dcg += 1.0 / math.log2(i + 1)
    return dcg

def ndcg_at_k_from_binary(rel: List[int], k: int) -> float:
    dcg = dcg_at_k_from_binary(rel, k)
    ideal = sorted(rel, reverse=True)  # all 1s first
    idcg = dcg_at_k_from_binary(ideal, k)
    return 0.0 if idcg == 0 else dcg / idcg

# ---------- Main pipeline ----------

def hybrid_rrf_with_eval(
    splade_csv: str,
    mpnet_csv: str,
    join_key: str = "question",
    answer_col: str = "answer",
    passage_col: str = "passage",  # GT text
    splade_docs_col: str = "splade_ret_docs",
    mpnet_docs_col: str = "mpnet_ret_docs",
    k_rrf: int = 60,
    topk_to_keep: int = 10
) -> pd.DataFrame:
    # Read
    s = pd.read_csv(splade_csv)
    d = pd.read_csv(mpnet_csv)

    # Keep needed columns
    s_cols = [c for c in [join_key, answer_col, passage_col, splade_docs_col] if c in s.columns]
    d_cols = [c for c in [join_key, answer_col, passage_col, mpnet_docs_col] if c in d.columns]
    s_small = s[s_cols].copy()
    d_small = d[d_cols].copy()

    # Rename duplicates from d_small before merge (except join key)
    for col in d_cols:
        if col != join_key and col in s_small.columns:
            d_small = d_small.rename(columns={col: f"{col}__d"})

    # Merge
    df = s_small.merge(d_small, on=join_key, how="outer")

    def pick(row, primary: str, secondary: str) -> Any:
        if primary in row and not is_nan_like(row[primary]):
            return row[primary]
        if secondary in row and not is_nan_like(row[secondary]):
            return row[secondary]
        return None

    rows = []
    for _, row in df.iterrows():
        question = row.get(join_key)
        answer = row.get(answer_col) if answer_col in row else row.get(f"{answer_col}__d")

        # Retrieved lists
        splade_cell = pick(row, splade_docs_col, f"{splade_docs_col}__d")
        mpnet_cell = pick(row, mpnet_docs_col, f"{mpnet_docs_col}__d")
        splade_docs = parse_docs(splade_cell)
        mpnet_docs = parse_docs(mpnet_cell)

        # Fusion
        fused = rrf_fuse_lists([splade_docs, mpnet_docs], k_rrf=k_rrf)
        hybrid_ret_docs = fused[:topk_to_keep]

        # Ground-truth texts from 'passage'
        passage_cell = pick(row, passage_col, f"{passage_col}__d")
        gt_texts = parse_passage_as_text_list(passage_cell)

        # Binary relevance per retrieved doc by text match
        rel_binary = build_binary_relevance_from_text(hybrid_ret_docs, gt_texts)

        # Metrics
        map3 = average_precision_at_k_from_binary(rel_binary, 3)
        ndcg3 = ndcg_at_k_from_binary(rel_binary, 3)
        map5 = average_precision_at_k_from_binary(rel_binary, 5)
        ndcg5 = ndcg_at_k_from_binary(rel_binary, 5)
        map10 = average_precision_at_k_from_binary(rel_binary, 10)
        ndcg10 = ndcg_at_k_from_binary(rel_binary, 10)

        rows.append({
            "question": question,
            "answer": answer,
            # Keep original passage text(s) (normalized for clarity). If you prefer raw, store passage_cell instead.
            "passage_norm": gt_texts,
            "hybrid_ret_docs": hybrid_ret_docs,
            "MAP@3": map3,
            "NDCG@3": ndcg3,
            "MAP@5": map5,
            "NDCG@5": ndcg5,
            "MAP@10": map10,
            "NDCG@10": ndcg10
        })

    return pd.DataFrame(rows)

# -------- Example run and save --------
result = hybrid_rrf_with_eval(
    splade_csv="prior_results/halubench_splade.csv",
    mpnet_csv="prior_results/halubench_mpnet.csv",
    join_key="question",
    answer_col="answer",
    passage_col="passage",
    splade_docs_col="splade_ret_docs",
    mpnet_docs_col="mpnet_ret_docs",
    k_rrf=60,
    topk_to_keep=10
)
result.to_csv("hybrid_rrf_with_metrics.csv", index=False)



In [14]:
# results summary
import pandas as pd
for name, path in [("Hybrid Halubench","hybrid_rrf_with_metrics.csv")]:
    df = pd.read_csv(path)
    print(name, "Results:")
    for k in (3,5,10):
        print(f"  MAP@{k}: {pd.to_numeric(df[f'MAP@{k}'], errors='coerce').mean():.4f}, "
              f"NDCG@{k}: {pd.to_numeric(df[f'NDCG@{k}'], errors='coerce').mean():.4f}")
    print()

Hybrid Halubench Results:
  MAP@3: 0.5016, NDCG@3: 0.5057
  MAP@5: 0.5038, NDCG@5: 0.5093
  MAP@10: 0.5059, NDCG@10: 0.5136



# Hybrid all pipelines on Hotpot

In [None]:
import json
import ast
import re
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple

# ============ CONFIG (paths relative to the notebook's folder) ============
# Make sure your current working directory is the "hybrid_retrieval" folder (where the notebook lives).
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_weighted_score_sum_combmnz.csv"

# Baseline RRF parameter
RRF_K = 60

# Original optional rank-from-score normalization for RRF ranking
NORMALIZE_SCORES = False

# Fusion config
# "base" == plain RRF (exactly your original hybrid)
# Other options: "rrf" (alias of base), "weighted_rrf", "weighted_score_sum", "weighted_score_sum_combmnz"
FUSION_MODE = "weighted_score_sum_combmnz"
SCORE_NORM_PER_LIST = True     # used only for score-sum fusion modes: min-max per list
WEIGHT_TEMPERATURE = 1.17      # temperature applied to QC logits; 1.0 no change; >1 smoother; <1 sharper
LOG_WEIGHTS = True             # include weights & fusion meta in the hybrid output entries
# ========================================================================


# ---------- QC Classifier (BERT) ----------
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Local folder next to the notebook
QC_MODEL_FOLDER = "./bert_model_QC_finetuned"

_qc_tokenizer = None
_qc_model = None

def _lazy_load_qc():
    """
    Lazily load the local QC model and tokenizer exactly from QC_MODEL_FOLDER.
    Will not hit Hugging Face Hub (local_files_only=True).
    """
    global _qc_tokenizer, _qc_model
    if _qc_tokenizer is None or _qc_model is None:
        if not os.path.isdir(QC_MODEL_FOLDER):
            raise OSError(
                f"QC model folder not found: {QC_MODEL_FOLDER}. "
                f"CWD: {os.getcwd()}. "
                f"Ensure the directory exists and contains tokenizer + model files "
                f"(e.g., config.json, pytorch_model.bin, tokenizer.json or vocab.txt, "
                f"tokenizer_config.json, special_tokens_map.json)."
            )
        _qc_tokenizer = BertTokenizer.from_pretrained(QC_MODEL_FOLDER, local_files_only=True)
        _qc_model = BertForSequenceClassification.from_pretrained(QC_MODEL_FOLDER, local_files_only=True)
        _qc_model.eval()

def get_qc_weights(question: str, temperature: float = 1.0) -> Tuple[float, float]:
    """
    Run the QC classifier on a single question and return (w_sparse, w_dense).
    Assumes class 0 = sparse, class 1 = dense.
    Temperature is applied to logits before softmax.
    """
    _lazy_load_qc()
    inputs = _qc_tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _qc_model(**inputs)
        logits = outputs.logits
    probs = torch.softmax(logits / float(temperature), dim=-1)[0]  # shape: (2,)
    w_sparse = float(probs[0].item())
    w_dense = float(probs[1].item())
    s = w_sparse + w_dense
    if s <= 0:
        return 0.5, 0.5
    return w_sparse / s, w_dense / s


# ---------- Utilities ----------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def normalize_groundtruth_str(gt_field: Any) -> str:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        norm_items = [normalize_text(x) for x in data]
        return json.dumps(norm_items, ensure_ascii=False)
    return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    """
    Parse retrieval list into a list of dicts: {doc_id, score, full_text}
    """
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                # Use full_text (authoritative), ignore preview_snippet
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

def min_max_normalize(scores: List[float]) -> List[float]:
    if not scores:
        return []
    mn, mx = min(scores), max(scores)
    if mx == mn:
        return [0.0 for _ in scores]
    return [(x - mn) / (mx - mn) for x in scores]

def rank_from_scores(items: List[Dict[str, Any]], normalize: bool = False) -> Dict[str, int]:
    if not items:
        return {}
    arr = items.copy()
    if normalize:
        norm_scores = min_max_normalize([x["score"] for x in arr])
        for i, ns in enumerate(norm_scores):
            arr[i]["_tmp_score"] = ns
        key = "_tmp_score"
    else:
        key = "score"
    arr_sorted = sorted(arr, key=lambda x: x.get(key, 0.0), reverse=True)
    return {it["doc_id"]: idx for idx, it in enumerate(arr_sorted, start=1)}


# ---------- Fusion Methods ----------

def rrf_fuse_detailed(lists: List[List[Dict[str, Any]]], k: int = 60, normalize_scores: bool = False):
    """
    Reciprocal Rank Fusion with detailed output.
    Returns a list sorted by RRF score desc:
      [
        {
          'doc_id': str,
          'rrf_score': float,
          'ranks': [rank_in_list0_or_None, rank_in_list1_or_None, ...]
        }, ...
      ]
    """
    rank_maps = [rank_from_scores(lst, normalize=normalize_scores) for lst in lists]
    all_doc_ids = set().union(*[set(rm.keys()) for rm in rank_maps])

    fused = []
    for d in all_doc_ids:
        ranks = [rm.get(d) for rm in rank_maps]
        rrf_score = sum(1.0 / (k + r) for r in ranks if r is not None)
        fused.append({"doc_id": d, "rrf_score": rrf_score, "ranks": ranks})

    fused.sort(key=lambda x: x["rrf_score"], reverse=True)
    return fused

def rrf_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], k: int = 60, normalize_scores: bool = False):
    """
    Weighted RRF: sum_i w_i * 1/(k + rank_i)
    """
    rank_maps = [rank_from_scores(lst, normalize=normalize_scores) for lst in lists]
    all_doc_ids = set().union(*[set(rm.keys()) for rm in rank_maps])

    fused = []
    for d in all_doc_ids:
        ranks = [rm.get(d) for rm in rank_maps]
        rrf_score = 0.0
        for i, r in enumerate(ranks):
            if r is not None:
                rrf_score += float(weights[i]) * (1.0 / (k + r))
        fused.append({"doc_id": d, "rrf_score": rrf_score, "ranks": ranks})

    fused.sort(key=lambda x: x["rrf_score"], reverse=True)
    return fused

def score_sum_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], normalize_per_list: bool = True, combmnz: bool = False):
    """
    Weighted linear score fusion:
      fused_score(d) = sum_i w_i * s_i(d)
    where s_i(d) are per-list scores, optionally min-max normalized within each list.
    COMBMNZ variant multiplies by the count of lists where the doc appears:
      fused_mnz(d) = fused_score(d) * (#systems that retrieved d)
    """
    # Build doc_id -> score per list
    score_maps = []
    for lst in lists:
        if normalize_per_list:
            norm_scores = min_max_normalize([x["score"] for x in lst])
            m = {lst[i]["doc_id"]: norm_scores[i] for i in range(len(lst))}
        else:
            m = {x["doc_id"]: float(x["score"]) for x in lst}
        score_maps.append(m)

    all_doc_ids = set().union(*[set(m.keys()) for m in score_maps])

    fused = []
    for d in all_doc_ids:
        per_list_scores = [m.get(d, 0.0) for m in score_maps]
        fused_score = sum(float(weights[i]) * per_list_scores[i] for i in range(len(per_list_scores)))
        num_systems_present = sum(1 for s in per_list_scores if s > 0.0)
        if combmnz:
            fused_score *= max(1, num_systems_present)
        fused.append({
            "doc_id": d,
            "fused_score": fused_score,
            "per_list_scores": per_list_scores,
            "num_systems": num_systems_present
        })

    fused.sort(key=lambda x: x["fused_score"], reverse=True)
    return fused


# ---------- Metrics (text-based evaluation) ----------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def mapk(actual_list: List[List[str]], predicted_list: List[List[str]], k: int) -> float:
    scores = [apk(a, p, k) for a, p in zip(actual_list, predicted_list)]
    return float(np.mean(scores)) if scores else 0.0

def dcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    return dcg

def idcg_at_k(ideal_texts: List[str], k: int) -> float:
    g = min(len(ideal_texts), k)
    return sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    idcg = idcg_at_k(ideal_texts, k)
    if idcg == 0.0:
        return 0.0
    return dcg_at_k(predicted, ideal_texts, k) / idcg


# ---------- Main Pipeline ----------

def main():
    # Optional sanity check
    print("CWD:", os.getcwd())
    print("Expecting QC model folder at:", os.path.abspath(QC_MODEL_FOLDER))

    # Load files
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals for output
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalized join keys
    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge on normalized keys
    df = pd.merge(
        df_sparse,
        df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )

    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    hybrid_ret_docs_col = []
    map3_list, ndcg3_list = [], []
    map5_list, ndcg5_list = [], []
    map10_list, ndcg10_list = [], []

    for _, row in df.iterrows():
        # Parse lists
        splade_list = parse_ret_list(row["splade_ret_docs"])
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])
        gt_list_raw = parse_groundtruth_list(row["_gt_orig_splade"])  # use either side; same after join

        # Build full_text lookup for ID->full_text
        fulltext_map = {}
        for d in splade_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))
        for d in mpnet_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))

        # QC weights for this query (only needed for weighted modes)
        question_text = row["_q_orig_splade"]
        if FUSION_MODE in ("weighted_rrf", "weighted_score_sum", "weighted_score_sum_combmnz"):
            w_sparse, w_dense = get_qc_weights(question_text, temperature=WEIGHT_TEMPERATURE)
            weights = [w_sparse, w_dense]  # [splade, mpnet]
        else:
            # Not used in base/rrf
            w_sparse, w_dense = None, None
            weights = None

        # Fusion
        # "base" is the original plain RRF (same as "rrf")
        mode = FUSION_MODE
        if mode == "base" or mode == "rrf":
            fused = rrf_fuse_detailed([splade_list, mpnet_list], k=RRF_K, normalize_scores=NORMALIZE_SCORES)
            fusion_label = "base"
            for f in fused:
                f["_fused_score"] = f["rrf_score"]
        elif mode == "weighted_rrf":
            fused = rrf_fuse_weighted([splade_list, mpnet_list], weights=weights, k=RRF_K, normalize_scores=NORMALIZE_SCORES)
            fusion_label = "weighted_rrf"
            for f in fused:
                f["_fused_score"] = f["rrf_score"]
        elif mode == "weighted_score_sum":
            fused_score = score_sum_fuse_weighted([splade_list, mpnet_list], weights=weights, normalize_per_list=SCORE_NORM_PER_LIST, combmnz=False)
            rank_maps = [rank_from_scores(splade_list, normalize=NORMALIZE_SCORES),
                         rank_from_scores(mpnet_list, normalize=NORMALIZE_SCORES)]
            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)
            fusion_label = "weighted_score_sum"
        elif mode == "weighted_score_sum_combmnz":
            fused_score = score_sum_fuse_weighted([splade_list, mpnet_list], weights=weights, normalize_per_list=SCORE_NORM_PER_LIST, combmnz=True)
            rank_maps = [rank_from_scores(splade_list, normalize=NORMALIZE_SCORES),
                         rank_from_scores(mpnet_list, normalize=NORMALIZE_SCORES)]
            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)
            fusion_label = "weighted_score_sum_combmnz"
        else:
            raise ValueError(f"Unknown FUSION_MODE: {FUSION_MODE}")

        # Build hybrid_ret_docs with fused score, rank, full_text, per-model rank, and weights
        hybrid_struct = []
        fused_norm_fulltexts = []  # for evaluation by text
        for idx, item in enumerate(fused, start=1):
            doc_id = item["doc_id"]
            fused_score = item["_fused_score"]
            ranks = item.get("ranks")
            if ranks is None:
                ranks = [None, None]
            full_text = fulltext_map.get(doc_id, "")

            entry = {
                "doc_id": doc_id,
                "score": fused_score,
                "rank": idx,
                "full_text": full_text,
                "source_ranks": {
                    "splade": ranks[0],
                    "mpnet": ranks[1],
                }
            }
            if LOG_WEIGHTS:
                if weights is not None:
                    entry["qc_weights"] = {"splade": weights[0], "mpnet": weights[1]}
                else:
                    entry["qc_weights"] = None
                entry["fusion_mode"] = fusion_label
            hybrid_struct.append(entry)

            # Collect normalized full text for evaluation by text
            fused_norm_fulltexts.append(normalize_text(full_text))

        # Save hybrid struct JSON
        hybrid_ret_docs_col.append(json.dumps(hybrid_struct, ensure_ascii=False))

        # Prepare normalized GT texts (maintain order for NDCG ideal)
        gt_norm_texts = [normalize_text(x) for x in parse_groundtruth_list(row["_gt_orig_splade"])]

        # Metrics by text
        pred = fused_norm_fulltexts
        map3_list.append(apk(gt_norm_texts, pred, 3))
        map5_list.append(apk(gt_norm_texts, pred, 5))
        map10_list.append(apk(gt_norm_texts, pred, 10))
        ndcg3_list.append(ndcg_at_k(pred, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred, gt_norm_texts, 10))

    # Build output with requested columns (use original SPLADE-side text columns)
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": hybrid_ret_docs_col,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })

    cols = [
        "question", "answer", "groundtruth_docs",
        "splade_ret_docs", "mpnet_ret_docs", "hybrid_ret_docs",
        "MAP@3", "NDCG@3", "MAP@5", "NDCG@5", "MAP@10", "NDCG@10"
    ]
    out = out[cols]

    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }
    print("Saved:", OUTPUT_CSV)
    print("Fusion mode:", FUSION_MODE)
    print("QC temperature:", WEIGHT_TEMPERATURE)
    print("Hybrid overall averages:", overall)


# If running in a notebook cell, call main() explicitly:
main()

- Saved: ./hybrid_results.csv
- Fusion mode: base
- QC temperature: 1.17
- Hybrid overall averages: {'MAP@3': 0.44218438823684025, 'NDCG@3': 0.5229603376262054, 'MAP@5': 0.4671846504463248, 'NDCG@5': 0.5585157793137547, 'MAP@10': 0.48421740577612526, 'NDCG@10': 0.5866646649926855}

- Saved: ./hybrid_weighted_rrf.csv
- Fusion mode: weighted_rrf
- QC temperature: 1.17
- Hybrid overall averages: {'MAP@3': 0.35653002709287396, 'NDCG@3': 0.43031830609708344, 'MAP@5': 0.3723286350448347, 'NDCG@5': 0.45498537614676815, 'MAP@10': 0.38600341152032897, 'NDCG@10': 0.4796897845341999}

- Saved: ./hybrid_weighted_score_sum_combmnz.csv
- Fusion mode: weighted_score_sum_combmnz
- QC temperature: 1.17
- Hybrid overall averages: {'MAP@3': 0.381144547559114, 'NDCG@3': 0.4556785730660241, 'MAP@5': 0.39439389803778846, 'NDCG@5': 0.4763792247129293, 'MAP@10': 0.4196332105481169, 'NDCG@10': 0.527365715144966}



- Why is “hybrid (base)” worse than SPLADE alone?


Hybrid underperforming SPLADE
- RRF is rank-only. If SPLADE is much stronger than MPNet on this dataset (your metrics show that), vanilla RRF can hurt by:
- - Elevating MPNet-only docs into the top-k due to reciprocal-rank contributions, displacing SPLADE’s strong hits.
- - Using equal contribution from both lists regardless of quality.

This is common: when one list is clearly better, naive RRF may degrade metrics.
Ways to fix

- Use top_k caps per list, then fuse:
- - If you feed 10 docs from each, the fused universe is max 20 docs; but more importantly, limiting MPNet’s long tail reduces harm.
- Prefer weighted fusion:
- - Weighted-RRF with global weights favoring SPLADE (e.g., w_sparse=0.7, w_dense=0.3) or use your QC weights per-query.
- - Weighted score fusion (min-max per list), optionally with COMBMNZ.
- Re-rank cut: Always take only top-N after fusion (e.g., N=10 for evaluation). Your metrics are computed using the full fused list right now; that’s fine because MAP@k, NDCG@k look at only top-k anyway, but it’s clearer to explicitly cap to K.
- Tune RRF K:
Lower K makes top ranks dominate more; higher K flattens. Try K in {10, 30, 60, 100}.
- Normalize scores before ranking:
For rank maps, you currently have NORMALIZE_SCORES=False. Try True to reduce sensitivity to per-list scoring peculiarities before rank extraction.

In [23]:
# New
# Implementing 1) How to Make hybrid fusion robust to spiky weights?

import json
import ast
import re
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple

# ============ CONFIG (paths relative to the notebook's folder) ============
# Make sure your current working directory is the "hybrid_retrieval" folder (where the notebook lives).
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_weighted_rrf.csv"

# Baseline RRF parameter
RRF_K = 60

# Original optional rank-from-score normalization for RRF ranking
NORMALIZE_SCORES = False

# Fusion config
# "base" == plain RRF (exactly your original hybrid)
# Other options: "rrf" (alias of base), "weighted_rrf", "weighted_score_sum", "weighted_score_sum_combmnz"
FUSION_MODE = "weighted_rrf"
SCORE_NORM_PER_LIST = True     # used only for score-sum fusion modes: min-max per list
WEIGHT_TEMPERATURE = 1.17      # temperature applied to QC logits; 1.0 no change; >1 smoother; <1 sharper
LOG_WEIGHTS = True             # include weights & fusion meta in the hybrid output entries

# Robustness knobs for spiky weights
CLAMP_WEIGHTS = True        # enforce floor/ceiling on per-query weights
WEIGHT_FLOOR = 0.25         # lower bound for one retriever
WEIGHT_CEIL = 0.75          # upper bound for the other (complements floor)

SMOOTH_TOWARD_HALF = True   # blend weights toward 0.5
SMOOTH_ALPHA = 0.7          # 0..1; 0 = full 0.5, 1 = original weight. e.g., 0.7 => w = 0.7*w + 0.3*0.5

# Control how many items from each retriever go into fusion and how many to keep at the end
TOP_K_PER_LIST = 20         # cap each input list before fusion
FINAL_TOP_K = 10            # cap the final fused list before saving/evaluating

# ========================================================================


# ---------- QC Classifier (BERT) ----------
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Local folder next to the notebook
QC_MODEL_FOLDER = "./bert_model_QC_finetuned"

_qc_tokenizer = None
_qc_model = None

def _lazy_load_qc():
    """
    Lazily load the local QC model and tokenizer exactly from QC_MODEL_FOLDER.
    Will not hit Hugging Face Hub (local_files_only=True).
    """
    global _qc_tokenizer, _qc_model
    if _qc_tokenizer is None or _qc_model is None:
        if not os.path.isdir(QC_MODEL_FOLDER):
            raise OSError(
                f"QC model folder not found: {QC_MODEL_FOLDER}. "
                f"CWD: {os.getcwd()}. "
                f"Ensure the directory exists and contains tokenizer + model files "
                f"(e.g., config.json, pytorch_model.bin, tokenizer.json or vocab.txt, "
                f"tokenizer_config.json, special_tokens_map.json)."
            )
        _qc_tokenizer = BertTokenizer.from_pretrained(QC_MODEL_FOLDER, local_files_only=True)
        _qc_model = BertForSequenceClassification.from_pretrained(QC_MODEL_FOLDER, local_files_only=True)
        _qc_model.eval()

def get_qc_weights(question: str, temperature: float = 1.0) -> Tuple[float, float]:
    """
    Run the QC classifier on a single question and return (w_sparse, w_dense).
    Assumes class 0 = sparse, class 1 = dense.
    Temperature is applied to logits before softmax.
    """
    _lazy_load_qc()
    inputs = _qc_tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _qc_model(**inputs)
        logits = outputs.logits
    probs = torch.softmax(logits / float(temperature), dim=-1)[0]  # shape: (2,)
    w_sparse = float(probs[0].item())
    w_dense = float(probs[1].item())
    s = w_sparse + w_dense
    if s <= 0:
        return 0.5, 0.5
    return w_sparse / s, w_dense / s

def clamp_pair(p0: float, p1: float, lo: float, hi: float) -> Tuple[float, float]:
    # clamp first, renormalize, then clamp complement
    p0 = max(lo, min(hi, p0))
    p1 = 1.0 - p0
    # ensure complement respects bounds too (if hi < 0.5 this is redundant but safe)
    if p1 < lo:
        p1 = lo
        p0 = 1.0 - p1
    if p1 > hi:
        p1 = hi
        p0 = 1.0 - p1
    return p0, p1

def smooth_toward_half(p0: float, p1: float, alpha: float) -> Tuple[float, float]:
    # convex mix toward 0.5
    p0_s = alpha * p0 + (1.0 - alpha) * 0.5
    p1_s = alpha * p1 + (1.0 - alpha) * 0.5
    s = p0_s + p1_s
    if s <= 0:
        return 0.5, 0.5
    return p0_s / s, p1_s / s

# ---------- Utilities ----------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def normalize_groundtruth_str(gt_field: Any) -> str:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        norm_items = [normalize_text(x) for x in data]
        return json.dumps(norm_items, ensure_ascii=False)
    return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    """
    Parse retrieval list into a list of dicts: {doc_id, score, full_text}
    """
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                # Use full_text (authoritative), ignore preview_snippet
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

def min_max_normalize(scores: List[float]) -> List[float]:
    if not scores:
        return []
    mn, mx = min(scores), max(scores)
    if mx == mn:
        return [0.0 for _ in scores]
    return [(x - mn) / (mx - mn) for x in scores]

def rank_from_scores(items: List[Dict[str, Any]], normalize: bool = False) -> Dict[str, int]:
    if not items:
        return {}
    arr = items.copy()
    if normalize:
        norm_scores = min_max_normalize([x["score"] for x in arr])
        for i, ns in enumerate(norm_scores):
            arr[i]["_tmp_score"] = ns
        key = "_tmp_score"
    else:
        key = "score"
    arr_sorted = sorted(arr, key=lambda x: x.get(key, 0.0), reverse=True)
    return {it["doc_id"]: idx for idx, it in enumerate(arr_sorted, start=1)}


# ---------- Fusion Methods ----------

def rrf_fuse_detailed(lists: List[List[Dict[str, Any]]], k: int = 60, normalize_scores: bool = False):
    """
    Reciprocal Rank Fusion with detailed output.
    Returns a list sorted by RRF score desc:
      [
        {
          'doc_id': str,
          'rrf_score': float,
          'ranks': [rank_in_list0_or_None, rank_in_list1_or_None, ...]
        }, ...
      ]
    """
    rank_maps = [rank_from_scores(lst, normalize=normalize_scores) for lst in lists]
    all_doc_ids = set().union(*[set(rm.keys()) for rm in rank_maps])

    fused = []
    for d in all_doc_ids:
        ranks = [rm.get(d) for rm in rank_maps]
        rrf_score = sum(1.0 / (k + r) for r in ranks if r is not None)
        fused.append({"doc_id": d, "rrf_score": rrf_score, "ranks": ranks})

    fused.sort(key=lambda x: x["rrf_score"], reverse=True)
    return fused

def rrf_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], k: int = 60, normalize_scores: bool = False):
    """
    Weighted RRF: sum_i w_i * 1/(k + rank_i)
    """
    rank_maps = [rank_from_scores(lst, normalize=normalize_scores) for lst in lists]
    all_doc_ids = set().union(*[set(rm.keys()) for rm in rank_maps])

    fused = []
    for d in all_doc_ids:
        ranks = [rm.get(d) for rm in rank_maps]
        rrf_score = 0.0
        for i, r in enumerate(ranks):
            if r is not None:
                rrf_score += float(weights[i]) * (1.0 / (k + r))
        fused.append({"doc_id": d, "rrf_score": rrf_score, "ranks": ranks})

    fused.sort(key=lambda x: x["rrf_score"], reverse=True)
    return fused

def score_sum_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], normalize_per_list: bool = True, combmnz: bool = False):
    """
    Weighted linear score fusion:
      fused_score(d) = sum_i w_i * s_i(d)
    where s_i(d) are per-list scores, optionally min-max normalized within each list.
    COMBMNZ variant multiplies by the count of lists where the doc appears:
      fused_mnz(d) = fused_score(d) * (#systems that retrieved d)
    """
    # Build doc_id -> score per list
    score_maps = []
    for lst in lists:
        if normalize_per_list:
            norm_scores = min_max_normalize([x["score"] for x in lst])
            m = {lst[i]["doc_id"]: norm_scores[i] for i in range(len(lst))}
        else:
            m = {x["doc_id"]: float(x["score"]) for x in lst}
        score_maps.append(m)

    all_doc_ids = set().union(*[set(m.keys()) for m in score_maps])

    fused = []
    for d in all_doc_ids:
        per_list_scores = [m.get(d, 0.0) for m in score_maps]
        fused_score = sum(float(weights[i]) * per_list_scores[i] for i in range(len(per_list_scores)))
        num_systems_present = sum(1 for s in per_list_scores if s > 0.0)
        if combmnz:
            fused_score *= max(1, num_systems_present)
        fused.append({
            "doc_id": d,
            "fused_score": fused_score,
            "per_list_scores": per_list_scores,
            "num_systems": num_systems_present
        })

    fused.sort(key=lambda x: x["fused_score"], reverse=True)
    return fused


# ---------- Metrics (text-based evaluation) ----------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def mapk(actual_list: List[List[str]], predicted_list: List[List[str]], k: int) -> float:
    scores = [apk(a, p, k) for a, p in zip(actual_list, predicted_list)]
    return float(np.mean(scores)) if scores else 0.0

def dcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    return dcg

def idcg_at_k(ideal_texts: List[str], k: int) -> float:
    g = min(len(ideal_texts), k)
    return sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    idcg = idcg_at_k(ideal_texts, k)
    if idcg == 0.0:
        return 0.0
    return dcg_at_k(predicted, ideal_texts, k) / idcg


# ---------- Main Pipeline ----------

def main():
    # Optional sanity check
    print("CWD:", os.getcwd())
    print("Expecting QC model folder at:", os.path.abspath(QC_MODEL_FOLDER))

    # Load files
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals for output
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalized join keys
    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge on normalized keys
    df = pd.merge(
        df_sparse,
        df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )

    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    hybrid_ret_docs_col = []
    map3_list, ndcg3_list = [], []
    map5_list, ndcg5_list = [], []
    map10_list, ndcg10_list = [], []

    for _, row in df.iterrows():
        # Parse lists
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]

        # Build full_text lookup for ID->full_text
        fulltext_map = {}
        for d in splade_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))
        for d in mpnet_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))

        # QC weights for this query (only needed for weighted modes)
        question_text = row["_q_orig_splade"]
        if FUSION_MODE in ("weighted_rrf", "weighted_score_sum", "weighted_score_sum_combmnz"):
            w_sparse, w_dense = get_qc_weights(question_text, temperature=WEIGHT_TEMPERATURE)

            # 1) optional smoothing toward 0.5
            if SMOOTH_TOWARD_HALF:
                w_sparse, w_dense = smooth_toward_half(w_sparse, w_dense, SMOOTH_ALPHA)

            # 2) optional clamping to guarantee minimum mix of both retrievers
            if CLAMP_WEIGHTS:
                w_sparse, w_dense = clamp_pair(w_sparse, w_dense, WEIGHT_FLOOR, WEIGHT_CEIL)

            # final sanity renormalization
            s = w_sparse + w_dense
            if s <= 0:
                w_sparse, w_dense = 0.5, 0.5
            else:
                w_sparse, w_dense = w_sparse / s, w_dense / s

            weights = [float(w_sparse), float(w_dense)]
        else:
            w_sparse, w_dense = None, None
            weights = None

        # Fusion
        # "base" is the original plain RRF (same as "rrf")
        mode = FUSION_MODE
        if mode == "base" or mode == "rrf":
            fused = rrf_fuse_detailed([splade_list, mpnet_list], k=RRF_K, normalize_scores=NORMALIZE_SCORES)
            fusion_label = "base"
            for f in fused:
                f["_fused_score"] = f["rrf_score"]
            # cap final list here
            fused = fused[:FINAL_TOP_K]
        elif mode == "weighted_rrf":
            fused = rrf_fuse_weighted([splade_list, mpnet_list], weights=weights, k=RRF_K, normalize_scores=NORMALIZE_SCORES)
            fusion_label = "weighted_rrf"
            for f in fused:
                f["_fused_score"] = f["rrf_score"]
            # cap final list here
            fused = fused[:FINAL_TOP_K]
        elif mode == "weighted_score_sum":
            fused_score = score_sum_fuse_weighted([splade_list, mpnet_list], weights=weights, normalize_per_list=SCORE_NORM_PER_LIST, combmnz=False)
            rank_maps = [rank_from_scores(splade_list, normalize=NORMALIZE_SCORES),
                         rank_from_scores(mpnet_list, normalize=NORMALIZE_SCORES)]
            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)
            fusion_label = "weighted_score_sum"
            # cap final list here
            fused = fused[:FINAL_TOP_K]
        elif mode == "weighted_score_sum_combmnz":
            fused_score = score_sum_fuse_weighted([splade_list, mpnet_list], weights=weights, normalize_per_list=SCORE_NORM_PER_LIST, combmnz=True)
            rank_maps = [rank_from_scores(splade_list, normalize=NORMALIZE_SCORES),
                         rank_from_scores(mpnet_list, normalize=NORMALIZE_SCORES)]
            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)
            fusion_label = "weighted_score_sum_combmnz"
            # cap final list here
            fused = fused[:FINAL_TOP_K]
        else:
            raise ValueError(f"Unknown FUSION_MODE: {FUSION_MODE}")

        # Build hybrid_ret_docs with fused score, rank, full_text, per-model rank, and weights
        hybrid_struct = []
        fused_norm_fulltexts = []  # for evaluation by text
        for idx, item in enumerate(fused, start=1):
            doc_id = item["doc_id"]
            fused_score = item["_fused_score"]
            ranks = item.get("ranks")
            if ranks is None:
                ranks = [None, None]
            full_text = fulltext_map.get(doc_id, "")

            entry = {
                "doc_id": doc_id,
                "score": fused_score,
                "rank": idx,
                "full_text": full_text,
                "source_ranks": {
                    "splade": ranks[0],
                    "mpnet": ranks[1],
                }
            }
            if LOG_WEIGHTS:
                if weights is not None:
                    entry["qc_weights"] = {"splade": weights[0], "mpnet": weights[1]}
                else:
                    entry["qc_weights"] = None
                entry["fusion_mode"] = fusion_label
            hybrid_struct.append(entry)

            # Collect normalized full text for evaluation by text
            fused_norm_fulltexts.append(normalize_text(full_text))

        # Save hybrid struct JSON
        hybrid_ret_docs_col.append(json.dumps(hybrid_struct, ensure_ascii=False))

        # Prepare normalized GT texts (maintain order for NDCG ideal)
        gt_norm_texts = [normalize_text(x) for x in parse_groundtruth_list(row["_gt_orig_splade"])]

        # Metrics by text
        pred = fused_norm_fulltexts
        map3_list.append(apk(gt_norm_texts, pred, 3))
        map5_list.append(apk(gt_norm_texts, pred, 5))
        map10_list.append(apk(gt_norm_texts, pred, 10))
        ndcg3_list.append(ndcg_at_k(pred, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred, gt_norm_texts, 10))

    # Build output with requested columns (use original SPLADE-side text columns)
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": hybrid_ret_docs_col,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })

    cols = [
        "question", "answer", "groundtruth_docs",
        "splade_ret_docs", "mpnet_ret_docs", "hybrid_ret_docs",
        "MAP@3", "NDCG@3", "MAP@5", "NDCG@5", "MAP@10", "NDCG@10"
    ]
    out = out[cols]

    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }
    print("Saved:", OUTPUT_CSV)
    print("Fusion mode:", FUSION_MODE)
    print("QC temperature:", WEIGHT_TEMPERATURE)
    print("Hybrid overall averages:", overall)


# If running in a notebook cell, call main() explicitly:
main() 

CWD: /home/csmala/journal_rag/hybrid_pipeline
Expecting QC model folder at: /home/csmala/journal_rag/hybrid_pipeline/bert_model_QC_finetuned
Saved: ./hybrid_weighted_rrf.csv
Fusion mode: weighted_rrf
QC temperature: 1.17
Hybrid overall averages: {'MAP@3': 0.40284538359668015, 'NDCG@3': 0.4782867077931083, 'MAP@5': 0.41393834538865126, 'NDCG@5': 0.4954578325048874, 'MAP@10': 0.42103152301077024, 'NDCG@10': 0.5075874708412451}


Saved: ./hybrid_weighted_score_sum_combmnz.csv
Fusion mode: weighted_score_sum_combmnz
QC temperature: 1.17
Hybrid overall averages: {'MAP@3': 0.40784726645032804, 'NDCG@3': 0.4839240380520279, 'MAP@5': 0.4282661318856842, 'NDCG@5': 0.5155112511570251, 'MAP@10': 0.4536183281168094, 'NDCG@10': 0.5602454339260866}


Saved: ./hybrid_weighted_rrf.csv
Fusion mode: weighted_rrf
QC temperature: 1.17
Hybrid overall averages: {'MAP@3': 0.40284538359668015, 'NDCG@3': 0.4782867077931083, 'MAP@5': 0.41393834538865126, 'NDCG@5': 0.4954578325048874, 'MAP@10': 0.42103152301077024, 'NDCG@10': 0.5075874708412451}

In [27]:
# 2. try linear interpolation with global weights
 
import json
import ast
import re
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple

# ============ CONFIG (paths relative to the notebook's folder) ============
# Make sure your current working directory is the "hybrid_retrieval" folder (where the notebook lives).
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_linear_interpolation.csv"

# Fusion mode: use score-based linear interpolation
# "weighted_score_sum" = λ·S_splade + (1−λ)·S_mpnet (after per-list min-max normalization)
# "weighted_score_sum_combmnz" = same as above, then multiply by number of lists containing the doc
FUSION_MODE = "weighted_score_sum_combmnz"  # "weighted_score_sum" | "weighted_score_sum_combmnz"

# Global weights (no QC). Set λ = GLOBAL_WEIGHTS[0], (1-λ) = GLOBAL_WEIGHTS[1]
USE_GLOBAL_WEIGHTS = True
GLOBAL_WEIGHTS = (0.7, 0.3)  # bias toward SPLADE which seems stronger on your dataset

# Per-list score normalization for score-sum fusions
SCORE_NORM_PER_LIST = True  # keep True for linear interpolation

# Input/output caps
TOP_K_PER_LIST = 20   # take top-N from each retriever before fusion
FINAL_TOP_K = 10      # keep only top-10 after fusion for storage/eval

# Logging options
LOG_WEIGHTS = True  # include weights & fusion meta in the hybrid output entries

# RRF options (not used here, but leaving for completeness if you switch modes)
RRF_K = 90
NORMALIZE_SCORES = False
# ========================================================================


# ---------- Utilities ----------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def normalize_groundtruth_str(gt_field: Any) -> str:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        norm_items = [normalize_text(x) for x in data]
        return json.dumps(norm_items, ensure_ascii=False)
    return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    """
    Parse retrieval list into a list of dicts: {doc_id, score, full_text}
    """
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

def min_max_normalize(scores: List[float]) -> List[float]:
    if not scores:
        return []
    mn, mx = min(scores), max(scores)
    if mx == mn:
        return [0.0 for _ in scores]
    return [(x - mn) / (mx - mn) for x in scores]


# ---------- Score-based Fusion (Linear Interpolation) ----------

def score_sum_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], normalize_per_list: bool = True, combmnz: bool = False):
    """
    Weighted linear score fusion:
      fused_score(d) = sum_i w_i * s_i(d)
    where s_i(d) are per-list scores, optionally min-max normalized within each list.
    COMBMNZ variant multiplies by the count of lists where the doc appears:
      fused_mnz(d) = fused_score(d) * (#systems that retrieved d)
    """
    # Build doc_id -> score per list
    score_maps = []
    for lst in lists:
        if normalize_per_list:
            norm_scores = min_max_normalize([x["score"] for x in lst])
            m = {lst[i]["doc_id"]: norm_scores[i] for i in range(len(lst))}
        else:
            m = {x["doc_id"]: float(x["score"]) for x in lst}
        score_maps.append(m)

    all_doc_ids = set().union(*[set(m.keys()) for m in score_maps])

    fused = []
    for d in all_doc_ids:
        per_list_scores = [m.get(d, 0.0) for m in score_maps]
        fused_score = sum(float(weights[i]) * per_list_scores[i] for i in range(len(per_list_scores)))
        num_systems_present = sum(1 for s in per_list_scores if s > 0.0)
        if combmnz:
            fused_score *= max(1, num_systems_present)
        fused.append({
            "doc_id": d,
            "fused_score": fused_score,
            "per_list_scores": per_list_scores,
            "num_systems": num_systems_present
        })

    fused.sort(key=lambda x: x["fused_score"], reverse=True)
    return fused


# ---------- Metrics (text-based evaluation) ----------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def mapk(actual_list: List[List[str]], predicted_list: List[List[str]], k: int) -> float:
    scores = [apk(a, p, k) for a, p in zip(actual_list, predicted_list)]
    return float(np.mean(scores)) if scores else 0.0

def dcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    return dcg

def idcg_at_k(ideal_texts: List[str], k: int) -> float:
    g = min(len(ideal_texts), k)
    return sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    idcg = idcg_at_k(ideal_texts, k)
    if idcg == 0.0:
        return 0.0
    return dcg_at_k(predicted, ideal_texts, k) / idcg


# ---------- Main Pipeline ----------

def main():
    # Optional sanity check
    print("CWD:", os.getcwd())

    # Load files
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals for output
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalized join keys
    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge on normalized keys
    df = pd.merge(
        df_sparse,
        df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )

    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    hybrid_ret_docs_col = []
    map3_list, ndcg3_list = [], []
    map5_list, ndcg5_list = [], []
    map10_list, ndcg10_list = [], []

    # Global linear interpolation weights (no QC)
    if not USE_GLOBAL_WEIGHTS:
        raise ValueError("This script expects USE_GLOBAL_WEIGHTS=True. Set it at the top.")
    w_sparse, w_dense = GLOBAL_WEIGHTS
    s = w_sparse + w_dense
    if s <= 0:
        w_sparse, w_dense = 0.5, 0.5
    else:
        w_sparse, w_dense = float(w_sparse)/s, float(w_dense)/s
    weights = [w_sparse, w_dense]
    fusion_label = FUSION_MODE

    for _, row in df.iterrows():
        # Parse and cap lists
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]
        gt_list_raw = parse_groundtruth_list(row["_gt_orig_splade"])  # use either side; same after join

        # Build full_text lookup for ID->full_text
        fulltext_map = {}
        for d in splade_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))
        for d in mpnet_list:
            fulltext_map.setdefault(d["doc_id"], d.get("full_text", ""))

        # Score-based fusion with global weights
        if FUSION_MODE == "weighted_score_sum":
            fused_score = score_sum_fuse_weighted(
                [splade_list, mpnet_list],
                weights=weights,
                normalize_per_list=SCORE_NORM_PER_LIST,
                combmnz=False
            )
            # Convert to fused entries with ranks
            # Create rank maps for reference (optional)
            # Not needed for sorting, only for logging source ranks
            # We'll compute ranks from the original lists' scores
            # Build rank maps (descending by score)
            def rank_from_scores(items):
                arr_sorted = sorted(items, key=lambda x: x.get("score", 0.0), reverse=True)
                return {it["doc_id"]: idx for idx, it in enumerate(arr_sorted, start=1)}
            rank_maps = [rank_from_scores(splade_list), rank_from_scores(mpnet_list)]

            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)

        elif FUSION_MODE == "weighted_score_sum_combmnz":
            fused_score = score_sum_fuse_weighted(
                [splade_list, mpnet_list],
                weights=weights,
                normalize_per_list=SCORE_NORM_PER_LIST,
                combmnz=True
            )
            def rank_from_scores(items):
                arr_sorted = sorted(items, key=lambda x: x.get("score", 0.0), reverse=True)
                return {it["doc_id"]: idx for idx, it in enumerate(arr_sorted, start=1)}
            rank_maps = [rank_from_scores(splade_list), rank_from_scores(mpnet_list)]

            fused = []
            for item in fused_score:
                d = item["doc_id"]
                fused.append({
                    "doc_id": d,
                    "_fused_score": item["fused_score"],
                    "ranks": [rank_maps[0].get(d), rank_maps[1].get(d)]
                })
            fused.sort(key=lambda x: x["_fused_score"], reverse=True)

        else:
            raise ValueError(f"Unsupported FUSION_MODE for this script: {FUSION_MODE}")

        # Cap final fused to top-K
        fused = fused[:FINAL_TOP_K]

        # Build hybrid_ret_docs with fused score, rank, full_text, per-model rank, and weights
        hybrid_struct = []
        fused_norm_fulltexts = []  # for evaluation by text
        for idx, item in enumerate(fused, start=1):
            doc_id = item["doc_id"]
            fused_score_val = item["_fused_score"]
            ranks = item.get("ranks")
            if ranks is None:
                ranks = [None, None]
            full_text = fulltext_map.get(doc_id, "")

            entry = {
                "doc_id": doc_id,
                "score": fused_score_val,
                "rank": idx,
                "full_text": full_text,
                "source_ranks": {
                    "splade": ranks[0],
                    "mpnet": ranks[1],
                }
            }
            if LOG_WEIGHTS:
                entry["qc_weights"] = {"splade": weights[0], "mpnet": weights[1]}
                entry["fusion_mode"] = fusion_label
            hybrid_struct.append(entry)

            fused_norm_fulltexts.append(normalize_text(full_text))

        # Save hybrid struct JSON
        hybrid_ret_docs_col.append(json.dumps(hybrid_struct, ensure_ascii=False))

        # Prepare normalized GT texts (maintain order for NDCG ideal)
        gt_norm_texts = [normalize_text(x) for x in gt_list_raw]

        # Metrics by text
        pred = fused_norm_fulltexts
        map3_list.append(apk(gt_norm_texts, pred, 3))
        map5_list.append(apk(gt_norm_texts, pred, 5))
        map10_list.append(apk(gt_norm_texts, pred, 10))
        ndcg3_list.append(ndcg_at_k(pred, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred, gt_norm_texts, 10))

    # Build output with requested columns (use original SPLADE-side text columns)
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": hybrid_ret_docs_col,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })

    cols = [
        "question", "answer", "groundtruth_docs",
        "splade_ret_docs", "mpnet_ret_docs", "hybrid_ret_docs",
        "MAP@3", "NDCG@3", "MAP@5", "NDCG@5", "MAP@10", "NDCG@10"
    ]
    out = out[cols]

    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }
    print("Saved:", OUTPUT_CSV)
    print("Fusion mode:", FUSION_MODE)
    print("Global weights (λ, 1-λ):", weights)
    print("TOP_K_PER_LIST:", TOP_K_PER_LIST, "FINAL_TOP_K:", FINAL_TOP_K)
    print("Hybrid overall averages:", overall)


# Run
main()

CWD: /home/csmala/journal_rag/hybrid_pipeline
Saved: ./hybrid_linear_interpolation.csv
Fusion mode: weighted_score_sum_combmnz
Global weights (λ, 1-λ): [0.7, 0.3]
TOP_K_PER_LIST: 20 FINAL_TOP_K: 10
Hybrid overall averages: {'MAP@3': 0.48431292269432563, 'NDCG@3': 0.5631704059232733, 'MAP@5': 0.5018147202509123, 'NDCG@5': 0.5883772764828926, 'MAP@10': 0.5156350223348495, 'NDCG@10': 0.6111674154268684}


In [28]:
# 3. implement the cross encoders re-ranker on the fused shortlist?

import json
import ast
import re
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple

# ============ CONFIG ============
# Paths relative to the notebook's folder
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_crossencoder_rerank.csv"

# Candidate generation
TOP_K_PER_LIST = 20   # how many from each retriever to consider before re-ranking
FINAL_TOP_K = 10      # final top-N after re-ranking to store/evaluate

# Cross-encoder model
# Option A: HF Hub model name
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Option B: Local folder if you have it downloaded (set CE_LOCAL_ONLY=True)
CROSS_ENCODER_MODEL_PATH = "./cross_encoder_ms_marco_minilm_l6_v2"  # change if needed
CE_LOCAL_ONLY = False   # True to force loading from local path/folder only

# Inference
CE_BATCH_SIZE = 32      # adjust based on CPU/GPU
CE_DEVICE = None        # None lets transformers pick automatically; or set "cuda" / "cpu"

# Logging
LOG_WEIGHTS = True      # keep for consistency (not used here but we log fusion_mode)
FUSION_LABEL = "ce_rerank"
# =================================


# ---------- Utilities ----------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def normalize_groundtruth_str(gt_field: Any) -> str:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        norm_items = [normalize_text(x) for x in data]
        return json.dumps(norm_items, ensure_ascii=False)
    return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    """
    Parse retrieval list into a list of dicts: {doc_id, score, full_text}
    """
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

# ---------- Metrics (text-based evaluation) ----------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def mapk(actual_list: List[List[str]], predicted_list: List[List[str]], k: int) -> float:
    scores = [apk(a, p, k) for a, p in zip(actual_list, predicted_list)]
    return float(np.mean(scores)) if scores else 0.0

def dcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    return dcg

def idcg_at_k(ideal_texts: List[str], k: int) -> float:
    g = min(len(ideal_texts), k)
    return sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    idcg = idcg_at_k(ideal_texts, k)
    if idcg == 0.0:
        return 0.0
    return dcg_at_k(predicted, ideal_texts, k) / idcg


# ---------- Cross-Encoder Loading and Scoring ----------

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

_ce_tokenizer = None
_ce_model = None

def load_cross_encoder():
    global _ce_tokenizer, _ce_model
    if _ce_tokenizer is not None and _ce_model is not None:
        return
    if CE_LOCAL_ONLY:
        if not os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            raise OSError(
                f"Local CE folder not found: {CROSS_ENCODER_MODEL_PATH}. "
                f"Set CE_LOCAL_ONLY=False to load from HF Hub or place the model locally."
            )
        _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
        _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
    else:
        # Try local path first if exists, else HF Hub
        if os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH)
        else:
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL)
    _ce_model.eval()
    if CE_DEVICE:
        _ce_model.to(CE_DEVICE)

def ce_score_pairs(pairs: List[Tuple[str, str]], batch_size: int = 32) -> List[float]:
    """
    pairs: list of (query, full_text)
    returns: list of float scores (higher = more relevant)
    """
    load_cross_encoder()
    scores = []
    device = CE_DEVICE or ("cuda" if torch.cuda.is_available() else "cpu")
    _ce_model.to(device)

    # Some CE heads output logits with shape [batch, 1] or [batch, 2]
    # We’ll take the positive class logit or the single logit.
    with torch.no_grad():
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            texts = list(batch)  # [(q, d), ...]
            inputs = _ce_tokenizer(
                [t[0] for t in texts],
                [t[1] for t in texts],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = _ce_model(**inputs)
            logits = outputs.logits  # [B, C] or [B]
            if logits.dim() == 1:
                batch_scores = logits.detach().float().cpu().tolist()
            else:
                # If binary, take the logit of the positive class (assume index 1)
                if logits.size(-1) == 1:
                    batch_scores = logits.squeeze(-1).detach().float().cpu().tolist()
                else:
                    batch_scores = logits[:, -1].detach().float().cpu().tolist()
            scores.extend(batch_scores)
    return scores


# ---------- Main Pipeline with CE re-ranking ----------

def main():
    print("CWD:", os.getcwd())
    # Load CSVs
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalized join keys
    def normalize_groundtruth_str(gt_field: Any) -> str:
        data = safe_parse_list(gt_field)
        if isinstance(data, list):
            norm_items = [normalize_text(x) for x in data]
            return json.dumps(norm_items, ensure_ascii=False)
        return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge
    df = pd.merge(
        df_sparse,
        df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )

    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    hybrid_ret_docs_col = []
    map3_list, ndcg3_list = [], []
    map5_list, ndcg5_list = [], []
    map10_list, ndcg10_list = [], []

    for _, row in df.iterrows():
        # Parse and cap lists
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]
        gt_list_raw = parse_groundtruth_list(row["_gt_orig_splade"])

        # Build candidate pool = union by doc_id, preserving first seen full_text
        cand_map = {}
        for d in splade_list:
            if d["doc_id"] not in cand_map:
                cand_map[d["doc_id"]] = {"doc_id": d["doc_id"], "full_text": d.get("full_text", ""), "sources": {"splade": True}}
            else:
                cand_map[d["doc_id"]]["sources"]["splade"] = True
        for d in mpnet_list:
            if d["doc_id"] not in cand_map:
                cand_map[d["doc_id"]] = {"doc_id": d["doc_id"], "full_text": d.get("full_text", ""), "sources": {"mpnet": True}}
            else:
                cand_map[d["doc_id"]]["sources"]["mpnet"] = True

        candidates = list(cand_map.values())
        # Re-rank with cross-encoder
        question_text = row["_q_orig_splade"]
        pairs = [(question_text, c["full_text"]) for c in candidates]
        if pairs:
            scores = ce_score_pairs(pairs, batch_size=CE_BATCH_SIZE)
        else:
            scores = []

        # Attach scores and sort
        for i, c in enumerate(candidates):
            c["_ce_score"] = float(scores[i]) if i < len(scores) else float("-inf")

        candidates.sort(key=lambda x: x["_ce_score"], reverse=True)
        reranked = candidates[:FINAL_TOP_K]

        # Build hybrid_ret_docs with CE score and rank
        hybrid_struct = []
        fused_norm_fulltexts = []
        for idx, c in enumerate(reranked, start=1):
            entry = {
                "doc_id": c["doc_id"],
                "score": c["_ce_score"],
                "rank": idx,
                "full_text": c.get("full_text", ""),
                "source_ranks": {  # we don’t compute per-list ranks for CE; mark presence
                    "splade": 1 if c["sources"].get("splade") else None,
                    "mpnet": 1 if c["sources"].get("mpnet") else None,
                },
                "fusion_mode": FUSION_LABEL
            }
            if LOG_WEIGHTS:
                entry["qc_weights"] = None
            hybrid_struct.append(entry)
            fused_norm_fulltexts.append(normalize_text(entry["full_text"]))

        hybrid_ret_docs_col.append(json.dumps(hybrid_struct, ensure_ascii=False))

        # Metrics by text
        gt_norm_texts = [normalize_text(x) for x in gt_list_raw]
        pred = fused_norm_fulltexts
        map3_list.append(apk(gt_norm_texts, pred, 3))
        map5_list.append(apk(gt_norm_texts, pred, 5))
        map10_list.append(apk(gt_norm_texts, pred, 10))
        ndcg3_list.append(ndcg_at_k(pred, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred, gt_norm_texts, 10))

    # Output
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": hybrid_ret_docs_col,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })

    cols = [
        "question", "answer", "groundtruth_docs",
        "splade_ret_docs", "mpnet_ret_docs", "hybrid_ret_docs",
        "MAP@3", "NDCG@3", "MAP@5", "NDCG@5", "MAP@10", "NDCG@10"
    ]
    out = out[cols]
    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }
    print("Saved:", OUTPUT_CSV)
    print("Fusion mode:", FUSION_LABEL)
    print("TOP_K_PER_LIST:", TOP_K_PER_LIST, "FINAL_TOP_K:", FINAL_TOP_K)
    print("Cross-Encoder:", CROSS_ENCODER_MODEL if not CE_LOCAL_ONLY else CROSS_ENCODER_MODEL_PATH)
    print("Hybrid overall averages:", overall)


# Run
main()

CWD: /home/csmala/journal_rag/hybrid_pipeline
Saved: ./hybrid_crossencoder_rerank.csv
Fusion mode: ce_rerank
TOP_K_PER_LIST: 20 FINAL_TOP_K: 10
Cross-Encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
Hybrid overall averages: {'MAP@3': 0.5191621427632709, 'NDCG@3': 0.5946335494517143, 'MAP@5': 0.5326992033886311, 'NDCG@5': 0.6147403898582319, 'MAP@10': 0.5451409088593357, 'NDCG@10': 0.6352804956790745}


Fusion mode: ce_rerank
TOP_K_PER_LIST: 20 FINAL_TOP_K: 10
Cross-Encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
Hybrid overall averages: {'MAP@3': 0.5191621427632709, 'NDCG@3': 0.5946335494517143, 'MAP@5': 0.5326992033886311, 'NDCG@5': 0.6147403898582319, 'MAP@10': 0.5451409088593357, 'NDCG@10': 0.6352804956790745}

#### Cross encoders are working better so far. I am trying to optimise their performance by adding other methods on top of that.

- Candidate generation:
- - Weighted RRF over SPLADE + MPNet (with global weights favoring SPLADE modestly).
- - Then score-based linear interpolation with COMBMNZ on the union to reinforce agreement.
- - This “two-step fusion” sharpens the shortlist before CE.

- Cross-encoder re-ranking:
- - Use cross-encoder/ms-marco-MiniLM-L-12-v2 (stronger than L-6, still efficient).
- - Re-rank top-N from fusion (N typically 60).

- Post-CE blending:
- - FinalScore = α·CE + (1−α)·Fused (both min-max normalized per query).
- - This stabilizes CE variability and often yields extra gains.

- Practical defaults:
- - Weighted RRF K=60, λ=(0.65, 0.35), Interp+COMBMNZ, candidate_pool_size=60, CE=L-12, α=0.85.
- - You can raise candidate_pool_size to 80 if your hardware allows; it often helps slightly.

Tuning tips to push toward +7–8% over SPLADE

Increase CANDIDATE_POOL_SIZE to 80 if you can afford the CE pass; often yields a further lift.
Try INTERP_GLOBAL_WEIGHTS = (0.6, 0.4) if MPNet is somewhat complementary on your dataset.
If CE latency is acceptable, test CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2" (already set). If still underwhelming, try "cross-encoder/ms-marco-electra-base".
Sweep ALPHA_CE_BLEND in {0.8, 0.85, 0.9}. 0.85 is a good default; 0.9 biases more to CE.
If you have memory/time, test TOP_K_PER_LIST=50 and/or RRF_K in {30, 60, 120}.
Why this should outperform prior attempts

RRF reduces sensitivity to raw scores and brings complementary docs up.
Score interpolation with COMBMNZ rewards agreement and sharpens the pool.
CE L-12 is materially stronger than L-6; blending stabilizes edge cases where CE alone may shuffle near-ties suboptimally.
This stacking structure (RRF -> Interp+COMBMNZ -> CE -> Blend) is a proven recipe in IR stacks to get consistent gains over strong sparse baselines.

A tiny sweep harness to run a short grid on:

λ for RRF and interpolation,
COMBMNZ on/off,
candidate_pool_size ∈ {60, 80},
α ∈ {0.8, 0.85, 0.9},
and print a summary table so you can pick the best configuration.

In [2]:
# Cross encoders are working better so far. I am trying to optimise their performance by adding other methods on top of that.


import json
import ast
import re
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple

# =================== CONFIG ===================
# Input/Output
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_optimized_fusion_ce.csv"

# Candidate generation caps
TOP_K_PER_LIST = 40         # take top-N from each retriever before initial fusion
CANDIDATE_POOL_SIZE = 60    # number of candidates passed to CE after fusion (try 60 or 80)
FINAL_TOP_K = 10            # final top-N to save/evaluate

# Weighted RRF settings (stage A)
USE_WEIGHTED_RRF = True
RRF_K = 60
RRF_GLOBAL_WEIGHTS = (0.65, 0.35)  # (w_splade, w_mpnet). Slight SPLADE bias but still lets MPNet matter

# Score interpolation settings (stage B) – applied on the union after RRF
USE_SCORE_INTERP = True
SCORE_NORM_PER_LIST = True     # min-max normalize per list before interpolation
INTERP_GLOBAL_WEIGHTS = (0.7, 0.3)  # λ for (splade, mpnet)
USE_COMBMNZ = True             # multiply interpolated score by #systems that retrieved the doc (reinforce agreement)

# Cross-encoder model (stage C)
# Stronger than L-6: slightly slower but typically better MAP/NDCG
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CROSS_ENCODER_MODEL_PATH = "./ce_ms_marco_minilm_l12_v2"  # optional local cache/folder
CE_LOCAL_ONLY = False
CE_BATCH_SIZE = 32
CE_DEVICE = None  # "cuda" or "cpu" or None (auto)

# Post-CE blending (stage D)
USE_POST_CE_BLENDING = True
ALPHA_CE_BLEND = 0.85  # Final = α·CE + (1−α)·Fused (both min-max normalized per query)

# Logging
LOG_RUN_META = True
# ==============================================


# --------------- Utilities ---------------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def normalize_groundtruth_str(gt_field: Any) -> str:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        norm_items = [normalize_text(x) for x in data]
        return json.dumps(norm_items, ensure_ascii=False)
    return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

def min_max_normalize(values: List[float]) -> List[float]:
    if not values:
        return []
    mn, mx = min(values), max(values)
    if mx == mn:
        return [0.0 for _ in values]
    return [(x - mn) / (mx - mn) for x in values]

def per_list_minmax_map(items: List[Dict[str, Any]]) -> Dict[str, float]:
    scores = [x["score"] for x in items]
    norm = min_max_normalize(scores)
    return {items[i]["doc_id"]: norm[i] for i in range(len(items))}

# --------------- Fusion blocks ---------------

def rrf_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], k: int = 60):
    """
    Weighted Reciprocal Rank Fusion:
    score(d) = Σ_i w_i * 1 / (k + rank_i(d))
    """
    rank_maps = []
    for lst in lists:
        ranks = {it["doc_id"]: r for r, it in enumerate(sorted(lst, key=lambda x: x["score"], reverse=True), start=1)}
        rank_maps.append(ranks)
    all_doc_ids = set().union(*[set(m.keys()) for m in rank_maps])

    fused = []
    for d in all_doc_ids:
        s = 0.0
        for i, ranks in enumerate(rank_maps):
            r = ranks.get(d)
            if r is not None:
                s += float(weights[i]) * (1.0 / (k + r))
        fused.append({"doc_id": d, "fused_rrf": s})
    fused.sort(key=lambda x: x["fused_rrf"], reverse=True)
    return fused

def score_sum_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], normalize_per_list: bool = True, combmnz: bool = False):
    """
    Weighted linear interpolation with optional COMBMNZ.
    """
    score_maps = []
    for lst in lists:
        if normalize_per_list:
            m = per_list_minmax_map(lst)
        else:
            m = {x["doc_id"]: float(x["score"]) for x in lst}
        score_maps.append(m)

    all_doc_ids = set().union(*[set(m.keys()) for m in score_maps])
    fused = []
    for d in all_doc_ids:
        per_scores = [m.get(d, 0.0) for m in score_maps]
        val = sum(float(weights[i]) * per_scores[i] for i in range(len(per_scores)))
        if combmnz:
            c = sum(1 for s in per_scores if s > 0.0)
            val *= max(1, c)
        fused.append({"doc_id": d, "fused_interp": val})
    fused.sort(key=lambda x: x["fused_interp"], reverse=True)
    return fused

# --------------- Metrics ---------------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    g = min(len(ideal_texts), k)
    idcg = sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))
    return (dcg / idcg) if idcg > 0.0 else 0.0

# --------------- Cross-Encoder ---------------

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

_ce_tokenizer = None
_ce_model = None

def load_cross_encoder():
    global _ce_tokenizer, _ce_model
    if _ce_tokenizer is not None and _ce_model is not None:
        return
    if CE_LOCAL_ONLY:
        if not os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            raise OSError(
                f"Local CE folder not found: {CROSS_ENCODER_MODEL_PATH}. "
                f"Set CE_LOCAL_ONLY=False to load from HF Hub or place the model locally."
            )
        _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
        _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
    else:
        if os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH)
        else:
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL)
    _ce_model.eval()
    if CE_DEVICE:
        _ce_model.to(CE_DEVICE)

def ce_score_pairs(pairs: List[Tuple[str, str]], batch_size: int = 32) -> List[float]:
    load_cross_encoder()
    device = CE_DEVICE or ("cuda" if torch.cuda.is_available() else "cpu")
    _ce_model.to(device)

    scores = []
    with torch.no_grad():
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            inputs = _ce_tokenizer(
                [q for q, _ in batch],
                [d for _, d in batch],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = _ce_model(**inputs)
            logits = outputs.logits
            if logits.dim() == 1:
                batch_scores = logits.detach().float().cpu().tolist()
            else:
                if logits.size(-1) == 1:
                    batch_scores = logits.squeeze(-1).detach().float().cpu().tolist()
                else:
                    batch_scores = logits[:, -1].detach().float().cpu().tolist()
            scores.extend(batch_scores)
    return scores

# --------------- Main: Optimized pipeline ---------------

def main():
    print("CWD:", os.getcwd())

    # Load CSVs
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalize join keys
    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge
    df = pd.merge(
        df_sparse, df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )
    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    # Normalize weights
    def norm_pair(w0, w1):
        s = float(w0) + float(w1)
        return (0.5, 0.5) if s <= 0 else (float(w0)/s, float(w1)/s)

    w_rrf = norm_pair(*RRF_GLOBAL_WEIGHTS)
    w_interp = norm_pair(*INTERP_GLOBAL_WEIGHTS)

    hybrid_ret_docs_col = []
    map3_list, map5_list, map10_list = [], [], []
    ndcg3_list, ndcg5_list, ndcg10_list = [], [], []

    for _, row in df.iterrows():
        # Parse and cap inputs
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]
        gt_list_raw = parse_groundtruth_list(row["_gt_orig_splade"])
        q_text = row["_q_orig_splade"]

        # Stage A: Weighted RRF to get an initial global ranking (conservative blend)
        if USE_WEIGHTED_RRF:
            rrf_fused = rrf_fuse_weighted([splade_list, mpnet_list], weights=w_rrf, k=RRF_K)
        else:
            # If disabled, start with union pool
            rrf_fused = [{"doc_id": x["doc_id"], "fused_rrf": 0.0} for x in (splade_list + mpnet_list)]

        # Collect union set and maps
        splade_map = {x["doc_id"]: x for x in splade_list}
        mpnet_map  = {x["doc_id"]: x for x in mpnet_list}

        # Attach a conservative fused score for stage B fallback
        # Build initial candidate list ordered by RRF
        candidates = []
        seen = set()
        for item in rrf_fused:
            d = item["doc_id"]
            if d in seen:
                continue
            seen.add(d)
            full_text = splade_map.get(d, mpnet_map.get(d, {})).get("full_text", "")
            candidates.append({
                "doc_id": d,
                "full_text": full_text,
                "rrf_score": item.get("fused_rrf", 0.0)
            })

        # Also ensure everything from both lists is included (union), with very small rrf if missing
        for dct in (splade_list + mpnet_list):
            d = dct["doc_id"]
            if d not in seen:
                seen.add(d)
                candidates.append({
                    "doc_id": d,
                    "full_text": dct.get("full_text", ""),
                    "rrf_score": 0.0
                })

        # Stage B: Score interpolation with COMBMNZ on the union, reinforce agreement
        if USE_SCORE_INTERP:
            # Build fused interpolation score on union
            # Prepare per-list normalized score maps
            s_map = per_list_minmax_map(splade_list) if SCORE_NORM_PER_LIST else {x["doc_id"]: float(x["score"]) for x in splade_list}
            m_map = per_list_minmax_map(mpnet_list)  if SCORE_NORM_PER_LIST else {x["doc_id"]: float(x["score"]) for x in mpnet_list}

            interp_list = []
            for c in candidates:
                d = c["doc_id"]
                s = s_map.get(d, 0.0)
                m = m_map.get(d, 0.0)
                fused = w_interp[0] * s + w_interp[1] * m
                if USE_COMBMNZ:
                    present = (1 if s > 0.0 else 0) + (1 if m > 0.0 else 0)
                    fused *= max(1, present)
                interp_list.append({"doc_id": d, "full_text": c["full_text"], "interp_score": fused, "rrf_score": c["rrf_score"]})

            # Combine interpolation and RRF as a pre-CE ranking
            # We can simply rank by interp_score; if ties, use rrf_score
            interp_list.sort(key=lambda x: (x["interp_score"], x["rrf_score"]), reverse=True)
            pre_ce = interp_list
        else:
            # Fallback to RRF ordering
            pre_ce = [{"doc_id": c["doc_id"], "full_text": c["full_text"], "interp_score": c["rrf_score"], "rrf_score": c["rrf_score"]} for c in candidates]

        # Truncate to the candidate pool for CE
        pre_ce = pre_ce[:CANDIDATE_POOL_SIZE]

        # Stage C: Cross-Encoder re-rank
        pairs = [(q_text, it["full_text"]) for it in pre_ce]
        if pairs:
            ce_scores = ce_score_pairs(pairs, batch_size=CE_BATCH_SIZE)
        else:
            ce_scores = []

        for i, it in enumerate(pre_ce):
            it["_ce_score"] = float(ce_scores[i]) if i < len(ce_scores) else float("-inf")

        # Stage D: Post-CE blending with pre-CE fused scores (stabilizes and often boosts)
        if USE_POST_CE_BLENDING:
            # Normalize CE and preCE scores per query
            ce_vals = [it["_ce_score"] for it in pre_ce]
            fused_vals = [it["interp_score"] for it in pre_ce]
            ce_norm = min_max_normalize(ce_vals)
            fused_norm = min_max_normalize(fused_vals)
            for i, it in enumerate(pre_ce):
                it["_final_score"] = ALPHA_CE_BLEND * ce_norm[i] + (1.0 - ALPHA_CE_BLEND) * fused_norm[i]
        else:
            # Use CE scores only
            for it in pre_ce:
                it["_final_score"] = it["_ce_score"]

        # Final sort and slice
        pre_ce.sort(key=lambda x: x["_final_score"], reverse=True)
        final = pre_ce[:FINAL_TOP_K]

        # Build output struct and compute metrics (by normalized full_text)
        hybrid_struct = []
        pred_texts = []
        for rank_idx, it in enumerate(final, start=1):
            entry = {
                "doc_id": it["doc_id"],
                "score": it["_final_score"],
                "rank": rank_idx,
                "full_text": it["full_text"],
                "fusion_mode": "weighted_rrf -> score_interp_combmnz -> CE(L12) -> blend"
            }
            if LOG_RUN_META:
                entry["meta"] = {
                    "rrf_k": RRF_K,
                    "rrf_weights": {"splade": w_rrf[0], "mpnet": w_rrf[1]},
                    "interp_weights": {"splade": w_interp[0], "mpnet": w_interp[1]},
                    "combmnz": USE_COMBMNZ,
                    "candidate_pool_size": CANDIDATE_POOL_SIZE,
                    "ce_model": CROSS_ENCODER_MODEL if not CE_LOCAL_ONLY else CROSS_ENCODER_MODEL_PATH,
                    "alpha_ce_blend": ALPHA_CE_BLEND if USE_POST_CE_BLENDING else None
                }
            hybrid_struct.append(entry)
            pred_texts.append(normalize_text(it["full_text"]))

        hybrid_ret_docs = json.dumps(hybrid_struct, ensure_ascii=False)

        # Metrics by text
        gt_norm_texts = [normalize_text(x) for x in gt_list_raw]
        map3_list.append(apk(gt_norm_texts, pred_texts, 3))
        map5_list.append(apk(gt_norm_texts, pred_texts, 5))
        map10_list.append(apk(gt_norm_texts, pred_texts, 10))
        ndcg3_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 10))

        # Store per-row
        row["_hybrid_ret_docs"] = hybrid_ret_docs

        # Assign back (collect later)
        # We’ll not mutate df in-place per cell to avoid SettingWithCopy; collect a list instead
        if "_hybrid_rows" not in locals():
            _hybrid_rows = []
        _hybrid_rows.append(hybrid_ret_docs)

    # Output dataframe
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": _hybrid_rows,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })

    cols = [
        "question", "answer", "groundtruth_docs",
        "splade_ret_docs", "mpnet_ret_docs", "hybrid_ret_docs",
        "MAP@3", "NDCG@3", "MAP@5", "NDCG@5", "MAP@10", "NDCG@10"
    ]
    out = out[cols]
    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }

    print("Saved:", OUTPUT_CSV)
    print("Pipeline: weighted_rrf -> score_interp_combmnz -> CE(L12) -> blend")
    print("Weighted RRF: K =", RRF_K, "weights =", w_rrf)
    print("Interpolation weights =", w_interp, "COMBMNZ =", USE_COMBMNZ)
    print("Candidate pool size =", CANDIDATE_POOL_SIZE, "| FINAL_TOP_K =", FINAL_TOP_K)
    print("Cross-Encoder:", CROSS_ENCODER_MODEL if not CE_LOCAL_ONLY else CROSS_ENCODER_MODEL_PATH)
    print("Post-CE blending:", USE_POST_CE_BLENDING, "alpha =", ALPHA_CE_BLEND)
    print("Hybrid overall averages:", overall)

# Run
main()

  from .autonotebook import tqdm as notebook_tqdm


CWD: /home/csmala/journal_rag/hybrid_pipeline
Saved: ./hybrid_optimized_fusion_ce.csv
Pipeline: weighted_rrf -> score_interp_combmnz -> CE(L12) -> blend
Weighted RRF: K = 60 weights = (0.65, 0.35)
Interpolation weights = (0.7, 0.3) COMBMNZ = True
Candidate pool size = 60 | FINAL_TOP_K = 10
Cross-Encoder: cross-encoder/ms-marco-MiniLM-L-12-v2
Post-CE blending: True alpha = 0.85
Hybrid overall averages: {'MAP@3': 0.5144442675320212, 'NDCG@3': 0.5901077866678365, 'MAP@5': 0.5279396747591463, 'NDCG@5': 0.6103883965954874, 'MAP@10': 0.5401260917400604, 'NDCG@10': 0.6306136362160123}


Time taken: 47m 43.1s
Pipeline: weighted_rrf -> score_interp_combmnz -> CE(L12) -> blend
Weighted RRF: K = 60 weights = (0.65, 0.35)
Interpolation weights = (0.7, 0.3) COMBMNZ = True
Candidate pool size = 60 | FINAL_TOP_K = 10
Cross-Encoder: cross-encoder/ms-marco-MiniLM-L-12-v2
Post-CE blending: True alpha = 0.85
Hybrid overall averages: {'MAP@3': 0.5144442675320212, 'NDCG@3': 0.5901077866678365, 'MAP@5': 0.5279396747591463, 'NDCG@5': 0.6103883965954874, 'MAP@10': 0.5401260917400604, 'NDCG@10': 0.6306136362160123}

Below code is for adaptive weighting scheme driven by a simple proxy for “keywordiness” of the query, which tends to correlate with sparse vs. dense effectiveness:

Keyword-heavy queries (many informative/high-IDF terms) → favor SPLADE (sparse).
Keyword-light/natural-language queries (few informative terms) → favor MPNet (dense).
This avoids your BERT QC’s overconfidence while still being dynamic per query.

How it differs from fixed linear interpolation

Linear interpolation with a fixed λ uses the same weights for every query.
Your TF-IDF-based approach computes λ dynamically from the query’s term statistics. It’s closer to a “soft decision rule” for query complexity, but simple, stable, and explainable.
A practical, robust way to implement it

Compute a “keywordiness” score from the query via TF-IDF or IDF lookup:
Example metric: average IDF of tokens present in the query.
Normalize it to [0, 1], then map to a weight λ for sparse.
Clamp λ to a band (e.g., [0.3, 0.7]) to avoid extreme swings.
Use that λ in score-based fusion (min–max normalize per list, then λ·S_splade + (1−λ)·S_mpnet).
Optionally apply COMBMNZ to reward agreement.
Then pass the fused shortlist to your cross-encoder for re-ranking and post-CE blending (as in your working pipeline).
In this script that:

Builds an IDF dictionary on-the-fly from your retriever corpora proxy (document full_texts across SPLADE/MPNet outputs).
Computes a per-query λ from average IDF of the query tokens.
Fuses using dynamic λ, then CE re-ranks and blends.

How to tune:

CANDIDATE_POOL_SIZE: 60 → 80 often gives a small boost.
LAMBDA_MIN/LAMBDA_MAX: Try (0.4, 0.8) if SPLADE is generally stronger; or (0.3, 0.7) as a conservative band.
USE_COMBMNZ: Try on/off; it often helps when systems agree on correct docs.
ALPHA_CE_BLEND: Try 0.8 vs 0.9; 0.85 is a strong default.
MIN_DF: 2 or 3; too high may ignore useful IDF signals.
Why this can hit +7–8%

Adaptive λ picks the right balance per query without relying on a brittle classifier.
CE L-12 re-ranking plus blending typically adds several points.
The IDF-driven weighting is stable (bounded) and explainable, avoiding the “spiky” QC issue.

We can also test:

λ band {(0.3,0.7), (0.4,0.8)}
CANDIDATE_POOL_SIZE {60, 80}
ALPHA {0.8, 0.85, 0.9}

In [3]:
import json
import ast
import re
import os
import math
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple, Iterable, DefaultDict
from collections import defaultdict, Counter

# =================== CONFIG ===================
SPLADE_CSV = "./prior_results/hotpotqa_splade.csv"
MPNET_CSV  = "./prior_results/hotpotqa_mpnet.csv"
OUTPUT_CSV = "./hybrid_tfidf_adaptive_ce.csv"

# Candidate generation caps
TOP_K_PER_LIST = 40         # items taken from each retriever before fusion
CANDIDATE_POOL_SIZE = 60    # items passed to CE
FINAL_TOP_K = 10            # final top-N to save/evaluate

# Adaptive λ band
LAMBDA_MIN = 0.3            # min weight on SPLADE
LAMBDA_MAX = 0.7            # max weight on SPLADE

# IDF construction
# Build a lightweight IDF dictionary from all candidate texts seen across rows
# This is a proxy for the underlying corpus; good enough to guide λ adaptively.
MIN_DF = 2                  # ignore terms appearing in < MIN_DF docs
MAX_VOCAB = 200_000         # safety cap

# Tokenization
TOKEN_PATTERN = r"[A-Za-z0-9_]+"  # simple alnum tokens, adjust if needed
LOWERCASE = True
STOPWORDS = set([
    "the","a","an","of","and","or","to","in","on","for","is","are","was","were","as","with","by","at","from","that","this","it","be","has","have","had"
])  # extend if needed

# Cross-encoder model
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CROSS_ENCODER_MODEL_PATH = "./ce_ms_marco_minilm_l12_v2"
CE_LOCAL_ONLY = False
CE_BATCH_SIZE = 32
CE_DEVICE = None  # "cuda" or "cpu" or None

# Post-CE blending
USE_POST_CE_BLENDING = True
ALPHA_CE_BLEND = 0.85       # Final = α·CE + (1−α)·Fused

# Interpolation fusion options
SCORE_NORM_PER_LIST = True
USE_COMBMNZ = True

# Logging
LOG_RUN_META = True
# ==============================================


# --------------- Text utils ---------------

def normalize_text(s: Any) -> str:
    if s is None or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s)
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def safe_parse_list(val: Any) -> Any:
    if isinstance(val, (list, dict)):
        return val
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    s = str(val).strip()
    if s == "":
        return []
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

def parse_ret_list(ret_field: Any) -> List[Dict[str, Any]]:
    data = safe_parse_list(ret_field)
    out = []
    if isinstance(data, list):
        for d in data:
            if isinstance(d, dict) and "doc_id" in d:
                doc_id = str(d.get("doc_id"))
                score = float(d.get("score", 0.0))
                full_text = d.get("full_text", "")
                out.append({"doc_id": doc_id, "score": score, "full_text": full_text})
    return out

def parse_groundtruth_list(gt_field: Any) -> List[str]:
    data = safe_parse_list(gt_field)
    if isinstance(data, list):
        return [str(x) for x in data]
    s = str(gt_field).strip()
    if not s:
        return []
    if "," in s:
        return [t.strip() for t in s.split(",") if t.strip()]
    return s.split()

def tokenize(text: str) -> List[str]:
    if text is None:
        return []
    if LOWERCASE:
        text = text.lower()
    toks = re.findall(TOKEN_PATTERN, text)
    toks = [t for t in toks if t and t not in STOPWORDS]
    return toks

def min_max_normalize(values: List[float]) -> List[float]:
    if not values:
        return []
    mn, mx = min(values), max(values)
    if mx == mn:
        return [0.0 for _ in values]
    return [(x - mn) / (mx - mn) for x in values]

def per_list_minmax_map(items: List[Dict[str, Any]]) -> Dict[str, float]:
    scores = [x["score"] for x in items]
    norm = min_max_normalize(scores)
    return {items[i]["doc_id"]: norm[i] for i in range(len(items))}


# --------------- Adaptive λ from IDF ---------------

def build_idf_dict(all_docs: Iterable[str], min_df: int = 2, max_vocab: int = MAX_VOCAB) -> Dict[str, float]:
    # Compute DF (document frequency) over the provided documents (strings)
    df_counter: DefaultDict[str, int] = defaultdict(int)
    total_docs = 0
    for txt in all_docs:
        total_docs += 1
        toks = set(tokenize(txt))
        for t in toks:
            df_counter[t] += 1
    # Compute IDF
    idf = {}
    for t, df in df_counter.items():
        if df >= min_df:
            idf[t] = math.log((1 + total_docs) / (1 + df)) + 1.0  # smooth idf
    # Trim vocab if huge
    if len(idf) > max_vocab:
        # Keep highest IDF terms first (rare terms)
        idf = dict(sorted(idf.items(), key=lambda kv: kv[1], reverse=True)[:max_vocab])
    return idf

def query_keywordiness_idf_avg(query: str, idf: Dict[str, float]) -> float:
    toks = tokenize(query)
    if not toks:
        return 0.0
    vals = [idf.get(t, 0.0) for t in toks]
    if not vals:
        return 0.0
    return float(sum(vals) / len(vals))

def adaptive_lambda_from_idf(avg_idf: float, min_idf: float, max_idf: float, lam_min: float, lam_max: float) -> float:
    # Map avg_idf in [min_idf, max_idf] -> λ in [lam_min, lam_max]
    if max_idf <= min_idf:
        return (lam_min + lam_max) / 2.0
    z = (avg_idf - min_idf) / (max_idf - min_idf)
    z = max(0.0, min(1.0, z))
    lam = lam_min + z * (lam_max - lam_min)
    return float(lam)


# --------------- Fusion (score interpolation + COMBMNZ) ---------------

def score_sum_fuse_weighted(lists: List[List[Dict[str, Any]]], weights: List[float], normalize_per_list: bool = True, combmnz: bool = False):
    score_maps = []
    for lst in lists:
        if normalize_per_list:
            m = per_list_minmax_map(lst)
        else:
            m = {x["doc_id"]: float(x["score"]) for x in lst}
        score_maps.append(m)

    all_doc_ids = set().union(*[set(m.keys()) for m in score_maps])
    fused = []
    for d in all_doc_ids:
        per_scores = [m.get(d, 0.0) for m in score_maps]
        val = sum(float(weights[i]) * per_scores[i] for i in range(len(per_scores)))
        if combmnz:
            c = sum(1 for s in per_scores if s > 0.0)
            val *= max(1, c)
        fused.append({"doc_id": d, "fused_interp": val})
    fused.sort(key=lambda x: x["fused_interp"], reverse=True)
    return fused


# --------------- Metrics ---------------

def apk(actual: List[str], predicted: List[str], k: int) -> float:
    if not actual:
        return 0.0
    pred_k = predicted[:k]
    hits, score = 0, 0.0
    seen = set()
    for i, p in enumerate(pred_k, start=1):
        if p in actual and p not in seen:
            hits += 1
            score += hits / i
            seen.add(p)
    return score / min(len(actual), k)

def ndcg_at_k(predicted: List[str], ideal_texts: List[str], k: int) -> float:
    pred_k = predicted[:k]
    dcg = 0.0
    for i, p in enumerate(pred_k, start=1):
        rel = 1.0 if p in ideal_texts else 0.0
        if rel:
            dcg += rel / np.log2(i + 1)
    g = min(len(ideal_texts), k)
    idcg = sum(1.0 / np.log2(i + 1) for i in range(1, g + 1))
    return (dcg / idcg) if idcg > 0.0 else 0.0


# --------------- Cross-Encoder ---------------

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

_ce_tokenizer = None
_ce_model = None

def load_cross_encoder():
    global _ce_tokenizer, _ce_model
    if _ce_tokenizer is not None and _ce_model is not None:
        return
    if CE_LOCAL_ONLY:
        if not os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            raise OSError(
                f"Local CE folder not found: {CROSS_ENCODER_MODEL_PATH}. "
                f"Set CE_LOCAL_ONLY=False to load from HF Hub or place the model locally."
            )
        _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
        _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH, local_files_only=True)
    else:
        if os.path.isdir(CROSS_ENCODER_MODEL_PATH):
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL_PATH)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL_PATH)
        else:
            _ce_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER_MODEL)
            _ce_model = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER_MODEL)
    _ce_model.eval()
    if CE_DEVICE:
        _ce_model.to(CE_DEVICE)

def ce_score_pairs(pairs: List[Tuple[str, str]], batch_size: int = 32) -> List[float]:
    load_cross_encoder()
    device = CE_DEVICE or ("cuda" if torch.cuda.is_available() else "cpu")
    _ce_model.to(device)

    scores = []
    with torch.no_grad():
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            inputs = _ce_tokenizer(
                [q for q, _ in batch],
                [d for _, d in batch],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = _ce_model(**inputs)
            logits = outputs.logits
            if logits.dim() == 1:
                batch_scores = logits.detach().float().cpu().tolist()
            else:
                if logits.size(-1) == 1:
                    batch_scores = logits.squeeze(-1).detach().float().cpu().tolist()
                else:
                    batch_scores = logits[:, -1].detach().float().cpu().tolist()
            scores.extend(batch_scores)
    return scores


# --------------- Main: Adaptive λ + CE re-rank ---------------

def main():
    print("CWD:", os.getcwd())

    # Load CSVs
    df_sparse = pd.read_csv(SPLADE_CSV)
    df_dense = pd.read_csv(MPNET_CSV)

    # Preserve originals
    df_sparse["_q_orig"] = df_sparse["question"]
    df_sparse["_a_orig"] = df_sparse["answer"]
    df_sparse["_gt_orig"] = df_sparse["groundtruth_docs"]

    df_dense["_q_orig"] = df_dense["question"]
    df_dense["_a_orig"] = df_dense["answer"]
    df_dense["_gt_orig"] = df_dense["groundtruth_docs"]

    # Normalize join keys
    def normalize_groundtruth_str(gt_field: Any) -> str:
        data = safe_parse_list(gt_field)
        if isinstance(data, list):
            norm_items = [normalize_text(x) for x in data]
            return json.dumps(norm_items, ensure_ascii=False)
        return json.dumps([normalize_text(str(gt_field))], ensure_ascii=False)

    df_sparse["_q_norm"] = df_sparse["question"].apply(normalize_text)
    df_sparse["_a_norm"] = df_sparse["answer"].apply(normalize_text)
    df_sparse["_gt_norm"] = df_sparse["groundtruth_docs"].apply(normalize_groundtruth_str)

    df_dense["_q_norm"] = df_dense["question"].apply(normalize_text)
    df_dense["_a_norm"] = df_dense["answer"].apply(normalize_text)
    df_dense["_gt_norm"] = df_dense["groundtruth_docs"].apply(normalize_groundtruth_str)

    # Merge
    df = pd.merge(
        df_sparse, df_dense,
        on=["_q_norm", "_a_norm", "_gt_norm"],
        suffixes=("_splade", "_mpnet"),
        how="inner"
    )
    if df.empty:
        raise ValueError("No rows matched after normalization. Check groundtruth formats across CSVs.")

    # Build an IDF dictionary from all candidate texts seen across rows
    # Gather a corpus proxy from top lists to estimate DF robustly
    doc_texts_for_idf = []
    for _, row in df.iterrows():
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]
        for d in splade_list:
            if d.get("full_text"):
                doc_texts_for_idf.append(d["full_text"])
        for d in mpnet_list:
            if d.get("full_text"):
                doc_texts_for_idf.append(d["full_text"])
    idf_dict = build_idf_dict(doc_texts_for_idf, min_df=MIN_DF, max_vocab=MAX_VOCAB)

    # For normalization of avg_idf -> λ, we need rough bounds across queries
    # Precompute average idf per query to estimate min/max
    avg_idfs_all = []
    q_texts = df["_q_orig_splade"].tolist()
    for q in q_texts:
        avg_idfs_all.append(query_keywordiness_idf_avg(q, idf_dict))
    # Handle degenerate case
    if not avg_idfs_all:
        avg_idfs_all = [0.0]
    global_min_idf = float(min(avg_idfs_all))
    global_max_idf = float(max(avg_idfs_all)) if max(avg_idfs_all) > global_min_idf else (global_min_idf + 1.0)

    hybrid_ret_docs_col = []
    map3_list, map5_list, map10_list = [], [], []
    ndcg3_list, ndcg5_list, ndcg10_list = [], [], []

    for idx, row in df.iterrows():
        # Parse inputs
        splade_list = parse_ret_list(row["splade_ret_docs"])[:TOP_K_PER_LIST]
        mpnet_list  = parse_ret_list(row["mpnet_ret_docs"])[:TOP_K_PER_LIST]
        gt_list_raw = parse_groundtruth_list(row["_gt_orig_splade"])
        q_text = row["_q_orig_splade"]

        # Compute adaptive λ from query keywordiness
        avg_idf = query_keywordiness_idf_avg(q_text, idf_dict)
        lam_splade = adaptive_lambda_from_idf(avg_idf, global_min_idf, global_max_idf, LAMBDA_MIN, LAMBDA_MAX)
        lam_mpnet = 1.0 - lam_splade
        weights = [lam_splade, lam_mpnet]

        # Fusion by score interpolation + COMBMNZ on union
        fused_list = score_sum_fuse_weighted(
            [splade_list, mpnet_list],
            weights=weights,
            normalize_per_list=SCORE_NORM_PER_LIST,
            combmnz=USE_COMBMNZ
        )

        # Build candidate pool with full_texts
        splade_map = {x["doc_id"]: x for x in splade_list}
        mpnet_map  = {x["doc_id"]: x for x in mpnet_list}
        candidates = []
        for item in fused_list:
            d = item["doc_id"]
            ft = splade_map.get(d, mpnet_map.get(d, {})).get("full_text", "")
            candidates.append({
                "doc_id": d,
                "full_text": ft,
                "pre_fused_score": item["fused_interp"]
            })

        # Truncate for CE
        candidates = candidates[:CANDIDATE_POOL_SIZE]

        # CE re-rank
        pairs = [(q_text, c["full_text"]) for c in candidates]
        if pairs:
            ce_scores = ce_score_pairs(pairs, batch_size=CE_BATCH_SIZE)
        else:
            ce_scores = []

        for i, c in enumerate(candidates):
            c["_ce_score"] = float(ce_scores[i]) if i < len(ce_scores) else float("-inf")

        # Post-CE blending
        if USE_POST_CE_BLENDING:
            ce_vals = [c["_ce_score"] for c in candidates]
            pre_vals = [c["pre_fused_score"] for c in candidates]
            ce_norm = min_max_normalize(ce_vals)
            pre_norm = min_max_normalize(pre_vals)
            for i, c in enumerate(candidates):
                c["_final_score"] = ALPHA_CE_BLEND * ce_norm[i] + (1.0 - ALPHA_CE_BLEND) * pre_norm[i]
        else:
            for c in candidates:
                c["_final_score"] = c["_ce_score"]

        candidates.sort(key=lambda x: x["_final_score"], reverse=True)
        final = candidates[:FINAL_TOP_K]

        # Output struct and metrics
        hybrid_struct = []
        pred_texts = []
        for rnk, c in enumerate(final, start=1):
            entry = {
                "doc_id": c["doc_id"],
                "score": c["_final_score"],
                "rank": rnk,
                "full_text": c["full_text"],
                "fusion_mode": "adaptive_tfidf_interp -> CE(L12) -> blend"
            }
            if LOG_RUN_META:
                entry["meta"] = {
                    "lambda_splade": lam_splade,
                    "lambda_mpnet": lam_mpnet,
                    "cand_pool": CANDIDATE_POOL_SIZE,
                    "combmnz": USE_COMBMNZ,
                    "ce_model": CROSS_ENCODER_MODEL if not CE_LOCAL_ONLY else CROSS_ENCODER_MODEL_PATH,
                    "alpha_ce_blend": ALPHA_CE_BLEND,
                    "avg_idf": avg_idf
                }
            hybrid_struct.append(entry)
            pred_texts.append(normalize_text(c["full_text"]))

        hybrid_ret_docs_col.append(json.dumps(hybrid_struct, ensure_ascii=False))

        gt_norm_texts = [normalize_text(x) for x in gt_list_raw]
        map3_list.append(apk(gt_norm_texts, pred_texts, 3))
        map5_list.append(apk(gt_norm_texts, pred_texts, 5))
        map10_list.append(apk(gt_norm_texts, pred_texts, 10))
        ndcg3_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 3))
        ndcg5_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 5))
        ndcg10_list.append(ndcg_at_k(pred_texts, gt_norm_texts, 10))

    # Output
    out = pd.DataFrame({
        "question": df["_q_orig_splade"],
        "answer": df["_a_orig_splade"],
        "groundtruth_docs": df["_gt_orig_splade"],
        "splade_ret_docs": df["splade_ret_docs"],
        "mpnet_ret_docs": df["mpnet_ret_docs"],
        "hybrid_ret_docs": hybrid_ret_docs_col,
        "MAP@3": map3_list,
        "NDCG@3": ndcg3_list,
        "MAP@5": map5_list,
        "NDCG@5": ndcg5_list,
        "MAP@10": map10_list,
        "NDCG@10": ndcg10_list,
    })
    cols = [
        "question","answer","groundtruth_docs",
        "splade_ret_docs","mpnet_ret_docs","hybrid_ret_docs",
        "MAP@3","NDCG@3","MAP@5","NDCG@5","MAP@10","NDCG@10"
    ]
    out = out[cols]
    out.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")

    overall = {
        "MAP@3": float(np.mean(map3_list)) if map3_list else 0.0,
        "NDCG@3": float(np.mean(ndcg3_list)) if ndcg3_list else 0.0,
        "MAP@5": float(np.mean(map5_list)) if map5_list else 0.0,
        "NDCG@5": float(np.mean(ndcg5_list)) if ndcg5_list else 0.0,
        "MAP@10": float(np.mean(map10_list)) if map10_list else 0.0,
        "NDCG@10": float(np.mean(ndcg10_list)) if ndcg10_list else 0.0,
    }
    print("Saved:", OUTPUT_CSV)
    print("Pipeline: adaptive_tfidf_interp -> CE(L12) -> blend")
    print("λ band (splade):", (LAMBDA_MIN, LAMBDA_MAX))
    print("Candidate pool:", CANDIDATE_POOL_SIZE, "| FINAL_TOP_K:", FINAL_TOP_K)
    print("CE:", CROSS_ENCODER_MODEL if not CE_LOCAL_ONLY else CROSS_ENCODER_MODEL_PATH)
    print("Blend α:", ALPHA_CE_BLEND, "| COMBMNZ:", USE_COMBMNZ)
    print("Hybrid overall averages:", overall)

# Run
main()

CWD: /home/csmala/journal_rag/hybrid_pipeline
Saved: ./hybrid_tfidf_adaptive_ce.csv
Pipeline: adaptive_tfidf_interp -> CE(L12) -> blend
λ band (splade): (0.3, 0.7)
Candidate pool: 60 | FINAL_TOP_K: 10
CE: cross-encoder/ms-marco-MiniLM-L-12-v2
Blend α: 0.85 | COMBMNZ: True
Hybrid overall averages: {'MAP@3': 0.5134620876676876, 'NDCG@3': 0.5892316464471192, 'MAP@5': 0.5271900810258897, 'NDCG@5': 0.6097992628793352, 'MAP@10': 0.5393110351771699, 'NDCG@10': 0.6299182972580143}


# Hybrid with modified metrics on HotpotQA (Current Working directory)

In [2]:
import pandas as pd
df = pd.read_csv("./prior_results/hotpotqa_splade_modified_metrics.csv")
df.columns

Index(['question', 'answer', 'passage', 'groundtruth_docs', 'splade_ret_docs',
       'passages_with_ids', 'groundtruth_with_ids', 'MAP@3', 'NDCG@3', 'MAP@5',
       'NDCG@5', 'MAP@10', 'NDCG@10'],
      dtype='object')

In [3]:
df = pd.read_csv("./prior_results/hotpotqa_mpnet_modified_metrics.csv")
df.columns

Index(['question', 'answer', 'passage', 'groundtruth_docs',
       'all-mpnet-base-v2_ret_docs', 'passages_with_ids',
       'groundtruth_with_ids', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10',
       'NDCG@10'],
      dtype='object')

In [22]:
# Merging sparse and dense as single file for passing into hybrid 

import pandas as pd

# ------------- Load -------------
splade_path = './prior_results/hotpotqa_splade_modified_metrics.csv'
mpnet_path  = './prior_results/hotpotqa_mpnet_modified_metrics.csv'

splade_df = pd.read_csv(splade_path)
mpnet_df  = pd.read_csv(mpnet_path)

# Normalize column name if needed
if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

# ------------- Keep only necessary columns -------------
splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

# Sanity checks
missing_splade = set(splade_keep) - set(splade_df.columns)
missing_mpnet  = set(mpnet_keep) - set(mpnet_df.columns)
if missing_splade:
    raise ValueError(f"Missing columns in splade_df: {missing_splade}")
if missing_mpnet:
    raise ValueError(f"Missing columns in mpnet_df: {missing_mpnet}")

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

# ------------- Merge -------------
# If you expect exactly one mpnet row per key, validate='m:1' will catch duplicates on the right.
merged = pd.merge(
    splade_df,
    mpnet_df,
    on=['passages_with_ids', 'groundtruth_with_ids'],
    how='inner',
    validate='m:1'
)

# ------------- Final selection and save -------------
final_cols = ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids', 'splade_ret_docs', 'mpnet_ret_docs']
merged = merged[final_cols]

out_path = 'hotpotqa_merged_splade_mpnet.csv'
merged.to_csv(out_path, index=False)

print('Merged shape:', merged.shape)
print('Saved to:', out_path)
print(merged.head(3).to_string(index=False))

Merged shape: (90447, 6)
Saved to: hotpotqa_merged_splade_mpnet.csv
                                                                                                                       question                  answer                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             

## Hybrid on HotpotQA

In [1]:
# Hybrid + vanilla RRF fusion + pytrec_eval evaluation

import pandas as pd
import json
import math
from collections import defaultdict
import pytrec_eval  # make sure available

CUTS = (3, 5, 10)

# ------------- Load + prep (same as before) -------------
splade_df = pd.read_csv('./prior_results/hotpotqa_splade_modified_metrics.csv')
mpnet_df = pd.read_csv('./prior_results/hotpotqa_mpnet_modified_metrics.csv')

if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['question', 'answer', 'mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

def lower_str(x):
    if pd.isna(x):
        return x
    return str(x).lower()

for col in ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids']:
    if col in splade_df.columns:
        splade_df[col] = splade_df[col].apply(lower_str)
    if col in mpnet_df.columns:
        mpnet_df[col] = mpnet_df[col].apply(lower_str)

merged = pd.merge(
    mpnet_df, splade_df,
    on=['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids'],
    how='inner'
)

# ------------- Parsing helpers (hardened) -------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_ranked_dict(items):
    d = {}
    if not items:
        return d
    for i, it in enumerate(items):
        doc_id, rank, score = None, i+1, None
        if isinstance(it, (str, int)):
            doc_id = normalize_doc_id(it)
        elif isinstance(it, dict):
            doc_id = it.get('doc_id') or it.get('id') or it.get('docid')
            if doc_id is None and len(it) == 1:
                k, v = next(iter(it.items()))
                doc_id = k
                score = v if isinstance(v, (int, float, str)) else None
            else:
                rank = it.get('rank', rank)
                score = it.get('score', it.get('sim'))
            doc_id = normalize_doc_id(doc_id)
        if not doc_id:
            continue
        try:
            rank = int(rank)
        except Exception:
            rank = i+1
        try:
            score = float(score) if score is not None else None
        except Exception:
            score = None
        d[doc_id] = (rank, score)
    return d

def to_qrels_docids_only(gt_items):
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

def rrf_fusion(rank_dicts, k=60):
    scores = defaultdict(float)
    components = defaultdict(list)
    for src_idx, d in enumerate(rank_dicts):
        for doc_id, (rank, _score) in d.items():
            if not doc_id:
                continue
            scores[doc_id] += 1.0 / (k + rank)
            components[doc_id].append({'source': src_idx, 'rank': rank})
    sorted_docs = sorted(scores.items(), key=lambda x: (-x[1], x[0]))
    fused = []
    for i, (doc_id, sc) in enumerate(sorted_docs, start=1):
        fused.append({'doc_id': doc_id, 'score': sc, 'rank': i, 'components': components[doc_id]})
    return fused

# ------------- Build runs + qrels for pytrec_eval -------------
results_by_qid = {}  # qid -> {doc_id: score}
qrels_by_qid = {}    # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items = safe_load(row['groundtruth_with_ids'])

    mpnet_ranked = to_ranked_dict(mpnet_items)
    splade_ranked = to_ranked_dict(splade_items)
    fused = rrf_fusion([mpnet_ranked, splade_ranked], k=60)

    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    run_dict = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    results_by_qid[qid] = run_dict

    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': fused,
        'passages_with_ids': row['passages_with_ids'],
        'groundtruth_with_ids': gt_items
    })

# ------------- Evaluate with pytrec_eval (request all cuts!) -------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
eval_res = evaluator.evaluate(results_by_qid)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"  # e.g., 'map_cut_3'
    return float(pv.get(key, 0.0))

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    pv = eval_res.get(qid, {})
    metrics = {
        'MAP':    float(pv.get('map', 0.0)),
        'NDCG':   float(pv.get('ndcg', 0.0)),
        'MAP@3':  round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5':  round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10':round(metric_at(pv, 'ndcg_cut', 10), 6),
    }
    metrics_rows.append({**r, **metrics})

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'mpnet_ret_docs', 'splade_ret_docs',
                       'hybrid_ret_docs', 'passages_with_ids', 'groundtruth_with_ids',
                       'MAP', 'NDCG', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10']]

hybrid_df.to_csv('hotpotqa_hybrid_rrf_results.csv', index=False)
print(f"Saved {len(hybrid_df)} rows to hotpotqa_hybrid_rrf_results.csv")

# Print averages
avg_metrics = hybrid_df[['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']].mean().round(6)
print('Averages:')
for m in ['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']:
    print(f"{m}: {avg_metrics[m]}")

Saved 80044 rows to hotpotqa_hybrid_rrf_results.csv
Averages:
MAP: 0.285133
NDCG: 0.365928
MAP@3: 0.256652
NDCG@3: 0.322585
MAP@5: 0.273515
NDCG@5: 0.342596
MAP@10: 0.282458
NDCG@10: 0.358922


In [4]:
'''
Hybrid(Splade + Mpnet) with Linear interpolation (LI) updated

the fusion step: we normalize scores per source list and combine them as
fused_score = alpha * mpnet_score + (1 - alpha) * splade_score

Adds: full_text field to each item in hybrid_ret_docs by building a doc_id -> text map
(prefer passages_with_ids; fallback to mpnet/splade lists).
'''

import pandas as pd
import json
import math
from collections import defaultdict
import pytrec_eval  # make sure available

CUTS = (3, 5, 10)
ALPHA = 0.3  # weight for mpnet; (1-ALPHA) is weight for splade

# ------------- Load (as provided) -------------
splade_path = './prior_results/hotpotqa_splade_modified_metrics.csv'
mpnet_path  = './prior_results/hotpotqa_mpnet_modified_metrics.csv'

splade_df = pd.read_csv(splade_path)
mpnet_df  = pd.read_csv(mpnet_path)

# Normalize column name if needed
if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

# ------------- Keep only necessary columns (as provided) -------------
splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

# Sanity checks
missing_splade = set(splade_keep) - set(splade_df.columns)
missing_mpnet  = set(mpnet_keep) - set(mpnet_df.columns)
if missing_splade:
    raise ValueError(f"Missing columns in splade_df: {missing_splade}")
if missing_mpnet:
    raise ValueError(f"Missing columns in mpnet_df: {missing_mpnet}")

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

# ------------- Merge (as provided) -------------
# If you expect exactly one mpnet row per key, validate='m:1' will catch duplicates on the right.
merged = pd.merge(
    splade_df,
    mpnet_df,
    on=['passages_with_ids', 'groundtruth_with_ids'],
    how='inner',
    validate='m:1'
)

# ------------- Final selection (as provided) -------------
final_cols = ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids', 'splade_ret_docs', 'mpnet_ret_docs']
merged = merged[final_cols]
print('Merged shape:', merged.shape)

# ------------- Parsing helpers (same as before) -------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'}; ignores items without both.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min-max normalize a dict of {doc_id: score} per query.
    If all scores equal or empty, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        # return zeros so that fusion still works deterministically
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_linear(mpnet_scores, splade_scores, alpha=0.5, normalize=True):
    """
    Linear interpolation fusion:
      fused = alpha * mpnet + (1 - alpha) * splade
    - If normalize, scores are min-max normalized per source before fusion.
    - Missing docs in one list get score 0 from that source.
    Returns list of dicts: {'doc_id', 'score'} sorted by score desc with rank.
    """
    m = mpnet_scores or {}
    s = splade_scores or {}
    if normalize:
        m = min_max_norm(m)
        s = min_max_norm(s)
    all_ids = set(m.keys()) | set(s.keys())
    fused = {}
    for did in all_ids:
        ms = m.get(did, 0.0)
        ss = s.get(did, 0.0)
        fused[did] = alpha * ms + (1.0 - alpha) * ss
    # sort by fused score desc, then doc_id for stability
    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# -------- NEW: Build a doc_id -> full_text map, preferring passages, then retriever lists --------
def build_doc_text_map(passages_items):
    """
    Build doc_id -> text map from passages_with_ids first (prefer 'full_text', then 'text'/'content').
    """
    mp = {}
    items = passages_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        txt = it.get('full_text') or it.get('text') or it.get('content')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

def extend_doc_text_map(mp, retriever_items):
    """
    Enrich the map using mpnet/splade lists (prefer 'full_text', fallback to 'snippet'/'preview_snippet'/'text').
    """
    items = retriever_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did or did in mp:
            continue
        txt = it.get('full_text') or it.get('snippet') or it.get('preview_snippet') or it.get('text')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

# ------------- Build runs + qrels for pytrec_eval (with linear fusion) -------------
results_by_qid = {}  # qid -> {doc_id: score}
qrels_by_qid = {}    # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items = safe_load(row['groundtruth_with_ids'])
    passages_items = safe_load(row['passages_with_ids'])

    mpnet_scores = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)

    fused = fuse_linear(mpnet_scores, splade_scores, alpha=ALPHA, normalize=True)

    # normalize doc_ids (ensure lowercased/trimmed)
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    # Build doc text map with priority: passages -> mpnet -> splade
    doc_text_map = build_doc_text_map(passages_items)
    doc_text_map = extend_doc_text_map(doc_text_map, mpnet_items)
    doc_text_map = extend_doc_text_map(doc_text_map, splade_items)

    # Attach full_text into hybrid_ret_docs entries
    hybrid_with_text = []
    for e in fused:
        did = e['doc_id']
        hybrid_with_text.append({
            'doc_id': did,
            'score': e['score'],
            'rank': e['rank'],
            'full_text': doc_text_map.get(did)  # may be None if not available anywhere
        })

    # run dict: doc_id -> fused score
    run_dict = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    results_by_qid[qid] = run_dict

    # qrels from groundtruth_with_ids
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': hybrid_with_text,  # now includes full_text
        'passages_with_ids': row['passages_with_ids'],
        'groundtruth_with_ids': gt_items
    })

# ------------- Evaluate with pytrec_eval (request normal + cuts) -------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
eval_res = evaluator.evaluate(results_by_qid)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"  # e.g., 'map_cut_3'
    return float(pv.get(key, 0.0))

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    pv = eval_res.get(qid, {})
    metrics = {
        'MAP':    float(pv.get('map', 0.0)),
        'NDCG':   float(pv.get('ndcg', 0.0)),
        'MAP@3':  round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5':  round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10':round(metric_at(pv, 'ndcg_cut', 10), 6),
    }
    metrics_rows.append({**r, **metrics})

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'mpnet_ret_docs', 'splade_ret_docs',
                       'hybrid_ret_docs', 'passages_with_ids', 'groundtruth_with_ids',
                       'MAP', 'NDCG', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10']]

# ------------- Save -------------
out_csv = 'hotpotqa_hybrid_linear_results_sparse_0.7.csv'
hybrid_df.to_csv(out_csv, index=False)
print(f"Saved {len(hybrid_df)} rows to {out_csv}")

# Print averages
avg_metrics = hybrid_df[['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']].mean().round(6)
print('Averages (alpha=%.2f):' % ALPHA)
for m in ['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']:
    print(f"{m}: {avg_metrics[m]}")

Merged shape: (90447, 6)
Saved 90447 rows to hotpotqa_hybrid_linear_results_sparse_0.7.csv
Averages (alpha=0.30):
MAP: 0.304735
NDCG: 0.382653
MAP@3: 0.282216
NDCG@3: 0.3501
MAP@5: 0.294928
NDCG@5: 0.362942
MAP@10: 0.302289
NDCG@10: 0.376265


Merged shape: (90447, 6)
Saved 90447 rows to hotpotqa_hybrid_linear_results.csv
Averages (alpha=0.40):
MAP: 0.301074
NDCG: 0.380012
MAP@3: 0.277347
NDCG@3: 0.345199
MAP@5: 0.290613
NDCG@5: 0.358912
MAP@10: 0.298428
NDCG@10: 0.373107


Merged shape: (90447, 6)
Saved 90447 rows to hotpotqa_hybrid_linear_results.csv
Averages (alpha=0.30):
MAP: 0.304735
NDCG: 0.382653
MAP@3: 0.282216
NDCG@3: 0.3501
MAP@5: 0.294928
NDCG@5: 0.362942
MAP@10: 0.302289
NDCG@10: 0.376265

In [5]:
"""
Hybrid (SPLADE + MPNet) with Linear Interpolation + Cross-Encoder Reranking updated
Reranker model: cross-encoder/ms-marco-MiniLM-L-6-v2

Adds: 'full_text' to each item in both hybrid_ret_docs (fused) and hybrid_reranked_docs.
"""

import json
from typing import List, Tuple, Dict, Optional

import pandas as pd
import pytrec_eval
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------- Config -----------------
CUTS = (3, 5, 10)

# Linear interpolation weight for MPNet vs SPLADE
ALPHA = 0.3   # fused = ALPHA * mpnet + (1 - ALPHA) * splade

# Reranking weight (beta is weight for reranker)
BETA = 0.85   # final = BETA * rerank_score + (1 - BETA) * fused_score

# Reranker pool and output serialization size
MAX_CANDIDATES_FOR_RERANK = 200
TOPK_SAVE = 100  # how many reranked docs to store per row

# Inputs
splade_path = './prior_results/hotpotqa_splade_modified_metrics.csv'
mpnet_path  = './prior_results/hotpotqa_mpnet_modified_metrics.csv'

# Outputs
out_path = f"hotpotqa_hybrid_linear_with_rerank_minilm_alpha_{ALPHA}_beta_{BETA}.csv"

# Reranker model
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
RERANKER_MAX_LENGTH = 512
RERANKER_BATCH_SIZE = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32  # the model is small; precision 32-bit is fine

# ----------------- Load -----------------
splade_df = pd.read_csv(splade_path)
mpnet_df  = pd.read_csv(mpnet_path)

# Normalize column name if needed
if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

# Keep only necessary columns
splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

missing_splade = set(splade_keep) - set(splade_df.columns)
missing_mpnet  = set(mpnet_keep) - set(mpnet_df.columns)
if missing_splade:
    raise ValueError(f"Missing columns in splade_df: {missing_splade}")
if missing_mpnet:
    raise ValueError(f"Missing columns in mpnet_df: {missing_mpnet}")

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

# Merge on passages_with_ids and groundtruth_with_ids
merged = pd.merge(
    splade_df,
    mpnet_df,
    on=['passages_with_ids', 'groundtruth_with_ids'],
    how='inner',
    validate='m:1'
)

final_cols = ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids', 'splade_ret_docs', 'mpnet_ret_docs']
merged = merged[final_cols]
print('Merged shape:', merged.shape)

# ----------------- Helpers -----------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'}; ignores items without both.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min–max normalize {doc_id: score}.
    If all scores equal or empty, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_linear(mpnet_scores, splade_scores, alpha=0.5, normalize=True):
    """
    Linear interpolation fusion:
      fused = alpha * mpnet + (1 - alpha) * splade
    - If normalize, scores are min–max normalized per source before fusion.
    - Missing docs in one list get score 0 from that source.
    Returns list of dicts: {'doc_id','score','rank'} sorted by score desc with rank.
    """
    m = mpnet_scores or {}
    s = splade_scores or {}
    if normalize:
        m = min_max_norm(m)
        s = min_max_norm(s)
    all_ids = set(m.keys()) | set(s.keys())
    fused = {}
    for did in all_ids:
        ms = m.get(did, 0.0)
        ss = s.get(did, 0.0)
        fused[did] = alpha * ms + (1.0 - alpha) * ss
    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# Build a doc_id -> full_text map
def build_doc_text_map(passages_items) -> Dict[str, str]:
    """
    Prefer passages_with_ids['full_text'|'text'|'content'].
    """
    mp = {}
    items = passages_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        txt = it.get('full_text') or it.get('text') or it.get('content')
        if did and isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

def extend_doc_text_map(mp, retriever_items):
    """
    Enrich map with text from mpnet/splade lists (prefer 'full_text', fallback to 'snippet'/'preview_snippet'/'text').
    """
    items = retriever_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did or did in mp:
            continue
        txt = it.get('full_text') or it.get('snippet') or it.get('preview_snippet') or it.get('text')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

# ----------------- Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2 -----------------
class MiniLMReranker:
    """
    Cross-encoder reranker using 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
    Returns a float score per (query, passage), where higher indicates higher relevance.
    """
    def __init__(self, model_name: str = RERANKER_MODEL_NAME, device: str = DEVICE, dtype=DTYPE, max_length: int = RERANKER_MAX_LENGTH):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length

    @torch.no_grad()
    def score(self, pairs: List[Tuple[str, str]], batch_size: int = RERANKER_BATCH_SIZE) -> List[float]:
        scores: List[float] = []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            queries = [b[0] for b in batch]
            docs    = [b[1] for b in batch]
            enc = self.tokenizer(
                queries,
                docs,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            logits = self.model(**enc).logits  # shape (B, 1) typically
            if logits.dim() == 2 and logits.size(1) == 1:
                batch_scores = logits.squeeze(1).detach().float().cpu().tolist()
            else:
                batch_scores = logits.view(-1).detach().float().cpu().tolist()
            scores.extend(batch_scores)
        return scores

reranker = MiniLMReranker()

def apply_reranker_on_fused_list(
    query: str,
    fused_list: List[Dict],
    doc_text_map: Dict[str, str],
    beta: float = BETA,
    max_candidates: int = MAX_CANDIDATES_FOR_RERANK,
    batch_size: int = RERANKER_BATCH_SIZE
) -> List[Dict]:
    """
    Input fused_list: [{'doc_id','score','rank'}]
    Output: [{'doc_id','fused_score','rerank_score','final_score','rank','full_text'}], sorted by final_score desc.
    """
    # Trim candidate pool
    pool = fused_list[:max_candidates]

    # Prepare (query, text) pairs; also track text to attach later
    pairs = []
    valid_idxs = []
    cached_texts = {}
    for i, e in enumerate(pool):
        did = e['doc_id']
        txt = doc_text_map.get(did)
        if isinstance(txt, str) and txt.strip():
            pairs.append((query, txt))
            valid_idxs.append(i)
            cached_texts[i] = txt  # store for output

    # If nothing to score, pass-through and attach whatever text we have (may be None)
    if not pairs:
        out = []
        for i, e in enumerate(pool, start=1):
            did = e['doc_id']
            out.append({
                'doc_id': did,
                'fused_score': float(e['score']),
                'rerank_score': 0.0,
                'final_score': float(e['score']),
                'rank': i,
                'full_text': doc_text_map.get(did)  # might be None
            })
        return out

    # Reranker scores
    rr_scores = reranker.score(pairs, batch_size=batch_size)

    # Attach scores and compute final
    for idx_local, s in zip(valid_idxs, rr_scores):
        pool[idx_local]['rerank_score'] = float(s)
    for e in pool:
        e['rerank_score'] = float(e.get('rerank_score', 0.0))
        e['fused_score'] = float(e.get('score', 0.0))
        e['final_score'] = beta * e['rerank_score'] + (1.0 - beta) * e['fused_score']

    # Sort and re-rank
    ranked = sorted(pool, key=lambda x: (-x['final_score'], x['doc_id']))

    # Build output with full_text
    out = []
    for new_rank, e in enumerate(ranked, start=1):
        did = e['doc_id']
        out.append({
            'doc_id': did,
            'fused_score': e['fused_score'],
            'rerank_score': e['rerank_score'],
            'final_score': e['final_score'],
            'rank': new_rank,
            'full_text': doc_text_map.get(did)  # attach document text
        })
    return out

# ----------------- Build Runs + Evaluate (Fusion + Rerank) -----------------
results_by_qid_fused = {}     # qid -> {doc_id: fused score}
results_by_qid_reranked = {}  # qid -> {doc_id: final score}
qrels_by_qid = {}             # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items = safe_load(row['groundtruth_with_ids'])
    passages_items = safe_load(row['passages_with_ids'])

    mpnet_scores = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)

    fused = fuse_linear(mpnet_scores, splade_scores, alpha=ALPHA, normalize=True)
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    # Build doc_id -> text map, using passages first, then enrich from retriever lists
    doc_text_map = build_doc_text_map(passages_items)
    doc_text_map = extend_doc_text_map(doc_text_map, mpnet_items)
    doc_text_map = extend_doc_text_map(doc_text_map, splade_items)

    # Apply reranker over fused list
    reranked = apply_reranker_on_fused_list(
        query=row['question'],
        fused_list=fused,
        doc_text_map=doc_text_map,
        beta=BETA,
        max_candidates=MAX_CANDIDATES_FOR_RERANK,
        batch_size=RERANKER_BATCH_SIZE
    )

    # Also attach full_text to fused list for serialization
    fused_with_text = []
    for e in fused:
        did = e['doc_id']
        fused_with_text.append({
            'doc_id': did,
            'score': e['score'],
            'rank': e['rank'],
            'full_text': doc_text_map.get(did)
        })

    # Build run dicts
    run_fused = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    run_reranked = {e['doc_id']: float(e['final_score']) for e in reranked if e['doc_id'] is not None}

    results_by_qid_fused[qid] = run_fused
    results_by_qid_reranked[qid] = run_reranked

    # Qrels
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    # Row serialization
    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': fused_with_text,                 # before rerank, now with full_text
        'hybrid_reranked_docs': reranked[:TOPK_SAVE],       # after rerank + blending, with full_text
        'passages_with_ids': passages_items,
        'groundtruth_with_ids': gt_items
    })

# ----------------- Evaluation (pytrec_eval) -----------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator_fused = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
evaluator_reranked = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)

eval_fused = evaluator_fused.evaluate(results_by_qid_fused)
eval_reranked = evaluator_reranked.evaluate(results_by_qid_reranked)

def extract_metrics(pv):
    return {
        'MAP': float(pv.get('map', 0.0)),
        'NDCG': float(pv.get('ndcg', 0.0)),
        'MAP@3': round(float(pv.get('map_cut_3', 0.0)), 6),
        'NDCG@3': round(float(pv.get('ndcg_cut_3', 0.0)), 6),
        'MAP@5': round(float(pv.get('map_cut_5', 0.0)), 6),
        'NDCG@5': round(float(pv.get('ndcg_cut_5', 0.0)), 6),
        'MAP@10': round(float(pv.get('map_cut_10', 0.0)), 6),
        'NDCG@10': round(float(pv.get('ndcg_cut_10', 0.0)), 6),
    }

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    fused_metrics = extract_metrics(eval_fused.get(qid, {}))
    rerank_metrics = extract_metrics(eval_reranked.get(qid, {}))
    metrics_rows.append({
        **r,
        **{f"FUSED_{k}": v for k, v in fused_metrics.items()},
        **{f"RERANK_{k}": v for k, v in rerank_metrics.items()},
    })

hybrid_df = pd.DataFrame(metrics_rows)
cols = ['question', 'answer', 'mpnet_ret_docs', 'splade_ret_docs',
        'hybrid_ret_docs', 'hybrid_reranked_docs', 'passages_with_ids',
        'groundtruth_with_ids',
        'FUSED_MAP','FUSED_NDCG','FUSED_MAP@3','FUSED_NDCG@3','FUSED_MAP@5','FUSED_NDCG@5','FUSED_MAP@10','FUSED_NDCG@10',
        'RERANK_MAP','RERANK_NDCG','RERANK_MAP@3','RERANK_NDCG@3','RERANK_MAP@5','RERANK_NDCG@5','RERANK_MAP@10','RERANK_NDCG@10']
hybrid_df = hybrid_df[cols]

# Save
hybrid_df.to_csv(out_path, index=False)
print(f"Saved {len(hybrid_df)} rows to {out_path}")

# Print averages
avg_fused = hybrid_df[['FUSED_MAP','FUSED_NDCG','FUSED_MAP@3','FUSED_NDCG@3','FUSED_MAP@5','FUSED_NDCG@5','FUSED_MAP@10','FUSED_NDCG@10']].mean().round(6)
avg_rerank = hybrid_df[['RERANK_MAP','RERANK_NDCG','RERANK_MAP@3','RERANK_NDCG@3','RERANK_MAP@5','RERANK_NDCG@5','RERANK_MAP@10','RERANK_NDCG@10']].mean().round(6)

print(f"Averages (alpha={ALPHA:.2f}) - FUSED:")
for m in ['FUSED_MAP','FUSED_NDCG','FUSED_MAP@3','FUSED_NDCG@3','FUSED_MAP@5','FUSED_NDCG@5','FUSED_MAP@10','FUSED_NDCG@10']:
    print(f"{m}: {avg_fused[m]}")

print(f"Averages (alpha={ALPHA:.2f}, beta={BETA:.2f}) - RERANK:")
for m in ['RERANK_MAP','RERANK_NDCG','RERANK_MAP@3','RERANK_NDCG@3','RERANK_MAP@5','RERANK_NDCG@5','RERANK_MAP@10','RERANK_NDCG@10']:
    print(f"{m}: {avg_rerank[m]}")

Merged shape: (90447, 6)
Saved 90447 rows to hotpotqa_hybrid_linear_with_rerank_minilm_alpha_0.3_beta_0.85.csv
Averages (alpha=0.30) - FUSED:
FUSED_MAP: 0.304735
FUSED_NDCG: 0.382653
FUSED_MAP@3: 0.282216
FUSED_NDCG@3: 0.3501
FUSED_MAP@5: 0.294928
FUSED_NDCG@5: 0.362942
FUSED_MAP@10: 0.302289
FUSED_NDCG@10: 0.376265
Averages (alpha=0.30, beta=0.85) - RERANK:
RERANK_MAP: 0.315942
RERANK_NDCG: 0.391368
RERANK_MAP@3: 0.295414
RERANK_NDCG@3: 0.362896
RERANK_MAP@5: 0.307314
RERANK_NDCG@5: 0.374353
RERANK_MAP@10: 0.313973
RERANK_NDCG@10: 0.38639


Merged shape: (90447, 6)
Saved 90447 rows to hotpotqa_hybrid_linear_with_rerank_minilm_alpha_0.3_beta_0.85.csv
Averages (alpha=0.30) - FUSED:
FUSED_MAP: 0.304735
FUSED_NDCG: 0.382653
FUSED_MAP@3: 0.282216
FUSED_NDCG@3: 0.3501
FUSED_MAP@5: 0.294928
FUSED_NDCG@5: 0.362942
FUSED_MAP@10: 0.302289
FUSED_NDCG@10: 0.376265
Averages (alpha=0.30, beta=0.85) - RERANK:
RERANK_MAP: 0.315942
RERANK_NDCG: 0.391368
RERANK_MAP@3: 0.295414
RERANK_NDCG@3: 0.362896
RERANK_MAP@5: 0.307314
RERANK_NDCG@5: 0.374353
RERANK_MAP@10: 0.313973
RERANK_NDCG@10: 0.38639

In [8]:
'''
CombMNZ works as:

Normalize scores per list (recommended) to make scales comparable.
For each doc, sum the normalized scores across systems (CombSUM) and multiply by the number of systems that retrieved the doc (NZ = non-zero count).
fused_score = (sum of normalized scores) * (number of non-zero contributors)

'''

import pandas as pd
import json
import math
from collections import defaultdict
import pytrec_eval  # make sure available

CUTS = (3, 5, 10)

# ------------- Load + prep (same as before) -------------
splade_df = pd.read_csv('./prior_results/hotpotqa_splade_modified_metrics.csv')
mpnet_df = pd.read_csv('./prior_results/hotpotqa_mpnet_modified_metrics.csv')

if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['question', 'answer', 'mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

def lower_str(x):
    if pd.isna(x):
        return x
    return str(x).lower()

for col in ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids']:
    if col in splade_df.columns:
        splade_df[col] = splade_df[col].apply(lower_str)
    if col in mpnet_df.columns:
        mpnet_df[col] = mpnet_df[col].apply(lower_str)

merged = pd.merge(
    mpnet_df, splade_df,
    on=['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids'],
    how='inner'
)

# ------------- Parsing helpers (hardened) -------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'}; ignores items without both.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min-max normalize a dict of {doc_id: score} per query.
    If all scores equal or empty, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_combmnz(mpnet_scores, splade_scores, normalize=True):
    """
    CombMNZ fusion:
      For each doc:
        s_norm = (normalize scores per list if requested)
        sum_scores = s_m + s_s
        nz = count of non-zero source scores contributing (>= 2 if in both lists and >0 after normalization)
        fused = sum_scores * nz
    Missing docs in a source treated as 0.
    Returns list of dicts: {'doc_id', 'score'} sorted by fused score desc with rank.
    """
    m = mpnet_scores or {}
    s = splade_scores or {}
    if normalize:
        m = min_max_norm(m)
        s = min_max_norm(s)

    all_ids = set(m.keys()) | set(s.keys())
    fused = {}
    for did in all_ids:
        ms = m.get(did, 0.0)
        ss = s.get(did, 0.0)
        sum_scores = ms + ss
        nz = (1 if ms > 0.0 else 0) + (1 if ss > 0.0 else 0)
        # If both are zero (rare with normalization unless singletons), nz stays 0; fused should be 0
        fused_score = sum_scores * nz if nz > 0 else 0.0
        fused[did] = fused_score

    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# ------------- Build runs + qrels for pytrec_eval (with CombMNZ fusion) -------------
results_by_qid = {}  # qid -> {doc_id: score}
qrels_by_qid = {}    # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items = safe_load(row['groundtruth_with_ids'])

    mpnet_scores = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)

    fused = fuse_combmnz(mpnet_scores, splade_scores, normalize=True)

    # ensure normalized doc_ids
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    # run dict: doc_id -> fused score
    run_dict = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    results_by_qid[qid] = run_dict

    # qrels from groundtruth_with_ids
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': fused,  # CombMNZ fused list with scores and ranks
        'passages_with_ids': row['passages_with_ids'],
        'groundtruth_with_ids': gt_items
    })

# ------------- Evaluate with pytrec_eval (request normal + cuts) -------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
eval_res = evaluator.evaluate(results_by_qid)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"  # e.g., 'map_cut_3'
    return float(pv.get(key, 0.0))

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    pv = eval_res.get(qid, {})
    metrics = {
        'MAP':    float(pv.get('map', 0.0)),
        'NDCG':   float(pv.get('ndcg', 0.0)),
        'MAP@3':  round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5':  round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10':round(metric_at(pv, 'ndcg_cut', 10), 6),
    }
    metrics_rows.append({**r, **metrics})

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'mpnet_ret_docs', 'splade_ret_docs',
                       'hybrid_ret_docs', 'passages_with_ids', 'groundtruth_with_ids',
                       'MAP', 'NDCG', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10']]

# ------------- Save -------------
hybrid_df.to_csv('hotpotqa_hybrid_combmnz_results.csv', index=False)
print(f"Saved {len(hybrid_df)} rows to hotpotqa_hybrid_combmnz_results.csv")

# Print averages
avg_metrics = hybrid_df[['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']].mean().round(6)
print('Averages (CombMNZ):')
for m in ['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']:
    print(f"{m}: {avg_metrics[m]}")

Saved 80044 rows to hotpotqa_hybrid_combmnz_results.csv
Averages (CombMNZ):
MAP: 0.287889
NDCG: 0.368214
MAP@3: 0.261554
NDCG@3: 0.328154
MAP@5: 0.276335
NDCG@5: 0.344823
MAP@10: 0.285001
NDCG@10: 0.36068


In [9]:
# Adaptive QC fusion using a BERT-based classifier to get per-query weights

import pandas as pd
import json
from collections import defaultdict
import pytrec_eval
import torch
from transformers import BertTokenizer, BertForSequenceClassification

CUTS = (3, 5, 10)

# ---------- Load adaptive weight model ----------
FOLDER_PATH = "bert_model_QC_finetuned"
TEMPERATURE = 1.17  # tune as needed

tokenizer = BertTokenizer.from_pretrained(FOLDER_PATH)
model = BertForSequenceClassification.from_pretrained(FOLDER_PATH)
model.eval()

@torch.no_grad()
def get_query_weights(query_text: str, temperature: float = TEMPERATURE):
    """
    Returns (w_sparse, w_dense) from the classifier's softmax over logits/temperature.
    Assumes class 0 = sparse, class 1 = dense.
    """
    inputs = tokenizer(query_text, return_tensors="pt", truncation=True, padding=True)
    outputs = model(**inputs)
    logits = outputs.logits  # shape [1,2]
    weights = torch.softmax(logits / temperature, dim=-1).squeeze(0)  # [2]
    w_sparse = float(weights[0].item())
    w_dense = float(weights[1].item())
    return w_sparse, w_dense

# ------------- Load + prep (same as before) -------------
splade_df = pd.read_csv('./prior_results/hotpotqa_splade_modified_metrics.csv')
mpnet_df = pd.read_csv('./prior_results/hotpotqa_mpnet_modified_metrics.csv')

if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

splade_keep = ['question', 'answer', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
mpnet_keep  = ['question', 'answer', 'mpnet_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']

splade_df = splade_df[splade_keep].copy()
mpnet_df  = mpnet_df[mpnet_keep].copy()

def lower_str(x):
    if pd.isna(x):
        return x
    return str(x).lower()

for col in ['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids']:
    if col in splade_df.columns:
        splade_df[col] = splade_df[col].apply(lower_str)
    if col in mpnet_df.columns:
        mpnet_df[col] = mpnet_df[col].apply(lower_str)

merged = pd.merge(
    mpnet_df, splade_df,
    on=['question', 'answer', 'passages_with_ids', 'groundtruth_with_ids'],
    how='inner'
)

# ------------- Parsing helpers -------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'}; ignores items without both.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min-max normalize a dict of {doc_id: score} per query.
    If all scores equal or empty, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_adaptive(mpnet_scores, splade_scores, w_sparse: float, w_dense: float, normalize=True):
    """
    Adaptive score-based fusion using classifier weights per query:
      fused = w_dense * dense_norm + w_sparse * sparse_norm
    Missing docs in a source get score 0 from that source.
    Returns list of dicts: {'doc_id', 'score', 'rank'} sorted by score desc.
    """
    dense = mpnet_scores or {}
    sparse = splade_scores or {}
    if normalize:
        dense = min_max_norm(dense)
        sparse = min_max_norm(sparse)
    all_ids = set(dense.keys()) | set(sparse.keys())
    fused = {}
    for did in all_ids:
        ds = dense.get(did, 0.0)
        ss = sparse.get(did, 0.0)
        fused[did] = w_dense * ds + w_sparse * ss
    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# ------------- Build runs + qrels with adaptive fusion -------------
results_by_qid = {}  # qid -> {doc_id: score}
qrels_by_qid = {}    # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    question_text = row['question']  # use original text for classifier
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items = safe_load(row['groundtruth_with_ids'])

    mpnet_scores = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)

    # Get adaptive weights per query: (w_sparse, w_dense)
    w_sparse, w_dense = get_query_weights(question_text)

    fused = fuse_adaptive(mpnet_scores, splade_scores, w_sparse=w_sparse, w_dense=w_dense, normalize=True)

    # ensure normalized doc_ids
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    # run dict: doc_id -> fused score
    run_dict = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    results_by_qid[qid] = run_dict

    # qrels from groundtruth_with_ids
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': fused,  # adaptive fused list
        'weights': {'sparse': w_sparse, 'dense': w_dense, 'temperature': TEMPERATURE},
        'passages_with_ids': row['passages_with_ids'],
        'groundtruth_with_ids': gt_items
    })

# ------------- Evaluate with pytrec_eval (normal + cuts) -------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
eval_res = evaluator.evaluate(results_by_qid)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"  # e.g., 'map_cut_3'
    return float(pv.get(key, 0.0))

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    pv = eval_res.get(qid, {})
    metrics = {
        'MAP':    float(pv.get('map', 0.0)),
        'NDCG':   float(pv.get('ndcg', 0.0)),
        'MAP@3':  round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5':  round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10':round(metric_at(pv, 'ndcg_cut', 10), 6),
    }
    metrics_rows.append({**r, **metrics})

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'mpnet_ret_docs', 'splade_ret_docs',
                       'hybrid_ret_docs', 'weights', 'passages_with_ids', 'groundtruth_with_ids',
                       'MAP', 'NDCG', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10']]

# ------------- Save -------------
hybrid_df.to_csv('hotpotqa_hybrid_adaptive_results.csv', index=False)
print(f"Saved {len(hybrid_df)} rows to hotpotqa_hybrid_adaptive_results.csv")

# Print averages
avg_metrics = hybrid_df[['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']].mean().round(6)
print('Averages (Adaptive weights via classifier):')
for m in ['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']:
    print(f"{m}: {avg_metrics[m]}")

  from .autonotebook import tqdm as notebook_tqdm


Saved 80044 rows to hotpotqa_hybrid_adaptive_results.csv
Averages (Adaptive weights via classifier):
MAP: 0.245609
NDCG: 0.331681
MAP@3: 0.214106
NDCG@3: 0.271198
MAP@5: 0.225806
NDCG@5: 0.285706
MAP@10: 0.240078
NDCG@10: 0.317363


## Hybrid on Halubench

In [27]:
import pandas as pd

df = pd.read_csv('./prior_results/halubench_splade_modified_metrics.csv')
df.columns

Index(['question', 'answer', 'passage', 'splade_ret_docs', 'passages_with_ids',
       'groundtruth_with_ids', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10',
       'NDCG@10'],
      dtype='object')

In [28]:
df = pd.read_csv('./prior_results/halubench_mpnet_modified_metrics.csv')
df.columns

Index(['question', 'answer', 'passage', 'mpnet_ret_docs', 'passages_with_ids',
       'groundtruth_with_ids', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10',
       'NDCG@10'],
      dtype='object')

In [2]:
# Hybrid + Linear interpolation

import os
import pandas as pd
import json
import pytrec_eval  # ensure installed

CUTS = (3, 5, 10)
ALPHA = 0.3  # weight for mpnet; (1-ALPHA) is weight for splade

# ------------- Load paths -------------
splade_path = './prior_results/halubench_splade_modified_metrics.csv'
mpnet_path  = './prior_results/halubench_mpnet_modified_metrics.csv'

splade_df = pd.read_csv(splade_path)
mpnet_df  = pd.read_csv(mpnet_path)

# Normalize mpnet column name if needed
if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

# ------------- Keep only necessary columns -------------
# From Splade we keep: question, answer, passage, splade_ret_docs, passages_with_ids, groundtruth_with_ids
splade_keep = ['question', 'answer', 'passage', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
missing_splade = set(splade_keep) - set(splade_df.columns)
if missing_splade:
    raise ValueError(f"Missing columns in splade_df: {missing_splade}")
splade_df = splade_df[splade_keep].copy()

# From MpNet we keep only mpnet_ret_docs (drop everything else)
mpnet_keep = ['mpnet_ret_docs']
missing_mpnet = set(mpnet_keep) - set(mpnet_df.columns)
if missing_mpnet:
    raise ValueError(f"Missing columns in mpnet_df: {missing_mpnet}")
mpnet_df = mpnet_df[mpnet_keep].copy()

# ------------- Align by row index and concatenate mpnet_ret_docs -------------
if len(splade_df) != len(mpnet_df):
    raise ValueError(f"Row count mismatch: Splade={len(splade_df)}, MpNet={len(mpnet_df)}. "
                     "Row-wise concatenation requires equal lengths.")

# Reset index to ensure alignment (optional but safe)
splade_df = splade_df.reset_index(drop=True)
mpnet_df  = mpnet_df.reset_index(drop=True)

# Concatenate the single column mpnet_ret_docs to splade_df
merged = pd.concat([splade_df, mpnet_df['mpnet_ret_docs']], axis=1)
print('Merged shape:', merged.shape)

# ------------- Parsing + fusion helpers -------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'} or {'id'/'docid', 'sim'}.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min-max normalize a dict of {doc_id: score} per query to [0, 1].
    If empty or all-equal, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_linear(mpnet_scores, splade_scores, alpha=0.5, normalize=True):
    """
    Linear interpolation fusion:
      fused = alpha * mpnet + (1 - alpha) * splade
    - If normalize, applies per-query min-max per source first.
    - Missing docs get 0 from the source where absent.
    Returns list of dicts: {'doc_id', 'score', 'rank'} sorted by score desc.
    """
    m = mpnet_scores or {}
    s = splade_scores or {}
    if normalize:
        m = min_max_norm(m)
        s = min_max_norm(s)
    all_ids = set(m.keys()) | set(s.keys())
    fused = {}
    for did in all_ids:
        ms = m.get(did, 0.0)
        ss = s.get(did, 0.0)
        fused[did] = alpha * ms + (1.0 - alpha) * ss
    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    """
    Build binary qrels: {doc_id: 1} from groundtruth_with_ids entries.
    Accepts list of dicts or list of ids.
    """
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# ---------- NEW: build doc_id -> full_text map (prefer passages, then fall back to retriever lists) ----------
def build_doc_text_map(passages_items):
    """
    Build doc_id -> text from passages_with_ids first (prefers 'full_text', falls back to 'text'/'content').
    """
    mp = {}
    items = passages_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        txt = it.get('full_text') or it.get('text') or it.get('content')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

def extend_doc_text_map(mp, retriever_items):
    """
    Enrich map from mpnet/splade lists (prefers 'full_text', falls back to 'snippet'/'preview_snippet'/'text').
    """
    items = retriever_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did or did in mp:
            continue
        txt = it.get('full_text') or it.get('snippet') or it.get('preview_snippet') or it.get('text')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

# ------------- Build runs + qrels (linear fusion) -------------
results_by_qid = {}  # qid -> {doc_id: score}
qrels_by_qid = {}    # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items     = safe_load(row['groundtruth_with_ids'])
    passages_items = safe_load(row['passages_with_ids'])

    mpnet_scores  = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)
    fused = fuse_linear(mpnet_scores, splade_scores, alpha=ALPHA, normalize=True)

    # normalize doc_ids in fused output
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    # Build text map, preferring passages_with_ids, then enriching from the two retrieval lists
    doc_text_map = build_doc_text_map(passages_items)
    doc_text_map = extend_doc_text_map(doc_text_map, mpnet_items)
    doc_text_map = extend_doc_text_map(doc_text_map, splade_items)

    # Attach full_text to each fused entry (hybrid_ret_docs)
    hybrid_with_text = []
    for e in fused:
        did = e['doc_id']
        hybrid_with_text.append({
            'doc_id': did,
            'score': e['score'],
            'rank': e['rank'],
            'full_text': doc_text_map.get(did)  # may be None if unavailable
        })

    qid = f"q_{q_idx}"

    # run dict for pytrec_eval: doc_id -> fused score
    run_dict = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    results_by_qid[qid] = run_dict

    # qrels from groundtruth
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'passage': row['passage'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': hybrid_with_text,  # now includes full_text
        'passages_with_ids': row['passages_with_ids'],
        'groundtruth_with_ids': gt_items
    })

# ------------- Evaluate (MAP/NDCG + cuts) -------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
eval_res = evaluator.evaluate(results_by_qid)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"  # e.g., 'map_cut_3'
    return float(pv.get(key, 0.0))

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    pv = eval_res.get(qid, {})
    metrics = {
        'MAP':    float(pv.get('map', 0.0)),
        'NDCG':   float(pv.get('ndcg', 0.0)),
        'MAP@3':  round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5':  round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10':round(metric_at(pv, 'ndcg_cut', 10), 6),
    }
    metrics_rows.append({**r, **metrics})

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'passage',
                       'mpnet_ret_docs', 'splade_ret_docs', 'hybrid_ret_docs',
                       'passages_with_ids', 'groundtruth_with_ids',
                       'MAP', 'NDCG', 'MAP@3', 'NDCG@3', 'MAP@5', 'NDCG@5', 'MAP@10', 'NDCG@10']]

# ------------- Save -------------
out_dir = 'hybrid_results'
os.makedirs(out_dir, exist_ok=True)
out_csv = os.path.join(out_dir, f'halubench_hybrid_linear_results_splade_{ALPHA}_new.csv')

hybrid_df.to_csv(out_csv, index=False)
print(f"Saved {len(hybrid_df)} rows to {out_csv}")

# Print averages
avg_metrics = hybrid_df[['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']].mean().round(6)
print('Averages (alpha=%.2f):' % ALPHA)
for m in ['MAP','NDCG','MAP@3','NDCG@3','MAP@5','NDCG@5','MAP@10','NDCG@10']:
    print(f"{m}: {avg_metrics[m]}")

Merged shape: (14900, 7)
Saved 14900 rows to hybrid_results/halubench_hybrid_linear_results_splade_0.3_new.csv
Averages (alpha=0.30):
MAP: 0.844085
NDCG: 0.864219
MAP@3: 0.834239
NDCG@3: 0.841771
MAP@5: 0.838806
NDCG@5: 0.850033
MAP@10: 0.842712
NDCG@10: 0.859471


Saved 14900 rows to hybrid_results/halubench_hybrid_linear_results_splade_0.5.csv
Averages (alpha=0.50):
MAP: 0.834822
NDCG: 0.85738
MAP@3: 0.82443
NDCG@3: 0.834067
MAP@5: 0.829473
NDCG@5: 0.843159
MAP@10: 0.833532
NDCG@10: 0.852891

Saved 14900 rows to hybrid_results/halubench_hybrid_linear_results_splade_0.6.csv
Averages (alpha=0.40):
MAP: 0.843462
NDCG: 0.863791
MAP@3: 0.833434
NDCG@3: 0.841126
MAP@5: 0.838273
NDCG@5: 0.849883
MAP@10: 0.842167
NDCG@10: 0.859296

Saved 14900 rows to hybrid_results/halubench_hybrid_linear_results_splade_0.7.csv
Averages (alpha=0.30):
MAP: 0.844085
NDCG: 0.864219
MAP@3: 0.834239
NDCG@3: 0.841771
MAP@5: 0.838806
NDCG@5: 0.850033
MAP@10: 0.842712
NDCG@10: 0.859471



In [6]:
"""
HALUbench: SPLADE + MPNet Hybrid with Linear Interpolation + Cross-Encoder Reranking updated
Reranker model: cross-encoder/ms-marco-MiniLM-L-6-v2

Adds: 'full_text' to each item in both hybrid_ret_docs (fused) and hybrid_reranked_docs.

- Keeps HALUbench layout: row-wise concat of mpnet_ret_docs to the SPLADE dataframe.
- Linear interpolation (min–max per-source per-query) with ALPHA=0.3.
- Cross-encoder reranking atop the fused list, then blending with BETA.
- Evaluates fused vs reranked with pytrec_eval and saves results.
"""

import os
import json
from typing import List, Tuple, Dict, Optional

import pandas as pd
import pytrec_eval
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------- Config -----------------
CUTS = (3, 5, 10)

# Linear interpolation weight for MPNet vs SPLADE (as in your HALUbench script)
ALPHA = 0.3   # fused = ALPHA * mpnet + (1 - ALPHA) * splade

# Reranking weight (beta weights the reranker vs fused score)
BETA = 0.85   # final = BETA * rerank_score + (1 - BETA) * fused_score

# Reranker pool and output serialization size
MAX_CANDIDATES_FOR_RERANK = 200
TOPK_SAVE = 100  # how many reranked docs to store per row

# Inputs
splade_path = './prior_results/halubench_splade_modified_metrics.csv'
mpnet_path  = './prior_results/halubench_mpnet_modified_metrics.csv'

# Outputs
out_csv = f'halubench_hybrid_linear_with_rerank_minilm_alpha_{ALPHA}_beta_{BETA}.csv'

# Reranker model settings
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
RERANKER_MAX_LENGTH = 512
RERANKER_BATCH_SIZE = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32


# ----------------- Load -----------------
splade_df = pd.read_csv(splade_path)
mpnet_df  = pd.read_csv(mpnet_path)

# Normalize mpnet column name if needed
if 'all-mpnet-base-v2_ret_docs' in mpnet_df.columns:
    mpnet_df = mpnet_df.rename(columns={'all-mpnet-base-v2_ret_docs': 'mpnet_ret_docs'})

# Keep only necessary columns (HALUbench)
splade_keep = ['question', 'answer', 'passage', 'splade_ret_docs', 'passages_with_ids', 'groundtruth_with_ids']
missing_splade = set(splade_keep) - set(splade_df.columns)
if missing_splade:
    raise ValueError(f"Missing columns in splade_df: {missing_splade}")
splade_df = splade_df[splade_keep].copy()

mpnet_keep = ['mpnet_ret_docs']
missing_mpnet = set(mpnet_keep) - set(mpnet_df.columns)
if missing_mpnet:
    raise ValueError(f"Missing columns in mpnet_df: {missing_mpnet}")
mpnet_df = mpnet_df[mpnet_keep].copy()

# Align by index and concatenate mpnet_ret_docs
if len(splade_df) != len(mpnet_df):
    raise ValueError(f"Row count mismatch: Splade={len(splade_df)}, MpNet={len(mpnet_df)}. Row-wise concatenation requires equal lengths.")

splade_df = splade_df.reset_index(drop=True)
mpnet_df  = mpnet_df.reset_index(drop=True)

merged = pd.concat([splade_df, mpnet_df['mpnet_ret_docs']], axis=1)
print('Merged shape:', merged.shape)


# ----------------- Helpers -----------------
def safe_load(x):
    if isinstance(x, (list, dict)):
        return x
    if pd.isna(x):
        return None
    s = str(x)
    try:
        return json.loads(s)
    except Exception:
        try:
            import ast
            return ast.literal_eval(s)
        except Exception:
            return None

def normalize_doc_id(x):
    return None if x is None else str(x).strip().lower()

def to_score_dict(items):
    """
    Convert retrieval list into dict: doc_id -> score (float).
    Accepts dicts containing {'doc_id', 'score'} or {'id'/'docid', 'sim'}.
    """
    d = {}
    if not items:
        return d
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did:
            continue
        score = it.get('score', it.get('sim'))
        try:
            score = float(score)
        except Exception:
            continue
        d[did] = score
    return d

def min_max_norm(d):
    """
    Min–max normalize {doc_id: score} to [0, 1] per query.
    If empty or all-equal, returns zeros.
    """
    if not d:
        return {}
    vals = list(d.values())
    vmin, vmax = min(vals), max(vals)
    if vmax == vmin:
        return {k: 0.0 for k in d}
    return {k: (v - vmin) / (vmax - vmin) for k, v in d.items()}

def fuse_linear(mpnet_scores, splade_scores, alpha=0.5, normalize=True):
    """
    Linear interpolation fusion:
      fused = alpha * mpnet + (1 - alpha) * splade
    - If normalize, applies per-query min-max per source first.
    - Missing docs get 0 from the source where absent.
    Returns list of dicts: {'doc_id', 'score', 'rank'} sorted by score desc.
    """
    m = mpnet_scores or {}
    s = splade_scores or {}
    if normalize:
        m = min_max_norm(m)
        s = min_max_norm(s)
    all_ids = set(m.keys()) | set(s.keys())
    fused = {}
    for did in all_ids:
        ms = m.get(did, 0.0)
        ss = s.get(did, 0.0)
        fused[did] = alpha * ms + (1.0 - alpha) * ss
    sorted_items = sorted(fused.items(), key=lambda x: (-x[1], x[0]))
    out = []
    for i, (did, sc) in enumerate(sorted_items, start=1):
        out.append({'doc_id': did, 'score': float(sc), 'rank': i})
    return out

def to_qrels_docids_only(gt_items):
    """
    Build binary qrels: {doc_id: 1} from groundtruth_with_ids entries.
    Accepts list of dicts or list of ids.
    """
    rel = {}
    if not gt_items:
        return rel
    for it in gt_items:
        candidate = None
        if isinstance(it, dict):
            candidate = it.get('doc_id') or it.get('id') or it.get('docid')
        elif isinstance(it, (str, int)):
            candidate = it
        did = normalize_doc_id(candidate)
        if did:
            rel[did] = 1
    return rel

# ---- NEW: text mapping helpers ----
def build_doc_text_map(passages_items) -> Dict[str, str]:
    """
    Prefer passages_with_ids['full_text'|'text'|'content'].
    """
    mp = {}
    items = passages_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        txt = it.get('full_text') or it.get('text') or it.get('content')
        if did and isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp

def extend_doc_text_map(mp, retriever_items):
    """
    Enrich with mpnet/splade text (prefer 'full_text', fallback to 'snippet'/'preview_snippet'/'text').
    """
    items = retriever_items or []
    for it in items:
        if not isinstance(it, dict):
            continue
        did = normalize_doc_id(it.get('doc_id') or it.get('id') or it.get('docid'))
        if not did or did in mp:
            continue
        txt = it.get('full_text') or it.get('snippet') or it.get('preview_snippet') or it.get('text')
        if isinstance(txt, str) and txt.strip():
            mp[did] = txt
    return mp


# ----------------- Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2 -----------------
class MiniLMReranker:
    """
    Cross-encoder reranker using 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
    Returns a float score per (query, passage), where higher indicates higher relevance.
    """
    def __init__(self, model_name: str = RERANKER_MODEL_NAME, device: str = DEVICE, dtype=DTYPE, max_length: int = RERANKER_MAX_LENGTH):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
        self.device = device
        self.dtype = dtype
        self.max_length = max_length

    @torch.no_grad()
    def score(self, pairs: List[Tuple[str, str]], batch_size: int = RERANKER_BATCH_SIZE) -> List[float]:
        scores: List[float] = []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            queries = [b[0] for b in batch]
            docs    = [b[1] for b in batch]
            enc = self.tokenizer(
                queries,
                docs,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            logits = self.model(**enc).logits  # (B, 1) typically
            if logits.dim() == 2 and logits.size(1) == 1:
                batch_scores = logits.squeeze(1).detach().float().cpu().tolist()
            else:
                batch_scores = logits.view(-1).detach().float().cpu().tolist()
            scores.extend(batch_scores)
        return scores

reranker = MiniLMReranker()

def apply_reranker_on_fused_list(
    query: str,
    fused_list: List[Dict],
    doc_text_map: Dict[str, str],
    beta: float = BETA,
    max_candidates: int = MAX_CANDIDATES_FOR_RERANK,
    batch_size: int = RERANKER_BATCH_SIZE
) -> List[Dict]:
    """
    Input fused_list: [{'doc_id','score','rank'}]
    Output: [{'doc_id','fused_score','rerank_score','final_score','rank','full_text'}], sorted by final_score desc.
    """
    # Trim candidate pool
    pool = fused_list[:max_candidates]

    # Prepare (query, text) pairs
    pairs = []
    valid_idxs = []
    for i, e in enumerate(pool):
        did = e['doc_id']
        txt = doc_text_map.get(did)
        if isinstance(txt, str) and txt.strip():
            pairs.append((query, txt))
            valid_idxs.append(i)

    # If nothing to score, pass-through and still attach text if present
    if not pairs:
        out = []
        for i, e in enumerate(pool, start=1):
            did = e['doc_id']
            out.append({
                'doc_id': did,
                'fused_score': float(e['score']),
                'rerank_score': 0.0,
                'final_score': float(e['score']),
                'rank': i,
                'full_text': doc_text_map.get(did)
            })
        return out

    # Reranker scores
    rr_scores = reranker.score(pairs, batch_size=batch_size)

    # Attach scores, compute final
    for idx_local, s in zip(valid_idxs, rr_scores):
        pool[idx_local]['rerank_score'] = float(s)
    for e in pool:
        e['rerank_score'] = float(e.get('rerank_score', 0.0))
        e['fused_score'] = float(e.get('score', 0.0))
        e['final_score'] = beta * e['rerank_score'] + (1.0 - beta) * e['fused_score']

    # Sort and re-rank
    ranked = sorted(pool, key=lambda x: (-x['final_score'], x['doc_id']))

    # Build output with full_text
    out = []
    for new_rank, e in enumerate(ranked, start=1):
        did = e['doc_id']
        out.append({
            'doc_id': did,
            'fused_score': e['fused_score'],
            'rerank_score': e['rerank_score'],
            'final_score': e['final_score'],
            'rank': new_rank,
            'full_text': doc_text_map.get(did)
        })
    return out


# ----------------- Build runs + qrels (fusion + rerank) -----------------
results_by_qid_fused = {}     # qid -> {doc_id: fused score}
results_by_qid_reranked = {}  # qid -> {doc_id: final score}
qrels_by_qid = {}             # qid -> {doc_id: relevance}
rows_out = []

for q_idx, row in merged.iterrows():
    mpnet_items = safe_load(row['mpnet_ret_docs'])
    splade_items = safe_load(row['splade_ret_docs'])
    gt_items     = safe_load(row['groundtruth_with_ids'])
    passages_items = safe_load(row['passages_with_ids'])

    mpnet_scores  = to_score_dict(mpnet_items)
    splade_scores = to_score_dict(splade_items)

    fused = fuse_linear(mpnet_scores, splade_scores, alpha=ALPHA, normalize=True)

    # normalize doc_ids in fused output
    for e in fused:
        e['doc_id'] = normalize_doc_id(e['doc_id'])

    qid = f"q_{q_idx}"

    # Build text map from passages, then enrich from retriever lists if needed
    doc_text_map = build_doc_text_map(passages_items)
    doc_text_map = extend_doc_text_map(doc_text_map, mpnet_items)
    doc_text_map = extend_doc_text_map(doc_text_map, splade_items)

    # Rerank on top of fused list
    reranked = apply_reranker_on_fused_list(
        query=row['question'],
        fused_list=fused,
        doc_text_map=doc_text_map,
        beta=BETA,
        max_candidates=MAX_CANDIDATES_FOR_RERANK,
        batch_size=RERANKER_BATCH_SIZE
    )

    # Attach full_text to fused list for serialization
    fused_with_text = []
    for e in fused:
        did = e['doc_id']
        fused_with_text.append({
            'doc_id': did,
            'score': e['score'],
            'rank': e['rank'],
            'full_text': doc_text_map.get(did)
        })

    # run dicts
    run_fused = {e['doc_id']: float(e['score']) for e in fused if e['doc_id'] is not None}
    run_reranked = {e['doc_id']: float(e['final_score']) for e in reranked if e['doc_id'] is not None}

    results_by_qid_fused[qid] = run_fused
    results_by_qid_reranked[qid] = run_reranked

    # qrels
    qrels = to_qrels_docids_only(gt_items)
    qrels_by_qid[qid] = {doc_id: int(rel) for doc_id, rel in qrels.items()}

    rows_out.append({
        'qid': qid,
        'question': row['question'],
        'answer': row['answer'],
        'passage': row['passage'],
        'mpnet_ret_docs': mpnet_items,
        'splade_ret_docs': splade_items,
        'hybrid_ret_docs': fused_with_text,                 # with full_text
        'hybrid_reranked_docs': reranked[:TOPK_SAVE],       # with full_text
        'passages_with_ids': passages_items,
        'groundtruth_with_ids': gt_items
    })

# ----------------- Evaluate (MAP/NDCG + cuts) -----------------
metric_keys = {"map", "ndcg"} | {f"map_cut_{k}" for k in CUTS} | {f"ndcg_cut_{k}" for k in CUTS}
evaluator_fused = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)
evaluator_reranked = pytrec_eval.RelevanceEvaluator(qrels_by_qid, metric_keys)

eval_fused = evaluator_fused.evaluate(results_by_qid_fused)
eval_reranked = evaluator_reranked.evaluate(results_by_qid_reranked)

def metric_at(pv, metric_base, k):
    key = f"{metric_base}_{k}"
    return float(pv.get(key, 0.0))

def extract_metrics(pv):
    return {
        'MAP': float(pv.get('map', 0.0)),
        'NDCG': float(pv.get('ndcg', 0.0)),
        'MAP@3': round(metric_at(pv, 'map_cut', 3), 6),
        'NDCG@3': round(metric_at(pv, 'ndcg_cut', 3), 6),
        'MAP@5': round(metric_at(pv, 'map_cut', 5), 6),
        'NDCG@5': round(metric_at(pv, 'ndcg_cut', 5), 6),
        'MAP@10': round(metric_at(pv, 'map_cut', 10), 6),
        'NDCG@10': round(metric_at(pv, 'ndcg_cut', 10), 6),
    }

metrics_rows = []
for r in rows_out:
    qid = r['qid']
    fused_metrics = extract_metrics(eval_fused.get(qid, {}))
    rerank_metrics = extract_metrics(eval_reranked.get(qid, {}))
    metrics_rows.append({
        **r,
        **{f"FUSED_{k}": v for k, v in fused_metrics.items()},
        **{f"RERANK_{k}": v for k, v in rerank_metrics.items()},
    })

hybrid_df = pd.DataFrame(metrics_rows)
hybrid_df = hybrid_df[['question', 'answer', 'passage',
                       'mpnet_ret_docs', 'splade_ret_docs',
                       'hybrid_ret_docs', 'hybrid_reranked_docs',
                       'passages_with_ids', 'groundtruth_with_ids',
                       'FUSED_MAP', 'FUSED_NDCG', 'FUSED_MAP@3', 'FUSED_NDCG@3',
                       'FUSED_MAP@5', 'FUSED_NDCG@5', 'FUSED_MAP@10', 'FUSED_NDCG@10',
                       'RERANK_MAP', 'RERANK_NDCG', 'RERANK_MAP@3', 'RERANK_NDCG@3',
                       'RERANK_MAP@5', 'RERANK_NDCG@5', 'RERANK_MAP@10', 'RERANK_NDCG@10']]

# ------------- Save -------------
# Safer save in case out_csv has no directory
# dirpath = os.path.dirname(out_csv)
# if dirpath:
#     os.makedirs(dirpath, exist_ok=True)
hybrid_df.to_csv(out_csv, index=False)
print(f"Saved {len(hybrid_df)} rows to {out_csv}")

# Print averages
avg_fused = hybrid_df[['FUSED_MAP','FUSED_NDCG','FUSED_MAP@3','FUSED_NDCG@3','FUSED_MAP@5','FUSED_NDCG@5','FUSED_MAP@10','FUSED_NDCG@10']].mean().round(6)
avg_rerank = hybrid_df[['RERANK_MAP','RERANK_NDCG','RERANK_MAP@3','RERANK_NDCG@3','RERANK_MAP@5','RERANK_NDCG@5','RERANK_MAP@10','RERANK_NDCG@10']].mean().round(6)

print(f"Averages (alpha={ALPHA:.2f}) - FUSED:")
for m in ['FUSED_MAP','FUSED_NDCG','FUSED_MAP@3','FUSED_NDCG@3','FUSED_MAP@5','FUSED_NDCG@5','FUSED_MAP@10','FUSED_NDCG@10']:
    print(f"{m}: {avg_fused[m]}")

print(f"Averages (alpha={ALPHA:.2f}, beta={BETA:.2f}) - RERANK:")
for m in ['RERANK_MAP','RERANK_NDCG','RERANK_MAP@3','RERANK_NDCG@3','RERANK_MAP@5','RERANK_NDCG@5','RERANK_MAP@10','RERANK_NDCG@10']:
    print(f"{m}: {avg_rerank[m]}")

Merged shape: (14900, 7)
Saved 14900 rows to halubench_hybrid_linear_with_rerank_minilm_alpha_0.3_beta_0.85.csv
Averages (alpha=0.30) - FUSED:
FUSED_MAP: 0.844085
FUSED_NDCG: 0.864219
FUSED_MAP@3: 0.834239
FUSED_NDCG@3: 0.841771
FUSED_MAP@5: 0.838806
FUSED_NDCG@5: 0.850033
FUSED_MAP@10: 0.842712
FUSED_NDCG@10: 0.859471
Averages (alpha=0.30, beta=0.85) - RERANK:
RERANK_MAP: 0.855388
RERANK_NDCG: 0.872985
RERANK_MAP@3: 0.846342
RERANK_NDCG@3: 0.852801
RERANK_MAP@5: 0.850752
RERANK_NDCG@5: 0.860798
RERANK_MAP@10: 0.854345
RERANK_NDCG@10: 0.869451
