# KGRAG Ex demo: MedMCQA (medical dataset)

This notebook demonstrates KGRAG Ex on MedMCQA, a medical multiple-choice dataset.

Workflow:
1. Load StatPearls as LangChain Documents.
2. Build the vector store and retrieval setup.
3. Build a knowledge graph from triplets extracted from the same documents.
4. Run the KGRAG Ex pipeline and convert KG paths into pseudo paragraphs.
5. Inspect explanations, optional perturbations, and report stats.
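
Step 4 amounts to turning each `(head, relation, tail)` triple on a KG path into a sentence. The sketch below is a template-based illustration only. The function names and the triple format are hypothetical, and the real `KGRAGExPipeline` may phrase paths differently (the KG context printed later in this notebook reads like LLM-written prose).

```python
from typing import List, Tuple

# Hypothetical triple format: (head, relation, tail)
Triple = Tuple[str, str, str]


def relation_to_phrase(relation: str) -> str:
    """Turn a snake_case relation label such as 'is_affected_by' into plain words."""
    return relation.replace("_", " ").strip()


def path_to_pseudo_paragraph(path: List[Triple]) -> str:
    """Verbalize each triple as one sentence and join the sentences in path order."""
    sentences = [f"{head} {relation_to_phrase(rel)} {tail}." for head, rel, tail in path]
    return " ".join(sentences)


example_path = [
    ("Asbestos exposure", "affects", "Lung"),
    ("Lung", "part_of", "Respiratory system"),
]
print(path_to_pseudo_paragraph(example_path))
# Asbestos exposure affects Lung. Lung part of Respiratory system.
```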



## Notebook setup and project path

Ensure the project root is on `sys.path` so local modules import correctly.


In [36]:
import sys
from pathlib import Path

project_root = next((p for p in [Path.cwd()] + list(Path.cwd().parents) if (p / "src").exists()), None)
if project_root is None:
    raise RuntimeError('"src" directory not found. Please run this notebook from inside the project.')

root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print("Project root:", project_root)


Project root: /Users/MeinNotebook/xai-rag


## Imports

Load data loaders, RAG components, KG builders, and the KGRAG Ex pipeline.


In [37]:
from pathlib import Path

from src.modules.loader.statspearls_data_loader import StatPearlsDataLoader
from src.modules.loader.medmcqa_data_loader import MedMCQADataLoader

from src.modules.rag.rag_engine import RAGEngine
from src.modules.llm.llm_client import LLMClient

from src.modules.knowledge_graph.kg_triplet_extractor import KGTripletExtractor
from src.modules.knowledge_graph.kg_build_service import KGBuildService

from src.modules.knowledge_graph.kgrag_ex_pipeline import KGRAGExPipeline
from src.modules.explainers.kgrag_ex_explainer import KGRAGExExplainer
from src.modules.knowledge_graph.relation_registry import RelationRegistry, ProposedRelation, canon_relation


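## 1. Load StatPearls as chunked Documents

For a smoke test, keep `limit_articles` small. The loader splits each article into chunks (here 300 articles yield 13,922 Documents), and every chunk is later sent through LLM triplet extraction, so KG building gets expensive quickly.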


In [38]:
sp_loader = StatPearlsDataLoader()

statpearls_docs, sp_stats = sp_loader.setup(
    limit_articles=300,
    as_documents=True,
    force_download=False,
    force_extract=False,
    force_rebuild_jsonl=False,
)

print("StatPearls stats:", sp_stats)
print("Docs:", len(statpearls_docs))
print("Sample meta:", statpearls_docs[0].metadata)
print("Sample preview:", statpearls_docs[0].page_content[:200])




StatPearls stats: StatPearlsBuildStats(tarball_downloaded=False, extracted=False, nxml_files_found=9629, jsonl_files_created=0, articles_loaded=300, chunks_emitted=13922)
Docs: 13922
Sample meta: {'source': 'statpearls', 'split': 'repo', 'title': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'topic_name': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'source_filename': 'article-100024.nxml', 'chunk_index': 0, 'chunk_id': '6de0970aebeb1d82231d62346489c3cd6ee24eebe32979badfdeae7817d4c9bb'}
Sample preview: Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings throu
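
The stats show 300 articles turning into 13,922 Documents, so each Document is a chunk of an article rather than a whole article. A minimal sketch of how such chunk records could carry a `chunk_index` and a 64-character hex `chunk_id` like the sample metadata above; the fixed character window and the SHA-256 inputs are assumptions, not `StatPearlsDataLoader`'s actual logic.

```python
import hashlib
from typing import Dict, List


def chunk_article(text: str, source_filename: str, max_chars: int = 1000) -> List[Dict]:
    """Split text into fixed-size character windows with stable IDs (hypothetical scheme)."""
    records = []
    for index, start in enumerate(range(0, len(text), max_chars)):
        piece = text[start:start + max_chars]
        # SHA-256 over filename, position and content gives a deterministic
        # 64-hex-character ID, so rebuilding yields the same IDs for unchanged chunks.
        key = f"{source_filename}:{index}:{piece}".encode("utf-8")
        records.append(
            {
                "page_content": piece,
                "metadata": {
                    "source_filename": source_filename,
                    "chunk_index": index,
                    "chunk_id": hashlib.sha256(key).hexdigest(),
                },
            }
        )
    return records


chunks = chunk_article("x" * 2500, "article-example.nxml", max_chars=1000)
print(len(chunks), [c["metadata"]["chunk_index"] for c in chunks])
# 3 [0, 1, 2]
```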

## 2. RAG engine setup

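Build a Chroma vector store from the documents, or reuse the one already persisted in `persist_dir`. With `reset=False`, an existing store is loaded instead of rebuilt, as the output below shows.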


In [39]:
rag = RAGEngine(persist_dir="../data/vector_db_statpearls_kgragex")
rag.setup(documents=statpearls_docs, reset=False)


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 3c895ca0-86e1-414c-85f5-06a77389744a)')' thrown while requesting HEAD https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/./modules.json
Retrying in 1s [Retry 1/5].


Loading existing vector store from ../data/vector_db_statpearls_kgragex...
RagEngine ready.
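
Retrieval with `top_k_docs` returns the chunks whose embeddings are closest to the question embedding. A minimal sketch of that idea with cosine similarity over toy 3-dimensional vectors; the vectors and IDs are made up (the real store holds `all-MiniLM-L6-v2` embeddings, per the download log above), and Chroma's own index and distance settings may differ.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: List[float], store: Dict[str, List[float]], k: int = 4) -> List[Tuple[str, float]]:
    """Rank stored vectors by similarity to the query and keep the k best."""
    scored = [(doc_id, cosine(query, vec)) for doc_id, vec in store.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


toy_store = {
    "chunk-a": [1.0, 0.0, 0.0],
    "chunk-b": [0.9, 0.1, 0.0],
    "chunk-c": [0.0, 1.0, 0.0],
}
hits = top_k([1.0, 0.0, 0.0], toy_store, k=2)
print([doc_id for doc_id, _ in hits])
# ['chunk-a', 'chunk-b']
```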


## 3. LLM client

Ollama runs locally; adjust the model name to your installation.


In [40]:
client = LLMClient(provider="ollama", model_name="llama3.2:latest")

## 4. Build the KG from StatPearls chunks (JSONL cache)

In [41]:
registry_path = Path("../data/relations_registry.json")
registry = RelationRegistry.load(registry_path)

extractor = KGTripletExtractor(llm_client=client, max_retries=2, relation_registry=registry)
kg_service = KGBuildService(
    extractor=extractor,
    relation_registry=registry,
    registry_cache_path=registry_path,
    add_reverse_edges=False,
)

kg_store, kg_stats = kg_service.build_or_load(
    docs=statpearls_docs,
    cache_path=Path("../data/statpearls_kg/kg_statpearls.jsonl"),
    limit=len(statpearls_docs),  # process every loaded chunk (13,922 here)
    force_rebuild=False,
    chunk_id_prefix="sp",
    source_name="statpearls",
)

print("KG stats:", kg_store.stats())
print("Build stats:", kg_stats)



KG stats: {'nodes': 1387, 'edges': 2460}
Build stats: KGBuildStats(docs_seen=0, docs_with_triples=0, triples_forward_total=0, triples_reverse_total=0, triples_written_total=2460)
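
All build counters are 0 except `triples_written_total`, which suggests the graph was loaded from the JSONL cache instead of being re-extracted (`force_rebuild=False`). A minimal sketch of such a cache, assuming one JSON object per triple with hypothetical keys `head`, `relation`, `tail` and `chunk_id`; the actual schema of `kg_statpearls.jsonl` is not shown in this notebook.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List


def write_triples(path: Path, triples: Iterable[Dict[str, str]]) -> int:
    """Write one JSON object per line; returns the number of triples written."""
    count = 0
    with path.open("w", encoding="utf-8") as fh:
        for triple in triples:
            fh.write(json.dumps(triple, ensure_ascii=False) + "\n")
            count += 1
    return count


def load_triples(path: Path) -> List[Dict[str, str]]:
    """Read the cache back, skipping blank lines."""
    with path.open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


# Hypothetical records; "sp-0" only mimics the chunk_id_prefix="sp" setting above.
triples = [
    {"head": "Asbestosis", "relation": "affects", "tail": "Lung", "chunk_id": "sp-0"},
    {"head": "Lung", "relation": "part_of", "tail": "Respiratory system", "chunk_id": "sp-0"},
]
with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "kg_cache.jsonl"
    written = write_triples(cache, triples)
    loaded = load_triples(cache)
print(written, loaded == triples)
# 2 True
```

On a later run, reading these lines and adding one edge per record rebuilds the graph without repeating the LLM extraction.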


## 5. Quick KG stats


In [42]:
import numpy as np

nodes = [str(n) for n in kg_store.g.nodes]
token_counts = np.array([len(n.split()) for n in nodes])
print("Nodes:", len(nodes))
print("Tokens per node min/median/mean/max:", token_counts.min(), np.median(token_counts), token_counts.mean(), token_counts.max())

long_nodes = [n for n in nodes if len(n.split()) > 10]
print("Nodes with >10 tokens:", len(long_nodes))
print("Example long nodes:", long_nodes[:5])

Nodes: 1387
Tokens per node min/median/mean/max: 1 2.0 1.9798125450612833 9
Nodes with >10 tokens: 0
Example long nodes: []


### Graph connectivity and relation stats

Inspect connectivity, degree distribution, and relation frequencies in the KG.


In [43]:
import networkx as nx
from collections import Counter
import numpy as np

G = kg_store.g
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

wcc = list(nx.weakly_connected_components(G)) if G.number_of_nodes() else []
sizes = sorted([len(c) for c in wcc], reverse=True) if wcc else [0]
print("Weakly connected components:", len(wcc))
print("Largest component size:", sizes[0])

deg = np.array([d for _, d in G.degree()]) if G.number_of_nodes() else np.array([0])
print("Degree min/median/mean/max:", int(deg.min()), float(np.median(deg)), float(deg.mean()), int(deg.max()))

rel_counts = Counter()
for _, _, data in G.edges(data=True):
    rel_counts[str(data.get("relation", "Unknown"))] += 1
print("Top relations:", rel_counts.most_common(12))

Nodes: 1387 Edges: 2460
Weakly connected components: 255
Largest component size: 633
Degree min/median/mean/max: 2 2.0 3.5472242249459263 70
Top relations: [('affects', 421), ('is_affected_by', 421), ('part_of', 290), ('has_part', 290), ('involves_medication', 129), ('is_involved_in_treatment', 129), ('is_detected_by', 52), ('detects', 52), ('has_occurrence_of', 49), ('occurs_in', 49), ('can_affect', 38), ('can_be_affected_by', 38)]
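
The top relations come in pairs with identical counts (`affects`/`is_affected_by`, `part_of`/`has_part`, and so on), and the minimum degree is 2. That would fit every fact being stored in both directions, even though the build cell passes `add_reverse_edges=False`; the cache may predate that setting, or the extractor may emit inverses, and neither is confirmed here. A small sketch to test the hypothesis: pair each edge with any edge running the opposite way, then count relation pairs and unpaired edges. The `(source, relation, target)` tuple format is an assumption.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple

Edge = Tuple[str, str, str]  # (source, relation, target)


def inverse_pairs(edges: Iterable[Edge]) -> Tuple[Counter, int]:
    """Count (relation, reverse relation) pairs and edges without a reverse edge."""
    edge_list = list(edges)
    # Index relation labels by (source, target) so reverse lookups are cheap.
    by_endpoints: Dict[Tuple[str, str], List[str]] = {}
    for src, rel, dst in edge_list:
        by_endpoints.setdefault((src, dst), []).append(rel)

    pairs: Counter = Counter()
    unpaired = 0
    for src, rel, dst in edge_list:
        reverse_rels = by_endpoints.get((dst, src), [])
        if not reverse_rels:
            unpaired += 1
        for rev_rel in reverse_rels:
            pairs[(rel, rev_rel)] += 1
    return pairs, unpaired


toy_edges = [
    ("Asbestosis", "affects", "Lung"),
    ("Lung", "is_affected_by", "Asbestosis"),
    ("Alveolus", "part_of", "Lung"),
]
pairs, unpaired = inverse_pairs(toy_edges)
print(dict(pairs), unpaired)
```

On the notebook's graph, `inverse_pairs((u, str(d.get("relation")), v) for u, v, d in G.edges(data=True))` would run the same check; an `unpaired` count of 0 would support the both-directions reading.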


## 6. Pipeline and explainer


In [44]:
pipeline = KGRAGExPipeline(rag=rag, kg=kg_store, llm_client=client)
explainer = KGRAGExExplainer(pipeline)

## 7. Load MedMCQA, small sample


In [45]:
mcqa_loader = MedMCQADataLoader()
questions = mcqa_loader.setup(split="train", as_documents=True, limit=1)

print("MedMCQA questions:", len(questions))
print("Sample question:", questions[0].metadata.get("question"))
print("Gold:", questions[0].metadata.get("answer"))

MedMCQA questions: 1
Sample question: True statements about asbestosis
Gold: D


## 8. Single-run test on one MedMCQA question


In [46]:
d = questions[0]

qid = str(d.metadata.get("question_id", "q-0"))
question_text = str(d.metadata.get("question") or d.page_content or "").strip()
gold = d.metadata.get("answer")

run = pipeline.run(
    question_id=qid,
    question=question_text,
    gold_answer=gold,
    entity_k=6,
    top_k_docs=4,
)

print("Question ID:", run.question_id)
print("Question:", run.question)
print("Gold:", run.gold_answer)
print("Entities:", run.entities)
print("Start, End:", run.start, run.end)
print("Path length:", len(run.path))
print("KG context preview:", (run.kg_context or "")[:600])
print("Answer:", run.answer)
print("LLM calls:", run.llm_calls)

Connecting to local Ollama (llama3.2:latest)...
Question ID: b255c99a-ee84-4903-a501-35b501afeff0
Question: True statements about asbestosis
Gold: D
Entities: ['asbestosis', 'disease', 'lung', 'cancer', 'risk']
Start, End: Disease Lung
Path length: 0
KG context preview: 
Answer: I cannot provide information on asbestosis based on the provided context. Can I help you with something else?
LLM calls: 1
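
The run picked start and end entities (`Disease`, `Lung`) but returned an empty path, so the KG context is empty and the LLM declines to answer. A minimal sketch of a directed breadth-first search that behaves the same way: it returns the triples on one shortest path, or an empty list when an endpoint is missing or unreachable. That the pipeline uses BFS over outgoing edges is an assumption, and the toy adjacency dict is made up.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple

Triple = Tuple[str, str, str]
# Adjacency: node -> list of (relation, neighbor) for outgoing edges.
Adjacency = Dict[str, List[Tuple[str, str]]]


def shortest_path_triples(adj: Adjacency, start: Optional[str], end: Optional[str]) -> List[Triple]:
    """Directed BFS; returns the triples of one shortest start->end path, or []."""
    if start is None or end is None or start == end:
        return []
    parent: Dict[str, Tuple[str, str]] = {}  # node -> (previous node, relation)
    seen = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for relation, neighbor in adj.get(node, []):
            if neighbor in seen:
                continue
            seen.add(neighbor)
            parent[neighbor] = (node, relation)
            if neighbor == end:
                # Walk back from end to start, then reverse into path order.
                triples: List[Triple] = []
                current = end
                while current != start:
                    prev, rel = parent[current]
                    triples.append((prev, rel, current))
                    current = prev
                return triples[::-1]
            queue.append(neighbor)
    return []  # end missing from the graph or unreachable from start


adj: Adjacency = {
    "Asbestosis": [("affects", "Lung")],
    "Lung": [("part_of", "Respiratory system")],
    "Respiratory system": [],
    "Disease": [],
}
print(shortest_path_triples(adj, "Asbestosis", "Respiratory system"))
print(shortest_path_triples(adj, "Disease", "Lung"))  # []
```

A directed search can miss connections that exist only against edge direction, which is why the debug cell at the end of this notebook also tries an undirected shortest path.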


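## 9. Optional: KG-level perturbations

The explainer removes parts of the KG path (nodes, edges, subpaths) and records whether the answer changes. The run above has an empty path, so there is nothing to perturb and every field below is `None`.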


In [47]:
report = explainer.explain(run)

print("Most influential node:", report.most_influential_node)
print("Most influential edge:", report.most_influential_edge)
print("Most influential subpath:", report.most_influential_subpath)

print("\nFirst 10 perturbation outcomes:")
for o in report.outcomes[:10]:
    print({"kind": o.kind, "removed": o.removed, "answer_changed": o.answer_changed, "answer": (o.answer or "")[:50]})

Most influential node: None
Most influential edge: None
Most influential subpath: None

First 10 perturbation outcomes:
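
No outcomes are listed because the run has no KG path to perturb. To show what the report would contain otherwise, here is a leave-one-out sketch: drop each edge, then each node (with its incident triples), re-answer, and flag answer changes. `answer_fn` is a hypothetical stand-in for the LLM call, the `Outcome` fields mirror those printed above, and taking the first change as "most influential" is a simplification; `KGRAGExExplainer` may score subpaths and ties differently.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

Triple = Tuple[str, str, str]


@dataclass
class Outcome:
    kind: str                     # "edge" or "node"
    removed: Union[str, Triple]   # removed node name or removed triple
    answer: str
    answer_changed: bool


def perturb_path(
    path: List[Triple], answer_fn: Callable[[List[Triple]], str]
) -> Tuple[List[Outcome], Optional[str], Optional[Triple]]:
    """Leave-one-out perturbation over the edges and nodes of a KG path."""
    if not path:
        return [], None, None  # nothing to perturb, as in the run above
    baseline = answer_fn(path)
    outcomes: List[Outcome] = []

    # Edge perturbation: drop one triple at a time.
    for i, triple in enumerate(path):
        answer = answer_fn(path[:i] + path[i + 1:])
        outcomes.append(Outcome("edge", triple, answer, answer != baseline))

    # Node perturbation: drop every triple that touches the node.
    nodes: List[str] = []
    for head, _, tail in path:
        for node in (head, tail):
            if node not in nodes:
                nodes.append(node)
    for node in nodes:
        reduced = [t for t in path if node not in (t[0], t[2])]
        answer = answer_fn(reduced)
        outcomes.append(Outcome("node", node, answer, answer != baseline))

    # Simplification: the first removal that flips the answer counts as most influential.
    top_node = next((o.removed for o in outcomes if o.kind == "node" and o.answer_changed), None)
    top_edge = next((o.removed for o in outcomes if o.kind == "edge" and o.answer_changed), None)
    return outcomes, top_node, top_edge


def toy_answer(context: List[Triple]) -> str:
    """Hypothetical stand-in for the LLM: answers 'D' only if 'Lung' is in the context."""
    return "D" if any("Lung" in (h, t) for h, _, t in context) else "unknown"


path = [("Asbestosis", "affects", "Lung"), ("Lung", "part_of", "Respiratory system")]
outcomes, top_node, top_edge = perturb_path(path, toy_answer)
print(top_node, top_edge, sum(o.answer_changed for o in outcomes))
# Lung None 1
```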


## 10. Mini-batch test


In [48]:
import pandas as pd

rows = []
for d in questions:
    qid = str(d.metadata.get("question_id", "q"))
    qtxt = str(d.metadata.get("question") or d.page_content or "").strip()
    gold = d.metadata.get("answer")

    run_i = pipeline.run(
        question_id=qid,
        question=qtxt,
        gold_answer=gold,
        entity_k=6,
        top_k_docs=4,
    )
    rep_i = explainer.explain(run_i)
    num_changed = sum(1 for o in rep_i.outcomes if o.answer_changed)

    rows.append(
        {
            "question_id": run_i.question_id,
            "gold": run_i.gold_answer,
            "path_len": len(run_i.path),
            "entities": ", ".join(run_i.entities[:6]),
            "answer": (run_i.answer or "")[:10],
            "perturbations": len(rep_i.outcomes),
            "answer_changes": num_changed,
            "most_node": rep_i.most_influential_node,
            "most_edge": rep_i.most_influential_edge,
            "most_subpath": rep_i.most_influential_subpath,
        }
    )

df = pd.DataFrame(rows)
df

|   | question_id | gold | path_len | entities | answer | perturbations | answer_changes | most_node | most_edge | most_subpath |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | b255c99a-ee84-4903-a501-35b501afeff0 | D | 0 | asbestosis, disease, lung, cancer, risk | I cannot p | 0 | 0 | None | None | None |
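
The table puts the gold letter next to a free-text answer, so correctness is not computed. A minimal sketch that extracts an option letter A to D from answers phrased like "Answer: D" or given as a bare letter, and counts a row as correct only when a letter was found and matches; the regular expressions and helper names are hypothetical. Apply it to the full `run_i.answer`, since the table truncates answers to 10 characters.

```python
import re
from typing import Dict, List, Optional

# "Answer: D", "The answer is (B)." and similar; the letter itself must be uppercase.
_EXPLICIT = re.compile(r"(?i:answer)\s*(?:(?i:is)\s*)?[:\-]?\s*\(?([A-D])\)?(?![A-Za-z])")
# A reply that is only a letter, optionally wrapped or punctuated: "D", "(D)", "D."
_BARE = re.compile(r"^\s*\(?([A-D])\)?[.):]?\s*$")


def extract_choice(answer: Optional[str]) -> Optional[str]:
    """Return the option letter A-D named in the answer, or None."""
    if not answer:
        return None
    match = _EXPLICIT.search(answer) or _BARE.match(answer)
    return match.group(1) if match else None


def accuracy(rows: List[Dict]) -> float:
    """Share of rows whose extracted letter equals the gold letter."""
    if not rows:
        return 0.0
    correct = 0
    for row in rows:
        predicted = extract_choice(row.get("answer"))
        if predicted is not None and predicted == row.get("gold"):
            correct += 1
    return correct / len(rows)


print(accuracy([
    {"gold": "D", "answer": "Answer: D"},
    {"gold": "D", "answer": "I cannot provide information on asbestosis based on the provided context."},
]))
# 0.5
```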


## 11. Optional: compact stats for your research questions (RQs)

RQ2 (position proxy) depends on explainer implementation details.  
RQ3 (node types) depends on which entity types your triplet extractor emits.  
RQ4 (graph metrics) depends on whether your metrics module computes degree and betweenness.

If your report exposes these fields, the cell below prints a short sample of each.


In [49]:
if getattr(report, "rq2_positions", None):
    print("RQ2 positions sample:", report.rq2_positions[:5])

if getattr(report, "rq3_node_types", None):
    print("RQ3 node types sample:", report.rq3_node_types[:10])

if getattr(report, "rq4_graph_metrics", None):
    print("RQ4 graph metrics sample:", report.rq4_graph_metrics[:5])
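
For RQ4, networkx already provides degree (`G.degree()`) and `nx.betweenness_centrality`. To make betweenness concrete without extra dependencies, here is a minimal sketch of Brandes' algorithm for an unweighted, undirected graph given as a symmetric adjacency dict. Scores are unnormalized, and whether the project's metrics module computes betweenness this way is an assumption.

```python
from collections import deque
from typing import Dict, Hashable, List


def betweenness(adj: Dict[Hashable, List[Hashable]]) -> Dict[Hashable, float]:
    """Brandes' algorithm for unweighted graphs; adj must be symmetric (undirected)."""
    bc = {v: 0.0 for v in adj}
    for s in adj:
        # Single-source BFS: count shortest paths (sigma) and record predecessors.
        stack: List[Hashable] = []
        preds: Dict[Hashable, List[Hashable]] = {v: [] for v in adj}
        sigma = {v: 0 for v in adj}
        dist = {v: -1 for v in adj}
        sigma[s], dist[s] = 1, 0
        queue = deque([s])
        while queue:
            v = queue.popleft()
            stack.append(v)
            for w in adj[v]:
                if dist[w] < 0:
                    dist[w] = dist[v] + 1
                    queue.append(w)
                if dist[w] == dist[v] + 1:
                    sigma[w] += sigma[v]
                    preds[w].append(v)
        # Back-propagate dependencies in order of non-increasing distance.
        delta = {v: 0.0 for v in adj}
        while stack:
            w = stack.pop()
            for v in preds[w]:
                delta[v] += sigma[v] / sigma[w] * (1.0 + delta[w])
            if w != s:
                bc[w] += delta[w]
    # In an undirected graph every pair is counted once from each endpoint.
    return {v: score / 2.0 for v, score in bc.items()}


star = {"x": ["p", "q", "r"], "p": ["x"], "q": ["x"], "r": ["x"]}
print(betweenness(star))
```

For the KG, `GU = G.to_undirected(as_view=True)` and `{n: list(GU.neighbors(n)) for n in GU}` would build the input; the scores should match `nx.betweenness_centrality(GU, normalized=False)`.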


## 12. Try your own question

Use this to quickly check whether the KG has enough coverage.

Note:
If `path_len` is often 0, it's usually an entity extraction/matching or KG coverage issue, not a RAG issue.
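
A quick way to separate the two causes: runs where `start` or `end` is `None` point to entity extraction/matching, while runs with both endpoints but an empty path point to missing KG coverage or connectivity. A minimal sketch over per-run records; the record keys are hypothetical (they could be collected in the section 10 loop from `run_i.start`, `run_i.end` and `len(run_i.path)`), and the three example records restate runs shown in this notebook.

```python
from collections import Counter
from typing import Dict, Iterable, Optional


def diagnose(start: Optional[str], end: Optional[str], path_len: int) -> str:
    """Label a run by where the KG lookup stopped."""
    if start is None or end is None:
        return "no_entity_match"  # entity extraction/matching problem
    if path_len == 0:
        return "no_kg_path"       # endpoints found, but no connecting path
    return "path_found"


def diagnose_runs(records: Iterable[Dict]) -> Counter:
    return Counter(diagnose(r.get("start"), r.get("end"), r.get("path_len", 0)) for r in records)


records = [
    {"start": "Disease", "end": "Lung", "path_len": 0},        # asbestosis question
    {"start": "CTO", "end": "Cardiomyopathy", "path_len": 0},  # manual CTO question
    {"start": None, "end": None, "path_len": 0},               # tension pneumothorax question
]
print(diagnose_runs(records))
# Counter({'no_kg_path': 2, 'no_entity_match': 1})
```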


In [50]:
qid = "manual-1"
question = "How does CTO relates to Cardiomyopathy"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


Entities: ['CTO', 'Cardiomyopathy']
Start, End: CTO Cardiomyopathy
Path length: 0
KG context preview: A cerebrovascular accident (CTO) can have a significant impact on cardiovascular health, as it often leads to ischemia in the brain. Ischemic heart disease, characterized by reduced blood flow to the heart muscle, is a common complication of CTO and shares many risk factors with cardiomyopathy, a condition where the heart muscle becomes weakened and cannot function properly. In fact, individuals w
Answer: CTO relates to Cardiomyopathy as a risk factor and complication.


KeyboardInterrupt: 

In [None]:
qid = "manual-2"
question = "How does Tension pneumothorax relates to Septicemia"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


Entities: ['Tension', 'pneumothorax', 'septicemia']
Start, End: None None
Path length: 0
KG context preview: 
Answer: I cannot determine how tension pneumothorax relates to septicemia based on the provided context.
Most influential node: None
Most influential edge: None
Most influential subpath: None


### Optional: print RQ metrics for the manual question

Surface the RQ fields directly for quick inspection.


In [None]:
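# getattr keeps the cell working if the report lacks one of the RQ fields
print("RQ1 sensitivity:", getattr(rep_m, "rq1_sensitivity", None))
print("RQ2 positions:", getattr(rep_m, "rq2_positions", None))
print("RQ3 node types:", getattr(rep_m, "rq3_node_types", None))
print("RQ4 graph metrics:", getattr(rep_m, "rq4_graph_metrics", None))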

### Debug: graph connectivity between start/end nodes

Check whether the start/end nodes of the latest manual run (`run_m`) exist in the KG, inspect their neighbors, and probe for an undirected shortest path between them.


In [None]:
import networkx as nx

G = kg_store.g
start = run_m.start
end = run_m.end

print("start exists:", start in G)
print("end exists:", end in G)

if start in G:
    print("\nStart neighbors (out):", list(G.successors(start))[:30])
    print("Start neighbors (in):", list(G.predecessors(start))[:30])

if end in G:
    print("\nEnd neighbors (out):", list(G.successors(end))[:30])
    print("End neighbors (in):", list(G.predecessors(end))[:30])

if start in G and end in G:
    GU = G.to_undirected(as_view=True)
    try:
        p = nx.shortest_path(GU, start, end)
        print("\nUndirected shortest path length:", len(p) - 1)
        print("Path nodes:", p[:20], "..." if len(p) > 20 else "")
    except nx.NetworkXNoPath as e:
        # Both nodes exist (checked above), so only a missing path can fail here
        print("\nNo undirected path:", e)

# Look up node names that contain the query (case-insensitive substring match)
def find_similar_nodes(q, limit=20):
    ql = (q or "").lower()
    hits = [n for n in G.nodes if ql in str(n).lower()]
    return hits[:limit]

print("\nSimilar to 'CTO':", find_similar_nodes("CTO", 30))
print("Similar to 'Chronic Total Occlusion':", find_similar_nodes("Chronic Total Occlusion", 30))
print("Similar to 'Stroke':", find_similar_nodes("Stroke", 30))
print("Similar to 'Ischemic Heart Disease':", find_similar_nodes("Ischemic Heart Disease", 30))
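
Substring search misses spelling variants (e.g. British and American spellings) and, for short queries such as `CTO`, can hit unrelated names that merely contain the letters (`cto` inside `factor`). A sketch of fuzzy matching with the standard library's `difflib.get_close_matches` on lowercased names; the cutoff of 0.6 is difflib's default and may need tuning, and the node names in the example are made up.

```python
import difflib
from typing import Dict, Iterable, List


def fuzzy_nodes(query: str, nodes: Iterable[str], limit: int = 10, cutoff: float = 0.6) -> List[str]:
    """Return node names most similar to the query, best match first."""
    # Compare lowercased names but return the original spelling.
    by_lower: Dict[str, str] = {}
    for node in nodes:
        by_lower.setdefault(str(node).lower(), str(node))
    matches = difflib.get_close_matches(query.lower(), list(by_lower), n=limit, cutoff=cutoff)
    return [by_lower[m] for m in matches]


toy_nodes = ["Septicemia", "Sepsis", "Pneumothorax", "Tension Pneumothorax", "Lung"]
print(fuzzy_nodes("septicaemia", toy_nodes, limit=3))
print(fuzzy_nodes("tension pneumothorax", toy_nodes, limit=3))
```

On the notebook's graph, `fuzzy_nodes("Septicemia", G.nodes, 30)` would complement `find_similar_nodes`.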
