# Counterfactual Evidence-Based Explanation for RAG (Updated)

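This notebook implements a post-hoc explanation method for retrieval-augmented generation (RAG) systems. It focuses on **evidence support** rather than semantic similarity: given a question and the documents retrieved for it, the RAG pipeline answers the question, and the explainer then highlights **the sentences in the retrieved documents that directly support the answer**. The highlights are tested counterfactually by re-answering the question with the highlighted sentences removed (comprehensiveness) and with only the highlighted sentences kept (sufficiency).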

This version fixes earlier issues by selecting evidence by sentence index rather than by raw text, so that each document chunk maps back to its sentences unambiguously. Questions for which no supporting evidence is found are flagged with `has_evidence = 0` (reported as `evidence_coverage` in the summary) so they can be excluded when interpreting the explanation metrics; the summary itself still averages over all questions.

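For reference, the scores that `run_one` computes below can be written compactly. Let $C$ be the sentences of the retrieved context, $E \subseteq C$ the highlighted sentences in document order with $n = |E|$, $f(\cdot)$ the letter (or UNKNOWN) the LLM returns for a context, with $f(\varnothing) = \text{UNKNOWN}$, $T$ the number of curve steps (`curve_steps`, 6 by default), and $E_k$ the first $\lceil kn/T \rceil$ highlighted sentences, so that $E_T = E$:

$$
\begin{aligned}
\text{comprehensiveness} &= \mathbb{1}\left[f(C \setminus E) \neq f(C)\right] \\
\text{sufficiency} &= \mathbb{1}\left[f(E) = f(C)\right] \\
\text{deletion AUC} &= \frac{1}{T}\sum_{k=1}^{T} \mathbb{1}\left[f(C \setminus E_k) \neq f(C)\right] \\
\text{insertion AUC} &= \frac{1}{T}\sum_{k=1}^{T} \mathbb{1}\left[f(E_k) = f(C)\right]
\end{aligned}
$$

Each score is a binary label flip rather than a change in answer probability, and the "AUC" values are the mean of the $T$ curve points rather than an integral.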

In [None]:

# Set up environment and paths
import sys
from pathlib import Path

NOTEBOOK_DIR = Path.cwd()
project_root = None
for p in [NOTEBOOK_DIR] + list(NOTEBOOK_DIR.parents):
    if (p / 'src').exists():
        project_root = p
        break
if project_root is None:
    raise RuntimeError('Notebook must live inside the repo')

sys.path.insert(0, str(project_root))
print('Project root:', project_root)


Project root: c:\Users\Admin\Desktop\XAI\FINAL\xai-rag


In [None]:

# Load configuration settings
import tomllib
config_path = project_root / 'config.toml'
with open(config_path, 'rb') as f:
    cfg = tomllib.load(f)

med_cfg = cfg['medmcqa']
rag_cfg = cfg['rag']
llm_cfg = cfg['llm']

QUESTION_IDS = med_cfg['question_ids']
KG_CAPABLE = set(med_cfg.get('kg_capable', []))
SPLIT = med_cfg['split']

print('MedMCQA questions:', len(QUESTION_IDS))
print('KG-capable questions:', len(KG_CAPABLE))
print('RAG hops:', rag_cfg['n_hops'])
print('LLM config:', llm_cfg)


MedMCQA questions: 59
KG-capable questions: 12
RAG hops: 2
LLM config: {'provider': 'ollama', 'model': 'gemma3:4b'}


In [None]:

# Load MedMCQA questions for evaluation
from src.modules.loader.medmcqa_data_loader import MedMCQADataLoader

loader = MedMCQADataLoader()
documents = loader.setup()
print('Loaded questions:', len(documents))

# Display a few examples
for d in documents[:3]:
    print('Question:', d.metadata['question'])
    print('Answer letter:', d.metadata['answer'])
    print('KG-capable:', d.metadata['question_id'] in KG_CAPABLE)


Loaded questions: 59
Question: Which of the following agents is likely to cause cerebral calcification and hydrocephalus in a newborn whose mother has history of taking spiramycin but was not compliant with therapy?
Answer letter: B
KG-capable: False
Question: Myocarditis is caused bya) Pertussisb) Measlesc) Diptheriad) Scorpion sting
Answer letter: A
KG-capable: False
Question: Childhood osteopetrosis is characterized by – a) B/L frontal bossingb) Multiple # (fracture)c) Hepatosplenomegalyd) Cataracte) Mental retardation
Answer letter: A
KG-capable: True


In [None]:

# Initialize the LLM
from src.modules.llm.llm_client import LLMClient

llm_client = LLMClient(provider=llm_cfg['provider'], model_name=llm_cfg['model'])
llm = llm_client.get_llm()
print('LLM ready:', llm)


Connecting to local Ollama (gemma3:4b)...
LLM ready: model='gemma3:4b' reasoning=False temperature=0.0


In [None]:

# Load StatPearls corpus and set up RAG engine
from src.modules.loader.statspearls_data_loader import StatPearlsDataLoader
from src.modules.rag.rag_engine import RAGEngine

sp_loader = StatPearlsDataLoader()
sp_output = sp_loader.setup()

# Unpack documents and stats
if isinstance(sp_output, tuple):
    statpearls_docs, sp_stats = sp_output
    print('StatPearls build stats:', sp_stats)
else:
    statpearls_docs = sp_output

# Flatten nested lists
if len(statpearls_docs) > 0 and isinstance(statpearls_docs[0], list):
    statpearls_docs = [doc for article in statpearls_docs for doc in article]

print('StatPearls chunks loaded:', len(statpearls_docs))
print('Sample StatPearls metadata keys:', statpearls_docs[0].metadata.keys())

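# Initialize RAG engine over StatPearls
# Note: rag_cfg['n_hops'] is passed as k, which appears to be the number of chunks
# returned per query (2 here, matching "Retrieved 2 documents" below)
rag = RAGEngine(persist_dir='../data/vector_db_statpearls')
rag.setup(documents=statpearls_docs, reset=False, k=rag_cfg['n_hops'])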
print('RAG engine initialized')


StatPearls build stats: StatPearlsBuildStats(tarball_downloaded=False, extracted=False, nxml_files_found=9629, jsonl_files_created=0, articles_loaded=300, chunks_emitted=13922)
StatPearls chunks loaded: 13922
Sample StatPearls metadata keys: dict_keys(['source', 'split', 'title', 'topic_name', 'source_filename', 'chunk_index', 'chunk_id'])
Loading existing vector store from ../data/vector_db_statpearls...
RagEngine ready.
RAG engine initialized


In [None]:

# Retrieve some documents for a sample question to inspect retrieval
sample_question = documents[0].metadata['question']
sample_docs = rag.retrieve_documents(sample_question)
print('Retrieved', len(sample_docs), 'documents')
for i, d in enumerate(sample_docs[:2]):
    print(f'Doc {i}:')
    print(d.page_content[:300])


Retrieved 2 documents
Doc 0:
programs on a large scale that offer both maternal or neonatal screening to identify infection in mothers and infants. No vaccines are present to prevent infection, and no efficacious and safe therapies are available for the treatment of maternal or fetal CMV infection. [22] In some setups, gancyclo
Doc 1:
of infant head lag are linked to conditions causing neonatal/infantile hypotonia. These conditions constitute a differential diagnosis list to consider when assessing an infant with persistent or severe head lag. This includes chromosome disorders (ie, Prader Willi), Hypoxic-ischemic injuries, cereb


In [None]:

# Helper functions for splitting text and highlighting
import re
import html

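def split_sentences(text: str):
    """Split text into sentences after '.', '!' or '?' followed by whitespace.

    A simple regex splitter: abbreviations such as 'e.g. ' will also cause a split.
    """
    text = (text or '').strip()
    if not text:
        return []
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

def highlight_html(text: str, snippets):
    """Return HTML-escaped text with every occurrence of each snippet wrapped in <mark>.

    Snippets must be verbatim substrings of text; empty snippets are skipped,
    because str.replace('') would insert a mark between every character.
    """
    out = html.escape(text or '')
    for s in snippets or []:
        if not s:
            continue
        esc = html.escape(s)
        out = out.replace(esc, f"<mark>{esc}</mark>")
    return out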


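Because `highlight_html` matches raw text, a sentence that occurs twice in a chunk is marked at every occurrence. Since evidence is selected by sentence index, highlighting can also be done by position. The sketch below is a minimal illustration; `highlight_sentences_html` is a hypothetical helper, not part of the repository, and `split_sentences` is copied from the cell above so the block runs on its own.

```python
import html
import re


def split_sentences(text: str):
    # Same rule as the notebook helper: split after . ! or ? followed by whitespace
    text = (text or '').strip()
    if not text:
        return []
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]


def highlight_sentences_html(sentences, indices):
    """Hypothetical helper: join sentences, wrapping those at `indices` in <mark>.

    Marking by position avoids substring-replacement pitfalls such as a
    repeated sentence being marked at every occurrence.
    """
    keep = set(indices)
    parts = []
    for i, s in enumerate(sentences):
        esc = html.escape(s)
        parts.append(f"<mark>{esc}</mark>" if i in keep else esc)
    return ' '.join(parts)


sents = split_sentences("Alpha is first. Beta & gamma follow. Alpha is first.")
# Only the third sentence is marked, although its text also appears first
print(highlight_sentences_html(sents, {2}))
```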

In [None]:

# Generate candidate windows of sentences for each document

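def candidate_windows(docs, max_sents_per_doc: int, window_size: int = 3):
    """Build trailing sentence windows for each retrieved document.

    For sentence idx, the window joins sentences max(0, idx - window_size + 1)..idx,
    so each sentence is judged together with up to window_size - 1 preceding
    sentences. Only the first max_sents_per_doc sentences of each document are used.
    Returns {doc_index: [{'sentence', 'window', 'sent_idx'}, ...]}, where sent_idx
    is the position of the window's last sentence within its document.
    """
    per_doc = {}
    for i, d in enumerate(docs):
        sents = split_sentences(d.page_content)[:max_sents_per_doc]
        windows = []
        for idx in range(len(sents)):
            start = max(0, idx - (window_size - 1))
            window = ' '.join(sents[start: idx+1])
            windows.append({
                'sentence': sents[idx],
                'window': window,
                'sent_idx': idx
            })
        per_doc[i] = windows
    return per_doc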

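When a window passes the support check, `run_one` credits only the window's last sentence (`sent_idx`), even though the supporting statement may sit in one of the preceding sentences of that window. The sketch below makes the covered span explicit; `window_spans` and `credited_indices` are hypothetical helpers, and crediting the whole window is shown only as an alternative design, which would tend to raise `highlight_fraction`.

```python
def window_spans(n_sents: int, window_size: int = 3):
    """Hypothetical helper: inclusive (start, end) sentence indices of each trailing window.

    Mirrors candidate_windows: the window for sentence idx covers
    sents[max(0, idx - (window_size - 1)) : idx + 1].
    """
    return [(max(0, idx - (window_size - 1)), idx) for idx in range(n_sents)]


def credited_indices(span, credit_whole_window: bool):
    """Sentence indices to credit when the window at `span` is judged supportive."""
    start, end = span
    return list(range(start, end + 1)) if credit_whole_window else [end]


spans = window_spans(4, window_size=3)
print(spans)
print(credited_indices(spans[3], credit_whole_window=False))  # what run_one does
print(credited_indices(spans[3], credit_whole_window=True))   # alternative design
```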

In [None]:

# Extract claims from baseline answer, resolving MCQ letters to option text

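import re

def extract_claims(answer: str, question_doc=None):
    """Turn a model answer into a list of claims to verify.

    Note: run_one below derives its claim from answer_text_from_context and
    does not call this helper.
    """
    a = (answer or '').strip()
    if not a or a.lower() == 'unknown':
        return []
    # Descriptive answer (more than a bare letter): use it directly as the claim
    if len(a) > 2:
        return [a]
    # Resolve an MCQ letter to its option text via the 'cop_raw' metadata field
    # (assumed to hold one option per line, e.g. "A: text" or "A) text")
    letter = a.upper().strip('().: ')
    if question_doc is not None and letter in {'A', 'B', 'C', 'D'}:
        raw = question_doc.metadata.get('cop_raw')
        if isinstance(raw, str):
            for line in raw.splitlines():
                # Require a delimiter after the letter so that e.g. "Aspirin" does not match "A"
                m = re.match(rf'\s*\(?{letter}\s*[):.]\s*(.*)', line, flags=re.IGNORECASE)
                if m:
                    return [m.group(1).strip()]
    return []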


In [None]:

# Check if a window supports the claim using the LLM

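def support_check(claim: str, window: str):
    """Ask the LLM whether `window` states or clearly implies `claim`; True on YES.

    Note: this calls llm.invoke directly rather than timed_invoke (defined later),
    so support checks are not included in LLM_CALLS / LLM_TIME_S in the summary.
    """
    prompt = (
        f'''You are verifying evidence support.
        Claim:
        {claim}
        Evidence context:
        {window}
        Decide if the evidence context states or clearly implies that the claim is correct.
        Reply with exactly one token: YES or NO.'''
    )
    out = llm.invoke(prompt).content.strip().upper()
    # Any reply that does not start with YES is treated as NO
    return out.startswith('YES')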

print('Support checker ready')


Support checker ready

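`support_check` treats any reply that does not begin with `YES` as NO, so a reply wrapped in markdown emphasis (for example `**YES**`) counts as unsupported, while an off-format reply is silently read as NO. A minimal sketch of a stricter parser follows; `parse_verdict` is a hypothetical helper that returns `None` for replies that are neither YES nor NO, so they can be counted or retried instead of folded into NO.

```python
import re


def parse_verdict(reply: str):
    """Hypothetical parser for the YES/NO reply requested by support_check.

    Returns True for YES, False for NO, and None when neither word leads the
    reply. Leading whitespace and markdown characters (*, _, `) are ignored.
    """
    text = (reply or '').strip().upper()
    text = re.sub(r'^[\s*_`]+', '', text)
    # \b rejects words that merely start with YES/NO, e.g. "NOT SURE"
    m = re.match(r'(YES|NO)\b', text)
    if not m:
        return None
    return m.group(1) == 'YES'


for r in ['YES', 'yes.', '**YES**', 'No', 'NOT SURE', 'Maybe']:
    print(repr(r), '->', parse_verdict(r))
```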

In [None]:

# Helpers for removing/keeping sentences and computing curve prefixes

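def remove_indices(indices, sentences):
    """Join all sentences except those at the given positions."""
    ix = set(indices)
    return ' '.join([s for i, s in enumerate(sentences) if i not in ix])

def keep_indices(indices, sentences):
    """Join only the sentences at the given positions, in document order."""
    ix = set(indices)
    return ' '.join([s for i, s in enumerate(sentences) if i in ix])

def curve_prefixes(indices, steps):
    """Return `steps` growing prefixes of `indices` for the deletion/insertion curves.

    Step k (1..steps) keeps the first ceil(k * n / steps) indices, computed with
    integer ceiling division; the last prefix contains all of them.
    """
    if not indices:
        return [set() for _ in range(steps)]
    n = len(indices)
    cuts = []
    for k in range(1, steps+1):
        m = min(n, max(1, (k*n + steps - 1)//steps))
        cuts.append(set(indices[:m]))
    return cuts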

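With fewer highlighted sentences than curve steps (3 versus the default 6 in the runs below), consecutive curve points repeat, so each distinct prefix is weighted twice in the AUC mean and costs a duplicate LLM call; the last prefix also reproduces the full removal/keep contexts already scored for comprehensiveness and sufficiency. The worked example below copies the helpers so it runs on its own.

```python
import math


def remove_indices(indices, sentences):
    ix = set(indices)
    return ' '.join(s for i, s in enumerate(sentences) if i not in ix)


def keep_indices(indices, sentences):
    ix = set(indices)
    return ' '.join(s for i, s in enumerate(sentences) if i in ix)


def curve_prefixes(indices, steps):
    # Same rule as the notebook: step k keeps the first ceil(k * n / steps) indices
    if not indices:
        return [set() for _ in range(steps)]
    n = len(indices)
    return [set(indices[:min(n, max(1, math.ceil(k * n / steps)))]) for k in range(1, steps + 1)]


sentences = ['s0.', 's1.', 's2.', 's3.', 's4.']
highlighted = [1, 3, 4]
prefixes = curve_prefixes(highlighted, steps=6)
print([sorted(p) for p in prefixes])           # each prefix appears twice
print(remove_indices(prefixes[-1], sentences))  # context for the last deletion point
print(keep_indices(prefixes[-1], sentences))    # context for the last insertion point
```

Caching the answer per distinct prefix would remove the duplicate calls without changing the curve values.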

In [None]:
# Run one counterfactual evaluation for a question

import time
import math
import random
import re

LLM_CALLS = 0
LLM_TIME_S = 0.0

def timed_invoke(prompt: str) -> str:
    global LLM_CALLS, LLM_TIME_S
    t0 = time.time()
    out = llm.invoke(prompt).content.strip()
    LLM_CALLS += 1
    LLM_TIME_S += (time.time() - t0)
    return out

def norm_choice(x: str) -> str:
    x = (x or "").strip().upper()
    m = re.search(r"\b([ABCD])\b", x)
    return m.group(1) if m else ""

def answer_choice_from_context(question: str, context: str) -> str:
    prompt = f'''
You are answering a multiple-choice question using only the provided context.

Return exactly one token: A or B or C or D.
If the answer cannot be derived from the context, return exactly: UNKNOWN

Question:
{question}

Context:
{context}

Answer:
'''.strip()
    out = timed_invoke(prompt)
    c = norm_choice(out)
    return c if c in {"A","B","C","D"} else "UNKNOWN"

def answer_text_from_context(question: str, context: str) -> str:
    prompt = f'''
Answer the question using only the provided context.

Return a short noun phrase (not a letter choice).
If the answer cannot be derived from the context, return exactly: UNKNOWN

Question:
{question}

Context:
{context}

Answer:
'''.strip()
    out = timed_invoke(prompt).strip()
    return out if out and out.upper() != "UNKNOWN" else "UNKNOWN"

def run_one(
    question_doc,
    curve_steps: int = 6,
    max_support_per_claim: int = 3,
    max_sentences_per_doc: int = 25,
):
    question = question_doc.metadata["question"]
    gold_choice = norm_choice(str(question_doc.metadata.get("answer", "")))

    docs = rag.retrieve_documents(question)
    context = "\n\n".join(d.page_content for d in docs if d.page_content)

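    # The letter choice and the free-text answer come from two separate LLM calls,
    # so they can disagree (e.g. pred_choice 'UNKNOWN' alongside a non-UNKNOWN pred_text)
    pred_choice = answer_choice_from_context(question, context)
    pred_text = answer_text_from_context(question, context)

    # The free-text answer (not the letter) is the claim checked against the evidence
    claims = [] if pred_text == "UNKNOWN" else [pred_text]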

    doc_offsets = []
    global_sents = []
    for d in docs:
        sents = split_sentences(d.page_content)
        doc_offsets.append(len(global_sents))
        global_sents.extend(sents)

    cands = candidate_windows(docs, max_sents_per_doc=max_sentences_per_doc, window_size=3)

    global_indices = set()
    if claims:
        for claim in claims:
            found = 0
            for doc_idx, items in cands.items():
                for item in items:
                    if support_check(claim, item["window"]):
                        global_idx = doc_offsets[doc_idx] + item["sent_idx"]
                        global_indices.add(global_idx)
                        found += 1
                        if found >= max_support_per_claim:
                            break
                if found >= max_support_per_claim:
                    break

    hl_indices = sorted(global_indices)

    ctx_wo = remove_indices(hl_indices, global_sents)
    ctx_only = keep_indices(hl_indices, global_sents)

    pred_wo = answer_choice_from_context(question, ctx_wo) if ctx_wo.strip() else "UNKNOWN"
    pred_only = answer_choice_from_context(question, ctx_only) if ctx_only.strip() else "UNKNOWN"

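    # Comprehensiveness: the answer changes once the highlighted sentences are removed.
    # Sufficiency: the highlighted sentences alone reproduce the answer.
    # Both compare against pred_choice, so an UNKNOWN prediction can still score 1.
    comprehensiveness = int(pred_wo != pred_choice)
    sufficiency = int(pred_only == pred_choice)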

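    # hl_indices is in document order (support_check gives no ranking),
    # so the curves remove/add evidence front to back
    prefixes = curve_prefixes(hl_indices, curve_steps)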

    del_curve = []
    for pref in prefixes:
        c = remove_indices(sorted(pref), global_sents)
        a = answer_choice_from_context(question, c) if c.strip() else "UNKNOWN"
        del_curve.append(int(a != pred_choice))
    deletion_auc = sum(del_curve) / len(del_curve) if del_curve else 0.0

    ins_curve = []
    for pref in prefixes:
        c = keep_indices(sorted(pref), global_sents)
        a = answer_choice_from_context(question, c) if c.strip() else "UNKNOWN"
        ins_curve.append(int(a == pred_choice))
    insertion_auc = sum(ins_curve) / len(ins_curve) if ins_curve else 0.0

    task_correct = (
        int(pred_choice == gold_choice)
        if (pred_choice in {"A","B","C","D"} and gold_choice in {"A","B","C","D"})
        else 0
    )

    return {
        "question_id": question_doc.metadata.get("question_id"),
        "question": question,
        "gold_choice": gold_choice,
        "pred_choice": pred_choice,
        "pred_text": pred_text,
        "task_correct": task_correct,
        "comprehensiveness": comprehensiveness,
        "sufficiency": sufficiency,
        "deletion_auc": deletion_auc,
        "insertion_auc": insertion_auc,
        "n_sentences": len(global_sents),
        "n_highlighted": len(hl_indices),
        "highlight_fraction": (len(hl_indices) / max(1, len(global_sents))),
        "has_evidence": int(len(hl_indices) > 0),
    }

print("Evaluation core ready.")

Evaluation core ready.

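`norm_choice` returns the first standalone A to D anywhere in the reply, so a verbose reply that opens with the article "A" is read as choice A. A stricter variant is sketched below; `norm_choice_strict` is hypothetical and accepts only a bare letter, so verbose replies become `""` and `answer_choice_from_context` would map them to UNKNOWN. The trade-off is more UNKNOWN answers when the model adds an explanation despite the prompt.

```python
import re


def norm_choice(x: str) -> str:
    # Copy of the notebook helper: first standalone A-D anywhere in the reply
    x = (x or "").strip().upper()
    m = re.search(r"\b([ABCD])\b", x)
    return m.group(1) if m else ""


def norm_choice_strict(x: str) -> str:
    """Hypothetical stricter variant: accept only 'B', '(B)', 'B.', 'B)' or 'B:'."""
    m = re.fullmatch(r"\(?([ABCD])[).:]?", (x or "").strip().upper())
    return m.group(1) if m else ""


for reply in ["B", "(c)", "UNKNOWN", "A likely answer is C"]:
    print(repr(reply), repr(norm_choice(reply)), repr(norm_choice_strict(reply)))
```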

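The counterfactual scoring in `run_one` can be checked without an LLM by swapping in a stub answerer. The sketch below mirrors the binary label-flip scores above under two simplifying assumptions: `answer_fn` (hypothetical) stands in for `answer_choice_from_context`, and the full context is the space-joined sentence list rather than the `"\n\n"`-joined documents. The third call reproduces the degenerate case where the model never commits to a letter: sufficiency and insertion AUC are 1 even though nothing was highlighted.

```python
import math
from typing import Callable, Dict, List, Sequence


def counterfactual_scores(
    sentences: Sequence[str],
    hl_indices: List[int],
    answer_fn: Callable[[str], str],
    steps: int = 6,
) -> Dict[str, float]:
    """Sketch of run_one's scoring with the LLM replaced by `answer_fn`.

    answer_fn maps a context to 'A'-'D' or 'UNKNOWN'; empty contexts are
    scored as 'UNKNOWN' without calling it, as in run_one.
    """
    def ask(ix, keep):
        chosen = set(ix)
        # keep=True: only the chosen sentences; keep=False: all but the chosen ones
        ctx = ' '.join(s for i, s in enumerate(sentences) if (i in chosen) == keep)
        return answer_fn(ctx) if ctx.strip() else 'UNKNOWN'

    pred = ask([], keep=False)  # full context, nothing removed
    n = len(hl_indices)
    prefixes = [hl_indices[:min(n, max(1, math.ceil(k * n / steps)))] if n else []
                for k in range(1, steps + 1)]
    deletion = [int(ask(p, keep=False) != pred) for p in prefixes]
    insertion = [int(ask(p, keep=True) == pred) for p in prefixes]
    return {
        'comprehensiveness': int(ask(hl_indices, keep=False) != pred),
        'sufficiency': int(ask(hl_indices, keep=True) == pred),
        'deletion_auc': sum(deletion) / steps,
        'insertion_auc': sum(insertion) / steps,
    }


def toy_answer(ctx: str) -> str:
    # Stub for the LLM: answers 'B' only when the key sentence is present
    return 'B' if 'key fact' in ctx else 'UNKNOWN'


def never_answers(ctx: str) -> str:
    return 'UNKNOWN'


sentences = ['Filler one.', 'The key fact is here.', 'Filler two.']
good = counterfactual_scores(sentences, [1], toy_answer)        # highlight hits the key sentence
bad = counterfactual_scores(sentences, [0], toy_answer)         # highlight misses it
degenerate = counterfactual_scores(sentences, [], never_answers)
print(good, bad, degenerate, sep='\n')
```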
In [None]:
# Evaluate across all MedMCQA questions (StatPearls corpus)

import time
import random
import pandas as pd

random.seed(0)

t0 = time.time()
rows = [run_one(doc) for doc in documents]
wall_time_s = time.time() - t0

df = pd.DataFrame(rows)

summary = {
    "n_questions": int(len(df)),
    "task_accuracy": float(df["task_correct"].mean()),
    "comprehensiveness": float(df["comprehensiveness"].mean()),
    "sufficiency": float(df["sufficiency"].mean()),
    "deletion_auc": float(df["deletion_auc"].mean()),
    "insertion_auc": float(df["insertion_auc"].mean()),
    "highlight_fraction": float(df["highlight_fraction"].mean()),
    "avg_sentences_per_context": float(df["n_sentences"].mean()),
    "avg_highlighted_sentences": float(df["n_highlighted"].mean()),
    "evidence_coverage": float(df["has_evidence"].mean()),
    "total_llm_calls": int(LLM_CALLS),
    "total_llm_time_s": float(LLM_TIME_S),
    "avg_llm_call_time_s": float(LLM_TIME_S / max(1, LLM_CALLS)),
    "wall_time_s": float(wall_time_s),
}

print("=== EVALUATION SUMMARY ===")
for k, v in summary.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")
    else:
        print(f"{k}: {v}")

display(pd.DataFrame([summary]))
df.head(10)

=== EVALUATION SUMMARY ===
n_questions: 59
task_accuracy: 0.1186
comprehensiveness: 0.4237
sufficiency: 0.5932
deletion_auc: 0.4011
insertion_auc: 0.5706
highlight_fraction: 0.3093
avg_sentences_per_context: 10.2542
avg_highlighted_sentences: 2.9492
evidence_coverage: 0.9831
total_llm_calls: 937
total_llm_time_s: 191.7426
avg_llm_call_time_s: 0.2046
wall_time_s: 228.2352


| | n_questions | task_accuracy | comprehensiveness | sufficiency | deletion_auc | insertion_auc | highlight_fraction | avg_sentences_per_context | avg_highlighted_sentences | evidence_coverage | total_llm_calls | total_llm_time_s | avg_llm_call_time_s | wall_time_s |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 59 | 0.118644 | 0.423729 | 0.59322 | 0.40113 | 0.570621 | 0.309289 | 10.254237 | 2.949153 | 0.983051 | 937 | 191.742582 | 0.204635 | 228.235206 |


| | question_id | question | gold_choice | pred_choice | pred_text | task_correct | comprehensiveness | sufficiency | deletion_auc | insertion_auc | n_sentences | n_highlighted | highlight_fraction | has_evidence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7f444937-f1ae-403c-9427-34f6d5c18aa6 | Which of the following agents is likely to cau... | B | A | Spiramycin | 0 | 0 | 0 | 0.0 | 0.0 | 8 | 3 | 0.375 | 1 |
| 1 | 6ffd899a-4d4b-4216-b8ac-6e16a4b0daa1 | Myocarditis is caused bya) Pertussisb) Measles... | A | B | Diphtheria | 0 | 0 | 0 | 0.666667 | 0.0 | 11 | 3 | 0.272727 | 1 |
| 2 | 3f63787d-7816-48fe-a623-b61ba10a3001 | Childhood osteopetrosis is characterized by – ... | A | A | frontal bossing | 1 | 0 | 1 | 0.0 | 1.0 | 12 | 3 | 0.25 | 1 |
| 3 | f7c6e673-3268-4c7a-abf2-dff2426a1ae0 | The main function of Vitamin C in the body is - | C | UNKNOWN | Vitamin C’s function | 0 | 1 | 1 | 0.666667 | 1.0 | 13 | 3 | 0.230769 | 1 |
| 4 | 80a922e3-e55d-4cdc-8a90-3100c3647e99 | The triad of hypertension, bradycardia and irr... | A | UNKNOWN | congestive heart failure | 0 | 0 | 1 | 0.0 | 1.0 | 12 | 3 | 0.25 | 1 |
| 5 | 71453fe5-854a-4456-9012-04208d5132c2 | Macrosomia is/are associated with:a) Gestation... | D | B | Maternal complications | 0 | 0 | 0 | 0.666667 | 0.666667 | 4 | 3 | 0.75 | 1 |
| 6 | 8409ee38-1922-4ac9-9178-ba699e33e643 | Long term use of lithium causes - | C | A | suicidal effects | 0 | 1 | 1 | 1.0 | 0.666667 | 13 | 3 | 0.230769 | 1 |
| 7 | 9d0968c7-0bfa-42c5-9e1f-1cf35c40b36e | Dupuytren's contracture mvolves- | B | A | hand contractures | 0 | 0 | 0 | 0.0 | 0.0 | 15 | 3 | 0.2 | 1 |
| 8 | d4f06476-b47f-4bf7-9d32-531db1974a8e | Mechanism of action of beta-Lactam antibiotics... | B | A | Beta-lactam action | 0 | 0 | 0 | 0.166667 | 0.0 | 10 | 3 | 0.3 | 1 |
| 9 | b452daac-d602-4482-a09a-6294405b6f61 | Iron absorption is increased by which of the f... | D | A | Dietary elements | 0 | 1 | 0 | 1.0 | 0.0 | 13 | 3 | 0.230769 | 1 |

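The summary above averages every explanation metric over all 59 questions, including questions where the model answered UNKNOWN (rows 3 and 4 of the per-question table, which still score sufficiency 1) and the one question without evidence (`evidence_coverage` 0.9831, i.e. 58 of 59). A minimal sketch of a filtered summary over the `rows` list produced by the evaluation cell follows; `summarize_with_evidence` is a hypothetical helper, and the toy rows exist only to exercise it and are not results.

```python
from statistics import mean

METRICS = ['comprehensiveness', 'sufficiency', 'deletion_auc', 'insertion_auc']


def summarize_with_evidence(rows):
    """Hypothetical helper: average the explanation metrics only over rows where
    evidence was found (has_evidence == 1) and the model committed to a letter."""
    kept = [r for r in rows if r['has_evidence'] == 1 and r['pred_choice'] != 'UNKNOWN']
    summary = {'n_total': len(rows), 'n_scored': len(kept)}
    for m in METRICS:
        # NaN signals "nothing to average" instead of a misleading 0
        summary[m] = mean(r[m] for r in kept) if kept else float('nan')
    return summary


# Toy rows for illustration only (same keys as run_one's output)
toy_rows = [
    {'has_evidence': 1, 'pred_choice': 'A', 'comprehensiveness': 1, 'sufficiency': 1,
     'deletion_auc': 1.0, 'insertion_auc': 0.5},
    {'has_evidence': 1, 'pred_choice': 'UNKNOWN', 'comprehensiveness': 1, 'sufficiency': 1,
     'deletion_auc': 1.0, 'insertion_auc': 1.0},
    {'has_evidence': 0, 'pred_choice': 'B', 'comprehensiveness': 0, 'sufficiency': 0,
     'deletion_auc': 0.0, 'insertion_auc': 0.0},
    {'has_evidence': 1, 'pred_choice': 'C', 'comprehensiveness': 0, 'sufficiency': 1,
     'deletion_auc': 0.0, 'insertion_auc': 1.0},
]
print(summarize_with_evidence(toy_rows))
```

In the notebook this would be called as `summarize_with_evidence(rows)` after the evaluation cell.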