# KGRAG Ex on a medical dataset: MedMCQA

This notebook demonstrates KGRAG Ex on a medical multiple-choice dataset, MedMCQA.

Workflow:
1. Load StatPearls as LangChain Documents (retrieval corpus) and MedMCQA as the question set
2. Set up the vector store and retrieval
3. Build a knowledge graph from triples extracted from the same Documents
4. Run the KGRAG Ex pipeline, converting the KG path into a pseudo-paragraph
5. Perturb the KG path at node, edge, and subpath level and measure the effect on the answer



In [1]:
import sys
from pathlib import Path

project_root = next((p for p in [Path.cwd()] + list(Path.cwd().parents) if (p / "src").exists()), None)
if project_root is None:
    raise RuntimeError('"src" directory not found. Please run the notebook inside the project.')

root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print("Project root:", project_root)


Project root: /Users/MeinNotebook/xai-rag


In [2]:
from pathlib import Path

from src.modules.statspearls_data_loader import StatPearlsDataLoader
from src.modules.medmcqa_data_loader import MedMCQADataLoader

from src.modules.rag_engine import RAGEngine
from src.modules.llm_client import LLMClient

from src.modules.kg_triplet_extractor import KGTripletExtractor
from src.modules.kg_build_service import KGBuildService

from src.modules.kgrag_ex_pipeline import KGRAGExPipeline
from src.modules.kgrag_ex_explainer import KGRAGExExplainer





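## 1. Load StatPearls, one Document per chunk

For a smoke test, load only a few articles (here `limit_articles=300`) and cap `max_chars_per_book`; otherwise the KG build gets expensive. Note that the loader emits chunk-level Documents: the 300 articles loaded here yield 13,922 chunks.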

In [3]:
sp_loader = StatPearlsDataLoader()

statpearls_docs, sp_stats = sp_loader.setup(
    limit_articles=300,
    as_documents=True,
    force_download=False,
    force_extract=False,
    force_rebuild_jsonl=False,
)

print("StatPearls stats:", sp_stats)
print("Docs:", len(statpearls_docs))
print("Sample meta:", statpearls_docs[0].metadata)
print("Sample preview:", statpearls_docs[0].page_content[:200])




StatPearls stats: StatPearlsBuildStats(tarball_downloaded=False, extracted=False, nxml_files_found=9629, jsonl_files_created=0, articles_loaded=300, chunks_emitted=13922)
Docs: 13922
Sample meta: {'source': 'statpearls', 'split': 'repo', 'title': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'topic_name': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'source_filename': 'article-100024.nxml', 'chunk_index': 0, 'chunk_id': '6de0970aebeb1d82231d62346489c3cd6ee24eebe32979badfdeae7817d4c9bb'}
Sample preview: Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings throu

## 2. RAG Engine Setup

We build a Chroma vector store from the Documents. With `reset=False`, an existing store under `persist_dir` is loaded instead of being rebuilt.


In [4]:
rag = RAGEngine(persist_dir="../data/vector_db_statpearls_kgragex")
rag.setup(documents=statpearls_docs, reset=False, k_documents=6)


Loading existing vector store from ../data/vector_db_statpearls_kgragex...
RagEngine ready.


## 3. LLM Client

Local Ollama; adjust the model name to match your local installation.


In [5]:
client = LLMClient(provider="ollama", model_name="llama3.2:latest")

## 5. Build the KG from StatPearls chunks, with a JSONL cache

Important:
KGBuildService uses doc.metadata["chunk_id"] when present, so that chunk_id stays stable.
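
A stable ID presumably lets cached triples be matched to the same chunk across runs. The following is a minimal sketch of that idea, not the actual `KGBuildService` logic (which is not shown in this notebook): the helper `resolve_chunk_id` is hypothetical, and the SHA-256 fallback for Documents without a `chunk_id` is an assumption.

```python
import hashlib
from typing import Any, Mapping


def resolve_chunk_id(metadata: Mapping[str, Any], page_content: str, prefix: str = "sp") -> str:
    """Hypothetical helper: return a chunk ID that is stable across runs."""
    existing = metadata.get("chunk_id")
    if existing:
        # Reuse the loader-provided ID unchanged.
        return str(existing)
    # Fallback (assumption): hash the text so identical content maps to the same ID.
    digest = hashlib.sha256(page_content.encode("utf-8")).hexdigest()
    return f"{prefix}-{digest[:16]}"


# A Document with a chunk_id keeps it; one without gets a deterministic hash.
print(resolve_chunk_id({"chunk_id": "abc123"}, "some text"))
print(resolve_chunk_id({}, "some text") == resolve_chunk_id({}, "some text"))
```

Because the fallback depends only on the text, re-running the loader on unchanged content would yield the same IDs, which is the property the note above asks for.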


In [6]:
extractor = KGTripletExtractor(llm_client=client, max_retries=2)
kg_service = KGBuildService(extractor=extractor, add_reverse_edges=True)

kg_store, kg_stats = kg_service.build_or_load(
    docs=statpearls_docs,
    cache_path=Path("../data/statpearls_kg/kg_statpearls.jsonl"),
    limit=10,
    force_rebuild=False,
    chunk_id_prefix="sp",
    source_name="statpearls",
)

print("KG stats:", kg_store.stats())
print("Build stats:", kg_stats)



KG stats: {'nodes': 15, 'edges': 20}
Build stats: KGBuildStats(docs_seen=0, docs_with_triples=0, triples_forward_total=0, triples_reverse_total=0, triples_written_total=20)


## 6. KG Quick Stats

In [7]:
import numpy as np

nodes = [str(n) for n in kg_store.g.nodes]
token_counts = np.array([len(n.split()) for n in nodes])
print("Nodes:", len(nodes))
print("Tokens per node min/median/mean/max:", token_counts.min(), np.median(token_counts), token_counts.mean(), token_counts.max())

long_nodes = [n for n in nodes if len(n.split()) > 10]
print("Nodes with >10 tokens:", len(long_nodes))
print("Example long nodes:", long_nodes[:5])

Nodes: 15
Tokens per node min/median/mean/max: 1 2.0 1.9333333333333333 4
Nodes with >10 tokens: 0
Example long nodes: []


In [8]:
import networkx as nx
from collections import Counter
import numpy as np

G = kg_store.g
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

wcc = list(nx.weakly_connected_components(G)) if G.number_of_nodes() else []
sizes = sorted([len(c) for c in wcc], reverse=True) if wcc else [0]
print("Weakly connected components:", len(wcc))
print("Largest component size:", sizes[0])

deg = np.array([d for _, d in G.degree()]) if G.number_of_nodes() else np.array([0])
print("Degree min/median/mean/max:", int(deg.min()), float(np.median(deg)), float(deg.mean()), int(deg.max()))

rel_counts = Counter()
for _, _, data in G.edges(data=True):
    rel_counts[str(data.get("relation", "Unknown"))] += 1
print("Top relations:", rel_counts.most_common(12))

Nodes: 15 Edges: 20
Weakly connected components: 5
Largest component size: 4
Degree min/median/mean/max: 2 2.0 2.6666666666666665 4
Top relations: [('affects', 5), ('is_affected_by', 5), ('comorbid_with', 4), ('involves_medication', 1), ('is_involved_in_treatment', 1), ('has_increased_risk_due_to', 1), ('increases_risk_of', 1), ('has_occurrence_of', 1), ('occurs_in', 1)]
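
The relation counts above appear to pair up (`affects` and `is_affected_by` both occur 5 times), which is what one would expect from `add_reverse_edges=True`. Below is a minimal sketch of that idea on plain triples. The names `with_reverse_edges` and `INVERSE_OF` are hypothetical, and the explicit inverse-label mapping is an assumption; how `KGBuildService` actually derives reverse labels is not shown in this notebook.

```python
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]  # (head, relation, tail)

# Assumed inverse labels, inferred from the relation counts printed above.
INVERSE_OF: Dict[str, str] = {
    "affects": "is_affected_by",
    "comorbid_with": "comorbid_with",  # symmetric relation: same label both ways
}


def with_reverse_edges(triples: List[Triple], inverse_of: Dict[str, str]) -> List[Triple]:
    """Return the forward triples plus a reverse triple where an inverse label is known."""
    out: List[Triple] = []
    for head, relation, tail in triples:
        out.append((head, relation, tail))
        inverse = inverse_of.get(relation)
        if inverse is not None:
            out.append((tail, inverse, head))
    return out


forward = [
    ("CTO", "affects", "Ischemic Heart Disease"),
    ("Ischemic Heart Disease", "comorbid_with", "Cardiomyopathy"),
]
for triple in with_reverse_edges(forward, INVERSE_OF):
    print(triple)
```

With reverse edges in place, every node has matching in- and out-neighbors, so a directed path search can traverse a relation in either direction.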


## 7. Pipeline and Explainer

In [7]:
pipeline = KGRAGExPipeline(rag=rag, kg=kg_store, llm_client=client)
explainer = KGRAGExExplainer(pipeline)

## 8. Load MedMCQA, small sample

In [8]:
mcqa_loader = MedMCQADataLoader()
questions = mcqa_loader.setup(split="train", as_documents=True, limit=8)

print("MedMCQA questions:", len(questions))
print("Sample question:", questions[0].metadata.get("question"))
print("Gold:", questions[0].metadata.get("answer"))

MedMCQA questions: 8
Sample question: Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma
Gold: C


## 9. Single-run test on one MedMCQA question


In [11]:
d = questions[0]

qid = str(d.metadata.get("question_id", "q-0"))
question_text = str(d.metadata.get("question") or d.page_content or "").strip()
gold = d.metadata.get("answer")

run = pipeline.run(
    question_id=qid,
    question=question_text,
    gold_answer=gold,
    entity_k=6,
    top_k_docs=4,
)

print("Question ID:", run.question_id)
print("Question:", run.question)
print("Gold:", run.gold_answer)
print("Entities:", run.entities)
print("Start, End:", run.start, run.end)
print("Path length:", len(run.path))
print("KG context preview:", (run.kg_context or "")[:600])
print("Answer:", run.answer)
print("LLM calls:", run.llm_calls)

Connecting to local Ollama (llama3.2:latest)...
Question ID: e9ad821a-c438-4965-9f77-760819dfa155
Question: Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma
Gold: C
Entities: ['obstruction', 'hyperplasia', 'kidney']
Start, End: None None
Path length: 0
KG context preview: 
Answer: I cannot determine the change in kidney parenchyma due to chronic urethral obstruction caused by benign prostatic hyperplasia from the provided context.
LLM calls: 1


## 10. Optional: perturbations at KG level

In [12]:
report = explainer.explain(run)

print("Most influential node:", report.most_influential_node)
print("Most influential edge:", report.most_influential_edge)
print("Most influential subpath:", report.most_influential_subpath)

print("\nFirst 10 perturbation outcomes:")
for o in report.outcomes[:10]:
    print({"kind": o.kind, "removed": o.removed, "answer_changed": o.answer_changed, "answer": (o.answer or "")[:50]})

Most influential node: None
Most influential edge: None
Most influential subpath: None

First 10 perturbation outcomes:
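
No outcomes are listed because the path length is 0: with an empty path there is nothing to remove. For a non-empty path, the candidates at the three perturbation levels could be enumerated roughly as sketched below. This assumes the `node_list` / `edge_list` layout that `PathAsLists` prints later in this notebook; the function `enumerate_perturbations` is hypothetical, and `KGRAGExExplainer` may build or apply its removals differently.

```python
from typing import Dict, List


def enumerate_perturbations(node_list: List[str], edge_list: List[str]) -> List[Dict[str, str]]:
    """Hypothetical sketch: list every node, edge, and subpath removal for one KG path."""
    if len(edge_list) != max(len(node_list) - 1, 0):
        raise ValueError("a path with n nodes must have n - 1 edges")

    # One (head, relation, tail) triple per hop of the path.
    triples = [(node_list[i], edge_list[i], node_list[i + 1]) for i in range(len(edge_list))]

    candidates: List[Dict[str, str]] = []
    for node in node_list:
        candidates.append({"kind": "node", "removed": node})
    for head, rel, tail in triples:
        candidates.append({"kind": "edge", "removed": f"{head}::{rel}::{tail}"})
    for head, rel, tail in triples:
        candidates.append({"kind": "subpath", "removed": f"({head}, {rel}, {tail})"})
    return candidates


path_nodes = ["CTO", "Ischemic Heart Disease", "Cardiomyopathy"]
path_edges = ["affects", "comorbid_with"]
for cand in enumerate_perturbations(path_nodes, path_edges):
    print(cand)

# An empty path yields no candidates at all.
print(len(enumerate_perturbations([], [])))
```

For each candidate, the explainer would then regenerate the answer without the removed element and record whether it changed.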


## 11. Mini Batch Test


In [13]:
import pandas as pd

rows = []
for d in questions:
    qid = str(d.metadata.get("question_id", "q"))
    qtxt = str(d.metadata.get("question") or d.page_content or "").strip()
    gold = d.metadata.get("answer")

    run_i = pipeline.run(
        question_id=qid,
        question=qtxt,
        gold_answer=gold,
        entity_k=6,
        top_k_docs=4,
    )
    rep_i = explainer.explain(run_i)
    num_changed = sum(1 for o in rep_i.outcomes if o.answer_changed)

    rows.append(
        {
            "question_id": run_i.question_id,
            "gold": run_i.gold_answer,
            "path_len": len(run_i.path),
            "entities": ", ".join(run_i.entities[:6]),
            "answer": (run_i.answer or "")[:10],
            "perturbations": len(rep_i.outcomes),
            "answer_changes": num_changed,
            "most_node": rep_i.most_influential_node,
            "most_edge": rep_i.most_influential_edge,
            "most_subpath": rep_i.most_influential_subpath,
        }
    )

df = pd.DataFrame(rows)
df

| | question_id | gold | path_len | entities | answer | perturbations | answer_changes | most_node | most_edge | most_subpath |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | e9ad821a-c438-4965-9f77-760819dfa155 | C | 0 | obstruction, hyperplasia, kidney | I cannot d | 0 | 0 | | | |
| 1 | e3d3c4e1-4fb2-45e7-9f88-247cc8f373b3 | C | 0 | Vitamin D, Vitamin B12, Vitamin A | A | 0 | 0 | | | |
| 2 | 5c38bea6-787a-44a9-b2df-88f4218ab914 | D | 0 | surgical, options, for, morbid, obesity, except | C | 0 | 0 | | | |
| 3 | cdeedb04-fbe9-432c-937c-d53ac24475de | A | 0 | common carotid, optic nerve, artery, thrombus,... | C | 0 | 0 | | | |
| 4 | dc6794a3-b108-47c5-8b1b-3b4931577249 | B | 0 | growth, hormone, effect, on, growth, through | A | 0 | 0 | | | |
| 5 | 5ab84ea8-12d1-47d4-ab22-668ebf01e64c | C | 0 | Scrub typhus, typhus, scrub | I cannot d | 0 | 0 | | | |
| 6 | a83de6e4-9427-4480-b404-d96621ebb640 | C | 0 | colposcopy, cervical intraepithelial neoplasia... | I cannot d | 0 | 0 | | | |
| 7 | f3bf8583-231b-4b7a-828c-179b0f9ccdd9 | C | 0 | Rectal exam, Anal fissure, Hemorrhoids, Divert... | A | 0 | 0 | | | |


## 12. Optional: simple statistics for your RQs

RQ2, position proxy: depends on implementation details of your explainer.  
RQ3, node types: depends on which entity types your triplet extractor generates.  
RQ4, graph metrics: depends on whether your metrics module computes degree and betweenness.

If your report contains these fields, we show them in compact form.


In [14]:
if getattr(report, "rq2_positions", None):
    print("RQ2 positions sample:", report.rq2_positions[:5])

if getattr(report, "rq3_node_types", None):
    print("RQ3 node types sample:", report.rq3_node_types[:10])

if getattr(report, "rq4_graph_metrics", None):
    print("RQ4 graph metrics sample:", report.rq4_graph_metrics[:5])


## 13. Try your own question

Here you can quickly check whether the KG has sufficient coverage.

Note:
If `path_len` is often 0, the cause is not the RAG retrieval but entity extraction, entity matching, or KG coverage.
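
To tell these causes apart, it helps to check which extracted entities map to a KG node at all. The sketch below assumes a simple two-stage match (case-insensitive exact match first, then substring match); `match_entities` is a hypothetical helper, and the pipeline's real matching logic may differ. The node names are taken from outputs in this notebook.

```python
from typing import Dict, Iterable, List


def match_entities(entities: Iterable[str], node_names: Iterable[str]) -> Dict[str, List[str]]:
    """Hypothetical diagnostic: map each entity to the KG nodes it could match."""
    lowered = {name.lower(): name for name in node_names}
    result: Dict[str, List[str]] = {}
    for entity in entities:
        key = entity.strip().lower()
        if key in lowered:
            # Case-insensitive exact match.
            result[entity] = [lowered[key]]
        else:
            # Fallback: entity text appears inside a node name.
            result[entity] = [name for low, name in lowered.items() if key and key in low]
    return result


kg_nodes = ["CTO", "Ischemic Heart Disease", "Cardiomyopathy", "TIMI Flow Grading System"]
entities = ["obstruction", "hyperplasia", "kidney", "cto"]

matches = match_entities(entities, kg_nodes)
unmatched = [e for e, hits in matches.items() if not hits]
print(matches)
print("Unmatched:", unmatched)
```

If the entities look sensible but none of them match, the KG lacks coverage for the question; if the entities themselves are generic tokens, entity extraction is the weak step.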


In [8]:
qid = "manual-1"
question = "How does CTO relate to Cardiomyopathy?"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


Connecting to local Ollama (llama3.2:latest)...
('CTO', 'Cardiomyopathy')
PathAsLists(node_list=['CTO', 'Ischemic Heart Disease', 'Cardiomyopathy'], edge_list=['affects', 'comorbid_with'], subpath_list=[('CTO', 'affects', 'Ischemic Heart Disease'), ('Ischemic Heart Disease', 'comorbid_with', 'Cardiomyopathy')], chain_str='CTO->[affects]->Ischemic Heart Disease->[comorbid_with]->Cardiomyopathy')
CTO->[affects]->Ischemic Heart Disease->[comorbid_with]->Cardiomyopathy
Entities: ['CTO', 'Cardiomyopathy']
Start, End: CTO Cardiomyopathy
Path length: 2
KG context preview: A cerebrovascular accident (CTO) can have a significant impact on cardiovascular health, as it often leads to ischemia in the brain. Ischemic heart disease, characterized by reduced blood flow to the heart muscle, is a common complication of CTO and shares many risk factors with cardiomyopathy, a condition where the heart muscle becomes weakened and cannot function properly. In fact, individuals w
Answer: CTO relates to Card

In [9]:
print("sensitivity:", rep_m.rq1_sensitivity)
print("Position:", rep_m.rq2_positions)
print("Node types:", rep_m.rq3_node_types)
print("graph_metrics:", rep_m.rq4_graph_metrics)

sensitivity: RQ1Sensitivity(node_changed=3, edge_changed=1, subpath_changed=2, node_total=3, edge_total=2, subpath_total=2)
Position: [RQ2PositionRecord(kind='node', removed='CTO', index=0, rel_pos=0.0, answer_changed=True), RQ2PositionRecord(kind='node', removed='Ischemic Heart Disease', index=1, rel_pos=0.5, answer_changed=True), RQ2PositionRecord(kind='node', removed='Cardiomyopathy', index=2, rel_pos=1.0, answer_changed=True), RQ2PositionRecord(kind='edge', removed='CTO::affects::Ischemic Heart Disease', index=0, rel_pos=0.0, answer_changed=False), RQ2PositionRecord(kind='edge', removed='Ischemic Heart Disease::comorbid_with::Cardiomyopathy', index=1, rel_pos=1.0, answer_changed=True), RQ2PositionRecord(kind='subpath', removed='(CTO, affects, Ischemic Heart Disease)', index=0, rel_pos=0.0, answer_changed=True), RQ2PositionRecord(kind='subpath', removed='(Ischemic Heart Disease, comorbid_with, Cardiomyopathy)', index=1, rel_pos=1.0, answer_changed=True)]
Node types: [RQ3NodeTypeRecor
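
The raw counts in `rq1_sensitivity` are easier to compare as rates, and the `rel_pos` values in the output are consistent with `index / (total - 1)`. The sketch below reproduces both from the numbers printed above. `Sensitivity` is a stand-in dataclass with the same six count fields, not the project's `RQ1Sensitivity`, and returning 0.0 for single-element or empty cases is an assumption.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Sensitivity:
    """Stand-in for the six counters shown in the RQ1 output."""
    node_changed: int
    edge_changed: int
    subpath_changed: int
    node_total: int
    edge_total: int
    subpath_total: int


def change_rates(s: Sensitivity) -> Dict[str, float]:
    """Fraction of perturbations per level that changed the answer."""
    def rate(changed: int, total: int) -> float:
        return changed / total if total else 0.0  # assumption: 0.0 when nothing was perturbed

    return {
        "node": rate(s.node_changed, s.node_total),
        "edge": rate(s.edge_changed, s.edge_total),
        "subpath": rate(s.subpath_changed, s.subpath_total),
    }


def relative_position(index: int, total: int) -> float:
    """Map an element index to [0, 1] along the path."""
    return index / (total - 1) if total > 1 else 0.0  # assumption: 0.0 for a single element


s = Sensitivity(node_changed=3, edge_changed=1, subpath_changed=2,
                node_total=3, edge_total=2, subpath_total=2)
print(change_rates(s))                               # {'node': 1.0, 'edge': 0.5, 'subpath': 1.0}
print([relative_position(i, 3) for i in range(3)])   # [0.0, 0.5, 1.0]
```

For this single question, every node and subpath removal changed the answer, while only one of the two edge removals did.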

In [9]:
import networkx as nx

G = kg_store.g
start = run_m.start
end = run_m.end

print("start exists:", start in G)
print("end exists:", end in G)

if start in G:
    print("\nStart neighbors (out):", list(G.successors(start))[:30])
    print("Start neighbors (in):", list(G.predecessors(start))[:30])

if end in G:
    print("\nEnd neighbors (out):", list(G.successors(end))[:30])
    print("End neighbors (in):", list(G.predecessors(end))[:30])

if start in G and end in G:
    GU = G.to_undirected(as_view=True)
    try:
        p = nx.shortest_path(GU, start, end)
        print("\nUndirected shortest path length:", len(p) - 1)
        print("Path nodes:", p[:20], "..." if len(p) > 20 else "")
    except nx.NetworkXNoPath as e:
        print("\nNo undirected path:", e)

# additionally: look for similar node names
def find_similar_nodes(q, limit=20):
    ql = (q or "").lower()
    hits = [n for n in G.nodes if ql in str(n).lower()]
    return hits[:limit]

print("\nSimilar to 'CTO':", find_similar_nodes("CTO", 30))
print("Similar to 'Chronic Total Occlusion':", find_similar_nodes("Chronic Total Occlusion", 30))
print("Similar to 'Stroke':", find_similar_nodes("Stroke", 30))
print("Similar to 'Ischemic Heart Disease':", find_similar_nodes("Ischemic Heart Disease", 30))


start exists: True
end exists: True

Start neighbors (out): ['Ischemic Heart Disease', 'TIMI Flow Grading System']
Start neighbors (in): ['Ischemic Heart Disease', 'TIMI Flow Grading System']

End neighbors (out): ['Ischemic Heart Disease']
End neighbors (in): ['Ischemic Heart Disease']

Undirected shortest path length: 2
Path nodes: ['CTO', 'Ischemic Heart Disease', 'Cardiomyopathy'] 

Similar to 'CTO': ['CTO']
Similar to 'Chronic Total Occlusion': ['Chronic Total Occlusion']
Similar to 'Stroke': []
Similar to 'Ischemic Heart Disease': ['Ischemic Heart Disease']
