# MedAnchor - Multi-Agent QA System

### 1. Installing Dependencies

In [1]:
!pip install -q "langchain>=0.2.10,<1.0.0" \
               "langchain-community>=0.2.10" \
               "langchain-text-splitters>=0.2.0" \
               "langgraph>=0.2.0" \
               "chromadb>=0.5.5" \
               "sentence-transformers>=2.2.2" \
               "pypdf>=4.2.0" \
               transformers accelerate bitsandbytes faiss-cpu
!pip install gradio -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m39.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m95.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.4/155.4 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m450.8/450.8 kB[0m [31m37.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.7/21.7 MB[0m [31m100.2 MB/s[0m eta [36m0:00:

### 2. Imports & global config

In [2]:
import os
import json
import re
from typing import TypedDict, List, Dict, Any, Optional
from pathlib import Path

import torch
import numpy as np
import pandas as pd

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline,
)

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.llms import HuggingFacePipeline

from langgraph.graph import StateGraph, START, END

import gradio as gr

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### 3. Loading the Model

In [6]:
HF_TOKEN = "TOKEN"  # <<< PUT YOUR HF TOKEN HERE
model_name = "meta-llama/Llama-2-13b-chat-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=HF_TOKEN
)

print("Loading model in 4-bit NF4...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_auth_token=HF_TOKEN
)

print("Model loaded.")

Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Loading model in 4-bit NF4...




config.json:   0%|          | 0.00/587 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/33.4k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/6.18G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

Model loaded.


### 3. LogProb Generation for Doctor / Nurse Agents

In [7]:
def generate_with_logprobs(prompt: str, max_new_tokens: int = 100):
    """
    Generate a continuation from the given prompt with LLaMA-2-13B-chat
    and return both the full decoded text and token-level log probabilities
    for the generated continuation.

    Args:
        prompt (str):
            The full text prompt to condition on.
        max_new_tokens (int, optional):
            Maximum number of new tokens to generate. Defaults to 256.

    Returns:
        tuple[str, list[float]]:
            - full_text: the decoded full sequence (prompt + continuation),
              with special tokens removed.
            - logprobs: list of log-probabilities (float) for each generated
              token in the continuation (not including prompt tokens).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode full sequence (prompt + continuation)
    full_text = tokenizer.decode(out.sequences[0], skip_special_tokens=True)

    # Compute logprobs only for generated tokens
    scores = out.scores  # list of [batch, vocab] logits
    gen_tokens = out.sequences[0][inputs["input_ids"].shape[-1]:]

    logprobs = []
    for i, tok_id in enumerate(gen_tokens):
        logits = scores[i][0]
        lp = torch.log_softmax(logits, dim=-1)[tok_id].item()
        logprobs.append(lp)

    return full_text, logprobs

### 4. RAG Pipeline

In [8]:
CHUNK_SIZE = 500
CHUNK_OVERLAP = 20

def create_RAG(doc_dir: str,
               persist_dir: str,
               collection: str,
               emb: HuggingFaceEmbeddings):
    """
    Create a RAG vector store from PDF documents in a directory.

    Steps:
      1. Load each PDF from `doc_dir` using PyPDFLoader.
      2. Merge all pages of each file into one big Document.
      3. Split Documents into overlapping chunks using
         RecursiveCharacterTextSplitter with CHUNK_SIZE and CHUNK_OVERLAP.
      4. Assign a stable chunk_id to each chunk.
      5. Build a Chroma vector store with the given embedding model `emb`.
      6. Persist the Chroma database to `persist_dir`.
    """
    documents: List[Document] = []
    for d in os.listdir(doc_dir):
        path = os.path.join(doc_dir, d)
        if not d.lower().endswith(".pdf"):
            continue
        loader = PyPDFLoader(path)
        pages = loader.load()

        # Merge pages into one big Document
        full_text = "\n".join([p.page_content for p in pages])
        documents.append(Document(page_content=full_text, metadata={"source": path}))

    print(len(documents), "combined documents")

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.split_documents(documents)
    print(f"Created {len(chunks)} chunks")

    for i, c in enumerate(chunks):
        c.metadata["chunk_id"] = f"{collection}_{i}"

    vs = Chroma.from_documents(
        documents=chunks,
        embedding=emb,
        persist_directory=persist_dir,
        collection_name=collection,
        collection_metadata={"hnsw:space": "cosine"},
    )
    vs.persist()
    print("Persisted at:", os.path.abspath(persist_dir))
    print("Vector count:", vs._collection.count())
    return vs

In [9]:
def load_or_create_RAG(doc_dir: str,
                       persist_dir: str,
                       collection: str,
                       emb: HuggingFaceEmbeddings):
    """
    If a Chroma DB already exists at `persist_dir`, load it.
    Otherwise, create it from the PDFs in `doc_dir` using create_RAG.
    """
    if Path(persist_dir).exists():
        print(f"Loading existing Chroma DB from {persist_dir} (collection={collection})")
        return Chroma(
            persist_directory=persist_dir,
            collection_name=collection,
            embedding_function=emb,
        )
    else:
        print(f"No existing DB at {persist_dir}, creating new one...")
        return create_RAG(doc_dir, persist_dir, collection, emb)

### 5. Building Cardiology & Dermatology RAGs


In [10]:
CARDIO_DIR = "CARDIOLOGY_DOC_DIR"
DERM_DIR   = "DERMATOLOGY_DOC_DIR"

os.makedirs(CARDIO_DIR, exist_ok=True)
os.makedirs(DERM_DIR, exist_ok=True)

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

cardio_vs = load_or_create_RAG(
    doc_dir=CARDIO_DIR,
    persist_dir="./cardio_db",
    collection="cardio",
    emb=emb,
)

derm_vs = load_or_create_RAG(
    doc_dir=DERM_DIR,
    persist_dir="./derm_db",
    collection="derm",
    emb=emb,
)

cardio_retriever = cardio_vs.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 4, "score_threshold": 0.3}
)

derm_retriever = derm_vs.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 4, "score_threshold": 0.3}
)

  emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

No existing DB at ./cardio_db, creating new one...
3 combined documents
Created 3356 chunks


  vs.persist()


Persisted at: /content/cardio_db
Vector count: 3356
No existing DB at ./derm_db, creating new one...
1 combined documents
Created 2807 chunks
Persisted at: /content/derm_db
Vector count: 2807


### Helper Functions

In [11]:
def extract_agent_answer_text(agent_text: str) -> str:
    """
    Given a full agent output (which may echo the prompt and include 'Answer:'),
    return only the part after 'Answer:' if present. Otherwise return stripped text.
    """
    if not agent_text:
        return ""
    m = re.search(r"Answer:\s*(.*)", agent_text, re.DOTALL)
    if m:
        return m.group(1).strip()
    return agent_text.strip()


def extract_doctor_json(text: str) -> Dict[str, Any]:
    """
    Try to parse the doctor's output as JSON.
    If that fails, try to salvage at least the `"answer"` field.
    Works for both case/treatment JSON and Q&A JSON.
    """
    if not text:
        return {}

    t = text.strip()

    # First try: whole text is JSON
    try:
        return json.loads(t)
    except json.JSONDecodeError:
        pass

    # Second try: first {...} block
    m = re.search(r"\{[\s\S]*\}", t)
    if m:
        candidate = m.group(0)
        # Remove trailing commas like ", }"
        candidate = re.sub(r",\s*([}\]])", r"\1", candidate)
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass

    # Last-chance: just extract "answer": "..."
    m_ans = re.search(r'"answer"\s*:\s*"(?P<ans>.*?)"', t, re.DOTALL)
    if m_ans:
        ans = " ".join(m_ans.group("ans").split())
        return {"answer": ans}

    return {}

def extract_answer_from_doctor_text(text: str) -> str:
    """
    Try to recover a clean doctor answer from the raw model text.

    Strategy:
      1) If there's a JSON-like `"answer": "..."` field, grab that string.
      2) Otherwise, look for a final 'Answer:' section and return that paragraph.
      3) Fallback: return stripped text.
    """
    if not text:
        return ""

    # 1) Try to pull "answer": "..." from JSON-like block
    m = re.search(r'"answer"\s*:\s*"([^"]+)"', text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    # 2) Look for the last "Answer:" section near the bottom
    marker = "Answer:"
    idx = text.rfind(marker)
    if idx != -1:
        tail = text[idx + len(marker):]

        # Cut off at Rationale/Confidence/Citations if present
        for stop in ["\nRationale:", "\nConfidence:", "\nCitations:"]:
            s_idx = tail.find(stop)
            if s_idx != -1:
                tail = tail[:s_idx]
                break

        return tail.strip()

    # 3) Fallback
    return text.strip()


    # --- 3) Fallback ---
    return text.strip()

def extract_nurse_answer(nurse_text: str) -> str:
    """
    From the full nurse output, return only the content after 'Answer:'.
    If 'Answer:' is not found, return the stripped original text.
    Works for both case-mode and QA-mode nurse answers.
    """
    if not nurse_text:
        return ""

    m = re.search(r"Answer:\s*(.*)", nurse_text, re.DOTALL)
    if not m:
        return nurse_text.strip()

    answer = m.group(1).strip()
    # normalize whitespace
    return " ".join(answer.split())

def safe_parse_doctor_json(text: str) -> dict:
    """
    Try very hard to parse a JSON object from the model output.
    1) Try json.loads on the whole string.
    2) Strip ```json fences if present.
    3) Use a regex to grab the first {...} block and json.loads that.
    Raise if nothing works.
    """
    if not text:
        raise ValueError("Empty doctor text")

    raw = text.strip()

    # 1) Direct attempt
    try:
        return json.loads(raw)
    except Exception:
        pass

    # 2) Strip ```json ... ``` fences if present
    if raw.startswith("```"):
        # remove leading/trailing ```
        raw2 = raw.strip("`").strip()
        # drop optional 'json' token at start
        if raw2.lower().startswith("json"):
            raw2 = raw2[4:].strip()
        try:
            return json.loads(raw2)
        except Exception:
            raw = raw2  # keep this for regex step

    # 3) Regex: grab first {...} block
    m = re.search(r"\{[\s\S]*\}", raw)
    if m:
        candidate = m.group(0)
        try:
            return json.loads(candidate)
        except Exception:
            pass

    raise ValueError("No valid JSON object found in doctor text")

### 6. RAG Agent Prompts (Cardio & Derm)

In [12]:
cardio_prompt = ChatPromptTemplate.from_template("""
You are a cardiology specialist.

Use ONLY the retrieved passages to answer the question in 1-2 sentences.
Each passage in the context starts with "Chunk ID: <id>". When you cite evidence,
you MUST use those exact chunk IDs in the Citation field.

Format:
Answer: <text>
Confidence: <float 0-1>
Citation: <comma-separated chunk_ids>

Patient case:
{patient_case}

Patient question:
{patient_question}

Context:
{context}

Answer:
""")

derm_prompt = ChatPromptTemplate.from_template("""
You are a dermatology specialist.

Use ONLY the retrieved passages to answer the question in 1-2 sentences.
Each passage in the context starts with "Chunk ID: <id>". When you cite evidence,
you MUST use those exact chunk IDs in the Citation field.

Format:
Answer: <text>
Confidence: <float 0-1>
Citation: <comma-separated chunk_ids>

Patient case:
{patient_case}

Patient question:
{patient_question}

Context:
{context}

Answer:
""")

rag_llm_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=80,
    do_sample=False,
    temperature=0.0,
    pad_token_id=tokenizer.eos_token_id,
)

rag_llm = HuggingFacePipeline(pipeline=rag_llm_pipe)

doc_prompt = PromptTemplate.from_template(
    "Chunk ID: {chunk_id}\nContent:\n{page_content}"
)

cardio_chain = create_stuff_documents_chain(
    rag_llm,
    cardio_prompt,
    document_prompt=doc_prompt,
)

derm_chain = create_stuff_documents_chain(
    rag_llm,
    derm_prompt,
    document_prompt=doc_prompt,
)

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  rag_llm = HuggingFacePipeline(pipeline=rag_llm_pipe)


###7.1. Doctor + Nurse Prompts (Case & Treatment)


In [13]:
CASE_DOCTOR_PROMPT = """
You are a senior physician overseeing this patient's care.

You receive:
- The patient case.
- The patient's question.
- A cardiology agent's answer.
- A dermatology agent's answer.

Your job is to synthesize a clinically reasonable response.

Return ONLY a valid JSON object with keys:
- "diagnosis": string
- "plan": list of 2-4 bullet-point strings
- "rationale": string (2-3 sentences)
- "confidence": float between 0 and 1
- "citations": list of chunk_ids (strings)

Patient case:
{patient_case}

Patient question:
{patient_question}

Cardiology agent:
{cardio}

Dermatology agent:
{derm}

JSON:
"""

CASE_NURSE_PROMPT = """
You are a nurse explaining the doctor's treatment plan to the patient.

You receive:
- The patient case and question.
- The doctor's JSON summary with keys: diagnosis, plan, rationale, confidence, citations.

Your job:
- Explain the situation in simple, non-technical language.
- Do NOT introduce new treatments or diagnoses.
- Focus on what is happening and what the patient should do.
- Be clear about when to seek urgent medical attention.

Write:
- 1–2 short paragraphs
- Then 3–5 bullet points with clear instructions

Patient case:
{patient_case}

Patient question:
{patient_question}

Doctor answer (JSON):
{doctor_json}

Answer:
"""

### 7.2. Doctor + Nurse Prompts (Q&A)

In [14]:
DOCTOR_QA_PROMPT = """
You are a senior physician answering a guideline-style question.

You receive:
- The question.
- A cardiology agent answer (with citations).
- A dermatology agent answer (with citations).

Your task is to give a concise, direct answer to the question
for another clinician (doctor-level detail).

You MUST return ONLY a valid JSON object, with NO extra text, NO commentary, and NO backticks.
The JSON must have exactly these keys:
- "answer": string
- "rationale": string (1-3 sentences)
- "confidence": float between 0 and 1
- "citations": list of chunk_ids (strings)

Question:
{question}

Cardiology agent:
{cardio}

Dermatology agent:
{derm}

Now return the JSON object:
"""


NURSE_QA_PROMPT = """
You are a nurse explaining the doctor's answer to a patient.

You receive:
- The original question.
- The doctor's JSON answer with fields: answer, rationale, confidence, citations.

Your job:
- Explain the answer in simple, non-technical language.
- Do NOT add new treatments or diagnoses.
- Focus on what this means for the patient.
- Keep it reassuring but honest.

Write:
- 1–2 short paragraphs
- Then 3–5 bullet points with clear, practical instructions (if relevant)

Question:
{question}

Doctor answer (JSON):
{doctor_json}

Answer:
"""

### 8.1 LangGraph State & Nodes (Case & Treatment)


In [15]:
class CaseState(TypedDict, total=False):
    """
    State for case/treatment task.
    """
    patient_case: str
    patient_question: str

    cardio_answer: str
    cardio_context: List[str]

    derm_answer: str
    derm_context: List[str]

    doctor_text: str
    doctor_json: Dict[str, Any]
    doctor_logprobs: List[float]

    nurse_text: str
    nurse_logprobs: List[float]

def cardio_case_node(state: CaseState) -> CaseState:
    print(">> [CASE] Cardiology Agent")
    case = state["patient_case"]
    q = state["patient_question"]

    docs = cardio_retriever.invoke(case + "\n\n" + q)
    raw = cardio_chain.invoke({
        "context": docs,
        "patient_case": case,
        "patient_question": q,
    })

    answer = extract_agent_answer_text(raw)
    retrieved_ids = [d.metadata.get("chunk_id") for d in docs]

    return {
        "cardio_answer": answer,
        "cardio_context": retrieved_ids,
    }


def derm_case_node(state: CaseState) -> CaseState:
    print(">> [CASE] Dermatology Agent")
    case = state["patient_case"]
    q = state["patient_question"]

    docs = derm_retriever.invoke(case + "\n\n" + q)
    raw = derm_chain.invoke({
        "context": docs,
        "patient_case": case,
        "patient_question": q,
    })

    answer = extract_agent_answer_text(raw)
    retrieved_ids = [d.metadata.get("chunk_id") for d in docs]

    return {
        "derm_answer": answer,
        "derm_context": retrieved_ids,
    }


def doctor_case_node(state: CaseState) -> CaseState:
    print(">> [CASE] Doctor Agent")

    case = state["patient_case"]
    q = state["patient_question"]

    prompt = CASE_DOCTOR_PROMPT.format(
        patient_case=case,
        patient_question=q,
        cardio=state["cardio_answer"],
        derm=state["derm_answer"],
    )

    text, logprobs = generate_with_logprobs(prompt, max_new_tokens=256)
    doctor_json = extract_doctor_json(text)

    return {
        **state,
        "doctor_text": text,
        "doctor_json": doctor_json,
        "doctor_logprobs": logprobs,
    }


def nurse_case_node(state: CaseState) -> CaseState:
    print(">> [CASE] Nurse Agent")

    case = state["patient_case"]
    q = state["patient_question"]
    doctor_json_str = json.dumps(state.get("doctor_json", {}), indent=2)

    prompt = CASE_NURSE_PROMPT.format(
        patient_case=case,
        patient_question=q,
        doctor_json=doctor_json_str,
    )

    text, logprobs = generate_with_logprobs(prompt, max_new_tokens=256)

    return {
        **state,
        "nurse_text": text,
        "nurse_logprobs": logprobs,
    }

### 8.2 LangGraph State & Nodes (Q&A)

In [16]:
def format_context(docs: List[Document]) -> str:
    """
    Turn retrieved Documents into a single string where each chunk
    is explicitly tagged with its chunk_id. This matches the format
    your prompts expect.

    Each block looks like:

    Chunk ID: cardio_123
    Content:
    <text>

    Args:
        docs: list of LangChain Document objects with metadata["chunk_id"] set.

    Returns:
        A single string to be passed as {context} into the specialist prompts.
    """
    parts = []
    for d in docs:
        cid = d.metadata.get("chunk_id", "unknown")
        parts.append(f"Chunk ID: {cid}\nContent:\n{d.page_content}")
    return "\n\n".join(parts)

In [17]:
class QAState(TypedDict, total=False):
    """
    State for Q&A task (no patient_case).
    """
    question: str

    cardio_answer: str
    cardio_context: List[str]

    derm_answer: str
    derm_context: List[str]

    doctor_text: str
    doctor_json: Dict[str, Any]
    doctor_logprobs: List[float]

    nurse_text: str
    nurse_logprobs: List[float]


def cardio_qa_node(state: QAState) -> QAState:
    """
    Cardiology QA agent for guideline-style questions.

    - Uses cardio_retriever to fetch cardiology guideline chunks.
    - If no docs are retrieved (e.g. below similarity score threshold),
      it abstains with a polite fallback answer.
    - Otherwise, it calls the existing cardio_chain and extracts the
      main answer text.
    """
    print(">> [QA] Cardiology Agent")
    q = state["question"]

    # Retrieve cardiology documents for this question
    docs = cardio_retriever.invoke(q)

    # If nothing relevant found, abstain
    if not docs:
        ans = (
            "I do not have sufficient cardiology information in the retrieved "
            "guidelines to answer this question reliably."
        )
        return {
            "cardio_answer": ans,
            "cardio_context": [],
        }

    # Normal case: call your existing cardiology RAG chain
    raw = cardio_chain.invoke({
        "context": docs,        # list[Document], same as in case pipeline
        "patient_case": "",     # left empty for QA mode
        "patient_question": q,
    })

    # Clean out prompt echo, keep only the actual answer text
    answer = extract_agent_answer_text(raw)
    retrieved_ids = [d.metadata.get("chunk_id") for d in docs]

    return {
        "cardio_answer": answer,
        "cardio_context": retrieved_ids,
    }


def derm_qa_node(state: QAState) -> QAState:
    """
    Dermatology QA agent for guideline-style questions.

    - Uses derm_retriever to fetch dermatology (or "wrong") chunks.
    - If no docs are retrieved, abstains with a fallback message.
    - Otherwise, calls the existing derm_chain and extracts the main answer.
    """
    print(">> [QA] Dermatology Agent")
    q = state["question"]

    # Retrieve dermatology documents for this question
    docs = derm_retriever.invoke(q)

    if not docs:
        ans = (
            "I do not have sufficient dermatology information in the retrieved "
            "guidelines to answer this question reliably."
        )
        return {
            "derm_answer": ans,
            "derm_context": [],
        }

    raw = derm_chain.invoke({
        "context": docs,
        "patient_case": "",
        "patient_question": q,
    })

    answer = extract_agent_answer_text(raw)
    retrieved_ids = [d.metadata.get("chunk_id") for d in docs]

    return {
        "derm_answer": answer,
        "derm_context": retrieved_ids,
    }


def doctor_qa_node(state: QAState) -> QAState:
    print(">> [QA] Doctor Agent")

    question = state["question"]
    cardio_answer = state.get("cardio_answer", "")
    derm_answer = state.get("derm_answer", "")

    prompt = DOCTOR_QA_PROMPT.format(
        question=question,
        cardio=cardio_answer,
        derm=derm_answer,
    )

    # Generate model output
    text, lps = generate_with_logprobs(prompt, max_new_tokens=256)

    # Try to parse JSON robustly
    try:
        doc_json = safe_parse_doctor_json(text)
    except Exception as e:
        # Debug print to see what actually came out of the model (optional)
        print("DOCTOR RAW TEXT (parse failed):")
        print(text)
        print("JSON parse error:", repr(e))

        # 🔁 MUCH BETTER FALLBACK:
        # Instead of "Unable to parse doctor answer.", we at least
        # use the doctor's text as the 'answer' so Gradio shows something useful.
        fallback_answer = extract_answer_from_doctor_text(text) or text.strip()
        doc_json = {
            "answer": fallback_answer,
            "rationale": "",
            "confidence": 0.0,
            "citations": [],
        }

    return {
        "doctor_text": text,
        "doctor_json": doc_json,
        "doctor_logprobs": lps,
    }


def nurse_qa_node(state: QAState) -> QAState:
    print(">> [QA] Nurse Agent")

    question = state["question"]
    doctor_json = state.get("doctor_json", {})

    doctor_json_str = json.dumps(doctor_json, indent=2)

    # Use your correct prompt variable name
    prompt = NURSE_QA_PROMPT.format(
        question=question,
        doctor_json=doctor_json_str,
    )

    text, lps = generate_with_logprobs(prompt, max_new_tokens=256)

    return {
        "nurse_text": text,
        "nurse_logprobs": lps,
    }

### 9.1 Building LangGraph (Case & Treatment)


In [18]:
case_graph = StateGraph(CaseState)

case_graph.add_node("cardio_case", cardio_case_node)
case_graph.add_node("derm_case", derm_case_node)
case_graph.add_node("doctor_case", doctor_case_node)
case_graph.add_node("nurse_case", nurse_case_node)

case_graph.add_edge(START, "cardio_case")
case_graph.add_edge(START, "derm_case")
case_graph.add_edge("cardio_case", "doctor_case")
case_graph.add_edge("derm_case", "doctor_case")
case_graph.add_edge("doctor_case", "nurse_case")
case_graph.add_edge("nurse_case", END)

case_app = case_graph.compile()

### 9.2 Building LangGraph (Q&A)

In [19]:
qa_graph = StateGraph(QAState)

qa_graph.add_node("cardio_qa", cardio_qa_node)
qa_graph.add_node("derm_qa", derm_qa_node)
qa_graph.add_node("doctor_qa", doctor_qa_node)
qa_graph.add_node("nurse_qa", nurse_qa_node)

qa_graph.add_edge(START, "cardio_qa")
qa_graph.add_edge(START, "derm_qa")
qa_graph.add_edge("cardio_qa", "doctor_qa")
qa_graph.add_edge("derm_qa", "doctor_qa")
qa_graph.add_edge("doctor_qa", "nurse_qa")
qa_graph.add_edge("nurse_qa", END)

qa_app = qa_graph.compile()

### 10. Evaluation Helpers

In [20]:
def answer_case_with_role(patient_case: str, patient_question: str, role: str) -> str:
    state: CaseState = {
        "patient_case": patient_case,
        "patient_question": patient_question,
    }
    out = case_app.invoke(state)

    role = role.lower()
    doctor_json = out.get("doctor_json", {}) or {}
    doctor_text = out.get("doctor_text", "") or ""

    if role == "doctor":
        diag = doctor_json.get("diagnosis", "")
        plan = doctor_json.get("plan", []) or []
        if diag or plan:
            plan_str = "\n".join(f"- {p}" for p in plan)
            return f"Diagnosis: {diag}\n\nPlan:\n{plan_str}".strip()

        ans = extract_answer_from_doctor_text(doctor_text)
        return ans or doctor_text.strip()

    elif role == "nurse":
        nurse_text = (out.get("nurse_text") or "").strip()
        if nurse_text:
            # ✅ clean nurse output
            return extract_nurse_answer(nurse_text)

        # fallback: at least give doctor diagnosis
        diag = doctor_json.get("diagnosis", "")
        if diag:
            return diag
        return extract_answer_from_doctor_text(doctor_text) or doctor_text.strip()

    else:
        raise ValueError("role must be 'doctor' or 'nurse'")

In [21]:
def answer_qa_with_role(question: str, role: str) -> str:
    state: QAState = {"question": question}
    out = qa_app.invoke(state)

    role = role.lower()
    doctor_json = out.get("doctor_json", {}) or {}
    doctor_text = out.get("doctor_text", "") or ""

    if role == "doctor":
        ans = (doctor_json.get("answer") or "").strip()
        if not ans:
            ans = extract_answer_from_doctor_text(doctor_text)
        return ans or doctor_text.strip()

    elif role == "nurse":
        nurse_text = (out.get("nurse_text") or "").strip()
        if nurse_text:
            # ✅ return only the nurse's actual explanation + bullets
            return extract_nurse_answer(nurse_text)

        # fallback to doctor's answer if nurse failed
        ans = (doctor_json.get("answer") or "").strip()
        if not ans:
            ans = extract_answer_from_doctor_text(doctor_text)
        return ans or doctor_text.strip()

    else:
        raise ValueError("role must be 'doctor' or 'nurse'")

In [22]:
def safe_parse_doctor_json(text: str) -> dict:
    """
    Try very hard to parse a JSON object from the model output.
    1) Try json.loads on the whole string.
    2) Strip ```json fences if present.
    3) Use a regex to grab the first {...} block and json.loads that.
    Raise if nothing works.
    """
    if not text:
        raise ValueError("Empty doctor text")

    raw = text.strip()

    # 1) Direct attempt
    try:
        return json.loads(raw)
    except Exception:
        pass

    # 2) Strip ```json ... ``` fences if present
    if raw.startswith("```"):
        # remove leading/trailing ```
        raw2 = raw.strip("`").strip()
        # drop optional 'json' token at start
        if raw2.lower().startswith("json"):
            raw2 = raw2[4:].strip()
        try:
            return json.loads(raw2)
        except Exception:
            raw = raw2  # keep this for regex step

    # 3) Regex: grab first {...} block
    m = re.search(r"\{[\s\S]*\}", raw)
    if m:
        candidate = m.group(0)
        try:
            return json.loads(candidate)
        except Exception:
            pass

    raise ValueError("No valid JSON object found in doctor text")


def extract_citations(text: str) -> List[str]:
    """
    Extract cited chunk IDs from an agent's answer text.

    Expected format (in the answer text):
        Citation: chunk_id1, chunk_id2, ...

    Args:
        text (str):
            Full textual answer from an agent (cardio/derm).

    Returns:
        list[str]:
            A list of cited chunk IDs (strings). If no citation line
            is found, returns an empty list.
    """
    if text is None:
        return []
    m = re.search(r"Citation\s*:\s*(.*)", text)
    if not m:
        return []
    raw = m.group(1)
    return [c.strip() for c in raw.split(",") if c.strip()]


def hallucinated(pred_cites: List[str], retrieved: List[str]) -> bool:
    """
    Basic hallucination heuristic based on citation–retrieval mismatch.

    We mark an answer as hallucinated if it cites any chunk IDs that
    were not part of the retriever's returned context.

    Args:
        pred_cites (list[str]):
            Chunk IDs cited by the model output.
        retrieved (list[str]):
            Chunk IDs actually returned by the retriever for that query.

    Returns:
        bool:
            True if at least one cited chunk is not in `retrieved`,
            False otherwise.
    """
    return any(c not in retrieved for c in pred_cites)


def citation_accuracy(pred: List[str], retrieved: List[str]) -> float:
    """
    Compute citation accuracy as overlap between predicted and retrieved IDs.

    Args:
        pred (list[str]):
            Chunk IDs cited by the model.
        retrieved (list[str]):
            Chunk IDs returned by the retriever.

    Returns:
        float:
            Jaccard-style overlap: |pred ∩ retrieved| / |pred|.
            If pred is empty, returns 1.0 by convention.
    """
    pred_set = set(pred)
    retr_set = set(retrieved)
    if not pred_set:
        return 1.0
    return len(pred_set & retr_set) / len(pred_set)


def mean_lp(x: List[float]) -> float:
    """
    Compute mean log probability over a list of logprobs.

    Args:
        x (list[float]):
            List of token-level log probabilities.

    Returns:
        float:
            Mean of the list, or NaN if the list is empty.
    """
    return float(np.mean(x)) if x else float("nan")

In [23]:
def answer_qa_with_role(question: str, role: str) -> str:
    """
    Run the QA multi-agent system and return an answer
    formatted for the requested role.

    role:
        - "doctor": returns doctor-level answer (doctor_json["answer"])
        - "nurse": returns nurse_text (patient-facing explanation)
    """
    state: QAState = {"question": question}
    out = qa_app.invoke(state)

    role = role.lower()
    if role == "doctor":
        doctor_json = out.get("doctor_json", {})
        ans = doctor_json.get("answer", "").strip()
        if not ans:
            ans = out.get("doctor_text", "").strip()
        return ans

    elif role == "nurse":
        text = out.get("nurse_text", "").strip()
        if text:
            return text
        # fallback: simple paraphrase of doctor answer if nurse failed
        doctor_json = out.get("doctor_json", {})
        return doctor_json.get("answer", "").strip()

    else:
        raise ValueError(f"Unknown role: {role}. Use 'doctor' or 'nurse'.")

In [24]:
def run_case_full(patient_case: str, patient_question: str) -> CaseState:
    """
    Run the full case/treatment graph and return the entire state, including:
    - cardio_answer, cardio_context
    - derm_answer, derm_context
    - doctor_text, doctor_json, doctor_logprobs
    - nurse_text, nurse_logprobs
    """
    state: CaseState = {
        "patient_case": patient_case,
        "patient_question": patient_question,
    }
    return case_app.invoke(state)

In [25]:
def run_qa_full(question: str) -> QAState:
    """
    Run the full Q&A graph and return the entire state, including all agent outputs.
    """
    state: QAState = {"question": question}
    return qa_app.invoke(state)

### 11. Running Single Exmaple

In [26]:
def run_example() -> None:
    """
    Convenience function to run the full multi-agent pipeline on a
    hard-coded example case and print all intermediate agent outputs.

    Input:
        None (uses an inline example case + question).

    Output:
        Prints:
          - Cardiology agent answer and retrieved chunk IDs
          - Dermatology agent answer and retrieved chunk IDs
          - Doctor agent raw text + parsed JSON
          - Nurse agent answer
    """
    casee = """
    64-year-old man with hypertension and hyperlipidemia
    presents with crushing chest pain radiating to the left arm.
    ECG shows ST depressions in V4-V6, troponin elevated.
    """
    question = "What is going on and what treatment is recommended?"

    state: State = {
        "patient_case": casee,
        "patient_question": question,
    }
    out = app.invoke(state)

    print("=== Cardiology ===")
    print(out["cardio_answer"])
    print("Retrieved:", out["cardio_context"])

    print("\n=== Dermatology ===")
    print(out["derm_answer"])
    print("Retrieved:", out["derm_context"])

    print("\n=== Doctor ===")
    print(out["doctor_text"])
    print(json.dumps(out["doctor_json"], indent=2))

    print("\n=== Nurse ===")
    print(out["nurse_text"])

# Uncomment to quickly sanity-check the pipeline:
# run_example()

### 12. Multi-Case Evaluation Loop

In [27]:
def evaluate(test_cases: List[Dict[str, str]]) -> pd.DataFrame:
    """
    Run the multi-agent system on a list of test cases and compute
    basic hallucination + logprob statistics for each case.

    Each test case dict should have:
        - "id": unique identifier for the case
        - "case": patient_case text
        - "question": patient_question text

    For each case, we:
        - Invoke the LangGraph pipeline
        - Extract citations from cardio/derm answers and doctor JSON
        - Detect hallucinations via `hallucinated`
        - Compute mean doctor/nurse logprobs

    Args:
        test_cases (list[dict]):
            List of test case dictionaries with keys "id", "case", "question".

    Returns:
        pandas.DataFrame:
            One row per test case with columns:
            - id
            - cardio_hall (bool)
            - derm_hall (bool)
            - doctor_hall (bool)
            - doctor_lp (float)
            - nurse_lp (float)
    """
    rows = []
    for t in test_cases:
        print("\n>>> Running", t["id"])
        state: State = {
            "patient_case": t["case"],
            "patient_question": t["question"],
        }
        out = app.invoke(state)

        cardio_pred = extract_citations(out["cardio_answer"])
        derm_pred   = extract_citations(out["derm_answer"])
        doctor_pred = out["doctor_json"].get("citations", [])

        row = {
            "id": t["id"],
            "cardio_hall": hallucinated(cardio_pred, out["cardio_context"]),
            "derm_hall": hallucinated(derm_pred, out["derm_context"]),
            "doctor_hall": hallucinated(
                doctor_pred,
                out["cardio_context"] + out["derm_context"]
            ),
            "doctor_lp": mean_lp(out["doctor_logprobs"]),
            "nurse_lp": mean_lp(out["nurse_logprobs"]),
        }
        rows.append(row)

    return pd.DataFrame(rows)

In [28]:
casE = """
A 67-year-old man with HFrEF (LVEF 30%) is hospitalized for worsening
dyspnea. He is not currently taking SGLT2 inhibitors and asks whether
any new medications could improve his survival.
"""
question = "What medication classes are recommended as part of guideline-directed medical therapy for HFrEF?"

full = run_case_full(casE, question)

full.keys()
# dict_keys(['question', 'cardio_answer', 'cardio_context',
#            'derm_answer', 'derm_context',
#            'doctor_text', 'doctor_json', 'doctor_logprobs',
#            'nurse_text', 'nurse_logprobs'])

print("CARDIO:\n", full["cardio_answer"])
print("\nDERM:\n", full["derm_answer"])
print("\nDOCTOR JSON:\n", json.dumps(full["doctor_json"], indent=2))
print("\nRAW NURSE TEXT:\n", full["nurse_text"])

>> [CASE] Cardiology Agent
>> [CASE] Dermatology Agent
>> [CASE] Doctor Agent
>> [CASE] Nurse Agent
CARDIO:
 <text>
Confidence: <float 0-1>
Citation: <comma-separated chunk_ids>

Patient case:

A 67-year-old man with HFrEF (LVEF 30%) is hospitalized for worsening
dyspnea. He is not currently taking SGLT2 inhibitors and asks whether
any new medications could improve his survival.


Patient question:
What medication classes are recommended as part of guideline-directed medical therapy for HFrEF?

Context:
Chunk ID: cardio_1930
Content:
600 and empagliflo-
zin601 in patients with HFrEF, with both trials showing similar benefits
in those without type 2 DM.
The specific pattern of trial results (e.g. early separation of curves
for HF hospitalization) suggests that the benefits of SGLT2 inhibitors
may relate more to cardiorenal haemodynamic effects than to athe-
rosclerosis.
600 Other than genitourinary infections, rates of adverse
events (including diabetic ketoacidosis) were generally low.

In [29]:
q = "When should exercise testing be performed in stable STEMI patients who are not selected for cardiac catheterization?"
full = run_qa_full(q)

full.keys()
# dict_keys(['question', 'cardio_answer', 'cardio_context',
#            'derm_answer', 'derm_context',
#            'doctor_text', 'doctor_json', 'doctor_logprobs',
#            'nurse_text', 'nurse_logprobs'])

print("CARDIO:\n", full["cardio_answer"])
print("\nDERM:\n", full["derm_answer"])
print("\nDOCTOR JSON:\n", json.dumps(full["doctor_json"], indent=2))
print("\nRAW NURSE TEXT:\n", full["nurse_text"])

>> [QA] Cardiology Agent
>> [QA] Dermatology Agent
>> [QA] Doctor Agent
DOCTOR RAW TEXT (parse failed):

You are a senior physician answering a guideline-style question.

You receive:
- The question.
- A cardiology agent answer (with citations).
- A dermatology agent answer (with citations).

Your task is to give a concise, direct answer to the question
for another clinician (doctor-level detail).

You MUST return ONLY a valid JSON object, with NO extra text, NO commentary, and NO backticks.
The JSON must have exactly these keys:
- "answer": string
- "rationale": string (1-3 sentences)
- "confidence": float between 0 and 1
- "citations": list of chunk_ids (strings)

Question:
When should exercise testing be performed in stable STEMI patients who are not selected for cardiac catheterization?

Cardiology agent:
<text>
Confidence: <float 0-1>
Citation: <comma-separated chunk_ids>

Patient case:


Patient question:
When should exercise testing be performed in stable STEMI patients who are 

### 13. Running the System | Evaluation

In [30]:
"""
# Defining your test cases for evaluation
test_cases = [
  {
    "id": "case1",
    "case": "A patient with heart failure with reduced ejection fraction (HFrEF) is being started on standard therapy.",
    "question": "What medication classes are included in guideline-directed medical therapy (GDMT) for HFrEF, and what role do SGLT2 inhibitors play?"
  },
  {
    "id": "case2",
    "case": "A patient has been diagnosed with heart failure with mildly reduced ejection fraction (HFmrEF). Another patient has heart failure with preserved ejection fraction (HFpEF).",
    "question": "How are heart failure therapies recommended for patients with mildly reduced (HFmrEF) or preserved ejection fraction (HFpEF)?"
  },
  {
    "id": "case3",
    "case": "A patient previously had heart failure with reduced ejection fraction (HFrEF) but now has an improved LVEF above 40%.",
    "question": "What is the recommendation for patients with improved LVEF who previously had HFrEF?"
  },
    # Add more cases here...
]

# 3. Run evaluation loop
print("\n=== Running evaluation on test cases ===")
results = evaluate(test_cases)
print("\n=== RESULTS ===")
print(results)
"""

'\n# Defining your test cases for evaluation\ntest_cases = [\n  {\n    "id": "case1",\n    "case": "A patient with heart failure with reduced ejection fraction (HFrEF) is being started on standard therapy.",\n    "question": "What medication classes are included in guideline-directed medical therapy (GDMT) for HFrEF, and what role do SGLT2 inhibitors play?"\n  },\n  {\n    "id": "case2",\n    "case": "A patient has been diagnosed with heart failure with mildly reduced ejection fraction (HFmrEF). Another patient has heart failure with preserved ejection fraction (HFpEF).",\n    "question": "How are heart failure therapies recommended for patients with mildly reduced (HFmrEF) or preserved ejection fraction (HFpEF)?"\n  },\n  {\n    "id": "case3",\n    "case": "A patient previously had heart failure with reduced ejection fraction (HFrEF) but now has an improved LVEF above 40%.",\n    "question": "What is the recommendation for patients with improved LVEF who previously had HFrEF?"\n  },\n

### 14. Deployment with Gradio

In [31]:
# ========== UI HANDLERS ==========

def ui_case_handler(case_text: str, question: str, role: str, show_details: bool):
    """
    Handle the Case/Treatment tab.

    - Always runs the full case graph once via run_case_full.
    - Returns:
        simple_answer, cardio_ans, cardio_ctx, derm_ans, derm_ctx, doctor_json_str, nurse_raw
    """
    if not case_text.strip() or not question.strip():
        simple = "Please provide both a patient case and a clinical question."
        return simple, "", "", "", "", "", ""

    role = role.lower()

    # Run full multi-agent graph once
    state = run_case_full(case_text, question)

    # ----- build simple answer -----
    doctor_json = state.get("doctor_json", {}) or {}
    doctor_text = state.get("doctor_text", "") or ""
    nurse_text  = state.get("nurse_text", "") or ""

    if role == "doctor":
        # Case mode JSON: expect diagnosis / plan / rationale / ...
        diag = (doctor_json.get("diagnosis") or "").strip()
        plan = doctor_json.get("plan", []) or []
        if diag or plan:
            plan_str = "\n".join(f"- {p}" for p in plan)
            simple_answer = f"Diagnosis: {diag}\n\nPlan:\n{plan_str}".strip()
        else:
            # Fallback to text parsing
            simple_answer = extract_answer_from_doctor_text(doctor_text) or doctor_text.strip()
    else:  # nurse role
        if nurse_text:
            simple_answer = extract_nurse_answer(nurse_text)
        else:
            # Fallback: at least show diagnosis if present
            diag = (doctor_json.get("diagnosis") or "").strip()
            if diag:
                simple_answer = diag
            else:
                simple_answer = extract_answer_from_doctor_text(doctor_text) or doctor_text.strip()

    # ----- details (only if requested) -----
    if show_details:
        cardio_ans = state.get("cardio_answer", "")
        cardio_ctx = ", ".join(state.get("cardio_context", []) or [])
        derm_ans   = state.get("derm_answer", "")
        derm_ctx   = ", ".join(state.get("derm_context", []) or [])

        # Show full doctor JSON (diagnosis/plan/rationale/confidence/citations)
        doctor_json_str = json.dumps(doctor_json, indent=2, ensure_ascii=False)
        nurse_raw = nurse_text
    else:
        cardio_ans = ""
        cardio_ctx = ""
        derm_ans   = ""
        derm_ctx   = ""
        doctor_json_str = ""
        nurse_raw = ""

    return simple_answer, cardio_ans, cardio_ctx, derm_ans, derm_ctx, doctor_json_str, nurse_raw


def ui_qa_handler(question: str, role: str, show_details: bool):
    """
    Handler for the Guideline Q&A tab.

    Returns 7 values:
      simple_answer, cardio_ans, cardio_ctx, derm_ans, derm_ctx, doctor_json_str, nurse_raw
    """
    try:
        if not question.strip():
            simple = "Please enter a question."
            return simple, "", "", "", "", "", ""

        role = role.lower()

        # Run full multi-agent QA graph once
        state = run_qa_full(question)

        # ----- doctor / nurse raw outputs -----
        doctor_json = state.get("doctor_json", {}) or {}
        doctor_text = state.get("doctor_text", "") or ""
        nurse_text  = state.get("nurse_text", "") or ""

        # ----- build simple answer -----
        if role == "doctor":
            ans = (doctor_json.get("answer") or "").strip()

            # If it's clearly useless or contaminated, re-extract from raw text
            if (
                not ans
                or ans == "Unable to parse doctor answer."
                or "Dermatology agent:" in ans
                or "Cardiology agent:" in ans
                or "Patient question:" in ans
            ):
                ans = extract_answer_from_doctor_text(doctor_text)

            simple_answer = ans or extract_answer_from_doctor_text(doctor_text)

        else:  # nurse
            if nurse_text:
                simple_answer = extract_nurse_answer(nurse_text)
            else:
                ans = (doctor_json.get("answer") or "").strip()
                if not ans:
                    ans = extract_answer_from_doctor_text(doctor_text)
                simple_answer = ans or doctor_text.strip()

        # ----- details (only if requested) -----
        if show_details:
            cardio_ans = state.get("cardio_answer", "")
            cardio_ctx = ", ".join(state.get("cardio_context", []) or [])
            derm_ans   = state.get("derm_answer", "")
            derm_ctx   = ", ".join(state.get("derm_context", []) or [])
            doctor_json_str = json.dumps(doctor_json, indent=2, ensure_ascii=False)
            nurse_raw = nurse_text
        else:
            cardio_ans = ""
            cardio_ctx = ""
            derm_ans   = ""
            derm_ctx   = ""
            doctor_json_str = ""
            nurse_raw = ""

        return simple_answer, cardio_ans, cardio_ctx, derm_ans, derm_ctx, doctor_json_str, nurse_raw

    except Exception as e:
        # If anything goes wrong, show the error in Final Answer box
        err_msg = f"[UI error] {type(e).__name__}: {e}"
        return err_msg, "", "", "", "", "", ""

# ========== GRADIO APP ==========

with gr.Blocks() as demo:
    gr.Markdown("# 🩺 MedAnchor - Multi-Agent Medical Assistant")
    gr.Markdown(
        "Use the tabs for **Case/Treatment** vs **Guideline Q&A**. "
        "Choose role (Doctor/Nurse). Toggle *Show detailed agent flow* to inspect intermediate steps."
    )

    with gr.Tabs():
        # ---------- CASE / TREATMENT TAB ----------
        with gr.Tab("Case / Treatment"):
            case_input = gr.Textbox(
                label="Patient Case",
                lines=6,
                placeholder="Paste the case description here...",
            )
            question_input = gr.Textbox(
                label="Clinical Question",
                placeholder="e.g., What is going on and what treatment is recommended?",
            )
            role_case = gr.Radio(["Doctor", "Nurse"], value="Doctor", label="Role")
            show_details_case = gr.Checkbox(label="Show detailed agent flow", value=False)

            run_case_btn = gr.Button("Run")

            simple_case_out = gr.Textbox(label="Final Answer", lines=6)

            with gr.Accordion("Detailed Agent Flow (optional)", open=False):
                cardio_case_out = gr.Textbox(label="[CASE] Cardiology Agent Answer", lines=4)
                cardio_case_ctx = gr.Textbox(label="[CASE] Cardiology Retrieved Chunk IDs", lines=2)
                derm_case_out   = gr.Textbox(label="[CASE] Dermatology Agent Answer", lines=4)
                derm_case_ctx   = gr.Textbox(label="[CASE] Dermatology Retrieved Chunk IDs", lines=2)
                doctor_case_json = gr.Textbox(label="[CASE] Doctor JSON", lines=8)
                nurse_case_raw   = gr.Textbox(label="[CASE] Raw Nurse Output", lines=6)

            run_case_btn.click(
                ui_case_handler,
                inputs=[case_input, question_input, role_case, show_details_case],
                outputs=[
                    simple_case_out,
                    cardio_case_out,
                    cardio_case_ctx,
                    derm_case_out,
                    derm_case_ctx,
                    doctor_case_json,
                    nurse_case_raw,
                ],
            )

        # ---------- QA TAB ----------
        with gr.Tab("Guideline Q&A"):
            question_qa = gr.Textbox(
                label="Question",
                lines=3,
                placeholder="e.g., When should exercise testing be performed in stable STEMI patients...?",
            )
            role_qa = gr.Radio(["Doctor", "Nurse"], value="Doctor", label="Role")
            show_details_qa = gr.Checkbox(label="Show detailed agent flow", value=False)

            run_qa_btn = gr.Button("Run")

            simple_qa_out = gr.Textbox(label="Final Answer", lines=6)

            with gr.Accordion("Detailed Agent Flow (optional)", open=False):
                cardio_qa_out = gr.Textbox(label="[QA] Cardiology Agent Answer", lines=4)
                cardio_qa_ctx = gr.Textbox(label="[QA] Cardiology Retrieved Chunk IDs", lines=2)
                derm_qa_out   = gr.Textbox(label="[QA] Dermatology Agent Answer", lines=4)
                derm_qa_ctx   = gr.Textbox(label="[QA] Dermatology Retrieved Chunk IDs", lines=2)
                doctor_qa_json = gr.Textbox(label="[QA] Doctor JSON", lines=8)
                nurse_qa_raw   = gr.Textbox(label="[QA] Raw Nurse Output", lines=6)

            run_qa_btn.click(
                ui_qa_handler,
                inputs=[question_qa, role_qa, show_details_qa],
                outputs=[
                    simple_qa_out,
                    cardio_qa_out,
                    cardio_qa_ctx,
                    derm_qa_out,
                    derm_qa_ctx,
                    doctor_qa_json,
                    nurse_qa_raw,
                ],
            )

demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://ad130fc9fa45757bd3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


