# KGRAG Ex demo: MedMCQA (medical dataset)

This notebook demonstrates KGRAG Ex on MedMCQA, a medical multiple-choice dataset.

Workflow:
1. Load StatPearls as LangChain Documents.
2. Build the vector store and retrieval setup.
3. Build a knowledge graph from triplets extracted from the same documents.
4. Run the KGRAG Ex pipeline and convert KG paths into pseudo-paragraphs.
5. Inspect explanations, optional perturbations, and report stats.
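
Step 4 turns a KG path into a pseudo-paragraph that can be passed to the LLM as context. A minimal sketch of that idea follows, assuming a path is a list of `(head, relation, tail)` triples. `path_to_pseudo_paragraph` is a hypothetical helper and the entity names are placeholders; the actual conversion inside `KGRAGExPipeline` may differ.

```python
def path_to_pseudo_paragraph(triples):
    """Render (head, relation, tail) triples as one short sentence per hop."""
    sentences = []
    for head, relation, tail in triples:
        # Relation labels are often UPPER_SNAKE_CASE; make them readable.
        rel_text = str(relation).replace("_", " ").lower()
        sentences.append(f"{head} {rel_text} {tail}.")
    return " ".join(sentences)


# Placeholder path, not taken from the StatPearls KG.
example_path = [
    ("Drug X", "TREATS", "Condition Y"),
    ("Condition Y", "HAS_SYMPTOM", "Symptom Z"),
]
paragraph = path_to_pseudo_paragraph(example_path)
print(paragraph)
```

One sentence per hop keeps each edge separable, which matters later when single nodes or edges are removed for perturbation.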



## Notebook setup and project path

Ensure the project root is on `sys.path` so local modules import correctly.


In [1]:
import sys
from pathlib import Path

project_root = next((p for p in [Path.cwd()] + list(Path.cwd().parents) if (p / "src").exists()), None)
if project_root is None:
    raise RuntimeError('"src" directory not found. Please run the notebook from inside the project.')

root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print("Project root:", project_root)


Project root: /data/home/hod31575/xai-rag


## Imports

Load data loaders, RAG components, KG builders, and the KGRAG Ex pipeline.


In [2]:
# %%
from pathlib import Path

from src.modules.loader.statspearls_data_loader import StatPearlsDataLoader
from src.modules.loader.medmcqa_data_loader import MedMCQADataLoader  # needed in section 7

from src.modules.rag.rag_engine import RAGEngine
from src.modules.llm.llm_client import LLMClient

from src.modules.knowledge_graph.kg_triplet_extractor import KGTripletExtractor
from src.modules.knowledge_graph.kg_build_service import KGBuildService

from src.modules.knowledge_graph.kgrag_ex_pipeline import KGRAGExPipeline
from src.modules.explainers.kgrag_ex_explainer import KGRAGExExplainer
from src.modules.knowledge_graph.relation_registry import RelationRegistry, ProposedRelation, canon_relation




## 1. Load StatPearls as chunked Documents

For a smoke test, load only a few articles (`limit_articles`); otherwise KG building gets expensive. The loader emits one Document per chunk, so the 300 articles in the run below yield 13,883 Documents.


In [3]:
sp_loader = StatPearlsDataLoader()

statpearls_docs, sp_stats = sp_loader.setup(
    limit_articles=300,
    as_documents=True,
    force_download=False,
    force_extract=False,
    force_rebuild_jsonl=False,
)

print("StatPearls stats:", sp_stats)
print("Docs:", len(statpearls_docs))
print("Sample meta:", statpearls_docs[0].metadata)
print("Sample preview:", statpearls_docs[0].page_content[:200])




StatPearls stats: StatPearlsBuildStats(tarball_downloaded=False, extracted=False, nxml_files_found=9629, jsonl_files_created=0, articles_loaded=300, chunks_emitted=13883)
Docs: 13883
Sample meta: {'source': 'statpearls', 'split': 'repo', 'title': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'topic_name': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'source_filename': 'article-100024.nxml', 'chunk_index': 0, 'chunk_id': '6de0970aebeb1d82231d62346489c3cd6ee24eebe32979badfdeae7817d4c9bb'}
Sample preview: Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings throu

## 2. RAG engine setup

Build the Chroma vector store from the documents. With `reset=False`, an existing store at `persist_dir` is loaded instead of being rebuilt.


In [4]:
rag = RAGEngine(persist_dir="../data/vector_db_statpearls_kgragex")
rag.setup(documents=statpearls_docs, reset=False)


Loading existing vector store from ../data/vector_db_statpearls_kgragex...
RagEngine ready.


## 3. LLM client

Ollama runs locally; adjust the model name to your installation.


In [5]:
client = LLMClient(provider="ollama", model_name="llama3.2:latest")

## 4. Build the KG from StatPearls chunks (JSONL cache)

In [6]:
registry_path = Path("../data/relations_registry.json")
registry = RelationRegistry.load(registry_path)

extractor = KGTripletExtractor(llm_client=client, max_retries=2, relation_registry=registry)
kg_service = KGBuildService(extractor=extractor, relation_registry=registry, registry_cache_path=registry_path, add_reverse_edges=False)


kg_store, kg_stats = kg_service.build_or_load(
    docs=statpearls_docs,
    cache_path=Path("../data/statpearls_kg/kg_statpearls.jsonl"),
    limit=13922,
    force_rebuild=False,
    chunk_id_prefix="sp",
    source_name="statpearls",
)

print("KG stats:", kg_store.stats())
print("Build stats:", kg_stats)



Connecting to local Ollama (llama3.2:latest)...


KeyboardInterrupt: 

## 5. Quick KG stats


In [None]:
import numpy as np

nodes = [str(n) for n in kg_store.g.nodes]
token_counts = np.array([len(n.split()) for n in nodes])
print("Nodes:", len(nodes))
print("Tokens per node min/median/mean/max:", token_counts.min(), np.median(token_counts), token_counts.mean(), token_counts.max())

long_nodes = [n for n in nodes if len(n.split()) > 10]
print("Nodes with >10 tokens:", len(long_nodes))
print("Example long nodes:", long_nodes[:5])

### Graph connectivity and relation stats

Inspect connectivity, degree distribution, and relation frequencies in the KG.


In [None]:
import networkx as nx
from collections import Counter
import numpy as np

G = kg_store.g
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

wcc = list(nx.weakly_connected_components(G)) if G.number_of_nodes() else []
sizes = sorted([len(c) for c in wcc], reverse=True) if wcc else [0]
print("Weakly connected components:", len(wcc))
print("Largest component size:", sizes[0])

deg = np.array([d for _, d in G.degree()]) if G.number_of_nodes() else np.array([0])
print("Degree min/median/mean/max:", int(deg.min()), float(np.median(deg)), float(deg.mean()), int(deg.max()))

rel_counts = Counter()
for _, _, data in G.edges(data=True):
    rel_counts[str(data.get("relation", "Unknown"))] += 1
print("Top relations:", rel_counts.most_common(12))

In [None]:
u, v, data = next(iter(G.edges(data=True)))
print("u:", u)
print("v:", v)
print("edge keys:", list(data.keys()))
print("edge data:", data)


In [None]:
from collections import Counter

type_count = Counter(G.nodes[n].get("type", "Unknown") for n in G.nodes)
print(type_count.most_common(12))


In [None]:
DG = nx.DiGraph(
    (u, v, d) for u, v, d in G.edges(data=True)
    if not any(obs.get("meta", {}).get("synthetic") for obs in d.get("observations", []))
)

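is_dag = nx.is_directed_acyclic_graph(DG)
print("Is DAG:", is_dag)

# dag_longest_path raises NetworkXUnfeasible on cyclic graphs, so guard the call
if is_dag:
    longest_path = nx.dag_longest_path(DG)
    print("Longest DAG path length:", len(longest_path) - 1)
    print("Path:", longest_path)
else:
    print("Graph contains cycles; the longest DAG path is undefined.")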


In [None]:
import networkx as nx

G = kg_store.g
GU = G.to_undirected()

# pick the largest connected component
largest_comp = max(nx.connected_components(GU), key=len)
H = GU.subgraph(largest_comp).copy()

# two-pass BFS heuristic for an approximate diameter path
n0 = next(iter(H.nodes))
u = max(nx.single_source_shortest_path_length(H, n0).items(), key=lambda x: x[1])[0]
v, dist = max(nx.single_source_shortest_path_length(H, u).items(), key=lambda x: x[1])

path = nx.shortest_path(H, u, v)

print("Largest component size:", H.number_of_nodes())
print("Approx diameter (edges):", dist)
print("Path nodes:", path)
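
The cell above uses a two-pass BFS ("double sweep") instead of computing the exact diameter, which would need a BFS from every node. The sketch below isolates the heuristic on a toy graph; `double_sweep` is a hypothetical helper name. The result is exact on trees and a lower bound on general connected graphs.

```python
import networkx as nx


def double_sweep(graph):
    """Return (u, v, dist), where dist is a lower bound on the diameter.

    Assumes `graph` is undirected and connected.
    """
    start = next(iter(graph.nodes))
    # First sweep: find the node farthest from an arbitrary start node.
    first = nx.single_source_shortest_path_length(graph, start)
    u = max(first.items(), key=lambda kv: kv[1])[0]
    # Second sweep: find the node farthest from u.
    second = nx.single_source_shortest_path_length(graph, u)
    v, dist = max(second.items(), key=lambda kv: kv[1])
    return u, v, dist


# Toy tree: a path 0-1-2-3-4-5 with a short branch 2-6.
tree = nx.path_graph(6)
tree.add_edge(2, 6)

u, v, dist = double_sweep(tree)
print("Endpoints:", u, v, "distance:", dist)
print("Exact diameter:", nx.diameter(tree))
```

On the full KG the exact `nx.diameter` call can be slow, so the heuristic is a reasonable trade-off when only a rough sense of graph depth is needed.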


In [None]:
path_with_rels = []
for a, b in zip(path, path[1:]):
    # the path comes from the undirected view, so take the relation from whichever direction exists
    if G.has_edge(a, b):
        rel = G[a][b].get("relation")
    elif G.has_edge(b, a):
        rel = G[b][a].get("relation")
    else:
        rel = None
    path_with_rels.append((a, rel, b))

for triple in path_with_rels:
    print(triple)


## 6. Pipeline and explainer


In [None]:
pipeline = KGRAGExPipeline(rag=rag, kg=kg_store, llm_client=client)
explainer = KGRAGExExplainer(pipeline)

## 7. Load MedMCQA, small sample


In [None]:
mcqa_loader = MedMCQADataLoader()
questions = mcqa_loader.setup(split="train", as_documents=True, limit=8)

print("MedMCQA questions:", len(questions))
print("Sample question:", questions[0].metadata.get("question"))
print("Gold:", questions[0].metadata.get("answer"))

## 8. Single-run test on one MedMCQA question


In [None]:
d = questions[0]

qid = str(d.metadata.get("question_id", "q-0"))
question_text = str(d.metadata.get("question") or d.page_content or "").strip()
gold = d.metadata.get("answer")

run = pipeline.run(
    question_id=qid,
    question=question_text,
    gold_answer=gold,
    entity_k=6,
    top_k_docs=4,
)

print("Question ID:", run.question_id)
print("Question:", run.question)
print("Gold:", run.gold_answer)
print("Entities:", run.entities)
print("Start, End:", run.start, run.end)
print("Path length:", len(run.path))
print("KG context preview:", (run.kg_context or "")[:600])
print("Answer:", run.answer)
print("LLM calls:", run.llm_calls)

## 9. Optional: KG-level perturbations


In [None]:
report = explainer.explain(run)

print("Most influential node:", report.most_influential_node)
print("Most influential edge:", report.most_influential_edge)
print("Most influential subpath:", report.most_influential_subpath)

print("\nFirst 10 perturbation outcomes:")
for o in report.outcomes[:10]:
    print({"kind": o.kind, "removed": o.removed, "answer_changed": o.answer_changed, "answer": (o.answer or "")[:50]})
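
One way to read these outcomes: each perturbation removes one element from the KG path, regenerates the answer, and records whether it changed. The sketch below shows the node-removal case only. `perturb_nodes` and `toy_answer` are hypothetical stand-ins, not the `KGRAGExExplainer` implementation, and the real explainer may rank influence differently.

```python
def perturb_nodes(path_nodes, answer_fn):
    """Remove each node in turn and record whether the answer changes."""
    baseline = answer_fn(path_nodes)
    outcomes = []
    for i, node in enumerate(path_nodes):
        reduced = path_nodes[:i] + path_nodes[i + 1:]
        answer = answer_fn(reduced)
        outcomes.append(
            {
                "kind": "node",
                "removed": node,
                "answer": answer,
                "answer_changed": answer != baseline,
            }
        )
    return outcomes


def toy_answer(nodes):
    """Stand-in for the LLM call: answers "B" only if node "n2" is in context."""
    return "B" if "n2" in nodes else "A"


outcomes = perturb_nodes(["n1", "n2", "n3"], toy_answer)
changed = [o["removed"] for o in outcomes if o["answer_changed"]]
print("Nodes whose removal flips the answer:", changed)
```

In this toy setup only `n2` flips the answer, so it would be the most influential node. With a real LLM, each perturbation costs one extra call, which is why the step is optional.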

## 10. Mini-batch test


In [None]:
import pandas as pd

rows = []
for d in questions:
    qid = str(d.metadata.get("question_id", "q"))
    qtxt = str(d.metadata.get("question") or d.page_content or "").strip()
    gold = d.metadata.get("answer")

    run_i = pipeline.run(
        question_id=qid,
        question=qtxt,
        gold_answer=gold,
        entity_k=6,
        top_k_docs=4,
    )
    rep_i = explainer.explain(run_i)
    num_changed = sum(1 for o in rep_i.outcomes if o.answer_changed)

    rows.append(
        {
            "question_id": run_i.question_id,
            "gold": run_i.gold_answer,
            "path_len": len(run_i.path),
            "entities": ", ".join(run_i.entities[:6]),
            "answer": (run_i.answer or "")[:10],
            "perturbations": len(rep_i.outcomes),
            "answer_changes": num_changed,
            "most_node": rep_i.most_influential_node,
            "most_edge": rep_i.most_influential_edge,
            "most_subpath": rep_i.most_influential_subpath,
        }
    )

df = pd.DataFrame(rows)
df

## 11. Optional: compact stats for your RQs

RQ2 (position proxy) depends on explainer implementation details.  
RQ3 (node types) depends on which entity types your triplet extractor emits.  
RQ4 (graph metrics) depends on whether your metrics module computes degree and betweenness.

If your report exposes these fields, we summarize them here.


In [None]:
if getattr(report, "rq2_positions", None):
    print("RQ2 positions sample:", report.rq2_positions[:5])

if getattr(report, "rq3_node_types", None):
    print("RQ3 node types sample:", report.rq3_node_types[:10])

if getattr(report, "rq4_graph_metrics", None):
    print("RQ4 graph metrics sample:", report.rq4_graph_metrics[:5])
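
If your metrics module does not yet compute the RQ4 values, degree and betweenness can be read directly from NetworkX. The sketch below runs on a toy directed graph; the graph, `path_nodes`, and the `metrics` row layout are illustrative assumptions and do not reflect the structure of `report.rq4_graph_metrics`.

```python
import networkx as nx

# Toy directed graph standing in for the KG.
toy = nx.DiGraph()
toy.add_edges_from([("a", "b"), ("b", "c"), ("b", "d"), ("d", "e")])

# Nodes on a hypothetical KG path from "a" to "e".
path_nodes = ["a", "b", "d", "e"]

degree = dict(toy.degree())                   # in-degree + out-degree
betweenness = nx.betweenness_centrality(toy)  # normalized by default

metrics = [
    {"node": n, "degree": degree[n], "betweenness": betweenness[n]}
    for n in path_nodes
]
for row in metrics:
    print(row)
```

Betweenness is computed over all node pairs, so on a large KG it is considerably more expensive than degree; `nx.betweenness_centrality` accepts a `k` argument to estimate it from a sample of nodes.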


## 12. Try your own question

Use this to quickly check whether the KG has enough coverage.

Note:
If `path_len` is often 0, it's usually an entity extraction/matching or KG coverage issue, not a RAG issue.


In [None]:
qid = "manual-1"
question = "How does CTO relate to cardiomyopathy?"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


In [None]:
qid = "manual-2"
question = "How does tension pneumothorax relate to septicemia?"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


### Optional: print RQ metrics for the manual question

Surface the RQ fields directly for quick inspection.


In [None]:
print("Sensitivity:", rep_m.rq1_sensitivity)
print("Positions:", rep_m.rq2_positions)
print("Node types:", rep_m.rq3_node_types)
print("Graph metrics:", rep_m.rq4_graph_metrics)

### Debug: graph connectivity between start/end nodes

Check whether the start/end nodes exist, inspect neighbors, and probe shortest paths.


In [None]:
import networkx as nx

G = kg_store.g
start = run_m.start
end = run_m.end

print("start exists:", start in G)
print("end exists:", end in G)

if start in G:
    print("\nStart neighbors (out):", list(G.successors(start))[:30])
    print("Start neighbors (in):", list(G.predecessors(start))[:30])

if end in G:
    print("\nEnd neighbors (out):", list(G.successors(end))[:30])
    print("End neighbors (in):", list(G.predecessors(end))[:30])

if start in G and end in G:
    GU = G.to_undirected(as_view=True)
    try:
        p = nx.shortest_path(GU, start, end)
        print("\nUndirected shortest path length:", len(p) - 1)
        print("Path nodes:", p[:20], "..." if len(p) > 20 else "")
    except nx.NetworkXNoPath as e:
        print("\nNo undirected path:", e)

# Additional: search for similar node names
def find_similar_nodes(q, limit=20):
    ql = (q or "").lower()
    hits = [n for n in G.nodes if ql in str(n).lower()]
    return hits[:limit]

print("\nSimilar to 'CTO':", find_similar_nodes("CTO", 30))
print("Similar to 'Chronic Total Occlusion':", find_similar_nodes("Chronic Total Occlusion", 30))
print("Similar to 'Stroke':", find_similar_nodes("Stroke", 30))
print("Similar to 'Ischemic Heart Disease':", find_similar_nodes("Ischemic Heart Disease", 30))
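
The substring search above misses spelling variants and plurals. A fuzzy lookup with the standard library's `difflib` can complement it. This is a minimal sketch: `find_close_nodes` is a hypothetical helper, the node list is a toy example, and the `cutoff` value is an arbitrary starting point.

```python
import difflib


def find_close_nodes(query, node_names, limit=5, cutoff=0.6):
    """Return node names whose lowercase form is similar to the query."""
    # Map lowercase form -> original name so matching is case-insensitive.
    lowered = {}
    for name in node_names:
        lowered.setdefault(str(name).lower(), str(name))
    matches = difflib.get_close_matches(
        query.lower(), list(lowered), n=limit, cutoff=cutoff
    )
    return [lowered[m] for m in matches]


# Toy node list; in the notebook you would pass G.nodes instead.
toy_nodes = ["Chronic total occlusion", "Cardiomyopathy", "Cardiomyopathies", "Stroke"]
hits = find_close_nodes("cardiomyopathy", toy_nodes)
print(hits)
```

`get_close_matches` compares the query against every candidate, so for a KG with many nodes it is better suited to occasional debugging than to use inside the pipeline.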
