# Counterfactual Evidence-Based Explanation for RAG

This notebook demonstrates a post-hoc explanation method for retrieval-augmented generation (RAG) systems that focuses on **evidence support rather than semantic similarity**.

Given a question and a set of retrieved documents, the system first generates a final answer using the standard RAG pipeline.  
It then explains this answer by identifying **which parts of the retrieved documents explicitly support the answer content**.

Unlike similarity-based or attention-based highlighting methods, this approach treats explanation as an **evidence verification problem**:  
a sentence is highlighted only if it can be shown to directly support the generated answer.

The method is:
- model-agnostic and fully post-hoc
- compatible with local LLMs (e.g. Ollama)
- lightweight and bounded in runtime
- designed to reduce over-highlighting and spurious evidence

The resulting highlights aim to answer the question:  
**“Which retrieved text actually justifies the answer?”**

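Limitation: the same LLM that answers the question is also used to judge whether each retrieved sentence justifies that answer. If the LLM is weak, these judgments (and therefore the highlights) can be wrong.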

In [1]:
# Cell 1: Path + import hygiene
import sys
from pathlib import Path


def _find_project_root(start):
    """Return ``start`` or the nearest ancestor containing a ``src`` folder, else None."""
    for candidate in (start, *start.parents):
        if (candidate / "src").exists():
            return candidate
    return None


project_root = _find_project_root(Path.cwd())
if project_root is None:
    raise RuntimeError('"src" directory not found. Run notebook inside the repo.')

# Make `src.*` importable from the repo root (prepended only once).
root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

# Drop sys.path entries located under a virtualenv's src/ folder (Windows and
# POSIX separators), presumably so copies installed there cannot be imported
# instead of the repo's own "src" package.
bad_markers = ["\\.venv\\src\\", "/.venv/src/"]
sys.path = [entry for entry in sys.path if all(marker not in entry for marker in bad_markers)]

print("project_root:", project_root)
print("sys.path[0]:", sys.path[0])

In [2]:
# Cell 2: Runtime checks (python + ollama)
import sys

import requests

print("python:", sys.version)
print("executable:", sys.executable)

# Query the local Ollama server for its installed models; a connection error
# or an HTTP error status raises here and stops the notebook early.
r = requests.get("http://localhost:11434/api/tags", timeout=10)
r.raise_for_status()
payload = r.json()
models = [entry["name"] for entry in payload.get("models", [])]
print("ollama models:", models)

In [3]:
# Cell 3: Imports (project + deps)import reimport htmlimport numpy as npfrom IPython.display import display, HTMLfrom tqdm.auto import tqdmfrom src.modules.rag_engine import RAGEnginefrom src.modules.llm_client import LLMClientfrom src.modules.data_loader_single_hop import BoolQDataLoader

In [4]:
# Cell 4: Config
# Ollama model tag passed to LLMClient when building `llm`.
OLLAMA_MODEL = "qwen3-vl:8b"
# Vector store location; relative, so it resolves against the notebook's cwd.
PERSIST_DIR = "../data/vector_db_bool"
# TOP_K_DOCS: retrieved documents kept per question.
# MAX_SENTENCES_PER_DOC: upper bound on candidate sentences per document.
# MAX_SUPPORT_PER_CLAIM: not used by the fast sweep in this notebook; presumably
#   consumed elsewhere — TODO confirm.
TOP_K_DOCS, MAX_SENTENCES_PER_DOC, MAX_SUPPORT_PER_CLAIM = 4, 10, 2
print("Config loaded")

In [5]:
# Cell 5: Load data + build vector DB

# 1) Load the BoolQ documents through the project loader.
data_loader = BoolQDataLoader()
documents = data_loader.setup()
print("loaded documents:", len(documents))

# 2) Build (or open) the vector store at PERSIST_DIR over those documents.
rag = RAGEngine(persist_dir=PERSIST_DIR)
rag.setup(documents=documents)

# 3) Create the Ollama-backed LLM used by every later cell.
client = LLMClient(provider="ollama", model_name=OLLAMA_MODEL)
llm = client.get_llm()

print("READY")

In [ ]:
# We build helpers to link highlighted sentences to indices.
# These are needed for comprehensiveness and sufficiency.
def mask_remove(indices, sentences):
    """Return the context with the sentences at ``indices`` removed.

    Args:
        indices: Iterable of integer positions into ``sentences`` to drop.
        sentences: Ordered list of sentence strings.

    Returns:
        The remaining sentences, in order, joined with single spaces.
    """
    drop = set(indices)  # O(1) membership instead of scanning a list per sentence
    return " ".join(s for i, s in enumerate(sentences) if i not in drop)


def mask_except(indices, sentences):
    """Return only the sentences at ``indices``, in original order, joined with spaces.

    Args:
        indices: Iterable of integer positions into ``sentences`` to keep.
        sentences: Ordered list of sentence strings.
    """
    keep = set(indices)
    return " ".join(s for i, s in enumerate(sentences) if i in keep)


def get_highlight_indices(sentences, highlights_by_doc):
    """Map highlighted sentence strings back to positions in ``sentences``.

    A highlight matches every sentence that is equal to it after ``strip()``,
    so duplicated sentences all receive their index.

    Args:
        sentences: Ordered list of sentence strings (the full context).
        highlights_by_doc: Mapping whose values are lists of highlighted
            sentence strings; the keys are ignored.

    Returns:
        Sorted list of unique matching indices (empty if nothing matches).
    """
    # Index the context once: stripped sentence text -> all positions it occupies.
    positions = {}
    for i, s in enumerate(sentences):
        positions.setdefault(s.strip(), []).append(i)

    hl = set()
    for doc_sents in highlights_by_doc.values():
        for h in doc_sents:
            hl.update(positions.get(h.strip(), ()))
    return sorted(hl)

In [ ]:
# We measure answer drops for removing and keeping evidence.# This implements comprehensiveness and sufficiency.def answer_from_context(question, ctx):    p = f"""Answer using only the context.If the answer is missing, say unknown.Return only the final answer.Question: {question}Context:{ctx}Answer:""".strip()    return llm.invoke(p).content.strip()def run_comp_suff(question, context, highlight_indices):    sentences = split_sentences(context)    baseline = answer_from_context(question, context)    ctx_comp = mask_remove(highlight_indices, sentences)    ans_comp = answer_from_context(question, ctx_comp)    comp_drop = 1 if ans_comp != baseline else 0    ctx_suff = mask_except(highlight_indices, sentences) if highlight_indices else ""    ans_suff = answer_from_context(question, ctx_suff) if ctx_suff else "unknown"    suff_drop = 1 if ans_suff != baseline else 0    return {        "baseline": baseline,        "comp_answer": ans_comp,        "suff_answer": ans_suff,        "comprehensiveness_drop": comp_drop,        "sufficiency_drop": suff_drop    }

In [ ]:
# We run a fast sweep with caching to reduce LLM calls.
# This is a smoke test across multiple BoolQ questions.
import random
from statistics import mean
from functools import lru_cache


@lru_cache(maxsize=4096)
def llm_cached(prompt):
    """Invoke ``llm`` at most once per distinct prompt; return the stripped reply."""
    return llm.invoke(prompt).content.strip()


def answer_from_context_cached(q, ctx):
    """Cached version of the context-only answer prompt for question ``q``."""
    p = f"""Answer using only the context.
If the answer is missing, say unknown.
Return only the final answer.

Question: {q}

Context:
{ctx}

Answer:""".strip()
    return llm_cached(p)


@lru_cache(maxsize=4096)
def support_check_cached(claim, sentence):
    """Return True when the LLM's (cached) reply to the support check starts with YES."""
    check_prompt = f"""You are verifying evidence support.

Claim:
{claim}

Evidence sentence:
{sentence}

Decide if the evidence sentence explicitly supports the claim.
Reply with exactly one token: YES or NO."""
    return llm_cached(check_prompt).upper().startswith("YES")


def get_questions_from_loader(loader, n=5):
    """Return up to ``n`` question strings.

    Tries ``loader.questions``, then ``question`` keys in ``loader.data``
    (assumed to be dicts), then ``question`` entries in the metadata of the
    global ``documents``.
    """
    if hasattr(loader, "questions"):
        return list(loader.questions)[:n]
    if hasattr(loader, "data"):
        return [x.get("question") for x in loader.data if "question" in x][:n]
    qs = []
    for d in documents:
        meta = getattr(d, "metadata", {}) or {}
        if "question" in meta:
            qs.append(meta["question"])
    return qs[:n]


def _answers_match(a, b):
    """Compare two LLM answers ignoring case, whitespace runs and edge punctuation."""
    def norm(text):
        return " ".join(str(text).split()).casefold().strip(" .,;:!?\"'`")
    return norm(a) == norm(b)


def _comp_suff_fast(q, sentences, highlight_indices, baseline):
    """Return (comprehensiveness_drop, sufficiency_drop) as 0/1 via the cached LLM.

    Reuses the already-computed ``baseline`` instead of asking the LLM again, and
    masks ``sentences`` the same way as mask_remove / mask_except (space-joined).
    """
    keep = set(highlight_indices)
    ctx_comp = " ".join(s for i, s in enumerate(sentences) if i not in keep)
    ctx_suff = " ".join(s for i, s in enumerate(sentences) if i in keep)
    ans_comp = answer_from_context_cached(q, ctx_comp)
    # No kept evidence: skip the call and use the prompt's "unknown" reply.
    ans_suff = answer_from_context_cached(q, ctx_suff) if ctx_suff else "unknown"
    comp_drop = 0 if _answers_match(ans_comp, baseline) else 1
    suff_drop = 0 if _answers_match(ans_suff, baseline) else 1
    return comp_drop, suff_drop


def run_one_fast(q, max_claims=1, max_support=1):
    """Highlight evidence for one question and score it with cached LLM calls.

    Args:
        q: Question string.
        max_claims: Number of extracted claims to verify (default 1, as before).
        max_support: Stop searching for a claim after this many supporting
            sentences are found (default 1, as before).

    Returns:
        (comprehensiveness_drop, sufficiency_drop, highlights_by_doc_index).
    """
    docs = rag.retrieve_documents(q)[:TOP_K_DOCS]
    ctx = "\n\n".join(d.page_content for d in docs if d.page_content)
    baseline = answer_from_context_cached(q, ctx)
    claims = extract_claims(baseline)
    cands = candidate_sentences(docs, min(5, MAX_SENTENCES_PER_DOC))
    # NOTE(review): assumes candidate_sentences keys are 0..len(docs)-1 — confirm.
    hb = {i: [] for i in range(len(docs))}
    for claim in claims[:max_claims]:
        found = 0
        for di, sents in cands.items():
            for s in sents[:5]:
                if support_check_cached(claim, s):
                    hb[di].append(s)
                    found += 1
                    if found >= max_support:
                        break
            if found >= max_support:
                break
    sents_full = split_sentences(ctx)
    idxs = get_highlight_indices(sents_full, hb)
    # Score against the same cached baseline the claims came from.
    comp_drop, suff_drop = _comp_suff_fast(q, sents_full, idxs, baseline)
    return comp_drop, suff_drop, hb


qa_questions = get_questions_from_loader(data_loader, n=5)
# NOTE: this only reorders the first n questions; it does not sample a random subset.
random.shuffle(qa_questions)
comp_scores, suff_scores = [], []
for q in tqdm(qa_questions, desc="fast sweep"):
    c, s, _ = run_one_fast(q)
    comp_scores.append(c)
    suff_scores.append(s)
if comp_scores:
    print("avg comprehensiveness drop:", round(mean(comp_scores), 3))
    print("avg sufficiency drop:", round(mean(suff_scores), 3))
else:
    # statistics.mean raises StatisticsError on empty input.
    print("no questions found; skipping averages")
info = llm_cached.cache_info()
print("LLM cache — hits:", info.hits, "misses:", info.misses, "current:", info.currsize, "max:", info.maxsize)