# KGRAG Ex demo: MedMCQA (medical dataset)

This notebook demonstrates KGRAG Ex on MedMCQA, a medical multiple-choice dataset.

Workflow:
1. Load StatPearls as LangChain Documents.
2. Build the vector store and retrieval setup.
3. Build a knowledge graph from triplets extracted from the same documents.
4. Run the KGRAG Ex pipeline and convert KG paths into pseudo paragraphs.
5. Inspect explanations, optional perturbations, and report stats.
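Step 4 turns a KG path into text. The pipeline's debug output (see the single-question run later in this notebook) prints a path as node/edge lists plus a chain string of the form `A->[REL]->B`, while the pseudo paragraph itself appears to be written by the LLM. Below is a minimal, deterministic sketch of that conversion under those assumptions. `path_to_triples`, `chain_string`, and `template_paragraph` are hypothetical helpers, not part of `src.modules`, and the one-sentence-per-triple template is only a stand-in for the LLM rewrite.

```python
from typing import List, Tuple

Triple = Tuple[str, str, str]  # (subject, relation, object)


def path_to_triples(nodes: List[str], relations: List[str]) -> List[Triple]:
    """Pair consecutive nodes with the relation between them."""
    if len(relations) != max(len(nodes) - 1, 0):
        raise ValueError("expected exactly one relation per consecutive node pair")
    return [(nodes[i], relations[i], nodes[i + 1]) for i in range(len(relations))]


def chain_string(nodes: List[str], relations: List[str]) -> str:
    """Render a path in the 'A->[REL]->B' style of the pipeline's debug output."""
    parts = nodes[:1]  # start node (empty list for an empty path)
    for rel, node in zip(relations, nodes[1:]):
        parts += [f"[{rel}]", node]
    return "->".join(parts)


def template_paragraph(triples: List[Triple]) -> str:
    """Deterministic stand-in for the LLM-written pseudo paragraph: one sentence per triple."""
    sentences = []
    for subj, rel, obj in triples:
        verb = rel.lower().replace("_", " ")  # PRESENTS_AS -> "presents as"
        sentences.append(f"{subj} {verb} {obj}.")
    return " ".join(sentences)


# Usage with the one-hop path that appears in the debug output later in this notebook
nodes = ["Clostridium difficile infection", "Diarrhea"]
relations = ["PRESENTS_AS"]
print(chain_string(nodes, relations))
print(template_paragraph(path_to_triples(nodes, relations)))
```

A real pseudo paragraph adds connective prose and context that a template cannot, which is presumably why the pipeline delegates this step to the LLM.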



## Notebook setup and project path

Ensure the project root is on `sys.path` so local modules import correctly.


In [1]:
import sys
from pathlib import Path

project_root = next((p for p in [Path.cwd()] + list(Path.cwd().parents) if (p / "src").exists()), None)
if project_root is None:
    raise RuntimeError('"src" directory not found. Please run the notebook from inside the project.')

root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print("Project root:", project_root)


Project root: /Users/MeinNotebook/xai-rag


## Imports

Load data loaders, RAG components, KG builders, and the KGRAG Ex pipeline.


In [2]:
# %%
from pathlib import Path

from src.modules.loader.statspearls_data_loader import StatPearlsDataLoader
from src.modules.loader.medmcqa_data_loader import MedMCQADataLoader

from src.modules.rag.rag_engine import RAGEngine
from src.modules.rag.multihop_rag_engine import MultiHopRAGEngine
from src.modules.llm.llm_client import LLMClient

from src.modules.knowledge_graph.kg_triplet_extractor import KGTripletExtractor
from src.modules.knowledge_graph.kg_build_service import KGBuildService

from src.modules.knowledge_graph.kgrag_ex_pipeline import KGRAGExPipeline
from src.modules.explainers.kgrag_ex_explainer import KGRAGExExplainer
from src.modules.knowledge_graph.relation_registry import RelationRegistry, ProposedRelation, canon_relation




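## 1. Load StatPearls as Documents

For a smoke test, load only a subset of articles (`limit_articles`); otherwise KG building gets expensive. Each article is split into chunks and each chunk becomes one Document, so 300 articles yield 13,922 Documents in the run below.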


In [3]:
sp_loader = StatPearlsDataLoader()

statpearls_docs, sp_stats = sp_loader.setup(
    limit_articles=300,
    as_documents=True,
    force_download=False,
    force_extract=False,
    force_rebuild_jsonl=False,
)

print("StatPearls stats:", sp_stats)
print("Docs:", len(statpearls_docs))
print("Sample meta:", statpearls_docs[0].metadata)
print("Sample preview:", statpearls_docs[0].page_content[:200])




StatPearls stats: StatPearlsBuildStats(tarball_downloaded=False, extracted=False, nxml_files_found=9629, jsonl_files_created=0, articles_loaded=300, chunks_emitted=13922)
Docs: 13922
Sample meta: {'source': 'statpearls', 'split': 'repo', 'title': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'topic_name': 'Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings through hospital discharge.', 'source_filename': 'article-100024.nxml', 'chunk_index': 0, 'chunk_id': '6de0970aebeb1d82231d62346489c3cd6ee24eebe32979badfdeae7817d4c9bb'}
Sample preview: Thrombolysis in Myocardial Infarction (TIMI) Trial, Phase I: A comparison between intravenous tissue plasminogen activator and intravenous streptokinase. Clinical findings throu

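## 2. LLM client

Ollama runs locally; adjust the model name to your installation.


In [4]:
client = LLMClient(provider="ollama", model_name="gemma3:4b")

## 3. RAG engine setup

Build a Chroma vector store from the documents, or load the persisted store when one already exists (`reset=False`, as in the run below). `MultiHopRAGEngine` then wraps it for multi-hop retrieval with `num_hops=NUM_HOPS`.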


In [5]:
NUM_HOPS = 2
rag = RAGEngine(persist_dir="../data/vector_db_statpearls_kgragex")
rag.setup(documents=statpearls_docs, reset=False)
multi_hop = MultiHopRAGEngine(rag_engine=rag, llm_client=client, num_hops=NUM_HOPS)

Loading existing vector store from ../data/vector_db_statpearls_kgragex...
RagEngine ready.
Connecting to local Ollama (gemma3:4b)...


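## 4. Build the KG from StatPearls chunks (JSONL cache)

The triplet extractor uses the LLM client to extract triplets from the chunks, and the results are cached as JSONL at `cache_path`. With `force_rebuild=False`, an existing cache is reused instead of re-running extraction.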

In [12]:
registry_path = Path("../data/relations_registry.json")
registry = RelationRegistry.load(registry_path)

extractor = KGTripletExtractor(llm_client=client, max_retries=2, relation_registry=registry)
kg_service = KGBuildService(extractor=extractor, relation_registry=registry, registry_cache_path=registry_path, add_reverse_edges=False)


kg_store, kg_stats = kg_service.build_or_load(
    docs=statpearls_docs,
    cache_path=Path("../data/statpearls_kg/kg_statpearls.jsonl"),
    limit=5000,
    force_rebuild=False,
    chunk_id_prefix="sp",
    source_name="statpearls",
)

print("KG stats:", kg_store.stats())
print("Build stats:", kg_stats)



KG stats: {'nodes': 5718, 'edges': 6563}
Build stats: KGBuildStats(docs_seen=0, docs_with_triples=0, triples_forward_total=0, triples_reverse_total=0, triples_written_total=6563)
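The build stats above (`docs_seen=0`, yet `triples_written_total=6563`, equal to the edge count) suggest the graph was loaded from the JSONL cache rather than re-extracted. The JSONL schema used by `KGBuildService` is not shown in this notebook, so the sketch below assumes one JSON object per triple with hypothetical `subject`/`relation`/`object`/`source` keys. It only illustrates the load-or-build pattern that `build_or_load(..., force_rebuild=False)` suggests; all function names are illustrative.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Iterable, List, Tuple

Triple = Tuple[str, str, str]  # (subject, relation, object)


def write_triples_jsonl(path: Path, triples: Iterable[Triple], source: str) -> int:
    """Write one JSON object per triple and return how many were written."""
    path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with path.open("w", encoding="utf-8") as f:
        for subj, rel, obj in triples:
            record = {"subject": subj, "relation": rel, "object": obj, "source": source}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1
    return count


def read_triples_jsonl(path: Path) -> List[Triple]:
    """Read the cache back, skipping blank lines."""
    triples: List[Triple] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            if line.strip():
                rec = json.loads(line)
                triples.append((rec["subject"], rec["relation"], rec["object"]))
    return triples


def load_or_build(
    path: Path,
    build_fn: Callable[[], Iterable[Triple]],
    source: str,
    force_rebuild: bool = False,
) -> List[Triple]:
    """Reuse the cache if present; otherwise run the (expensive) builder and cache its output."""
    if path.exists() and not force_rebuild:
        return read_triples_jsonl(path)
    triples = list(build_fn())
    write_triples_jsonl(path, triples, source)
    return triples


# Usage: the second call is served from the cache, so the builder runs only once.
calls: List[int] = []


def fake_build() -> List[Triple]:
    calls.append(1)  # stands in for LLM triplet extraction
    return [("psoas muscle", "JOINS", "iliacus muscle")]


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "kg" / "kg_demo.jsonl"
    first = load_or_build(cache, fake_build, source="demo")
    second = load_or_build(cache, fake_build, source="demo")
```

Because a cache hit skips the builder entirely, counters such as "docs seen" stay at zero on reload, which matches the pattern in the stats above.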


## 5. Quick KG stats


In [7]:
import numpy as np

nodes = [str(n) for n in kg_store.g.nodes]
token_counts = np.array([len(n.split()) for n in nodes])
print("Nodes:", len(nodes))
print("Tokens per node min/median/mean/max:", token_counts.min(), np.median(token_counts), token_counts.mean(), token_counts.max())

long_nodes = [n for n in nodes if len(n.split()) > 10]
print("Nodes with >10 tokens:", len(long_nodes))
print("Example long nodes:", long_nodes[:5])

Nodes: 5718
Tokens per node min/median/mean/max: 1 2.0 2.2848898216159497 6
Nodes with >10 tokens: 0
Example long nodes: []


### Graph connectivity and relation stats

Inspect connectivity, degree distribution, and relation frequencies in the KG.


In [8]:
import networkx as nx
from collections import Counter
import numpy as np

G = kg_store.g
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

wcc = list(nx.weakly_connected_components(G)) if G.number_of_nodes() else []
sizes = sorted([len(c) for c in wcc], reverse=True) if wcc else [0]
print("Weakly connected components:", len(wcc))
print("Largest component size:", sizes[0])

deg = np.array([d for _, d in G.degree()]) if G.number_of_nodes() else np.array([0])
print("Degree min/median/mean/max:", int(deg.min()), float(np.median(deg)), float(deg.mean()), int(deg.max()))

rel_counts = Counter()
for _, _, data in G.edges(data=True):
    rel_counts[str(data.get("relation", "Unknown"))] += 1
print("Top relations:", rel_counts.most_common(12))

Nodes: 5718 Edges: 6563
Weakly connected components: 643
Largest component size: 4034
Degree min/median/mean/max: 1 1.0 2.2955578873732074 140
Top relations: [('ASSOCIATED_WITH', 854), ('CAUSES', 275), ('CONTRAINDICATED_FOR', 173), ('TREATED_WITH', 159), ('AFFECTS', 156), ('INCLUDES', 136), ('CONTRIBUTES_TO', 123), ('USED_FOR', 111), ('TREATS', 109), ('REQUIRES', 85), ('CAN_CAUSE', 83), ('INDICATED_FOR', 79)]
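The definitions behind these numbers can be made explicit with a stdlib-only sketch over a plain edge list (the `graph_summary` name and the `(head, relation, tail)` edge format are assumptions for illustration). Weakly connected components ignore edge direction, and in a multigraph every parallel edge adds to the degree of both endpoints, which matches how `G.degree()` counts in- plus out-edges for a `MultiDiGraph`. Nodes without any edge are not representable in this edge-list form.

```python
from collections import Counter
from statistics import median
from typing import Dict, List, Tuple

Edge = Tuple[str, str, str]  # (head, relation, tail)


def graph_summary(edges: List[Edge]) -> Dict[str, object]:
    """Weakly connected components, degree stats, and relation counts from an edge list."""
    parent: Dict[str, str] = {}

    def find(x: str) -> str:
        # Union-find root lookup with path halving
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    degree: Counter = Counter()
    relations: Counter = Counter()
    for head, rel, tail in edges:
        root_h, root_t = find(head), find(tail)
        if root_h != root_t:
            parent[root_h] = root_t  # direction is ignored: that is what "weakly" means
        degree[head] += 1  # every parallel edge counts, as in a MultiDiGraph
        degree[tail] += 1
        relations[rel] += 1

    sizes = sorted(Counter(find(n) for n in list(parent)).values(), reverse=True)
    degs = list(degree.values())
    return {
        "components": len(sizes),
        "largest_component": sizes[0] if sizes else 0,
        "degree_median": median(degs) if degs else 0,
        "degree_max": max(degs) if degs else 0,
        "top_relations": relations.most_common(3),
    }


summary = graph_summary([
    ("a", "CAUSES", "b"),
    ("b", "CAUSES", "c"),
    ("a", "TREATS", "b"),   # parallel edge a-b with a different relation
    ("x", "AFFECTS", "y"),  # separate component
])
print(summary)
```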


In [9]:
import networkx as nx
from typing import List, Tuple, Optional, Dict, Any

def _to_simple_undirected(G: nx.MultiDiGraph) -> nx.Graph:
    """
    Convert a MultiDiGraph or DiGraph into a simple undirected graph.
    Parallel edges are merged into one.
    """
    H = nx.Graph()
    H.add_nodes_from(G.nodes())
    H.add_edges_from((u, v) for u, v in G.edges())
    return H

def _largest_connected_component_nodes(H: nx.Graph) -> List[str]:
    if H.number_of_nodes() == 0:
        return []
    comps = list(nx.connected_components(H))
    if not comps:
        return []
    return list(max(comps, key=len))

def _double_sweep_long_path(H: nx.Graph, nodes_subset: Optional[List[str]] = None) -> List[str]:
    """
    Double-sweep BFS: returns a shortest path between two mutually distant nodes,
    often a good approximation of the diameter in large graphs.
    """
    if H.number_of_nodes() == 0:
        return []

    if nodes_subset is not None:
        S = set(nodes_subset)
        Hs = H.subgraph(S).copy()
    else:
        Hs = H

    if Hs.number_of_nodes() == 0:
        return []

    start = next(iter(Hs.nodes()))
    dist1 = nx.single_source_shortest_path_length(Hs, start)
    a = max(dist1, key=dist1.get)

    dist2 = nx.single_source_shortest_path_length(Hs, a)
    b = max(dist2, key=dist2.get)

    path = nx.shortest_path(Hs, source=a, target=b)
    return path

def path_relations_along_nodes(G: nx.MultiDiGraph, path_nodes: List[str]) -> List[str]:
    """
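    Extract one relation label per step u->v, falling back to the reverse edge v->u.
    A MultiDiGraph may hold several parallel edges; the first one's label is used as representative.
    Labels taken from a reverse edge keep path order, so printed as "u [rel] v" they may read backwards.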
    """
    rels: List[str] = []
    for u, v in zip(path_nodes, path_nodes[1:]):
        rel = None

        if G.has_edge(u, v):
            data = G.get_edge_data(u, v)
            if isinstance(data, dict) and data:
                first_key = next(iter(data.keys()))
                rel = data[first_key].get("relation") or first_key

        if rel is None and G.has_edge(v, u):
            data = G.get_edge_data(v, u)
            if isinstance(data, dict) and data:
                first_key = next(iter(data.keys()))
                rel = data[first_key].get("relation") or first_key

        rels.append(str(rel) if rel is not None else "Unknown")
    return rels

def longest_connected_path(
    G: nx.MultiDiGraph,
    *,
    use_largest_component: bool = True,
    return_relations: bool = True
) -> Dict[str, Any]:
    """
    Return a long connected path: approximately the diameter path, i.e. the shortest path
    between two distant nodes in the largest component (or in the whole graph).
    """
    H = _to_simple_undirected(G)

    if H.number_of_nodes() == 0:
        return {"path_nodes": [], "path_length_edges": 0, "component_size": 0, "relations": []}

    if use_largest_component:
        comp_nodes = _largest_connected_component_nodes(H)
        path_nodes = _double_sweep_long_path(H, nodes_subset=comp_nodes)
        component_size = len(comp_nodes)
    else:
        path_nodes = _double_sweep_long_path(H, nodes_subset=None)
        component_size = H.number_of_nodes()

    relations = path_relations_along_nodes(G, path_nodes) if return_relations else []

    return {
        "path_nodes": path_nodes,
        "path_length_edges": max(0, len(path_nodes) - 1),
        "component_size": component_size,
        "relations": relations,
    }

# Usage
G = kg_store.g

result = longest_connected_path(G, use_largest_component=True, return_relations=True)

print("Component size:", result["component_size"])
print("Path length in edges:", result["path_length_edges"])
print("Path nodes:")
print(" -> ".join(result["path_nodes"]))

if result["relations"]:
    print("\nRelations along the path:")
    for u, r, v in zip(result["path_nodes"], result["relations"], result["path_nodes"][1:]):
        print(f"{u}  [{r}]  {v}")


Component size: 4034
Path length in edges: 20
Path nodes:
L1 to L4 -> psoas muscle -> iliacus muscle -> lesser trochanter of the femur -> psoas major -> L1-3 branches of the lumbar plexus -> iliacus -> Abdominal aorta -> psoas -> iliopsoas -> Psoas syndrome -> MET -> diagnosis -> laboratory evaluation -> liver function -> liver transaminases -> disseminated disease -> Herpes simplex virus -> neonatal HSV -> sepsis evaluation -> lumbar puncture

Relations along the path:
L1 to L4  [ATTACHES_TO]  psoas muscle
psoas muscle  [JOINS]  iliacus muscle
iliacus muscle  [INSERTS_AT]  lesser trochanter of the femur
lesser trochanter of the femur  [INSERTS_ON]  psoas major
psoas major  [RECEIVES_INNERVATION_FROM]  L1-3 branches of the lumbar plexus
L1-3 branches of the lumbar plexus  [RECEIVES_INNERVATION_FROM]  iliacus
iliacus  [RECEIVES_BLOOD_FROM]  Abdominal aorta
Abdominal aorta  [RECEIVES_BLOOD_FROM]  psoas
psoas  [DIVIDED_INTO]  iliopsoas
iliopsoas  [ASSOCIATED_WITH]  Psoas syndrome
P
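Why the double sweep works: a BFS from any node finds a farthest node `a`; a second BFS from `a` finds a farthest node `b`, and the shortest `a`-`b` path is returned. Its length is always a lower bound on the graph diameter and equals the diameter when the graph is a tree. It is a *shortest* path between distant nodes, not a longest simple path. The sketch below reimplements the idea without networkx over an adjacency dict (`to_adjacency`, `bfs_parents`, and `double_sweep` are illustrative names). Because BFS discovers nodes in order of non-decreasing distance, the last node discovered is a farthest one.

```python
from collections import deque
from typing import Dict, Hashable, Iterable, List, Set, Tuple

Adj = Dict[Hashable, Set[Hashable]]


def to_adjacency(edges: Iterable[Tuple[Hashable, Hashable]]) -> Adj:
    """Undirected adjacency sets; parallel edges collapse, like nx.Graph."""
    adj: Adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return adj


def bfs_parents(adj: Adj, source: Hashable) -> Dict[Hashable, Hashable]:
    """BFS parent map; dict insertion order is discovery order (non-decreasing distance)."""
    parents = {source: source}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in parents:
                parents[v] = u
                queue.append(v)
    return parents


def double_sweep(adj: Adj, start: Hashable) -> List[Hashable]:
    """Sweep 1: start -> farthest node a. Sweep 2: a -> farthest node b. Return the a..b path."""
    a = list(bfs_parents(adj, start))[-1]  # last discovered node is at maximum distance
    parents = bfs_parents(adj, a)
    b = list(parents)[-1]
    path = [b]
    while path[-1] != a:  # walk BFS parents back from b to a
        path.append(parents[path[-1]])
    return path[::-1]


# A small tree: the path 0-1-2-3-4 with an extra leaf 5 hanging off node 1
tree = to_adjacency([(0, 1), (1, 2), (2, 3), (3, 4), (1, 5)])
path = double_sweep(tree, start=2)
print(path, "edges:", len(path) - 1)
```

On this tree the result has 4 edges, the exact diameter, regardless of the start node; on graphs with cycles it can undershoot, which is one reason the next cell tries many random start nodes.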

In [10]:
import random
import networkx as nx

def to_simple_undirected(G: nx.MultiDiGraph) -> nx.Graph:
    H = nx.Graph()
    H.add_nodes_from(G.nodes())
    H.add_edges_from((u, v) for u, v in G.edges())
    return H

def largest_component_subgraph(H: nx.Graph) -> nx.Graph:
    comp = max(nx.connected_components(H), key=len)
    return H.subgraph(comp).copy()

def double_sweep_path(H: nx.Graph, start=None):
    if H.number_of_nodes() == 0:
        return []

    if start is None:
        start = next(iter(H.nodes()))

    dist1 = nx.single_source_shortest_path_length(H, start)
    a = max(dist1, key=dist1.get)

    dist2 = nx.single_source_shortest_path_length(H, a)
    b = max(dist2, key=dist2.get)

    return nx.shortest_path(H, a, b)

def best_of_k_double_sweeps(H: nx.Graph, k: int = 50, seed: int = 42):
    rng = random.Random(seed)
    nodes = list(H.nodes())
    best = []

    for _ in range(max(1, k)):
        start = rng.choice(nodes)
        p = double_sweep_path(H, start=start)
        if len(p) > len(best):
            best = p

    return best

# Usage
G = kg_store.g
H = to_simple_undirected(G)
Hc = largest_component_subgraph(H)

path_nodes = best_of_k_double_sweeps(Hc, k=80, seed=1)
print("Component size:", Hc.number_of_nodes())
print("Path length in edges:", len(path_nodes) - 1)
print("Path:")
print(" -> ".join(path_nodes))


Component size: 4034
Path length in edges: 20
Path:
L1 to L4 -> psoas muscle -> iliacus muscle -> lesser trochanter of the femur -> psoas major -> L1-3 branches of the lumbar plexus -> iliacus -> Abdominal aorta -> psoas -> iliopsoas -> Psoas syndrome -> MET -> diagnosis -> laboratory evaluation -> liver function -> liver transaminases -> disseminated disease -> Herpes simplex virus -> neonatal HSV -> sepsis evaluation -> lumbar puncture


In [11]:
import random
import networkx as nx

def greedy_long_simple_path(H: nx.Graph, start, rng: random.Random) -> list:
    visited = {start}
    path = [start]
    cur = start

    while True:
        candidates = [n for n in H.neighbors(cur) if n not in visited]
        if not candidates:
            break

        # Heuristic: prefer low-degree ("narrow") regions instead of running straight into hubs
        candidates.sort(key=lambda n: H.degree(n))

        # Small random window over the lowest-degree candidates, so repeated trials differ
        top = candidates[: min(5, len(candidates))]
        nxt = rng.choice(top)

        visited.add(nxt)
        path.append(nxt)
        cur = nxt

    return path

def best_greedy_simple_path(H: nx.Graph, trials: int = 200, seed: int = 42) -> list:
    rng = random.Random(seed)
    nodes = list(H.nodes())
    best = []

    for _ in range(max(1, trials)):
        start = rng.choice(nodes)
        p = greedy_long_simple_path(H, start, rng)
        if len(p) > len(best):
            best = p

    return best

# Usage
G = kg_store.g
H = to_simple_undirected(G)
Hc = largest_component_subgraph(H)

p = best_greedy_simple_path(Hc, trials=500, seed=7)
print("Component size:", Hc.number_of_nodes())
print("Greedy simple path, edges:", len(p) - 1)
print("Path:")
print(" -> ".join(p))


Component size: 4034
Greedy simple path, edges: 10
Path:
hyaluronic acid -> deep fascia -> elastin -> Marfan syndrome -> fasciopathy -> excessive running -> Plantar fasciopathy -> stiffening of connective tissue -> Adhesive capsulitis -> Increased fascial stiffness -> Peyronie disease
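The greedy walk found a 10-edge simple path, while the double sweep found a 20-edge shortest path in the same component. Every shortest path is itself a simple path, so a longest simple path here has at least 20 edges; the greedy walk falls short because it stops as soon as all neighbors of the current node are visited. Finding a longest simple path exactly is NP-hard, which is why both cells rely on heuristics. For checking heuristics on toy graphs, an exhaustive DFS gives the exact answer; `longest_simple_path_exact` is a hypothetical helper and runs in exponential time in the worst case.

```python
from typing import Dict, Hashable, List, Set

Adj = Dict[Hashable, Set[Hashable]]


def longest_simple_path_exact(adj: Adj) -> List[Hashable]:
    """Exhaustive DFS over all simple paths from every start node (exponential; toy graphs only)."""
    best: List[Hashable] = []

    def extend(path: List[Hashable], visited: Set[Hashable]) -> None:
        nonlocal best
        if len(path) > len(best):
            best = path.copy()
        for nxt in adj[path[-1]]:
            if nxt not in visited:
                visited.add(nxt)
                path.append(nxt)
                extend(path, visited)
                path.pop()  # backtrack
                visited.remove(nxt)

    for start in adj:
        extend([start], {start})
    return best


# A 4-cycle 0-1-2-3-0 with a tail 3-4: the diameter is 3 edges (e.g. 4-3-0-1),
# but a simple path can use 4 edges (e.g. 4-3-0-1-2).
graph = {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2, 4}, 4: {3}}
longest = longest_simple_path_exact(graph)
print(longest, "edges:", len(longest) - 1)
```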


## 6. Pipeline and explainer


In [13]:
pipeline = KGRAGExPipeline(rag=multi_hop, kg=kg_store, llm_client=client)
explainer = KGRAGExExplainer(pipeline)

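## 7. Load MedMCQA and filter to the configured questions

The question IDs to evaluate are read from `config.toml` (`[medmcqa] question_ids`); only those questions among the first 10,000 training examples are kept. `tomllib` requires Python 3.11+.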


In [14]:
import tomllib

with open("../config.toml", "rb") as f:
    config = tomllib.load(f)

allowed_ids = set(config["medmcqa"]["question_ids"])

mcqa_loader = MedMCQADataLoader()
questions = mcqa_loader.setup(split="train", as_documents=True, limit=10000)
print("loaded:", len(questions))

questions = [
    q for q in questions
    if q.metadata.get("question_id") in allowed_ids
]
print("filtered:", len(questions))
question = questions[20]

print(questions)


loaded: 10000
filtered: 59
[Document(metadata={'question_id': '7a10e05d-f20e-4a3f-8e3c-b6835c8e1b6b', 'question': 'Following is a clinical feature of cerebellar disease', 'answer': 'C', 'cop_raw': 3, 'split': 'train', 'source': 'medmcqa', 'subject_name': 'Medicine', 'topic_name': 'C.N.S', 'choice_type': 'single'}, page_content='Following is a clinical feature of cerebellar disease\n\nA: Paralysis\nB: Sensory deficit\nC: Ataxia\nD: Resting tremors\n\nExplanation: . *Dysmetria and ataxia are the two impoant symptoms of cerebellar disease. Cerebellar ataxia can occur as a result of many diseases and may present with symptoms of an inability to coordinate balance, gait, extremity and eye movements.Lesions to the cerebellum can cause dyssynergia, dysmetria, dysdiadochokinesia, dysahria and ataxia of stance and gait. Deficits are observed with movements on the same side of the body as the lesion (ipsilateral). Clinicians often use visual observation of people performing motor tasks in orde

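## 8. Batch run over the filtered MedMCQA questions

Each question is run through the KGRAG Ex pipeline and the answer is normalized to a single letter; questions with an empty KG chain are skipped. When the answer matches the gold letter, a second LLM (the judge, here via Groq) decides whether the KG chain and paragraph alone plausibly support it (`context_ok` = 1 or 0). Wrong answers get `context_ok` = 0.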


In [None]:
import re
import pandas as pd

def normalize_mc_answer(ans: str) -> str:
    """
    Normalize an LLM answer to a single letter A/B/C/D.
    Accepts e.g. "A", "A.", "A:", "A )", "Answer: A", "A. ...".
    Returns "" if no standalone A-D letter is found (the first one found wins).
    """
    s = (ans or "").strip().upper()
    if not s:
        return ""
    m = re.search(r"\b([ABCD])\b", s)
    return m.group(1) if m else ""

def is_valid_letter(x: str) -> bool:
    return x in {"A", "B", "C", "D"}

def build_context_judger_prompt(
    question: str,
    options_block: str,
    gold: str,
    answer: str,
    kg_chain: str,
    kg_paragraph: str,
) -> str:
    """
    Prompt for the second (judge) LLM client, which must return only 0 or 1:
    1 = the context plausibly suffices to answer the question
    0 = the context is topically/factually insufficient or drifts off-topic
    """
    return f"""
You are a strict evaluator.

Task:
Decide whether the provided KG context is plausibly sufficient to answer the multiple-choice question.

Return format:
Return ONLY a single character: 1 or 0.
- 1: The context (chain and paragraph) contains direct, relevant information that supports choosing the answer.
- 0: The context is irrelevant, too generic, drifts to other topics, or does not contain information needed to justify an option.

Be conservative:
If key medical concepts from the question/options are missing, return 0.
If the paragraph only talks about generic things like "pain", "treatment", "risk" without linking to the asked concept or options, return 0.

Input:
Question:
{question}

Options:
{options_block}

Gold:
{gold}

Model Answer:
{answer}

KG Chain:
{kg_chain}

KG Paragraph:
{kg_paragraph}
""".strip()

def judge_context_binary(
    llm_client_judge,
    question: str,
    options_block: str,
    gold: str,
    answer: str,
    kg_chain: str,
    kg_paragraph: str,
) -> tuple[int, str]:
    """
    Query the second (judge) LLM client and return (0/1, raw_text); defaults to 0 if the reply contains no 0 or 1.
    """
    llm = llm_client_judge.get_llm()
    if llm is None:
        raise RuntimeError("Judge LLMClient.get_llm() returned None")

    prompt = build_context_judger_prompt(
        question=question,
        options_block=options_block,
        gold=gold,
        answer=answer,
        kg_chain=kg_chain,
        kg_paragraph=kg_paragraph,
    )
    resp = llm.invoke(prompt)
    raw = (getattr(resp, "content", str(resp)) or "").strip()
    m = re.search(r"[01]", raw)
    val = int(m.group(0)) if m else 0
    return val, raw

llm_client_judge = LLMClient(provider="groq", model_name="llama-3.3-70b-versatile")
rows = []

for d in questions:
    meta = (d.metadata or {}).copy()
    qid = str(meta.get("question_id", meta.get("id", "q")))

    question = str(meta.get("question", "")).strip()
    if not question:
        txt0 = str(d.page_content or "").strip()
        question = (txt0.splitlines()[0].strip() if txt0 else "")

    opts = {"A": "", "B": "", "C": "", "D": ""}
    txt = str(d.page_content or "")
    for line in txt.splitlines():
        line = line.strip()
        if len(line) >= 2 and line[0] in "ABCD" and line[1] == ":":
            opts[line[0]] = line[2:].strip()

    options_block = "\n".join([f"{k}. {v}" for k, v in opts.items() if v])

    gold = str(meta.get("answer", "")).strip().upper()
    if not is_valid_letter(gold):
        # Fallback: derive the letter from MedMCQA's correct-option index.
        # The sample metadata pairs cop_raw=3 with answer "C", so the index is treated as 1-based.
        gold = ""
        cop = meta.get("cop_raw", meta.get("cop"))
        try:
            cop_i = int(cop)
        except (TypeError, ValueError):
            cop_i = None
        if cop_i in (1, 2, 3, 4):
            gold = "ABCD"[cop_i - 1]

    run_i = pipeline.run(
        question_id=qid,
        question=question,
        gold_answer=gold,
        options=options_block,
    )
    if not (run_i.kg_chain or "").strip():
        continue
    ans_norm = normalize_mc_answer(run_i.answer)

    gold_ok = is_valid_letter(gold)
    ans_ok = is_valid_letter(ans_norm)
    matches_gold = bool(gold_ok and ans_ok and (ans_norm == gold))

    context_ok = None
    judge_raw = ""
    if matches_gold:
        context_ok, judge_raw = judge_context_binary(
            llm_client_judge=llm_client_judge,
            question=run_i.question,
            options_block=options_block,
            gold=gold,
            answer=ans_norm,
            kg_chain=run_i.kg_chain,
            kg_paragraph=run_i.kg_paragraph,
        )
    else:
        context_ok = 0

    rows.append(
        {
            "question_id": run_i.question_id,
            "question": run_i.question,
            "options": options_block,
            "gold": gold,
            "answer_raw": run_i.answer,
            "answer": ans_norm,
            "matches_gold": int(matches_gold),
            "context_ok": int(context_ok),
            "judge_raw": judge_raw,
            "kg_chain": run_i.kg_chain,
            "kg_paragraph": run_i.kg_paragraph,
            "llm_calls": run_i.llm_calls,
        }
    )

    # Checkpoint: rewrite the CSV after every question so partial results survive an interruption
    df = pd.DataFrame(rows)
    df.to_csv("train_questions.csv")


Connecting to Groq (llama-3.3-70b-versatile)...
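Two parsing choices in the cell above are lenient. `normalize_mc_answer` returns the first standalone A-D anywhere in the reply, so a free-text answer beginning with the article "A" (for example "A stress ulcer most often ...") is read as option A. The judge parser takes the first `0` or `1` character anywhere, so a reply like "10" counts as 1, and a reply without any digit silently becomes 0. A stricter sketch follows, assuming replies either name the option explicitly ("Answer: B") or start with the bare letter, and that unparseable judge output should be reported as `None` instead of being scored. Both function names are hypothetical.

```python
import re
from typing import Optional

# "Answer: B", "the answer is (C)", "Option D" (checked first)
_EXPLICIT = re.compile(r"\b(?:ANSWER|OPTION)\s*(?:IS)?\s*[:\-]?\s*\(?([ABCD])\b")
# A reply that starts with the bare letter: "B", "B.", "B:", "(B)", "B) Stomach"
_LEADING = re.compile(r"^\(?([ABCD])(?:[.:)\]]|$)")


def normalize_mc_answer_strict(ans: Optional[str]) -> str:
    """Return 'A'-'D' only for an explicit or leading option letter, else ''."""
    s = (ans or "").strip().upper()
    m = _EXPLICIT.search(s) or _LEADING.match(s)
    return m.group(1) if m else ""


def parse_judge_binary(raw: Optional[str]) -> Optional[int]:
    """Accept only a reply that is exactly '0' or '1' (ignoring whitespace and quotes)."""
    s = (raw or "").strip().strip("`'\"").strip()
    return int(s) if s in ("0", "1") else None
```

In the batch loop, a `None` verdict could be kept as its own value (or counted separately) so that malformed judge replies are not mistaken for insufficient context.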


In [19]:
context_ok_ids = (
    df.loc[df["context_ok"] == 1, "question_id"]
    .dropna()
    .astype(str)
    .tolist()
)

for i in context_ok_ids:
    print(i)

15374e65-2f0e-4b06-9698-2a3fd852ef08
fc913366-8a30-442c-9769-0e4aa0ad5829
6e46a29a-3d34-4b35-bb4f-39693b5c608e
75bec15b-5bca-42e6-98fc-f48d9963ff4c
3f63787d-7816-48fe-a623-b61ba10a3001
bc3df9cd-8f28-4710-beb7-8861adc974be
7f444937-f1ae-403c-9427-34f6d5c18aa6
71c5caf2-bd33-41ff-89d6-3ccf06137dc1
f4a69fc6-54ef-4342-8fcc-3849ca6e8d72
dea5db70-e900-4d6e-8851-e3e81b9df816
6f0fb72d-ed2c-4dbb-8091-b839451b46a0
d5820f6b-b829-435f-87e5-98d5514000d0
dc47f7fa-731b-4715-89fa-fef4958eb57f
05454fa7-0d26-4f83-be39-cd51e4622f19
329f4086-c777-4455-9ff0-79e80575784a
967f4844-2959-42d2-9164-b0cbd55012bf
b261484b-635b-4abb-8301-80bee3da24f7
70d993ec-1ae5-4189-9c5e-2f77f45434ce
ca659bf0-9bf4-415d-afac-18a5b878a136
bc191b12-a85e-41ef-99b9-a1aad5fa6c37
ae4abedb-6625-4037-b810-f27189f4eda2
c227709c-cc31-460f-a5e5-fa5248dfa3d3
250a657a-e35d-4a5b-98d3-6399760b0209
d710d2f4-0d24-46b4-b50d-9633fe43b73f
775156bc-1cc6-4d4d-ae0d-852182dfc94e
f7c6e673-3268-4c7a-abf2-dff2426a1ae0
fdd71923-4db9-4133-b496-a386aba2d4b1
c

In [None]:
question_ids = [
    "7a10e05d-f20e-4a3f-8e3c-b6835c8e1b6b",
    "52fd537b-dc7f-4cf1-9a3a-6757ebd9ea31",
    "bf28775c-403b-48c9-a5ba-07f1217f17e3",
    "fb704dd-cd21-4196-8f4a-3e63dc7b341f",
    "30f2a6d9-f28f-48f0-8749-ee1de6a3d6c8",
    "2fb10eb6-7899-4770-ae74-fd810d176a22",
    "2102be8a-6b29-43cf-a259-0c9a4cd33f53",
    "3f020b60-da1a-40fa-93eb-462cbcc61ed2",
    "2120feb1-085e-4af7-8788-e4540f510f54",
    "4546a1ec-61a8-4cce-ad74-88b704bec98c",
    "f54f8baa-fab2-4eae-a6d5-2be43f1d2930",
    "5288e33d-a4df-4fd8-8498-7b0030b6fbc3",
    "329564f9-5916-4e24-b8b6-efa776d8f334",
    "1578420b-45fb-46b1-bc81-da223c9a9b4a",
    "693d59e9-1eee-4b5b-addd-f6963f1e69e6",
    "80a922e3-e55d-4cdc-8a90-3100c3647e99",
    "71453fe5-854a-4456-9012-04208d5132c2",
    "d4f06476-b47f-4bf7-9d32-531db1974a8e",
    "b452daac-d602-4482-a09a-6294405b6f61",
    "9d0968c7-0bfa-42c5-9e1f-1cf35c40b36e",
    "b00dd38c-f342-4272-b2e3-a9700b6ff406",
    "ee52d885-7ce4-4a98-a4d3-a0464b388072",
    "2e237c12-b3f6-4771-ac63-b2780e15be83",
    "fc7caec4-0208-4d96-8fd4-4f3a9037209b",
    "0d6d8296-aef0-4eed-9297-01ef78d699ae",
    "8409ee38-1922-4ac9-9178-ba699e33e643",
    "99d4ff2a-2513-480c-ac7e-1020b52c3bc3",
    "6ffd899a-4d4b-4216-b8ac-6e16a4b0daa1",
    "15374e65-2f0e-4b06-9698-2a3fd852ef08",
    "fc913366-8a30-442c-9769-0e4aa0ad5829",
    "6e46a29a-3d34-4b35-bb4f-39693b5c608e",
    "75bec15b-5bca-42e6-98fc-f48d9963ff4c",
    "3f63787d-7816-48fe-a623-b61ba10a3001",
    "bc3df9cd-8f28-4710-beb7-8861adc974be",
    "7f444937-f1ae-403c-9427-34f6d5c18aa6",
    "71c5caf2-bd33-41ff-89d6-3ccf06137dc1",
    "f4a69fc6-54ef-4342-8fcc-3849ca6e8d72",
    "dea5db70-e900-4d6e-8851-e3e81b9df816",
    "6f0fb72d-ed2c-4dbb-8091-b839451b46a0",
    "d5820f6b-b829-435f-87e5-98d5514000d0",
    "dc47f7fa-731b-4715-89fa-fef4958eb57f",
    "05454fa7-0d26-4f83-be39-cd51e4622f19",
    "329f4086-c777-4455-9ff0-79e80575784a",
    "967f4844-2959-42d2-9164-b0cbd55012bf",
    "b261484b-635b-4abb-8301-80bee3da24f7",
    "70d993ec-1ae5-4189-9c5e-2f77f45434ce",
    "ca659bf0-9bf4-415d-afac-18a5b878a136",
    "bc191b12-a85e-41ef-99b9-a1aad5fa6c37",
    "ae4abedb-6625-4037-b810-f27189f4eda2",
    "c227709c-cc31-460f-a5e5-fa5248dfa3d3",
    "250a657a-e35d-4a5b-98d3-6399760b0209",
    "d710d2f4-0d24-46b4-b50d-9633fe43b73f",
    "775156bc-1cc6-4d4d-ae0d-852182dfc94e",
    "f7c6e673-3268-4c7a-abf2-dff2426a1ae0",
    "fdd71923-4db9-4133-b496-a386aba2d4b1",
    "c2b01b67-109e-4017-b9e8-90158959db10",
    "9895028a-61bb-41d1-b589-b595c5d70706",
    "45ddd48d-d0c6-4100-ae86-01c170a203b6",
    "566ebfa3-bb52-48eb-bd99-bfddee74eab9"
]
print(len(question_ids))


59


In [88]:
question = questions[50]
qid = question.metadata.get("question_id")
question_str = question.metadata.get("question")
gold = question.metadata.get("answer")

opts = {"A": "", "B": "", "C": "", "D": ""}
txt = str(question.page_content or "")
for line in txt.splitlines():
    line = line.strip()
    if len(line) >= 2 and line[0] in "ABCD" and line[1] == ":":
        opts[line[0]] = line[2:].strip()

options_block = "\n".join([f"{k}. {v}" for k, v in opts.items() if v])

print(qid)
print(question_str)
print(gold)
print(options_block)

250a657a-e35d-4a5b-98d3-6399760b0209
A patient in ICU developed stress related mucosal damage. MOST common site of stress ulcer is:
B
A. Ileum
B. Stomach
C. Duodenum
D. Esophagus
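The option-parsing loop appears in several cells (the batch run, this cell, and the next one). A small helper keeps the behavior identical while removing the copies. `parse_options` and `format_options_block` are hypothetical names; the input format (`A: text` lines in `page_content`) is the one visible in the MedMCQA Document printed in section 7.

```python
from typing import Dict

OPTION_LETTERS = "ABCD"


def parse_options(page_content: str) -> Dict[str, str]:
    """Collect 'A: text' lines from a MedMCQA Document, as the notebook's inline loops do."""
    opts = {k: "" for k in OPTION_LETTERS}
    for line in (page_content or "").splitlines():
        line = line.strip()
        if len(line) >= 2 and line[0] in OPTION_LETTERS and line[1] == ":":
            opts[line[0]] = line[2:].strip()
    return opts


def format_options_block(opts: Dict[str, str]) -> str:
    """Render non-empty options as 'A. text' lines, the format passed to pipeline.run(options=...)."""
    return "\n".join(f"{k}. {v}" for k, v in opts.items() if v)


sample = (
    "A patient in ICU developed stress related mucosal damage. MOST common site of stress ulcer is:\n\n"
    "A: Ileum\nB: Stomach\nC: Duodenum\nD: Esophagus\n"
)
print(format_options_block(parse_options(sample)))
```

In the notebook this would read `options_block = format_options_block(parse_options(question.page_content))`. The question line itself is skipped because its second character is a space, not a colon.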


In [90]:
valid_qs = []
for question in questions:
    qid = question.metadata.get("question_id")
    question_str = question.metadata.get("question")
    gold = question.metadata.get("answer")

    opts = {"A": "", "B": "", "C": "", "D": ""}
    txt = str(question.page_content or "")
    for line in txt.splitlines():
        line = line.strip()
        if len(line) >= 2 and line[0] in "ABCD" and line[1] == ":":
            opts[line[0]] = line[2:].strip()

    options_block = "\n".join([f"{k}. {v}" for k, v in opts.items() if v])

    run_m = pipeline.run(
        question_id=qid,
        question=question_str,
        gold_answer=gold,
        options=options_block
    )
    if len(run_m.path) > 0:
        valid_qs.append(qid)
# Optional diagnostics for the last processed question (disabled):
# print("Entities:", run_m.entities)
# print("Start, End:", run_m.start, run_m.end)
# print("Path length:", len(run_m.path))
# print("KG context preview:", (run_m.kg_context or "")[:400])
# print("Answer:", run_m.answer)
#
# rep_m = explainer.explain(run_m, options=options_block)
# print("Most influential node:", rep_m.most_influential_node)
# print("Most influential edge:", rep_m.most_influential_edge)
# print("Most influential subpath:", rep_m.most_influential_subpath)

Joined Paragrahhs:  
DEBUG shortest_path steps:
  subject: Clostridium difficile infection relation: PRESENTS_AS object: Diarrhea
DEBUG PathAsLists
  node_list: ['Clostridium difficile infection', 'Diarrhea']
  edge_list: ['PRESENTS_AS']
  subpath_list: [('Clostridium difficile infection', 'PRESENTS_AS', 'Diarrhea')]
  chain_str: Clostridium difficile infection->[PRESENTS_AS]->Diarrhea
Joined Paragrahhs:  *Clostridium difficile* infection, often abbreviated as CDI, is a bacterial infection caused by the overgrowth of the bacterium *Clostridium difficile* in the colon. This overgrowth frequently leads to inflammation and disruption of the normal gut flora, which is a primary driver of the condition. A hallmark symptom of CDI is diarrhea, resulting from the bacteria’s production of toxins that irritate the intestinal lining and stimulate fluid secretion.
MSG to LLM : You are a knowledgeable medical assistant.

        Use the following medical paragraph to answer the multiple-choice ques


In [93]:
print(valid_qs)

['52fd537b-dc7f-4cf1-9a3a-6757ebd9ea31', 'bf28775c-403b-48c9-a5ba-07f1217f17e3', '3f020b60-da1a-40fa-93eb-462cbcc61ed2', 'f54f8baa-fab2-4eae-a6d5-2be43f1d2930', '80a922e3-e55d-4cdc-8a90-3100c3647e99', 'fc913366-8a30-442c-9769-0e4aa0ad5829', '3f63787d-7816-48fe-a623-b61ba10a3001', 'dea5db70-e900-4d6e-8851-e3e81b9df816', 'bc191b12-a85e-41ef-99b9-a1aad5fa6c37', 'd710d2f4-0d24-46b4-b50d-9633fe43b73f', '775156bc-1cc6-4d4d-ae0d-852182dfc94e', '566ebfa3-bb52-48eb-bd99-bfddee74eab9']


In [31]:
#print("outcomes:", rep_m.outcomes)
for i in rep_m.outcomes:
    print(i)
print("rq1: ", rep_m.rq1_sensitivity)


rq1:  RQ1Sensitivity(node_changed=0, edge_changed=0, subpath_changed=0, node_total=0, edge_total=0, subpath_total=0)


## 9. Optional: compact stats for your RQs

RQ2 (position proxy) depends on explainer implementation details.  
RQ3 (node types) depends on which entity types your triplet extractor emits.  
RQ4 (graph metrics) depends on whether your metrics module computes degree and betweenness.

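If the explainer report exposes these fields (`rq1_sensitivity`, `rq2_positions`, `rq3_node_types`, `rq4_graph_metrics`), they are printed after the first manual run below.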
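A minimal sketch of such a compact summary, assuming only the fields visible in the printed reports (`RQ1Sensitivity` with `*_changed`/`*_total` counts, `RQ2PositionRecord` with `kind`, `rel_pos`, and `answer_changed`). The dataclasses are hypothetical stand-ins, not the explainer's own classes. Rates are returned as `None` when the total is 0, as in the `rq1` output above where every total is 0: no perturbation was run there, so reporting a rate of 0 would be misleading.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Dict, List, Optional


@dataclass
class Sensitivity:
    """Hypothetical stand-in with the fields printed for RQ1Sensitivity."""
    node_changed: int
    edge_changed: int
    subpath_changed: int
    node_total: int
    edge_total: int
    subpath_total: int


@dataclass
class PositionRecord:
    """Hypothetical stand-in with the fields printed for RQ2PositionRecord."""
    kind: str
    removed: str
    index: int
    rel_pos: float
    answer_changed: bool


def _rate(changed: int, total: int) -> Optional[float]:
    # None rather than 0.0: with no perturbations the rate is undefined
    return changed / total if total else None


def summarize_rq1(s: Sensitivity) -> Dict[str, Optional[float]]:
    """Share of node/edge/subpath removals that changed the answer."""
    return {
        "node": _rate(s.node_changed, s.node_total),
        "edge": _rate(s.edge_changed, s.edge_total),
        "subpath": _rate(s.subpath_changed, s.subpath_total),
    }


def summarize_rq2(records: List[PositionRecord], kind: str = "node") -> Dict[str, Optional[float]]:
    """Mean relative path position of removals that did / did not change the answer."""
    changed = [r.rel_pos for r in records if r.kind == kind and r.answer_changed]
    unchanged = [r.rel_pos for r in records if r.kind == kind and not r.answer_changed]
    return {
        "changed_mean_rel_pos": mean(changed) if changed else None,
        "unchanged_mean_rel_pos": mean(unchanged) if unchanged else None,
    }
```

The same functions could be called on `rep_m.rq1_sensitivity` and `rep_m.rq2_positions` directly, provided those objects expose the same attribute names.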


## 10. Try your own question

Use this to quickly check whether the KG has enough coverage.

Note:
If the path length (`len(run_m.path)`) is often 0, the cause is usually entity extraction/matching or KG coverage, not retrieval.
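One quick way to tell matching problems from coverage problems is to check which extracted entities have a KG node with the same label. How `KGRAGExPipeline` actually matches entities (exact, case-insensitive, or embedding-based) is not shown in this notebook, so this sketch assumes case- and whitespace-insensitive label matching; the helper names are hypothetical. In the notebook the inputs would be `run_m.entities` and `kg_store.g.nodes`.

```python
from typing import Dict, Iterable, List, Optional


def _label_key(label: str) -> str:
    """Case- and whitespace-insensitive matching key."""
    return " ".join(str(label).lower().split())


def match_entities_to_nodes(entities: Iterable[str], nodes: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each extracted entity to a KG node with the same normalized label, or None."""
    index: Dict[str, str] = {}
    for node in nodes:
        index.setdefault(_label_key(node), str(node))  # first node wins on collisions
    return {e: index.get(_label_key(e)) for e in entities}


def missing_entities(entities: Iterable[str], nodes: Iterable[str]) -> List[str]:
    """Entities with no matching node: candidates for a coverage (not retrieval) problem."""
    return [e for e, node in match_entities_to_nodes(entities, nodes).items() if node is None]
```

If an entity only matches after normalization, the pipeline's matching step is the likely culprit; if it does not match at all, the KG probably lacks that concept.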


In [None]:
qid = "manual-1"
question = "How does psoas muscle relates to lumbar puncture"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


Entities: ['psoas muscle', 'lumbar puncture']
Start, End: psoas muscle lumbar puncture
Path length: 19
KG context preview: The psoas muscle joins the iliacus muscle, and together they form the iliopsoas, which inserts at the lesser trochanter of the femur. The iliopsoas receives innervation from the L1-3 branches of the lumbar plexus, and its blood supply is derived from branches of the abdominal aorta. The psoas major muscle is part of this complex and is associated with the iliopsoas. Issues with the iliopsoas can c
Answer: The psoas muscle relates to lumbar puncture as it is located near the lumbar spine and its dysfunction can cause low back pain and other symptoms that may be relevant to the procedure, and the lumbar plexus which innervates the psoas muscle is also involved in the lumbar puncture procedure
Most influential node: iliacus muscle
Most influential edge: psoas muscle::JOINS::iliacus muscle
Most influential subpath: (psoas muscle, JOINS, iliacus muscle)


In [None]:
# RQ1
print("\nRQ1 sensitivity:")
print(rep_m.rq1_sensitivity)

# RQ2
print("\nRQ2 positions:")
for r in (rep_m.rq2_positions or []):
    print(r)

# RQ3
print("\nRQ3 node types:")
for r in (rep_m.rq3_node_types or []):
    print(r)

# RQ4
print("\nRQ4 graph metrics:")
for r in (rep_m.rq4_graph_metrics or []):
    print(r)



RQ1 sensitivity:
RQ1Sensitivity(node_changed=18, edge_changed=19, subpath_changed=19, node_total=18, edge_total=19, subpath_total=19)

RQ2 positions:
RQ2PositionRecord(kind='node', removed='iliacus muscle', index=1, rel_pos=0.05263157894736842, answer_changed=True)
RQ2PositionRecord(kind='node', removed='lesser trochanter of the femur', index=2, rel_pos=0.10526315789473684, answer_changed=True)
RQ2PositionRecord(kind='node', removed='lesser trochanter of the femur', index=2, rel_pos=0.10526315789473684, answer_changed=True)
RQ2PositionRecord(kind='node', removed='L1-3 branches of the lumbar plexus', index=4, rel_pos=0.21052631578947367, answer_changed=True)
RQ2PositionRecord(kind='node', removed='L1-3 branches of the lumbar plexus', index=4, rel_pos=0.21052631578947367, answer_changed=True)
RQ2PositionRecord(kind='node', removed='Abdominal aorta', index=6, rel_pos=0.3157894736842105, answer_changed=True)
RQ2PositionRecord(kind='node', removed='Abdominal aorta', index=6, rel_pos=0.3157
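
The RQ2 output above repeats some records verbatim (same node, same index). The minimal sketch below drops those exact repeats before relating answer changes to path position, then buckets `rel_pos` into equal-width bins. It assumes only the field names visible in the printout. `PosRecord` is a hypothetical stand-in for `RQ2PositionRecord`, the records are toy values, and the choice of three bins is arbitrary.

```python
from collections import defaultdict
from typing import NamedTuple


class PosRecord(NamedTuple):
    # Hypothetical stand-in for RQ2PositionRecord (same field names)
    kind: str
    removed: str
    index: int
    rel_pos: float
    answer_changed: bool


def flip_rate_by_position(records, bins=3):
    """Deduplicate records, then compute the answer-flip rate per rel_pos bin."""
    # Exact repeats collapse onto one key; the last record wins
    unique = {(r.kind, r.removed, r.index): r for r in records}.values()
    changed, total = defaultdict(int), defaultdict(int)
    for r in unique:
        # Map rel_pos in [0, 1] to a bin; rel_pos == 1.0 goes to the last bin
        b = min(int(r.rel_pos * bins), bins - 1)
        total[b] += 1
        changed[b] += int(r.answer_changed)
    return {b: changed[b] / total[b] for b in sorted(total)}


recs = [
    PosRecord("node", "a", 1, 0.1, True),
    PosRecord("node", "a", 1, 0.1, True),  # exact repeat, counted once
    PosRecord("node", "b", 5, 0.5, False),
    PosRecord("node", "c", 9, 0.9, True),
    PosRecord("node", "d", 8, 0.8, False),
]
print(flip_rate_by_position(recs))
```

To apply it to a real report, call `flip_rate_by_position(rep_m.rq2_positions or [])`.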

In [None]:
qid = "manual-2"
question = "How does Tension pneumothorax relates to Septicemia"

run_m = pipeline.run(
    question_id=qid,
    question=question,
    gold_answer=None,
    entity_k=6,
    top_k_docs=4,
)

print("Entities:", run_m.entities)
print("Start, End:", run_m.start, run_m.end)
print("Path length:", len(run_m.path))
print("KG context preview:", (run_m.kg_context or "")[:400])
print("Answer:", run_m.answer)

rep_m = explainer.explain(run_m)
print("Most influential node:", rep_m.most_influential_node)
print("Most influential edge:", rep_m.most_influential_edge)
print("Most influential subpath:", rep_m.most_influential_subpath)


Entities: ['Tension', 'pneumothorax', 'septicemia']
Start, End: None None
Path length: 0
KG context preview: 
Answer: I cannot determine how tension pneumothorax relates to septicemia based on the provided context.
Most influential node: None
Most influential edge: None
Most influential subpath: None
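
Here the extractor returned `'Tension'` and `'pneumothorax'` as separate entities. That split may explain why no start/end node was selected. One possible workaround is to re-join adjacent entities into n-grams and keep those that match a node name case-insensitively. The helper below is hypothetical and is not part of the pipeline, and it assumes the entity list follows question order. The node set is a toy; the output above does not show whether this KG actually contains a "tension pneumothorax" node.

```python
def rejoin_candidates(entities, node_names, max_n=3):
    """Return n-grams of adjacent entities that match a KG node (case-insensitive).

    Hypothetical helper; assumes `entities` are in question order, so adjacent
    items may be fragments of one multi-word term. Single entities are skipped.
    """
    by_lower = {str(n).lower(): n for n in node_names}
    hits = []
    for n in range(max_n, 1, -1):  # try the longest n-grams first
        for i in range(len(entities) - n + 1):
            phrase = " ".join(entities[i:i + n]).lower()
            if phrase in by_lower and by_lower[phrase] not in hits:
                hits.append(by_lower[phrase])
    return hits


toy_nodes = {"Tension pneumothorax", "Septicemia", "Pleura"}
print(rejoin_candidates(["Tension", "pneumothorax", "septicemia"], toy_nodes))
```

Against the real graph, the call would be `rejoin_candidates(run_m.entities, kg_store.g.nodes)`.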


### Optional: print RQ metrics for the manual question

Surface the RQ fields directly for quick inspection.


In [None]:
print("Sensitivity:", rep_m.rq1_sensitivity)
print("Positions:", rep_m.rq2_positions)
print("Node types:", rep_m.rq3_node_types)
print("Graph metrics:", rep_m.rq4_graph_metrics)

### Debug: graph connectivity between start/end nodes

Check whether the start/end nodes exist, inspect neighbors, and probe shortest paths.


In [None]:
import networkx as nx

G = kg_store.g
start = run_m.start
end = run_m.end

print("start exists:", start in G)
print("end exists:", end in G)

if start in G:
    print("\nStart neighbors (out):", list(G.successors(start))[:30])
    print("Start neighbors (in):", list(G.predecessors(start))[:30])

if end in G:
    print("\nEnd neighbors (out):", list(G.successors(end))[:30])
    print("End neighbors (in):", list(G.predecessors(end))[:30])

if start in G and end in G:
    GU = G.to_undirected(as_view=True)
    try:
        p = nx.shortest_path(GU, start, end)
        print("\nUndirected shortest path length:", len(p) - 1)
        print("Path nodes:", p[:20], "..." if len(p) > 20 else "")
    except nx.NetworkXNoPath as e:  # start and end lie in different components
        print("\nNo undirected path:", e)

# Additional: search for similar node names
def find_similar_nodes(q, limit=20):
    ql = (q or "").lower()
    hits = [n for n in G.nodes if ql in str(n).lower()]
    return hits[:limit]

# Probe the entities extracted for the current question
for ent in run_m.entities or []:
    print(f"\nSimilar to {ent!r}:", find_similar_nodes(ent, 30))

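`find_similar_nodes` only finds exact substrings. Spelling variants (British "septicaemia" vs. "septicemia") and small typos therefore slip through. The sketch below is a minimal fuzzy alternative that uses the standard library's `difflib.get_close_matches`. `find_close_nodes` is a hypothetical helper, the 0.75 cutoff is an arbitrary starting point to tune, and the node names are a toy list.

```python
import difflib


def find_close_nodes(query, node_names, limit=10, cutoff=0.75):
    """Fuzzy, case-insensitive lookup of node names (hypothetical helper)."""
    by_lower = {}
    for name in node_names:
        # Keep the first original spelling for each lowercased key
        by_lower.setdefault(str(name).lower(), name)
    matches = difflib.get_close_matches(
        (query or "").lower(), list(by_lower), n=limit, cutoff=cutoff
    )
    return [by_lower[m] for m in matches]


toy_nodes = ["Septicaemia", "Sepsis", "Tension pneumothorax", "Pneumonia"]
print(find_close_nodes("septicemia", toy_nodes))
```

In the debug cell above, `find_close_nodes(ent, G.nodes)` could be added alongside the substring search.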