<a href="https://colab.research.google.com/github/lakshaychitkara/2210990536/blob/main/Research-paper-RAGed-Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
!pip -q install -U transformers accelerate bitsandbytes sentence-transformers faiss-cpu datasets evaluate nltk pandas tqdm

import os, re, time, math, json
import pandas as pd
from tqdm import tqdm

import nltk
nltk.download("punkt")
nltk.download("punkt_tab")  # recent NLTK releases load sent_tokenize's model from punkt_tab
from nltk.tokenize import sent_tokenize

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import evaluate

In [None]:
# Load the QA dataset (question + answer). NQ-Open stands in for SciFact-style data in this demo.
# It is a lightweight public dataset loaded via datasets; to use your own dataset later, supply the same schema.

In [None]:
from datasets import load_dataset

# Small, fast demo dataset.
# BEIR SciFact (claims + abstracts) is not packaged on the HF Hub as a single QA-style dataset,
# so this demo uses a simpler open-domain QA dataset (NQ-Open) for correctness metrics and treats
# retrieved chunks as citations.
# For "evidence doc id" experiments, swap in BEIR SciFact later, using its explicit corpus and qrels.

ds = load_dataset("nq_open", split="train[:200]")  # keep small for Colab
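# NQ-Open lists several acceptable answers per question; keep the first as the single gold answer.
# (F1 is therefore scored against one reference, which can understate quality when aliases differ.)
ds = ds.map(lambda x: {"question": x["question"], "answer": x["answer"][0] if len(x["answer"]) else ""})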
ds[0]

Cell 3 — Build a tiny document corpus

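For controlled experiments, you need a fixed corpus. For the demo, we build one from a small slice of English Wikipedia: the first 2,000 articles of the wikimedia/wikipedia dataset on the Hugging Face Hub. That slice was not chosen to match the NQ-Open questions, so many gold answers will likely be missing from the corpus; read absolute F1 here as a smoke test and compare configurations relative to one another.
Later, you can replace this with your paper PDFs, website dumps, or your own dataset corpus.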

In [None]:
from datasets import load_dataset

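# Parquet-based Wikipedia dump (avoids the deprecated "wikipedia" loading script).
# Stream it: a sliced split such as "train[:2000]" still downloads and prepares the full English dump first.
corpus_stream = load_dataset(
    "wikimedia/wikipedia",
    "20231101.en",
    split="train",
    streaming=True
)
corpus_ds = list(corpus_stream.take(2000))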

raw_docs = [
    d["text"].replace("\n", " ").strip()
    for d in corpus_ds
    if d.get("text")
]

raw_docs = [t for t in raw_docs if len(t) > 400]
raw_docs = raw_docs[:1500]

print(len(raw_docs))
print(raw_docs[0][:300])

Cell 4 — Chunking functions

In [None]:
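def chunk_fixed(text, chunk_size=512, overlap=64):
    # Sliding window over whitespace-separated words (not model tokens);
    # consecutive windows start (chunk_size - overlap) words apart.
    words = text.split()
    step = max(1, chunk_size - overlap)
    chunks = []
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        if i + chunk_size >= len(words):
            break  # this window already reached the end; stepping again would add a redundant tail chunk
        i += step
    return chunks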

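def chunk_sentence(text, max_words=220):
    # Greedily pack whole sentences until adding the next one would exceed max_words.
    # A single sentence longer than max_words still becomes its own (oversized) chunk.
    sents = sent_tokenize(text)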
    chunks, buf = [], []
    buf_words = 0
    for s in sents:
        w = s.split()
        if buf_words + len(w) > max_words and buf:
            chunks.append(" ".join(buf))
            buf, buf_words = [], 0
        buf.append(s)
        buf_words += len(w)
    if buf:
        chunks.append(" ".join(buf))
    return chunks

def build_chunks(corpus_texts, strategy):
    all_chunks = []
    for doc_id, text in enumerate(corpus_texts):
        if strategy["type"] == "fixed":
            chunks = chunk_fixed(text, strategy["chunk_size"], strategy["overlap"])
        elif strategy["type"] == "sentence":
            chunks = chunk_sentence(text, strategy["max_words"])
        else:
            raise ValueError("Unknown chunking strategy")
        for ci, c in enumerate(chunks):
            all_chunks.append({
                "doc_id": doc_id,
                "chunk_id": f"{doc_id}_{ci}",
                "text": c
            })
    return all_chunks
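The fixed-size strategy is easiest to reason about in word positions. Below is a small, text-free sketch (the helper name window_spans is hypothetical, not part of the notebook) that lists the [start, end) word spans a sliding window of chunk_size words with overlap shared words covers, stopping once a window reaches the end of the document.

```python
def window_spans(n_words, chunk_size=512, overlap=64):
    """Return [start, end) word spans for a sliding window over n_words words.

    Hypothetical helper for illustration: consecutive windows start
    (chunk_size - overlap) words apart and share `overlap` words.
    """
    step = max(1, chunk_size - overlap)
    spans = []
    start = 0
    while start < n_words:
        end = min(start + chunk_size, n_words)
        spans.append((start, end))
        if end == n_words:
            break  # the last window already covers the tail of the document
        start += step
    return spans


# A 1,000-word document with the fixed_512 settings:
print(window_spans(1000, chunk_size=512, overlap=64))  # [(0, 512), (448, 960), (896, 1000)]
# The same document with the fixed_256 settings:
print(len(window_spans(1000, chunk_size=256, overlap=32)))  # 5
```

Halving chunk_size roughly doubles the number of chunks (and embeddings) per document, which is worth keeping in mind when comparing index build times across chunking strategies.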

Cell 5 — Embeddings + FAISS index builder

In [None]:
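embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"
embedder = SentenceTransformer(embed_model_name)
# This model truncates inputs longer than embedder.max_seq_length (256 word pieces), so only the
# beginning of a long chunk (e.g. a 512-word fixed chunk) affects its embedding.
print("embedder max_seq_length:", embedder.max_seq_length)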

def build_faiss_index(chunks):
    texts = [c["text"] for c in chunks]
    embs = embedder.encode(texts, batch_size=64, show_progress_bar=True, convert_to_numpy=True, normalize_embeddings=True)
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return index, embs
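faiss.IndexFlatIP performs exact (exhaustive) inner-product search. Because the embeddings are encoded with normalize_embeddings=True, every vector has unit L2 norm, so the inner product equals cosine similarity. A stdlib-only sketch of both facts (helper names are hypothetical; this is an illustration of the ranking rule, not FAISS itself):

```python
import math

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def flat_ip_search(query, vectors, k):
    """Exhaustive top-k by inner product (pure-Python illustration of what IndexFlatIP ranks by)."""
    scored = [(sum(q * x for q, x in zip(query, vec)), i) for i, vec in enumerate(vectors)]
    scored.sort(key=lambda t: t[0], reverse=True)  # highest inner product first
    return scored[:k]

raw = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]]
docs = [l2_normalize(v) for v in raw]
query = l2_normalize([1.0, 1.0])

best_score, best_id = flat_ip_search(query, docs, k=1)[0]
# The inner product of the normalized vectors matches the cosine of the raw vectors.
print(best_id, round(best_score, 4), round(cosine(raw[0], [1.0, 1.0]), 4))  # 0 0.9899 0.9899
```

Without normalization, IndexFlatIP would favor long vectors regardless of direction, which is why normalize_embeddings=True matters here.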

Cell 6 — Retriever + optional reranker

In [None]:
reranker_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker = CrossEncoder(reranker_name)

def retrieve(query, index, chunks, top_k=5):
    q_emb = embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True)
    scores, ids = index.search(q_emb, top_k)
    hits = []
    for score, idx in zip(scores[0], ids[0]):
        if idx < 0:
            continue  # FAISS pads with id -1 when top_k exceeds the number of indexed vectors
        c = chunks[int(idx)]
        hits.append({**c, "score": float(score)})
    return hits

def rerank(query, hits, top_k=5):
    pairs = [(query, h["text"]) for h in hits]
    rr_scores = reranker.predict(pairs)
    reranked = []
    for h, s in zip(hits, rr_scores):
        reranked.append({**h, "rr_score": float(s)})
    reranked.sort(key=lambda x: x["rr_score"], reverse=True)
    return reranked[:top_k]
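One property of rerank() worth keeping in mind: it can only choose among the hits it is given. If it receives exactly top_k hits and returns top_k, the set of chunks in the prompt is unchanged and only their order (and therefore the citation numbers) differs. A pure-Python sketch, where made-up scores stand in for reranker.predict and the helper name rerank_by_scores is hypothetical:

```python
def rerank_by_scores(hits, scores, top_k):
    """Sort hits by an external relevance score and keep top_k.

    Hypothetical stand-in for rerank(): `scores` replaces reranker.predict(...).
    """
    ranked = sorted(zip(hits, scores), key=lambda t: t[1], reverse=True)
    return [h for h, _ in ranked[:top_k]]

top_k = 3

# Case 1: rerank exactly the top_k dense hits -> the same chunks, only reordered.
dense_top3 = ["c1", "c2", "c3"]
same_set = rerank_by_scores(dense_top3, [0.1, 0.9, 0.5], top_k)
print(same_set)  # ['c2', 'c3', 'c1']

# Case 2: over-retrieve (here 2 * top_k) so the reranker can promote a chunk
# that dense retrieval ranked lower.
dense_top6 = ["c1", "c2", "c3", "c4", "c5", "c6"]
promoted = rerank_by_scores(dense_top6, [0.1, 0.9, 0.5, 0.2, 0.95, 0.3], top_k)
print(promoted)  # ['c5', 'c2', 'c3']
```

Retrieving a larger candidate pool (some multiple of top_k; the factor is a tuning choice) before reranking is the usual way to let the cross-encoder change what the LLM actually sees.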

Cell 7 — Load a small instruct LLM (4-bit)

This should fit on a typical Colab T4 (16 GB) or L4 (24 GB) GPU. If you run out of GPU memory (OOM), switch to a smaller model, e.g. Qwen/Qwen2.5-1.5B-Instruct.

In [None]:
model_name = "Qwen/Qwen2.5-3B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)
model.eval()
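A back-of-the-envelope check of why 4-bit loading matters: weight memory is roughly parameters × bits per parameter / 8 bytes. Treat this as a lower bound; it ignores quantization constants, any layers kept in higher precision, the KV cache, and activations. The parameter count below is a round hypothetical figure, not the model's exact size.

```python
def approx_weight_gb(n_params, bits_per_param):
    """Rough lower bound on memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

# Hypothetical round number: ~3 billion parameters.
print(approx_weight_gb(3e9, 16))  # fp16 weights -> 6.0
print(approx_weight_gb(3e9, 4))   # 4-bit (NF4) weights -> 1.5
```

The remaining headroom on the GPU goes mostly to the KV cache, which grows with prompt length; large top_k values with 512-word chunks make prompts long.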

Cell 8 — Prompts + generation wrapper (with citation format)

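The prompt numbers the retrieved chunks [1], [2], … and instructs the model to cite them. The model is asked, not forced, to comply, which is why citation presence and validity are measured as metrics below.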

In [None]:
def build_prompt(question, contexts, tactic="normal"):
    # contexts: list of dicts with "text"
    # Create numbered sources
    sources = []
    for i, c in enumerate(contexts, start=1):
        sources.append(f"[{i}] {c['text']}")
    sources_block = "\n\n".join(sources)

    if tactic == "normal":
        instr = (
            "Answer the question using the sources below. "
            "If unsure, say you don't know. Add citations like [1] after the sentences they support."
        )
    elif tactic == "strict":
        instr = (
            "You MUST answer using ONLY the sources below. "
            "If the answer is not explicitly stated, reply exactly: \"I don't know based on the provided sources.\" "
            "Every factual sentence must end with at least one citation like [1]."
        )
    else:
        raise ValueError("Unknown tactic")

    user = f"Question: {question}\n\nSources:\n{sources_block}\n\n{instr}\nAnswer:"
    return user

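@torch.no_grad()
def generate_answer(prompt, decoding):
    # Wrap the prompt in the model's chat template (Qwen2.5-Instruct is a chat model).
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Greedy decoding when temperature == 0; sampling parameters are only passed when sampling.
    do_sample = decoding["temperature"] > 0
    gen_kwargs = {"max_new_tokens": decoding["max_new_tokens"], "do_sample": do_sample}
    if do_sample:
        gen_kwargs.update(temperature=decoding["temperature"], top_p=decoding["top_p"])

    t0 = time.perf_counter()
    out = model.generate(
        **inputs,
        **gen_kwargs,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.eos_token_id
    )
    latency = time.perf_counter() - t0

    # Decode only the newly generated tokens, so the prompt is never echoed back into the answer.
    prompt_tokens = int(inputs["input_ids"].shape[1])
    gen_tokens = int(out.shape[1] - prompt_tokens)
    ans = tokenizer.decode(out[0, prompt_tokens:], skip_special_tokens=True).strip()
    return ans, latency, prompt_tokens, gen_tokens

def extract_citations(answer_text):
    # Returns cited source indices, e.g. "... [1][3]" -> [1, 3]. Grouped forms like "[1, 2]" are not matched.
    return [int(c) for c in re.findall(r"\[(\d+)\]", answer_text)]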
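Because [n] refers to the n-th chunk in the prompt, citations can be mapped back to chunk_id (each context dict also carries doc_id), which is what "evidence doc id" experiments need once a dataset provides gold document ids. A minimal sketch, assuming contexts is the same ordered list passed to build_prompt; the helper name cited_chunk_ids is hypothetical:

```python
import re

def cited_chunk_ids(answer_text, contexts):
    """Map [n] citations in an answer to the chunk_id of the n-th retrieved context.

    Hypothetical helper: [1] refers to contexts[0]. Out-of-range numbers are
    skipped, and each chunk_id is listed once, in order of first citation.
    """
    ids = []
    for n in re.findall(r"\[(\d+)\]", answer_text):
        i = int(n)
        if 1 <= i <= len(contexts):
            cid = contexts[i - 1]["chunk_id"]
            if cid not in ids:
                ids.append(cid)
    return ids

contexts = [{"chunk_id": "12_0", "text": "..."}, {"chunk_id": "7_3", "text": "..."}]
print(cited_chunk_ids("Paris is the capital [2]. It is large [1][2][5].", contexts))
# ['7_3', '12_0']
```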

Cell 9 — Metrics (quality + citation accuracy + basic faithfulness proxy)

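For the demo:

Quality: token-level F1 against the gold answer (SQuAD-style, on a 0 to 100 scale)

Citation accuracy: the share of cited ids that point to an actual source (between 1 and the number of retrieved chunks), plus whether the answer has any citation at all

Faithfulness proxy (cheap): the fraction of answer sentences that carry at least one citation; it measures citation coverage, not whether the cited source supports the sentence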

In [None]:
squad_metric = evaluate.load("squad")

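def squad_like_f1(pred, gold):
    # Strip [n] citation markers before scoring: SQuAD normalization removes the brackets
    # but keeps the digits, which would count as extra (wrong) tokens and lower precision.
    pred = re.sub(r"\s*\[\d+\]", "", pred)
    # evaluate's "squad" metric expects lists of dicts; answer_start is not used for F1.
    refs = [{"id": "0", "answers": {"text": [gold], "answer_start": [0]}}]
    preds = [{"id": "0", "prediction_text": pred}]
    return squad_metric.compute(predictions=preds, references=refs)["f1"]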

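def faithfulness_proxy(answer):
    # Cheap proxy: fraction of answer sentences carrying at least one [n] citation.
    # Split after sentence-final punctuation or at newlines (keeps decimals like "3.5" intact).
    parts = [p.strip() for p in re.split(r"(?<=[.!?])\s+|\n+", answer) if p.strip()]
    sents = []
    for p in parts:
        # Citations written after the period ("... Paris. [1] Next ...") belong to the previous sentence.
        m = re.match(r"((?:\[\d+\]\s*)+)(.*)", p, flags=re.S)
        if m and sents:
            sents[-1] += " " + m.group(1).strip()
            p = m.group(2).strip()
        if p:
            sents.append(p)
    if not sents:
        return 0.0
    cited = sum(1 for s in sents if re.search(r"\[\d+\]", s))
    return cited / len(sents)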

def citation_metrics(answer, n_sources):
    cites = extract_citations(answer)
    if len(cites) == 0:
        return {"has_citation": 0, "valid_cite_rate": 0.0}
    valid = sum(1 for c in cites if 1 <= c <= n_sources)
    return {
        "has_citation": 1,
        "valid_cite_rate": valid / len(cites)
    }
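To make the F1 column easier to interpret, here is a pure-Python sketch of SQuAD-style token F1, written for illustration from the SQuAD v1.1 evaluation logic (use the evaluate metric for reported numbers; it reports the same quantity × 100 and takes the maximum over multiple gold answers). It also shows why citation markers left in a prediction lower precision.

```python
import re
import string
from collections import Counter

_PUNCT = set(string.punctuation)

def normalize_answer(s):
    """SQuAD v1.1-style normalization: lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def token_f1(pred, gold):
    pred_toks = normalize_answer(pred).split()
    gold_toks = normalize_answer(gold).split()
    common = Counter(pred_toks) & Counter(gold_toks)  # multiset intersection of tokens
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

print(token_f1("The Eiffel Tower", "Eiffel Tower"))      # 1.0
print(token_f1("The Eiffel Tower [1]", "Eiffel Tower"))  # about 0.8: "[1]" leaves a stray "1" token
```

Long, fully cited answers from the strict prompt will score lower on this metric than terse ones even when correct, so F1 differences between prompt tactics partly reflect answer length.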

Cell 10 — Experiment grid + runner

In [None]:
chunking_grid = [
    {"name":"fixed_512", "type":"fixed", "chunk_size":512, "overlap":64},
    {"name":"fixed_256", "type":"fixed", "chunk_size":256, "overlap":32},
    {"name":"sentence_220w", "type":"sentence", "max_words":220},
]

retrieval_grid = [
    {"top_k": 3},
    {"top_k": 5},
    {"top_k": 10},
]

rerank_grid = [
    {"rerank": False},
    {"rerank": True},
]

prompt_grid = [
    {"tactic": "normal"},
    {"tactic": "strict"},
]

decoding_grid = [
    {"name":"deterministic", "temperature":0.0, "top_p":1.0, "max_new_tokens":160},
    {"name":"balanced", "temperature":0.4, "top_p":0.9, "max_new_tokens":160},
]

def run_experiment(ds, corpus_texts, n_questions=50):
    rows = []
    sample = ds.select(range(min(n_questions, len(ds))))

    for ch in chunking_grid:
        print(f"\n--- Building chunks & index: {ch['name']} ---")
        chunks = build_chunks(corpus_texts, ch)
        index, _ = build_faiss_index(chunks)

        for r in retrieval_grid:
            for rr in rerank_grid:
                for pg in prompt_grid:
                    for dec in decoding_grid:
                        config_id = f"{ch['name']}|k{r['top_k']}|rr{int(rr['rerank'])}|{pg['tactic']}|{dec['name']}"
                        print(f"\nRunning: {config_id}")

                        for ex in tqdm(sample, desc=config_id):
                            q = ex["question"]
                            gold = ex["answer"]

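                            # When reranking, over-retrieve so the cross-encoder can change *which* chunks
                            # reach the prompt, not just their order. The 4x pool is an assumed default; tune it.
                            n_candidates = r["top_k"] * 4 if rr["rerank"] else r["top_k"]
                            t_retr0 = time.perf_counter()
                            hits = retrieve(q, index, chunks, top_k=n_candidates)
                            retr_latency = time.perf_counter() - t_retr0

                            if rr["rerank"]:
                                t_rr0 = time.perf_counter()
                                hits = rerank(q, hits, top_k=r["top_k"])
                                rr_latency = time.perf_counter() - t_rr0
                            else:
                                rr_latency = 0.0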

                            prompt = build_prompt(q, hits, tactic=pg["tactic"])
                            pred, gen_latency, p_tok, g_tok = generate_answer(prompt, dec)

                            f1 = squad_like_f1(pred, gold)
                            faith = faithfulness_proxy(pred)
                            cmet = citation_metrics(pred, n_sources=len(hits))

                            rows.append({
                                "config": config_id,
                                "question": q,
                                "gold": gold,
                                "pred": pred,
                                "f1": f1,
                                "faithfulness_proxy": faith,
                                "has_citation": cmet["has_citation"],
                                "valid_cite_rate": cmet["valid_cite_rate"],
                                "retrieval_latency_s": retr_latency,
                                "rerank_latency_s": rr_latency,
                                "gen_latency_s": gen_latency,
                                "prompt_tokens": p_tok,
                                "gen_tokens": g_tok,
                                "total_latency_s": retr_latency + rr_latency + gen_latency
                            })

    return pd.DataFrame(rows)

df = run_experiment(ds, raw_docs, n_questions=30)
df.head()
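The grid is full factorial, so its cost grows multiplicatively: every chunking strategy also triggers one full index build, and every configuration runs generation once per question. A quick count, with grid sizes copied from the definitions above; the per-call latency is a hypothetical placeholder, not a measurement:

```python
from itertools import product

def count_runs(grids, n_questions):
    """Return (number of configurations, number of generate() calls) for a full factorial grid."""
    n_configs = len(list(product(*grids)))
    return n_configs, n_configs * n_questions

# Sizes of chunking_grid, retrieval_grid, rerank_grid, prompt_grid, decoding_grid.
grids = [range(3), range(3), range(2), range(2), range(2)]
n_configs, n_calls = count_runs(grids, n_questions=30)
print(n_configs, n_calls)  # 72 2160

# Hypothetical per-call latency; substitute your measured mean of df["total_latency_s"].
seconds_per_call = 5.0
print(f"~{n_calls * seconds_per_call / 3600:.1f} h")  # ~3.0 h
```

If the estimate is too long for a Colab session, trimming one factor (e.g. a single decoding setting) halves the run.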

Cell 11 — Aggregate results + save

In [None]:
agg = (df.groupby("config")
       .agg(
           mean_f1=("f1","mean"),
           mean_faith=("faithfulness_proxy","mean"),
           citation_presence=("has_citation","mean"),
           mean_valid_cite=("valid_cite_rate","mean"),
           mean_total_latency=("total_latency_s","mean"),
           mean_prompt_tokens=("prompt_tokens","mean"),
           mean_gen_tokens=("gen_tokens","mean"),
       )
       .reset_index()
       .sort_values(["mean_f1","mean_faith","mean_valid_cite","mean_total_latency"], ascending=[False,False,False,True]))

display(agg.head(20))

df.to_csv("rag_design_study_runs.csv", index=False)
agg.to_csv("rag_design_study_summary.csv", index=False)

print("Saved: rag_design_study_runs.csv, rag_design_study_summary.csv")

2) What you do after running this code

Start with the demo (as-is) and confirm it runs end-to-end.

You’ll get rag_design_study_runs.csv (per-question) and rag_design_study_summary.csv (per-config).

Switch from the demo corpus to your real corpus (papers/articles).

Replace raw_docs with your documents:

If you have PDFs: extract text (PyMuPDF) → list of strings.

If you have web pages: store cleaned text per page.
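For the PDF route, a minimal sketch using PyMuPDF (imported as fitz). The helper names are hypothetical and the cleaning is deliberately simple; layout-heavy papers (two columns, tables, running headers) usually need more work. Each PDF becomes one string, matching the list-of-strings shape of raw_docs.

```python
import re

def clean_extracted_text(text):
    """Normalize raw PDF text: re-join words hyphenated across line breaks, collapse whitespace.

    Caveat: a genuine compound split at a line end (e.g. "state-" then "of-the-art")
    is joined as well.
    """
    text = re.sub(r"(\w)-\n(\w)", r"\1\2", text)
    return " ".join(text.split())

def pdfs_to_docs(paths, min_chars=400):
    """Extract one cleaned string per PDF with PyMuPDF (pip install pymupdf)."""
    import fitz  # PyMuPDF's import name; imported here so the cleaning helper works without it
    docs = []
    for path in paths:
        with fitz.open(path) as pdf:
            text = "\n".join(page.get_text() for page in pdf)
        text = clean_extracted_text(text)
        if len(text) > min_chars:  # same length filter the demo corpus uses
            docs.append(text)
    return docs

print(clean_extracted_text("implemen-\ntation of\n\nRAG  systems "))  # implementation of RAG systems
```

Usage, with a hypothetical folder name: raw_docs = pdfs_to_docs(sorted(glob.glob("papers/*.pdf"))) after import glob.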

Switch from the demo QA dataset to your evaluation datasets.

Make a CSV with columns: question, answer (and optionally gold_doc_id).

Load it in Colab and convert it to the schema the notebook expects: a datasets.Dataset with question and answer string columns (run_experiment calls .select() on it).
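A minimal sketch of that conversion using only the csv module; the helper name load_eval_rows is hypothetical. It keeps gold_doc_id on every row (empty when missing) so all rows share the same columns.

```python
import csv
import io

def load_eval_rows(csv_file):
    """Read question/answer rows (optional gold_doc_id) into the notebook's schema.

    Hypothetical helper: `csv_file` is any file-like object with a header row.
    Rows with an empty question are dropped; a missing answer or gold_doc_id becomes "".
    """
    rows = []
    for rec in csv.DictReader(csv_file):
        q = (rec.get("question") or "").strip()
        if not q:
            continue
        rows.append({
            "question": q,
            "answer": (rec.get("answer") or "").strip(),
            "gold_doc_id": (rec.get("gold_doc_id") or "").strip(),
        })
    return rows

sample = io.StringIO("question,answer,gold_doc_id\nWho wrote Hamlet?,William Shakespeare,doc_17\n,,\n")
print(load_eval_rows(sample))
# [{'question': 'Who wrote Hamlet?', 'answer': 'William Shakespeare', 'gold_doc_id': 'doc_17'}]
```

For a real file, open it with open(path, newline="", encoding="utf-8"), then wrap the result with datasets.Dataset.from_list(rows) so run_experiment can call .select() on it.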

Add interaction analyses (your “pairwise interaction effects” objective).

Keep the grid small (e.g., chunking × top_k) and plot heatmaps:

mean_f1 vs (chunking, top_k)

mean_total_latency vs (rerank, top_k)
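The results are keyed by a single config string, so the factors have to be split back out before looking at pairwise interactions. A pure-Python sketch (helper names hypothetical; it assumes the config_id format built in run_experiment) that produces the cell means a heatmap would show:

```python
from collections import defaultdict

def parse_config(config_id):
    """Split a config id such as "fixed_256|k5|rr1|strict|balanced" into its factors."""
    chunking, k, rr, tactic, decoding = config_id.split("|")
    return {"chunking": chunking, "top_k": int(k[1:]), "rerank": rr == "rr1",
            "tactic": tactic, "decoding": decoding}

def interaction_table(rows, factor_a, factor_b, metric):
    """Mean of `metric` for every (factor_a, factor_b) cell, averaging over all other factors."""
    cells = defaultdict(list)
    for row in rows:
        f = parse_config(row["config"])
        cells[(f[factor_a], f[factor_b])].append(row[metric])
    return {key: sum(vals) / len(vals) for key, vals in sorted(cells.items())}

# Toy per-question rows with made-up numbers, shaped like the notebook's df records.
rows = [
    {"config": "fixed_512|k3|rr0|normal|deterministic", "f1": 40.0},
    {"config": "fixed_512|k3|rr1|strict|balanced", "f1": 20.0},
    {"config": "fixed_512|k5|rr0|normal|deterministic", "f1": 50.0},
    {"config": "sentence_220w|k3|rr0|normal|deterministic", "f1": 10.0},
]
print(interaction_table(rows, "chunking", "top_k", "f1"))
# {('fixed_512', 3): 30.0, ('fixed_512', 5): 50.0, ('sentence_220w', 3): 10.0}
```

With the real results, pass df.to_dict("records"); in pandas the same numbers come from splitting df["config"] on "|" into columns and calling pivot_table, whose output is a natural heatmap input.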

Upgrade “faithfulness/factuality” measurement for your paper.

Keep the cheap proxy in the tables (it’s useful), but add:

an LLM-as-a-judge faithfulness score (open-source judge or API),

and a citation-to-evidence alignment check (sentence-level entailment, or judge).
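Whichever checker you choose (an entailment/NLI model or an LLM judge), the input is the same: each cited answer sentence paired with the text of the source it cites. A stdlib sketch that builds those (premise, hypothesis) pairs; the helper name is hypothetical, and its sentence splitter is deliberately simple (a citation written after the period, as in "Paris. [1]", ends up as its own fragment and is skipped).

```python
import re

def citation_evidence_pairs(answer, contexts):
    """Pair each cited answer sentence (citations stripped) with every source text it cites.

    Hypothetical helper; `contexts` is the ordered list passed to build_prompt, so [1] is contexts[0].
    Returns (premise, hypothesis) tuples: (source text, answer sentence).
    """
    pairs = []
    for sent in re.split(r"(?<=[.!?])\s+|\n+", answer):
        sent = sent.strip()
        nums = [int(n) for n in re.findall(r"\[(\d+)\]", sent)]
        claim = re.sub(r"\s*\[\d+\]", "", sent).strip()
        for n in dict.fromkeys(nums):  # de-duplicate citation numbers, keep order
            if claim and 1 <= n <= len(contexts):
                pairs.append((contexts[n - 1]["text"], claim))
    return pairs

contexts = [{"text": "Paris is the capital of France."}, {"text": "Berlin is in Germany."}]
answer = "The capital of France is Paris [1]. It is in Germany [2]."
for premise, hypothesis in citation_evidence_pairs(answer, contexts):
    print(premise, "=>", hypothesis)
```

The second pair here is a citation that points at a real source but does not support the claim ("It" refers to Paris), which is exactly the case the cheap proxy cannot catch; the share of pairs your checker judges as supported gives a per-answer alignment score.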

For a proper RAG paper, you should eventually move to:

BEIR SciFact

BEIR Natural Questions

HotpotQA

your own domain corpus