In [1]:
# !pip install pandas transformers torch datasets pyarrow fastparquet ragas pymilvus
# !pip install sentence-transformers
# !pip install "pymilvus[milvus_lite]"
# !pip install evaluate

In [2]:
"""
What it does:
1) Load passages and QA test split from rag-mini-wikipedia (HuggingFace datasets via parquet).
2) Build two embedding sizes: 384-d (all-MiniLM-L6-v2) and 768-d (all-mpnet-base-v2).
3) Store passages + embeddings into Milvus Lite (local .db), index, and load.
4) Provide retrieval functions (top-k) and optional reranking (cross-encoder).
5) Load a small seq2seq LLM (flan-t5-base) and run RAG generation.
6) Evaluate with SQuAD EM/F1; runs on a small subset by default for speed.
7) (Optional) Run RAGAS on a small subset if dependencies are available.
"""

import os
import sys
import json
import time
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

# --- Try pandas read_parquet with fastparquet fallback ---
def read_parquet_hf(path: str) -> pd.DataFrame:
    """Read a parquet file (local path or fsspec URL such as ``hf://...``).

    The pandas default engine is tried first; if it raises for any reason the read
    is retried once with ``engine="fastparquet"``. An error from that retry
    propagates to the caller.
    """
    try:
        frame = pd.read_parquet(path)
    except Exception:
        frame = pd.read_parquet(path, engine="fastparquet")
    return frame

# --- Config ---
# Milvus Lite database file and one collection per embedding size.
DB_PATH = "./rag_wikipedia_mini.db"
COLLECTION_384 = "rag_mini"
COLLECTION_768 = "rag_mini_768"

# Hugging Face model ids.
EMBED_MODEL_384 = "sentence-transformers/all-MiniLM-L6-v2"   # 384-d sentence embeddings
EMBED_MODEL_768 = "sentence-transformers/all-mpnet-base-v2"  # 768-d sentence embeddings
RERANK_CE = "cross-encoder/ms-marco-MiniLM-L-6-v2"           # cross-encoder reranker
LLM_NAME = "google/flan-t5-base"                             # seq2seq generator

# Evaluation subset sizes; override via the N_EVAL / N_RAGAS environment variables.
# Kept small by default so a CPU-only run stays fast.
N_EVAL = int(os.getenv("N_EVAL", "100"))
N_RAGAS = int(os.getenv("N_RAGAS", "60"))

# Feature toggles.
USE_RERANK = True
USE_RAGAS = True

In [3]:
# --- Load data ---
def load_data():
    """Load the rag-mini-wikipedia passages and test QA pairs from the HF hub.

    Rows with a missing ``passage`` (passages) or a missing ``question``/``answer``
    (queries) are dropped, and both frames are re-indexed 0..n-1.

    Returns:
        (passages, queries) DataFrames.
    """
    sources = {
        "passages": "hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet",
        "test": "hf://datasets/rag-datasets/rag-mini-wikipedia/data/test.parquet/part.0.parquet",
    }

    print("Loading passages...")
    raw_passages = read_parquet_hf(sources["passages"])
    print("Loading test questions...")
    raw_queries = read_parquet_hf(sources["test"])

    # Basic cleaning: keep only rows that carry the text we need.
    passages_clean = raw_passages.dropna(subset=["passage"]).reset_index(drop=True)
    queries_clean = raw_queries.dropna(subset=["question", "answer"]).reset_index(drop=True)

    print(f"Passages: {len(passages_clean)} | Queries: {len(queries_clean)}")
    return passages_clean, queries_clean

# --- Embeddings ---
from sentence_transformers import SentenceTransformer
import numpy as np

def build_embeddings(passages: pd.DataFrame, model_name: str) -> tuple[np.ndarray, SentenceTransformer]:
    """Encode every passage with a SentenceTransformer model.

    Args:
        passages: DataFrame with a "passage" text column.
        model_name: model id passed to ``SentenceTransformer``.

    Returns:
        (vecs, model): a numpy array with one normalized embedding row per passage,
        in DataFrame order, and the loaded model so queries can later be encoded
        with the same encoder.
    """
    model = SentenceTransformer(model_name)
    vecs = model.encode(
        passages["passage"].tolist(),
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,  # unit-length vectors, matching the COSINE index
    )
    return vecs, model

# --- Milvus Lite setup ---
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

def list_collection_names(client):
    """Return the client's collection names as a plain list of strings.

    ``client.list_collections()`` may yield names directly or dicts carrying a
    "collection_name" key (a dict without that key maps to ""). Anything that is
    not a non-empty list/tuple yields [].
    """
    listed = client.list_collections()
    if not isinstance(listed, (list, tuple)) or not listed:
        return []
    if isinstance(listed[0], dict):
        return [entry.get("collection_name", "") for entry in listed]
    return list(listed)

def create_collection(client: MilvusClient, name: str, dim: int, max_text_len: int = 1000):
    """(Re)create, index and load a Milvus collection for passage embeddings.

    Schema: ``id`` (INT64 primary key, supplied by the caller), ``passage``
    (VARCHAR) and ``embedding`` (FLOAT_VECTOR of size ``dim``). An IVF_FLAT index
    with COSINE metric (nlist=128) is built on ``embedding``, and the collection is
    loaded so it can be searched immediately.

    Destructive: an existing collection called ``name`` is dropped first.

    Args:
        client: connected MilvusClient.
        name: collection name.
        dim: embedding dimensionality; must match the vectors inserted later.
        max_text_len: ``max_length`` of the ``passage`` VARCHAR field. Defaults to
            1000, the previously hard-coded value. Milvus rejects inserts whose
            passage is longer, so raise it (Milvus allows up to 65535) for corpora
            with longer passages.
    """
    if name in list_collection_names(client):
        client.drop_collection(name)

    id_f = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False)
    text_f = FieldSchema(name="passage", dtype=DataType.VARCHAR, max_length=max_text_len)
    emb_f = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
    schema = CollectionSchema(fields=[id_f, text_f, emb_f], description=f"RAG mini Wikipedia passages dim={dim}")
    client.create_collection(collection_name=name, schema=schema)

    idxp = MilvusClient.prepare_index_params()
    idxp.add_index(field_name="embedding", index_type="IVF_FLAT", metric_type="COSINE", params={"nlist": 128})
    client.create_index(collection_name=name, index_params=idxp)
    client.load_collection(name)

def insert_all(client: MilvusClient, name: str, passages: pd.DataFrame, embeddings: np.ndarray):
    """Insert one row per passage into collection ``name``.

    The row id is the passage's positional index (0..len-1). The text comes from
    ``passages["passage"]``, and the vector is the matching row of ``embeddings``.

    Returns:
        The "insert_count" reported by ``client.insert``, or 0 if absent.
    """
    payload = []
    for row_id in range(len(passages)):
        payload.append({
            "id": int(row_id),
            "passage": str(passages["passage"].iloc[row_id]),
            "embedding": embeddings[row_id].tolist(),
        })
    result = client.insert(collection_name=name, data=payload)
    return result.get("insert_count", 0)

# --- Retrieval helpers ---
def retrieve_contexts(client: MilvusClient, question: str, top_k: int, emb_model: SentenceTransformer, collection: str):
    """Vector-search ``collection`` for ``question`` and return up to ``top_k`` passage texts.

    The question is encoded with ``emb_model`` (normalized, as the stored passages
    are) and searched against the "embedding" field. Passages come back in the
    order Milvus ranks them.
    """
    query_vectors = emb_model.encode(
        [question], normalize_embeddings=True, convert_to_numpy=True
    ).tolist()
    results = client.search(
        collection_name=collection,
        data=query_vectors,
        anns_field="embedding",
        limit=top_k,
        output_fields=["id", "passage"],
    )
    first_query_hits = results[0]  # a single query vector was sent
    passages_out = []
    for hit in first_query_hits:
        passages_out.append(hit["entity"]["passage"])
    return passages_out

# Optional reranking using a cross-encoder
def build_reranker():
    """Load the cross-encoder named by RERANK_CE.

    Returns None, after printing the reason, if the import or the model load fails
    for any reason.
    """
    try:
        from sentence_transformers import CrossEncoder
        model = CrossEncoder(RERANK_CE)
    except Exception as err:
        print("Reranker not available:", err)
        model = None
    return model

def retrieve_with_rerank(client: MilvusClient, question: str, k: int, initial: int, emb_model: SentenceTransformer, collection: str, reranker=None):
    """Retrieve ``initial`` candidates by vector search, optionally rerank them, and keep the top ``k``.

    Args:
        client, emb_model, collection: forwarded to retrieve_contexts.
        question: query text.
        k: number of passages to return.
        initial: candidate pool size fetched before reranking.
        reranker: object whose ``predict(list[(question, passage)])`` returns one
            score per pair (e.g. a sentence-transformers CrossEncoder). If it is
            falsy, the first ``k`` vector-search hits are returned unchanged.

    Returns:
        Up to ``k`` passage strings, highest reranker score first. Equal scores
        keep their vector-search order, because Python's sort is stable even
        with reverse=True.
    """
    candidates = retrieve_contexts(client, question, top_k=initial, emb_model=emb_model, collection=collection)
    # Nothing to score: skip the reranker, so predict() is never called on an empty batch.
    if not reranker or not candidates:
        return candidates[:k]
    scores = reranker.predict([(question, c) for c in candidates])
    order = sorted(range(len(candidates)), key=lambda j: scores[j], reverse=True)
    return [candidates[j] for j in order[:k]]

# --- LLM for generation ---
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def load_llm():
    """Load the tokenizer and seq2seq model for LLM_NAME; returns ``(tokenizer, model)``."""
    model_id = LLM_NAME
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return tokenizer, model

def generate_answer(tok, lm, context_text: str, question_text: str, max_input_tokens: int = 2048) -> str:
    """Answer ``question_text`` from ``context_text`` with a seq2seq LM (beam search, no sampling).

    The prompt is: instruction, then the context, then the question and an
    "Answer:" cue. Because the question comes last, plain right-side truncation of
    an over-long prompt would silently cut off the question. Instead, only the
    context is shortened so the whole prompt fits in ``max_input_tokens``.

    Args:
        tok: tokenizer for ``lm`` (called on text and used to ``decode``).
        lm: model exposing ``generate(**inputs, ...)``.
        context_text: retrieved passages joined into one string.
        question_text: the question to answer.
        max_input_tokens: token budget for the whole prompt (default 2048, as before).

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    system_prompt = "You are a helpful assistant. Answer based only on the given context."
    head = f"{system_prompt}\n\nContext:\n"
    tail = f"\n\nQuestion: {question_text}\nAnswer:"

    # Tokens taken by everything except the context (includes any special tokens).
    fixed_len = len(tok(head + tail)["input_ids"])
    ctx_budget = max(0, max_input_tokens - fixed_len)
    ctx_ids = tok(context_text, add_special_tokens=False)["input_ids"]
    if len(ctx_ids) > ctx_budget:
        # Trim the context, never the question, then rebuild the prompt from text.
        context_text = tok.decode(ctx_ids[:ctx_budget], skip_special_tokens=True)

    prompt = head + context_text + tail
    # Keep truncation as a safety net: re-tokenizing decoded text can shift counts slightly.
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens)
    out_ids = lm.generate(**inputs, max_new_tokens=128, do_sample=False, num_beams=4, early_stopping=True)
    return tok.decode(out_ids[0], skip_special_tokens=True)

def rewrite_query(tok, lm, q: str) -> str:
    """Ask the seq2seq LM to rewrite ``q`` as an explicit, self-contained query.

    Decoding is deterministic (4-beam search, no sampling). The prompt is
    truncated to 512 tokens and the rewrite is capped at 64 new tokens.
    """
    instruction = "Rewrite the question to be explicit, self-contained, and include key entities. Keep it short."
    prompt_text = f"""{instruction}

Original question: {q}
Rewritten:"""
    encoded = tok(prompt_text, return_tensors="pt", truncation=True, max_length=512)
    generated = lm.generate(
        **encoded,
        max_new_tokens=64,
        do_sample=False,
        num_beams=4,
        early_stopping=True,
    )
    first_sequence = generated[0]
    return tok.decode(first_sequence, skip_special_tokens=True)

# --- Evaluation (SQuAD EM/F1) ---
import evaluate
# SQuAD metric from Hugging Face `evaluate`, loaded once when this cell runs and
# shared by every run_eval_squad call; its result dict provides "exact_match" and "f1".
squad = evaluate.load("squad")

def run_eval_squad(queries_df, top_k: int, collection: str, emb_model: SentenceTransformer, tok, lm, N: int = 100, variant: str = "baseline", reranker=None, milvus_client=None, rerank_initial: int = 10):
    """Run the RAG pipeline on the first ``N`` rows of ``queries_df`` and score it with SQuAD EM/F1.

    Variants:
      - "rewrite": the question is first rewritten with the LLM, and the rewritten
        text is used for BOTH retrieval and generation.
      - "rerank": ``rerank_initial`` candidates are retrieved, reranked with
        ``reranker``, and ``top_k`` are kept.
      - anything else (e.g. "baseline"): plain top_k vector retrieval.

    Args:
        queries_df: DataFrame with "question" and "answer" columns.
        top_k: number of passages placed in the prompt.
        collection: Milvus collection to search.
        emb_model: encoder used for the question (must match the collection).
        tok, lm: tokenizer and seq2seq model for rewriting and generation.
        N: number of leading rows to evaluate.
        variant: see above.
        reranker: cross-encoder for the "rerank" variant (None disables reranking).
        milvus_client: MilvusClient to search. Defaults to the notebook-global
            ``client`` (the previous implicit behaviour); pass one to avoid
            depending on hidden notebook state.
        rerank_initial: candidate pool size for "rerank" (default 10, as before).

    Returns:
        (metrics, preds, all_contexts): the squad metric dict ("exact_match", "f1"),
        the prediction records, and the retrieved context list for each question.
        If no rows are evaluated, both metrics are NaN.
    """
    # Only touch the global `client` when the caller did not supply one.
    db = milvus_client if milvus_client is not None else client

    preds, refs, all_contexts = [], [], []
    for i, row in queries_df.iloc[:N].iterrows():
        q = str(row["question"])
        gold = str(row["answer"])

        q_used = rewrite_query(tok, lm, q) if variant == "rewrite" else q

        if variant == "rerank":
            ctx_list = retrieve_with_rerank(db, q_used, k=top_k, initial=rerank_initial, emb_model=emb_model, collection=collection, reranker=reranker)
        else:
            ctx_list = retrieve_contexts(db, q_used, top_k=top_k, emb_model=emb_model, collection=collection)

        context_text = "\n\n".join(ctx_list)
        gen = generate_answer(tok, lm, context_text, q_used)

        # The DataFrame index label is the example id; answer_start is a placeholder value.
        preds.append({"id": str(i), "prediction_text": gen})
        refs.append({"id": str(i), "answers": {"text": [gold], "answer_start": [0]}})
        all_contexts.append(ctx_list)

    if not preds:
        # EM/F1 are undefined over zero examples; report NaN instead of calling the metric.
        nan = float("nan")
        return {"exact_match": nan, "f1": nan}, preds, all_contexts

    metrics = squad.compute(predictions=preds, references=refs)
    return metrics, preds, all_contexts

In [4]:
import os
# Presumably set to silence HF tokenizers' parallelism/fork warnings in notebook runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Corpus passages and evaluation questions (rows missing text are dropped in load_data).
passages, queries = load_data()

# Encode the full corpus once per embedding model. Each call also returns the model,
# so the same encoder embeds queries at retrieval time.
print("\nBuilding 384-d embeddings...")
emb_384, emb_model_384 = build_embeddings(passages, EMBED_MODEL_384)

print("Building 768-d embeddings...")
emb_768, emb_model_768 = build_embeddings(passages, EMBED_MODEL_768)

# Milvus Lite keeps everything in the local DB_PATH file.
# NOTE: this module-global `client` is also read implicitly by run_eval_squad.
print("\nSetting up Milvus Lite at:", DB_PATH)
client = MilvusClient(DB_PATH)

# create_collection drops any same-named collection first, so re-running this cell
# rebuilds both collections from scratch. The vector dim is taken from the embeddings.
print("Creating and loading collection (384)...")
create_collection(client, COLLECTION_384, dim=emb_384.shape[1])
inserted = insert_all(client, COLLECTION_384, passages, emb_384)
print("Inserted (384):", inserted)

print("Creating and loading collection (768)...")
create_collection(client, COLLECTION_768, dim=emb_768.shape[1])
inserted = insert_all(client, COLLECTION_768, passages, emb_768)
print("Inserted (768):", inserted)

Loading passages...
Loading test questions...
Passages: 3200 | Queries: 918

Building 384-d embeddings...


Batches: 100%|██████████████████████████████████| 100/100 [00:07<00:00, 13.77it/s]


Building 768-d embeddings...


Batches: 100%|██████████████████████████████████| 100/100 [00:37<00:00,  2.66it/s]



Setting up Milvus Lite at: ./rag_wikipedia_mini.db
Creating and loading collection (384)...
Inserted (384): 3200
Creating and loading collection (768)...
Inserted (768): 3200


In [5]:
print("\nLoading LLM:", LLM_NAME)
tok, lm = load_llm()

# The reranker is optional; None means the "rerank" variant keeps the vector-search order.
reranker = None
if USE_RERANK:
    reranker = build_reranker()

# Experiment grid of (collection, embedding model, top_k, variant) tuples.
param_grid = [
    *((COLLECTION_384, emb_model_384, top, "baseline") for top in (1, 3, 5)),
    *((COLLECTION_768, emb_model_768, top, "baseline") for top in (1, 3, 5)),
    (COLLECTION_384, emb_model_384, 3, "rewrite"),
    (COLLECTION_384, emb_model_384, 3, "rerank"),
]


Loading LLM: google/flan-t5-base


In [8]:
# --- Ragas ---
import os
import openai

# SECURITY: never hardcode API keys in a notebook. They leak through git history,
# shared .ipynb files and rendered outputs. Any key that was ever pasted into this
# cell must be treated as compromised and revoked. Provide the key through the
# environment instead (e.g. `export OPENAI_API_KEY=...` before starting Jupyter).
if not os.environ.get("OPENAI_API_KEY"):
    print("OPENAI_API_KEY is not set; OpenAI-backed RAGAS metrics will fail until it is provided.")

from datasets import Dataset, Features, Value, Sequence
from ragas import evaluate as ragas_evaluate
try:
    # ragas exports this metric as `answer_relevancy`; alias it to the name used below.
    from ragas.metrics import answer_relevancy as answer_relevance
except ImportError:
    # answer_correctness is a DIFFERENT metric (it compares against ground truth),
    # so say so instead of substituting it silently.
    from ragas.metrics import answer_correctness as answer_relevance
    print("[ragas] answer_relevancy unavailable; using answer_correctness under the name 'answer_relevance'.")

from ragas.metrics import faithfulness, context_precision, context_recall

# ultra-simple backoff (works even if not rate limited)
import time
def _with_backoff(fn, *args, max_retries=6, base=0.8, **kwargs):
    for i in range(max_retries):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            msg = str(e).lower()
            is_rate = ('rate limit' in msg) or ('429' in msg) or ('too many requests' in msg)
            if not is_rate or i == max_retries - 1:
                raise
            delay = base * (2 ** i)
            print(f"[ragas/retry] sleeping {delay:.1f}s (attempt {i+1}/{max_retries})")
            time.sleep(delay)


def make_weak_gt_single(queries_df, contexts, N):
    """Build a weak ground truth for each of the first ``N`` questions.

    For question i, this is the first passage in ``contexts[i]`` whose lower-cased
    text contains the stripped, lower-cased gold answer, returned in its original
    casing. It is "" when the answer is empty or no passage contains it.
    """
    weak_gts = []
    for idx in range(N):
        needle = str(queries_df["answer"].iloc[idx]).strip().lower()
        candidates = contexts[idx]
        match = next((ctx for ctx in candidates if needle in ctx.lower()), "") if needle else ""
        weak_gts.append(match)
    return weak_gts

def build_ragas_dataset(queries_df, predictions, contexts, N):
    """Build a ``datasets.Dataset`` in the column layout that ragas scores.

    ragas treats the ``answer`` column as the RAG system's response; faithfulness
    and answer relevancy are computed on it. It must therefore hold the generated
    answers. Putting the gold answers there would score the references, not the
    model.

    Columns:
        question:         query text.
        answer:           generated answer (what ragas evaluates).
        generated_answer: same as ``answer``; kept for backward compatibility.
        gold_answer:      reference answer from ``queries_df["answer"]``.
        contexts:         retrieved passages for each question (list[str]).
        ground_truth:     first retrieved passage containing the gold answer
                          (make_weak_gt_single), or the gold answer itself when no
                          retrieved passage contains it, so it is never empty.

    ``N`` is clamped to the shortest input, as in run_ragas_eval_cheapo, so all
    columns have equal length.
    """
    n = max(0, min(N, len(predictions), len(contexts), len(queries_df)))

    gold = queries_df["answer"].iloc[:n].astype(str).tolist()
    weak_gt = make_weak_gt_single(queries_df, contexts, n)
    ground_truth = [passage if passage else ref for passage, ref in zip(weak_gt, gold)]
    generated = [str(p) for p in predictions[:n]]

    features = Features({
        "question":          Value("string"),
        "answer":            Value("string"),
        "generated_answer":  Value("string"),
        "gold_answer":       Value("string"),
        "contexts":          Sequence(Value("string")),
        "ground_truth":      Value("string"),
    })

    data = {
        "question":          queries_df["question"].iloc[:n].astype(str).tolist(),
        "answer":            generated,                        # response scored by ragas
        "generated_answer":  generated,
        "gold_answer":       gold,
        "contexts":          [list(c) for c in contexts[:n]],  # list[list[str]]
        "ground_truth":      ground_truth,
    }

    return Dataset.from_dict(data, features=features)

def run_ragas_eval(queries_df, predictions, contexts, N=60, llm=None, embeddings=None):
    """Score the first ``N`` predictions with ragas and return ``{metric_name: score}``.

    Metrics: faithfulness, context_precision, context_recall and answer_relevance.

    Args:
        queries_df, predictions, contexts, N: forwarded to build_ragas_dataset.
        llm, embeddings: judge model and embeddings forwarded to ``ragas.evaluate``.
            When None (the default), ragas uses its own defaults, which are
            OpenAI-backed and need OPENAI_API_KEY in the environment. Pass
            ragas-compatible wrappers to evaluate with other models.

    Unlike a hard-wired HuggingfaceLLM/HuggingfaceEmbeddings judge (not importable
    from ragas.llms in the installed version), nothing here references names that
    may not exist.
    """
    ds = build_ragas_dataset(queries_df, predictions, contexts, N)
    res = ragas_evaluate(
        ds,
        metrics=[faithfulness, context_precision, context_recall, answer_relevance],
        llm=llm,
        embeddings=embeddings,
    )
    # NOTE(review): assumes a dict-like result keyed by metric name (ragas 0.1.x style);
    # newer ragas versions may need e.g. res.to_pandas() aggregation instead.
    return {k: float(res[k]) for k in res}

ImportError: cannot import name 'HuggingfaceLLM' from 'ragas.llms' (/Users/kevin/miniconda3/envs/llm2025/lib/python3.10/site-packages/ragas/llms/__init__.py)

In [10]:
# -------------------------------
# Super-cheap "RAGAS-like" metrics
#   - No OpenAI / No HF
#   - Heuristic, fast, good-enough to compare naive vs enhanced
# -------------------------------

import re
from collections import Counter

def _normalize(txt: str) -> str:
    if txt is None:
        return ""
    txt = str(txt).lower().strip()
    # collapse whitespace
    txt = re.sub(r"\s+", " ", txt)
    return txt

def _token_set(s: str) -> set:
    """Return the distinct lower-case ASCII alphanumeric runs in ``s`` (after _normalize).

    Any character outside [a-z0-9] (punctuation, accented or non-Latin letters)
    acts as a separator.
    """
    normalized = _normalize(s)
    return {piece for piece in re.findall(r"[a-z0-9]+", normalized)}

def _contains_answer(ctx: str, ans: str) -> bool:
    return _normalize(ans) in _normalize(ctx) if ans else False

def cheap_faithfulness(generated: str, contexts: list[str]) -> float:
    """
    Heuristic faithfulness: share of the generated answer's distinct tokens that
    also occur somewhere in the retrieved contexts.

    score = |tokens(generated) & tokens(all contexts)| / (|tokens(generated)| + 1e-9)
    Returns 0.0 when the generated answer has no tokens.
    """
    context_vocab = set()
    for passage in contexts:
        context_vocab.update(_token_set(passage))
    answer_vocab = _token_set(generated)
    if len(answer_vocab) == 0:
        return 0.0
    overlap = answer_vocab.intersection(context_vocab)
    return len(overlap) / (len(answer_vocab) + 1e-9)

def cheap_context_precision(answer: str, contexts: list[str]) -> float:
    """
    Heuristic precision: the share of retrieved contexts that contain the gold
    answer (per _contains_answer). Returns 0.0 when nothing was retrieved.
    """
    if not contexts:
        return 0.0
    matched = [c for c in contexts if _contains_answer(c, answer)]
    return len(matched) / len(contexts)

def cheap_context_recall(answer: str, all_corpus_hit: int | None, contexts: list[str]) -> float:
    """
    Binary recall proxy: 1.0 if any retrieved context contains the gold answer,
    else 0.0.

    The number of relevant passages in the whole corpus is unknown, so recall is
    approximated as "was the answer retrieved at all". ``all_corpus_hit`` is
    accepted for interface compatibility but ignored.
    """
    for ctx in contexts:
        if _contains_answer(ctx, answer):
            return 1.0
    return 0.0

def cheap_answer_relevance(question: str, generated: str) -> float:
    """
    Heuristic relevance: the share of the generated answer's distinct tokens that
    also appear in the question. Returns 0.0 when the answer has no tokens.
    """
    question_vocab = _token_set(question)
    gen_vocab = _token_set(generated)
    if len(gen_vocab) == 0:
        return 0.0
    shared = gen_vocab.intersection(question_vocab)
    return len(shared) / (len(gen_vocab) + 1e-9)

def run_ragas_eval_cheapo(queries_df, predictions: list[str], contexts: list[list[str]], N: int = 60):
    """
    Heuristic stand-in for run_ragas_eval that needs no external LLM or embeddings.

    Scores the first min(N, len(predictions), len(contexts), len(queries_df)) rows
    with the cheap_* heuristics. Returns their means under RAGAS-like keys
    (faithfulness, context_precision, context_recall, answer_relevance); a metric
    with no scored rows is 0.0.
    """
    count = min(N, len(predictions), len(contexts), len(queries_df))
    scores = {
        "faithfulness": [],
        "context_precision": [],
        "context_recall": [],
        "answer_relevance": [],
    }
    for idx in range(count):
        question = str(queries_df["question"].iloc[idx])
        gold = str(queries_df["answer"].iloc[idx])
        generated = predictions[idx]
        retrieved = contexts[idx]

        scores["faithfulness"].append(cheap_faithfulness(generated, retrieved))
        scores["context_precision"].append(cheap_context_precision(gold, retrieved))
        scores["context_recall"].append(cheap_context_recall(gold, None, retrieved))
        scores["answer_relevance"].append(cheap_answer_relevance(question, generated))

    # Plain means; an empty metric list maps to 0.0.
    return {
        metric: (float(np.mean(values)) if values else 0.0)
        for metric, values in scores.items()
    }

In [11]:
# The two grid configurations compared with RAGAS-style metrics.
NAIVE_VARIANT = "baseline"
ENHANCED_VARIANT = "rerank"
NAIVE_TOPK = 3
ENHANCED_TOPK = 3
NAIVE_DIM = 384

def is_target_for_ragas(col, k, variant):
    """Return True if (col, k, variant) is the naive or the enhanced target run.

    Naive: NAIVE_VARIANT at NAIVE_TOPK on the NAIVE_DIM collection.
    Enhanced: ENHANCED_VARIANT at ENHANCED_TOPK, on any collection.
    """
    dim = 768 if col != COLLECTION_384 else 384
    is_naive = variant == NAIVE_VARIANT and k == NAIVE_TOPK and dim == NAIVE_DIM
    is_enhanced = variant == ENHANCED_VARIANT and k == ENHANCED_TOPK
    return is_naive or is_enhanced

# Evaluate only the naive/enhanced target configurations. Compute SQuAD EM/F1 and,
# if USE_RAGAS is set, the heuristic RAGAS-style metrics; save one row per run.
RESULTS_CSV = "results_grids_ragas.csv"

rows = []
for col, emb_model_used, k, variant in param_grid:
    if not is_target_for_ragas(col, k, variant):
        continue

    dim = 384 if col == COLLECTION_384 else 768
    t0 = time.time()
    print(f"\n=== Running {variant} | top_k={k} | dim={dim} ===")

    # The QA evaluation always runs, so `res` is defined for this iteration
    # regardless of USE_RAGAS.
    res, preds, all_contexts = run_eval_squad(
        queries_df=queries,
        top_k=k, collection=col, emb_model=emb_model_used,
        tok=tok, lm=lm, N=N_EVAL, variant=variant,
        reranker=reranker,
    )

    # Reset every iteration so a previous run's metrics can never leak into this row.
    ragas_dict = None
    if USE_RAGAS:
        ragas_dict = run_ragas_eval_cheapo(
            queries_df=queries,
            predictions=[p["prediction_text"] for p in preds],
            contexts=all_contexts,
            N=N_RAGAS,
        )

    dt = time.time() - t0
    print(f" -> Done. Time: {dt:.1f} sec | EM={res['exact_match']:.3f} | F1={res['f1']:.3f}")

    row = {"embedding_dim": dim, "top_k": k, "variant": variant,
           "EM": res["exact_match"], "F1": res["f1"], "secs": dt}
    if ragas_dict:
        row.update(ragas_dict)
        row["ragas_role"] = "naive" if variant == NAIVE_VARIANT else "enhanced"
    rows.append(row)

df = pd.DataFrame(rows)
df.to_csv(RESULTS_CSV, index=False)
# Report the file that was actually written.
print(f"Saved results to {RESULTS_CSV}")


=== Running baseline | top_k=3 | dim=384 ===
 -> Done. Time: 54.9 sec | EM=57.000 | F1=63.984

=== Running rerank | top_k=3 | dim=384 ===
 -> Done. Time: 70.2 sec | EM=67.000 | F1=73.893
Saved results to results_grid.csv
