In [15]:
# eval/beir_eval.py
import json
import psycopg2
import requests
import zipfile
import os
from sentence_transformers import SentenceTransformer, CrossEncoder
from tqdm import tqdm
import numpy as np
from ranx import Qrels, Run, evaluate
import sys
import pathlib

# Add project root to path
# sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
from fusion.adaptive_fusion import get_alpha

In [16]:
# ================================
# CONFIGURATION
# ================================
DATASETS = ["scifact", "trec-covid"]
TOP_K = 100
RERANK_K = 30
FINAL_K = 10

# Connect to PostgreSQL
conn = psycopg2.connect(
    host="localhost",
    port=5433,
    dbname="ir_db",
    user="postgres",
    password="mysecretpassword"
)
conn.autocommit = True
cur = conn.cursor()

In [17]:
# Models
dense_model = SentenceTransformer('all-MiniLM-L6-v2')
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [18]:
# Map dataset names to table names
def get_table_name(dataset_name):
    return f"beir_{dataset_name.replace('-', '_')}"

def download_beir_dataset(dataset_name):
    """Download BEIR dataset metadata (queries and qrels only)"""
    dataset_path = f"data/beir/{dataset_name}"
    if os.path.exists(dataset_path):
        print(f"Dataset {dataset_name} already exists, skipping download.")
        return dataset_path
    
    print(f"Downloading {dataset_name}...")
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    
    headers = {'User-Agent': 'Mozilla/5.0'}
    r = requests.get(url, timeout=120, headers=headers, stream=True)
    r.raise_for_status()
    
    os.makedirs("data/beir", exist_ok=True)
    zip_path = f"data/beir/{dataset_name}.zip"
    
    with open(zip_path, 'wb') as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
    
    print(f"Extracting...")
    with zipfile.ZipFile(zip_path) as z:
        z.extractall("data/beir")
    
    os.remove(zip_path)
    return dataset_path

def bm25_search(query, table_name, limit=TOP_K):
    """BM25 search on specified table"""
    cur.execute(f"""
        SELECT id, ts_rank_cd(clean_text, plainto_tsquery(%s)) AS score
        FROM {table_name}
        WHERE clean_text @@ plainto_tsquery(%s)
        ORDER BY score DESC LIMIT %s
    """, (query, query, limit))
    return [(row[0], float(row[1])) for row in cur.fetchall()]

def dense_search(query, table_name, limit=TOP_K):
    """Dense search on specified table"""
    q_emb = dense_model.encode(query, normalize_embeddings=True)
    cur.execute(f"""
        SELECT id, embedding <=> %s::vector AS dist
        FROM {table_name}
        ORDER BY dist LIMIT %s
    """, (q_emb.tolist(), limit))
    results = cur.fetchall()
    
    if not results:
        return []
    
    scores = [1 - row[1] for row in results]
    max_score = max(scores) if scores else 1
    normalized = [s / max_score for s in scores]
    return [(row[0], normalized[i]) for i, row in enumerate(results)]

def get_document_text(doc_id, table_name):
    """Retrieve document text for reranking"""
    cur.execute(f"SELECT text FROM {table_name} WHERE id = %s", (doc_id,))
    row = cur.fetchone()
    return row[0] if row else ""

def rerank_documents(query, doc_ids, table_name):
    """Rerank documents using cross-encoder"""
    passages = [get_document_text(doc_id, table_name)[:1000] for doc_id in doc_ids]
    pairs = [[query, passage] for passage in passages if passage]
    
    if not pairs:
        return []
    
    # Get raw logits from cross-encoder
    scores = reranker.predict(pairs)
    
    # Convert logits to positive scores using sigmoid
    # This ensures scores are between 0 and 1
    import math
    positive_scores = [1 / (1 + math.exp(-score)) for score in scores]
    
    return positive_scores

In [19]:
def load_dataset(dataset_name):
    """Load queries and qrels"""
    path = download_beir_dataset(dataset_name)
    
    queries_path = f"{path}/queries.jsonl"
    with open(queries_path) as f:
        # CRITICAL: Keep query IDs as strings exactly as they appear
        queries = {}
        for line in f:
            q = json.loads(line)
            qid = str(q['_id'])  # Ensure string
            queries[qid] = q['text']
    
    print(f"Loaded {len(queries)} queries. Sample query IDs: {list(queries.keys())[:5]}")
    
    # qrels_path = f"{path}/qrels/test.tsv"
    # with open(qrels_path) as f:
    #     qrels_dict = {}
    #     for i, line in enumerate(f):
    #         if i == 0 and 'query-id' in line.lower():
    #             continue
            
    #         parts = line.strip().split('\t')
    #         if len(parts) >= 3:
    #             qid = str(parts[0])  # Ensure string
    #             docid = str(parts[2])  # Ensure string
    #             rel = int(parts[3]) if len(parts) > 3 else 1
    #             qrels_dict.setdefault(qid, {})[docid] = rel
    qrels_path = f"{path}/qrels/test.tsv"
    with open(qrels_path) as f:
        qrels_dict = {}
        for i, line in enumerate(f):
            # Skip header if present
            if i == 0 and any(h in line.lower() for h in ["query-id", "qid", "topic"]):
                continue

            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue

            if len(parts) == 3:
                # BEIR-style: query-id corpus-id score
                qid, docid, rel = parts
            else:
                # TREC-style: query-id Q0 docid rel [iteration...]
                qid, _q0, docid, rel = parts[:4]

            qid = str(qid)
            docid = str(docid)
            rel = int(rel)

            qrels_dict.setdefault(qid, {})[docid] = rel

    
    print(f"Loaded {len(qrels_dict)} qrels. Sample qrel query IDs: {list(qrels_dict.keys())[:5]}")
    
    return queries, qrels_dict

def rrf_fusion(bm25_results, dense_results, k=60):
    """Reciprocal Rank Fusion"""
    scores = {}
    for i, (doc_id, _) in enumerate(bm25_results):
        scores[str(doc_id)] = scores.get(str(doc_id), 0) + 1 / (i + k)
    for i, (doc_id, _) in enumerate(dense_results):
        scores[str(doc_id)] = scores.get(str(doc_id), 0) + 1 / (i + k)
    return scores

# ================================
# EVALUATION LOOP
# ================================
results_summary = {}

for dataset_name in DATASETS:
    print(f"\n{'='*60}")
    print(f"=== Evaluating on {dataset_name.upper()} ===")
    print(f"{'='*60}")
    
    table_name = get_table_name(dataset_name)
    
    # Check if table exists
    cur.execute("""
        SELECT EXISTS (
            SELECT FROM information_schema.tables 
            WHERE table_schema = 'public' 
            AND table_name = %s
        );
    """, (table_name,))
    
    if not cur.fetchone()[0]:
        print(f"❌ Table {table_name} does not exist!")
        print(f"Please run: python indexing/index_beir.py")
        continue
    
    # Verify table has documents
    cur.execute(f"SELECT COUNT(*) FROM {table_name}")
    doc_count = cur.fetchone()[0]
    print(f"✓ Using table: {table_name} ({doc_count:,} documents)")
    
    try:
        queries, qrels_dict = load_dataset(dataset_name)
        qrels = Qrels(qrels_dict)
        some_qid = next(iter(qrels_dict.keys()))
        print("Sample qid:", some_qid)
        print("Sample qrel docids for that qid:", list(qrels_dict[some_qid].keys())[:10])

        table_name = get_table_name(dataset_name)
        for did in list(qrels_dict[some_qid].keys())[:10]:
            cur.execute(f"SELECT COUNT(*) FROM {table_name} WHERE id = %s", (did,))
            count = cur.fetchone()[0]
            print(did, "exists in DB?" , count > 0)

        run_bm25 = {}
        run_dense = {}
        run_rrf = {}
        run_adaptive = {}
        
        for qid, query in tqdm(queries.items(), desc="Running queries"):
            try:
                # 1. BM25
                bm25_results = bm25_search(query, table_name, TOP_K)
                run_bm25[qid] = {str(doc_id): score for doc_id, score in bm25_results} if bm25_results else {}
                
                # 2. Dense
                dense_results = dense_search(query, table_name, TOP_K)
                run_dense[qid] = {str(doc_id): score for doc_id, score in dense_results} if dense_results else {}
                
                # 3. RRF
                rrf_scores = rrf_fusion(bm25_results, dense_results)
                run_rrf[qid] = rrf_scores if rrf_scores else {}
                
                # 4. Adaptive Fusion + Re-rank
                alpha = get_alpha(query)
                fused = {}
                bm25_dict = {doc_id: score for doc_id, score in bm25_results}
                dense_dict = {doc_id: score for doc_id, score in dense_results}
                all_ids = set(bm25_dict) | set(dense_dict)
                
                for doc_id in all_ids:
                    s_bm25 = bm25_dict.get(doc_id, 0)
                    s_dense = dense_dict.get(doc_id, 0)
                    fused[doc_id] = alpha * s_bm25 + (1 - alpha) * s_dense
                
                candidates = sorted(fused.items(), key=lambda x: x[1], reverse=True)[:RERANK_K]
                candidate_ids = [doc_id for doc_id, _ in candidates]
                
                if candidate_ids:
                    rerank_scores = rerank_documents(query, candidate_ids, table_name)
                    final_scored = sorted(zip(candidate_ids, rerank_scores), key=lambda x: x[1], reverse=True)[:FINAL_K]
                    run_adaptive[qid] = {str(doc_id): score for doc_id, score in final_scored}
                else:
                    run_adaptive[qid] = {}
                    
            except Exception as e:
                print(f"\nError processing query {qid}: {e}")
                continue
        
        # Save runs
        os.makedirs("eval/output", exist_ok=True)
        Run(run_bm25).save(f"eval/output/{dataset_name}_bm25.json")
        Run(run_dense).save(f"eval/output/{dataset_name}_dense.json")
        Run(run_rrf).save(f"eval/output/{dataset_name}_rrf.json")
        Run(run_adaptive).save(f"eval/output/{dataset_name}_adaptive.json")
        
        # Evaluate
        print(f"\nEvaluating {dataset_name.upper()}...")
        metrics = ["ndcg@10", "recall@100", "map@100"]
        eval_results = {}
        
        bm25_run = Run(run_bm25)
        bm25_scores = evaluate(qrels, bm25_run, metrics, make_comparable=True)
        eval_results["BM25"] = bm25_scores
        
        dense_run = Run(run_dense)
        dense_scores = evaluate(qrels, dense_run, metrics, make_comparable=True)
        eval_results["Dense (MiniLM)"] = dense_scores
        
        rrf_run = Run(run_rrf)
        rrf_scores = evaluate(qrels, rrf_run, metrics, make_comparable=True)
        eval_results["RRF"] = rrf_scores
        
        adaptive_run = Run(run_adaptive)
        adaptive_scores = evaluate(qrels, adaptive_run, metrics, make_comparable=True)
        eval_results["Adaptive Fusion + Re-rank (Yours)"] = adaptive_scores
        
        results_summary[dataset_name] = eval_results
        
        print(f"\nResults on {dataset_name.upper()}:")
        print("-" * 60)
        for method, scores in eval_results.items():
            print(f"{method}:")
            for metric, value in scores.items():
                print(f"  {metric}: {value:.4f}")
        print("-" * 60)
        
    except Exception as e:
        print(f"Error processing {dataset_name}: {e}")
        import traceback
        traceback.print_exc()
        continue

# Summary Table
print("\n" + "="*60)
print("=== SUMMARY TABLE FOR IEEE REPORT ===")
print("="*60)
print("| Method                        | SciFact nDCG@10 | TREC-COVID nDCG@10 |")
print("|-------------------------------|-----------------|--------------------|")
for method in ["BM25", "Dense (MiniLM)", "RRF", "Adaptive Fusion + Re-rank (Yours)"]:
    sci = results_summary.get("scifact", {}).get(method, {}).get("ndcg@10", "N/A")
    trec = results_summary.get("trec-covid", {}).get(method, {}).get("ndcg@10", "N/A")
    if isinstance(sci, float) and isinstance(trec, float):
        print(f"| {method:<29} | {sci:>15.4f} | {trec:>18.4f} |")
    else:
        print(f"| {method:<29} | {str(sci):>15} | {str(trec):>18} |")

print("\nEvaluation complete! Results saved in eval/output/")

cur.close()
conn.close()


=== Evaluating on SCIFACT ===
✓ Using table: beir_scifact (5,183 documents)
Dataset scifact already exists, skipping download.
Loaded 1109 queries. Sample query IDs: ['0', '2', '4', '6', '9']
Loaded 300 qrels. Sample qrel query IDs: ['1', '3', '5', '13', '36']
Sample qid: 1
Sample qrel docids for that qid: ['31715818']
31715818 exists in DB? True


Running queries: 100%|██████████| 1109/1109 [37:12<00:00,  2.01s/it] 



Evaluating SCIFACT...

Results on SCIFACT:
------------------------------------------------------------
BM25:
  ndcg@10: 0.0436
  recall@100: 0.0428
  map@100: 0.0428
Dense (MiniLM):
  ndcg@10: 0.6424
  recall@100: 0.8753
  map@100: 0.5998
RRF:
  ndcg@10: 0.6548
  recall@100: 0.8753
  map@100: 0.6140
Adaptive Fusion + Re-rank (Yours):
  ndcg@10: 0.6684
  recall@100: 0.8140
  map@100: 0.6170
------------------------------------------------------------

=== Evaluating on TREC-COVID ===
✓ Using table: beir_trec_covid (171,332 documents)
Dataset trec-covid already exists, skipping download.
Loaded 50 queries. Sample query IDs: ['1', '2', '3', '4', '5']
Loaded 50 qrels. Sample qrel query IDs: ['1', '2', '3', '4', '5']
Sample qid: 1
Sample qrel docids for that qid: ['005b2j4b', '00fmeepz', 'g7dhmyyo', '0194oljo', '021q9884', '02f0opkr', '047xpt2c', '04ftw7k9', 'pl9ht0d0', '05vx82oo']
005b2j4b exists in DB? True
00fmeepz exists in DB? True
g7dhmyyo exists in DB? True
0194oljo exists in DB? T

Running queries: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it]



Evaluating TREC-COVID...

Results on TREC-COVID:
------------------------------------------------------------
BM25:
  ndcg@10: 0.1986
  recall@100: 0.0208
  map@100: 0.0135
Dense (MiniLM):
  ndcg@10: 0.4661
  recall@100: 0.0412
  map@100: 0.0285
RRF:
  ndcg@10: 0.5003
  recall@100: 0.0565
  map@100: 0.0376
Adaptive Fusion + Re-rank (Yours):
  ndcg@10: 0.3549
  recall@100: 0.0088
  map@100: 0.0073
------------------------------------------------------------

=== SUMMARY TABLE FOR IEEE REPORT ===
| Method                        | SciFact nDCG@10 | TREC-COVID nDCG@10 |
|-------------------------------|-----------------|--------------------|
| BM25                          |          0.0436 |             0.1986 |
| Dense (MiniLM)                |          0.6424 |             0.4661 |
| RRF                           |          0.6548 |             0.5003 |
| Adaptive Fusion + Re-rank (Yours) |          0.6684 |             0.3549 |

Evaluation complete! Results saved in eval/output/
