# Counterfactual Evidence-Based Explanation for RAG

This notebook demonstrates a post-hoc explanation method for retrieval-augmented generation (RAG) systems that focuses on **evidence support rather than semantic similarity**.

Given a question and a set of retrieved documents, the system first generates a final answer using the standard RAG pipeline.  
It then explains this answer by identifying **which parts of the retrieved documents explicitly support the answer content**.

Unlike similarity-based or attention-based highlighting methods, this approach treats explanation as an **evidence verification problem**:  
a sentence is highlighted only if it can be shown to directly support the generated answer.

The method is:
- model-agnostic and fully post-hoc
- compatible with local LLMs (e.g. Ollama)
- lightweight and bounded in runtime
- designed to reduce over-highlighting and spurious evidence

The resulting highlights aim to answer the question:  
**“Which retrieved text actually justifies the answer?”**

Con: we use the LLM to awnser the question of justifiying the awnser. If our LLM is poor, this can go wrong.

In [1]:
# Cell 1: Path + import hygiene
import sys
from pathlib import Path

project_root = next((p for p in [Path.cwd()] + list(Path.cwd().parents) if (p / "src").exists()), None)
if project_root is None:
    raise RuntimeError('"src" directory not found. Run notebook inside the repo.')

root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

bad_markers = ["\\.venv\\src\\", "/.venv/src/"]
sys.path = [p for p in sys.path if not any(m in p for m in bad_markers)]

print("project_root:", project_root)
print("sys.path[0]:", sys.path[0])


project_root: c:\Users\Admin\Desktop\XAI\xai-rag
sys.path[0]: c:\Users\Admin\Desktop\XAI\xai-rag


In [2]:
# Cell 2: Runtime checks (python + ollama)
import sys
import requests

print("python:", sys.version)
print("executable:", sys.executable)

r = requests.get("http://localhost:11434/api/tags", timeout=10)
r.raise_for_status()
models = [m["name"] for m in r.json().get("models", [])]
print("ollama models:", models)


python: 3.11.6 (tags/v3.11.6:8b6ee5b, Oct  2 2023, 14:57:12) [MSC v.1935 64 bit (AMD64)]
executable: c:\Users\Admin\Desktop\XAI\.venv\Scripts\python.exe
ollama models: ['qwen3-vl:8b']


In [3]:
# Cell 3: Imports (project + deps)
import re
import html
import numpy as np
from IPython.display import display, HTML
from tqdm.auto import tqdm

from src.modules.rag_engine import RAGEngine
from src.modules.llm_client import LLMClient
from src.modules.data_loader_single_hop import BoolQDataLoader


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Cell 4: Config
OLLAMA_MODEL = "qwen3-vl:8b"   # set to one of your /api/tags models
PERSIST_DIR = "../data/vector_db_bool"

TOP_K_DOCS = 4
MAX_SENTENCES_PER_DOC = 10      # cap work
MAX_SUPPORT_PER_CLAIM = 2       # reduce over-highlighting

print("Config loaded")


Config loaded


In [5]:
# Cell 5: Load data + build vector DB
data_loader = BoolQDataLoader()
documents = data_loader.setup()
print("loaded documents:", len(documents))

rag = RAGEngine(persist_dir=PERSIST_DIR)
rag.setup(documents=documents)

client = LLMClient(provider="ollama", model_name=OLLAMA_MODEL)
llm = client.get_llm()

print("READY")


loaded documents: 9427
Loading existing vector store from ../data/vector_db_bool...
RagEngine ready.
Connecting to local Ollama (qwen3-vl:8b)...
READY


In [6]:
# Cell 6: Ask question + retrieve docs
question = "What is the thrid tallest hotel in the world?"

docs = rag.retrieve_documents(question)
docs = docs[:TOP_K_DOCS]

print("retrieved docs:", len(docs))
for i, d in enumerate(docs):
    print("\n--- DOC", i, "---")
    print((d.page_content or "")[:500])


retrieved docs: 4

--- DOC 0 ---
The Burj Al Arab (Arabic: برج العرب‎, Tower of the Arabs) is a luxury hotel located in Dubai, United Arab Emirates. It is the third tallest hotel in the world (although 39% of its total height is made up of non-occupiable space). Burj Al Arab stands on an artificial island 280 m (920 ft) from Jumeirah Beach and is connected to the mainland by a private curving bridge. The shape of the structure is designed to resemble the sail of a ship. It has a helipad near the roof at a height of 210 m (689 f

--- DOC 1 ---
The Tower of Terror buildings are among the tallest structures found at their respective Disney resorts. At 199 feet (60.7 m), the Florida version is the second tallest attraction at the Walt Disney World Resort, with only Expedition Everest 199.5 feet (60.8 m) being taller. At the Disneyland Resort, the 183-foot (55.8 m) structure (which now houses Guardians of the Galaxy -- Mission: Breakout!) is the tallest building at the resort, as well as on

In [7]:
# Cell 7: Baseline answer (context-only)
context = "\n\n".join(d.page_content for d in docs if d.page_content)

prompt = f"""
Answer using only the context.
If the answer is missing, say unknown.
Return only the final answer.

Question: {question}

Context:
{context}

Answer:
""".strip()

print("Question: ", question)

baseline = llm.invoke(prompt).content.strip()
print("baseline:", baseline)


Question:  What is the thrid tallest hotel in the world?
baseline: Burj Al Arab


In [8]:
# Cell 8: Sentence splitter + HTML highlight helpers
def split_sentences(text: str):
    text = (text or "").strip()
    if not text:
        return []
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def _find_all_spans(text: str, needle: str):
    spans = []
    if not needle:
        return spans
    start = 0
    n = len(needle)
    while True:
        idx = text.find(needle, start)
        if idx == -1:
            break
        spans.append((idx, idx + n))
        start = idx + max(1, n)
    return spans

def _merge_spans(spans):
    if not spans:
        return []
    spans = sorted(spans, key=lambda x: (x[0], x[1]))
    merged = [spans[0]]
    for s, e in spans[1:]:
        ps, pe = merged[-1]
        if s <= pe:
            merged[-1] = (ps, max(pe, e))
        else:
            merged.append((s, e))
    return merged

def highlight_html_exact(text: str, snippets):
    spans = []
    for snip in snippets:
        snip = (snip or "").strip()
        if not snip:
            continue
        spans.extend(_find_all_spans(text, snip))
    spans = _merge_spans(spans)

    out = []
    last = 0
    for s, e in spans:
        out.append(html.escape(text[last:s]))
        out.append("<mark>")
        out.append(html.escape(text[s:e]))
        out.append("</mark>")
        last = e
    out.append(html.escape(text[last:]))
    return "".join(out)


In [9]:
# Cell 9: Claim extraction (simple and strict)
def extract_claims(answer: str):
    a = (answer or "").strip()
    if not a or a.lower() == "unknown":
        return []
    # For short answers like a color, treat as one claim
    return [a]

claims = extract_claims(baseline)
print("claims:", claims)


claims: ['Burj Al Arab']


In [10]:
# Cell 10: Support check prompt (strict entailment)
def support_check(claim: str, sentence: str):
    check_prompt = f"""
You are verifying evidence support.

Claim:
{claim}

Evidence sentence:
{sentence}

Decide if the evidence sentence explicitly supports the claim.
Reply with exactly one token: YES or NO.
""".strip()

    out = llm.invoke(check_prompt).content.strip().upper()
    return out.startswith("YES")

print("Support checker ready")


Support checker ready


In [11]:
# Cell 11: Candidate sentences (cheap prefilter)
def candidate_sentences(docs, max_sents_per_doc: int):
    per_doc = {}
    for i, d in enumerate(docs):
        sents = split_sentences(d.page_content)
        per_doc[i] = sents[:max_sents_per_doc]
    return per_doc

cands_by_doc = candidate_sentences(docs, MAX_SENTENCES_PER_DOC)
for di, sents in cands_by_doc.items():
    print("doc", di, "candidates:", len(sents))


doc 0 candidates: 5
doc 1 candidates: 4
doc 2 candidates: 5
doc 3 candidates: 2


In [12]:
# Cell 12: Evidence selection per claim (with progress bar)
highlights_by_doc = {i: [] for i in range(len(docs))}

total_checks = sum(len(sents) for sents in cands_by_doc.values()) * max(1, len(claims))
pbar = tqdm(total=total_checks, desc="Checking evidence support")

for claim in claims:
    found = 0
    for doc_idx, sents in cands_by_doc.items():
        for sent in sents:
            pbar.update(1)

            if support_check(claim, sent):
                highlights_by_doc[doc_idx].append(sent)
                found += 1
                if found >= MAX_SUPPORT_PER_CLAIM:
                    break

        if found >= MAX_SUPPORT_PER_CLAIM:
            break

pbar.close()

for i in highlights_by_doc:
    highlights_by_doc[i] = list(dict.fromkeys(highlights_by_doc[i]))
 
highlights_by_doc

Checking evidence support: 100%|██████████| 16/16 [16:18<00:00, 61.13s/it] 


{0: ['Burj Al Arab stands on an artificial island 280 m (920 ft) from Jumeirah Beach and is connected to the mainland by a private curving bridge.'],
 1: [],
 2: [],
 3: []}

In [13]:
# Cell 13: Render output
parts = []
parts.append("<style>mark{padding:0.08em 0.15em; border-radius:3px;}</style>")
parts.append("<h2>Answer</h2>")
parts.append(f"<div style='white-space: pre-wrap;'>{html.escape(baseline)}</div>")
parts.append("<hr/>")

parts.append("<h2>Retrieved Documents</h2>")
parts.append("<p>Highlights are sentences that explicitly support the claim.</p>")

for i, d in enumerate(docs):
    snippets = highlights_by_doc.get(i, [])
    body = highlight_html_exact(d.page_content, snippets) if snippets else html.escape(d.page_content)

    parts.append(f"<h3>Document {i}</h3>")
    parts.append("<div style='white-space: pre-wrap; border: 1px solid #ddd; padding: 10px; border-radius: 6px;'>")
    parts.append(body)
    parts.append("</div><br/>")

display(HTML("".join(parts)))


In [14]:
# Cell 14: Simple table for debugging
import pandas as pd

rows = []
for doc_idx, sents in highlights_by_doc.items():
    for s in sents:
        rows.append({
            "doc_idx": doc_idx,
            "highlighted_sentence": s
        })

df = pd.DataFrame(rows)
df


Unnamed: 0,doc_idx,highlighted_sentence
0,0,Burj Al Arab stands on an artificial island 28...


In [15]:
import html
from IPython.display import display, HTML

In [16]:
def highlight_text(text, snippets):
    out = html.escape(text)
    for s in snippets:
        out = out.replace(
            html.escape(s),
            f"<mark>{html.escape(s)}</mark>"
        )
    return out

html_out = "<h2>Answer</h2>"
html_out += f"<p>{html.escape(baseline)}</p><hr>"

for i, d in enumerate(docs):
    html_out += f"<h3>Document {i}</h3>"
    body = highlight_text(d.page_content, highlights_by_doc.get(i, []))
    html_out += f"<pre>{body}</pre>"

display(HTML(html_out))

In [17]:
# We prepare helpers for removing/keeping evidence and mapping highlights to sentence indices.
# We assume highlights mark full sentences.

def mask_remove(indices, sentences):
    return " ".join([s for i, s in enumerate(sentences) if i not in indices])

def mask_except(indices, sentences):
    return " ".join([s for i, s in enumerate(sentences) if i in indices])

def get_highlight_indices(sentences, highlights_by_doc):
    hl = set()
    for doc_sents in highlights_by_doc.values():
        for h in doc_sents:
            for i, s in enumerate(sentences):
                if h.strip() == s.strip():
                    hl.add(i)
    return sorted(list(hl))

sentences_full = split_sentences(context)
highlight_indices = get_highlight_indices(sentences_full, highlights_by_doc)
print("highlighted sentence indices:", highlight_indices)


highlighted sentence indices: [2]


In [18]:
# We measure drop when removing highlights (comprehensiveness) and when keeping only highlights (sufficiency).
# We compare answers relative to the baseline.

def answer_from_context(question, ctx):
    p = f"""Answer using only the context.
If the answer is missing, say unknown.
Return only the final answer.

Question: {question}

Context:
{ctx}

Answer:""".strip()
    return llm.invoke(p).content.strip()

def run_comp_suff(question, context, highlight_indices):
    sentences = split_sentences(context)
    baseline = answer_from_context(question, context)

    ctx_comp = mask_remove(highlight_indices, sentences)
    ans_comp = answer_from_context(question, ctx_comp)
    comp_drop = 1 if ans_comp != baseline else 0

    ctx_suff = mask_except(highlight_indices, sentences) if highlight_indices else ""
    ans_suff = answer_from_context(question, ctx_suff) if ctx_suff else "unknown"
    suff_drop = 1 if ans_suff != baseline else 0

    return {
        "baseline": baseline,
        "comp_answer": ans_comp,
        "suff_answer": ans_suff,
        "comprehensiveness_drop": comp_drop,
        "sufficiency_drop": suff_drop,
    }

res = run_comp_suff(question, context, highlight_indices)
res


{'baseline': 'Burj Al Arab',
 'comp_answer': 'Burj Al Arab',
 'suff_answer': 'unknown',
 'comprehensiveness_drop': 0,
 'sufficiency_drop': 1}

In [None]:
# Faster sweep: caching, fewer LLM calls, capped sentences, fewer questions.

import random
from statistics import mean
from functools import lru_cache

# global cache for any LLM call (same prompt -> instant result)
@lru_cache(maxsize=4096)
def llm_cached(prompt):
    return llm.invoke(prompt).content.strip()

def answer_from_context_cached(q, ctx):
    p = f"""Answer using only the context.
If the answer is missing, say unknown.
Return only the final answer.

Question: {q}

Context:
{ctx}

Answer:""".strip()
    return llm_cached(p)

@lru_cache(maxsize=4096)
def support_check_cached(claim, sentence):
    check_prompt = f"""You are verifying evidence support.

Claim:
{claim}

Evidence sentence:
{sentence}

Decide if the evidence sentence explicitly supports the claim.
Reply with exactly one token: YES or NO."""
    return llm_cached(check_prompt).upper().startswith("YES")


def run_one_fast(q):
    docs = rag.retrieve_documents(q)[:TOP_K_DOCS]
    ctx = "\n\n".join(d.page_content for d in docs if d.page_content)

    baseline = answer_from_context_cached(q, ctx)
    claims = extract_claims(baseline)

    cands = candidate_sentences(docs, min(5, MAX_SENTENCES_PER_DOC))

    hb = {i: [] for i in range(len(docs))}
    for claim in claims[:1]:  # only first claim to save time
        found = 0
        for di, sents in cands.items():
            for s in sents[:5]:
                if support_check_cached(claim, s):
                    hb[di].append(s)
                    found += 1
                    if found >= 1:
                        break
            if found >= 1:
                break

    sents_full = split_sentences(ctx)
    idxs = get_highlight_indices(sents_full, hb)

    r = run_comp_suff(q, ctx, idxs)
    return r["comprehensiveness_drop"], r["sufficiency_drop"]


qa_questions = get_questions_from_loader(data_loader, n=5)
random.shuffle(qa_questions)

comp_scores, suff_scores = [], []

for q in tqdm(qa_questions, desc="fast sweep"):
    c, s = run_one_fast(q)
    comp_scores.append(c)
    suff_scores.append(s)

print("avg comprehensiveness drop:", round(mean(comp_scores), 3))
print("avg sufficiency drop:", round(mean(suff_scores), 3))
info = llm_cached.cache_info()
print("LLM cache — hits:", info.hits, "misses:", info.misses, "current:", info.currsize, "max:", info.maxsize)


fast sweep: 100%|██████████| 5/5 [36:44<00:00, 440.94s/it]

avg comprehensiveness drop: 0.6
avg sufficiency drop: 0.8





AttributeError: 'CacheInfo' object has no attribute 'cache'

LLM cache — hits: 0 misses: 29 current: 29 max: 4096
