# Import and load test data

In [5]:
import os
os.chdir("E:/subject/PACE-UP/code/MedRAG_vdb")
os.environ["HF_HOME"] = "/media/pc1/Ubuntu/Extend_Data/hf_models"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [6]:
import json
import pandas as pd
import fitz
import os
import re
import time

from src.rag.rag import Rag
from google.genai import types
from PIL import Image
from google import genai
from typing import List
from tqdm import tqdm
from dotenv import load_dotenv


  from .autonotebook import tqdm as notebook_tqdm


In [25]:
load_dotenv()
# Configure the SDK
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY_2"))

In [3]:
def extract_text_from_pdf(pdf_path):
        # Open the PDF
        doc = fitz.open(pdf_path)
        all_text = ""

        # Loop through each page and extract text
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)         
            text = page.get_text("text")           
            all_text += f"\n--- Page {page_num + 1} ---\n{text}"

        doc.close()
        return all_text

In [4]:
result_file = "data/processed/medgemma/results_reranker_1.json"
with open(result_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# Generate answer from only medgemma

In [4]:
def safe_append_json(filepath, record):
    # Ensure the file exists with a valid empty JSON array
    if not os.path.exists(filepath):
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump([], f)

    # Load safely, reset if corrupted
    with open(filepath, "r+", encoding="utf-8") as f:
        try:
            data = json.load(f)
            if not isinstance(data, list):
                data = []  # Force reset if JSON is not a list
        except json.JSONDecodeError:
            data = []  # Reset if file is empty or invalid

        data.append(record)

        # Overwrite with updated JSON
        f.seek(0)
        json.dump(data, f, indent=4, ensure_ascii=False)
        f.truncate()

In [3]:
path_file = "data/raw/cases_test"
files = os.listdir(path_file)
text_lst = []
img_lst = []

for file in files:
    file_path = os.path.join(path_file, file)
    imgs = []

    # Load JSON
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Combine all fields except "Extracted Figure"
    combined_text = ""
    for key, value in data.items():
        if key not in ["Extracted Figure", "Further Reading"]:
            combined_text += f"{key}:\n{value}\n\n"

    # Add Descriptions from "Extracted Figure"
    descriptions = []
    for figure_group in data.get("Extracted Figure", []):
        for figure in figure_group:
            desc = figure.get("Description", "")
            descriptions.append(desc)

    if descriptions:
        combined_text += "Figure Descriptions:\n"
        for desc in descriptions:
            combined_text += f"{desc}\n"

    # Get image links
    image_paths = []
    for figure_group in data.get("Extracted Figure", []):
        for figure in figure_group:
            path = figure.get("Link")
            path = os.path.join('data/raw', path)
            if path:
                image_paths.append(path)

    # Final combined text
    text_lst.append(combined_text)
    img_lst.append(image_paths)


In [5]:
output_file = "data/processed/medgemma/results.json"
processed_files = []
with open(output_file, "r", encoding="utf-8") as f:
    data = json.load(f)
for case in data:
    for key, value in case.items():
            if key in ['original_file']:
                processed_files.append(value)

print(len(processed_files))

93


In [10]:
for query, imgs, file in zip(text_lst, img_lst, files):
    if file not in processed_files:
        answer = rag.get_answer_from_medgemma(
            query=query, 
            images_path=imgs, 
        )
        print(answer)

        record = {
            "query": query,
            "images_used": imgs,
            "answer": answer,
            "original_file": file
        }

        # Load existing data, append, and rewrite
        safe_append_json(output_file, record)


Based on the provided information, the most likely diagnosis is **Crusted Scabies (Norwegian Scabies)**.

Here's why:

*   **Clinical Presentation:** The widespread, exfoliative rash, especially in the axillae, buttocks, and thighs, along with fissures on the wrists and knees, is highly characteristic of crusted scabies. The presence of flakes of skin on the mattress further supports this diagnosis.
*   **Patient Demographics:** The patient is an indigenous Australian woman living in a remote community, which is a known risk factor for endemic scabies.
*   **Systemic Symptoms:** The patient's fever, tachycardia, tachypnea, and low blood pressure indicate sepsis, which is a common complication of crusted scabies due to secondary bacterial infections.
*   **Laboratory Findings:** The skin scrapings confirmed the presence of scabies mites and eggs, supporting the diagnosis. The blood cultures revealed *Staphylococcus aureus* infection, which is a common secondary bacterial infection in cr

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Based on the provided information, the most likely diagnosis for the 7-year-old girl from Peru with the chronic skin ulcer on her nose is **cutaneous leishmaniasis**.

Here's the reasoning:

*   **Clinical Presentation:** The description of the lesion (localized ulcer with raised edges, cobblestone-patterned bottom) is highly characteristic of cutaneous leishmaniasis.
*   **Geographic Location:** The patient lives in Peru, a known endemic area for cutaneous leishmaniasis.
*   **Travel History:** The patient's recent travel to a valley on the western slopes of the Andes further supports this diagnosis, as these areas are known to be endemic for leishmaniasis.
*   **Positive Leishmanin Skin Test:** A positive leishmanin skin test is a strong indicator of past or current exposure to Leishmania parasites.
*   **PCR Confirmation:** The positive PCR result confirms the presence of Leishmania parasites.
*   **Species Identification:** Identification of *Leishmania (Viannia) peruviana* provide

# Context Relevance

In [4]:
def clean_nejm_text(text: str) -> str:
    """
    Cleans extracted text from NEJM PDFs by removing repeated watermarks,
    copyright notices, and headers/footers.

    Args:
        text (str): The raw extracted text from the PDF.

    Returns:
        str: The cleaned text without repeated boilerplate.
    """
    # Pattern to match NEJM boilerplate lines
    patterns = [
        r"The New England Journal of Medicine is produced by NEJM Group.*",
        r"Downloaded from nejm\.org by .* on .*\. For personal use only\.",
        r"No other uses without permission\. Copyright © \d{4} Massachusetts Medical Society\. All rights reserved\."
    ]
    
    cleaned_text = text
    for pattern in patterns:
        cleaned_text = re.sub(pattern, "", cleaned_text, flags=re.IGNORECASE)

    # Remove repeated "--- Page X ---" markers if present
    cleaned_text = re.sub(r"--- Page \d+ ---", "", cleaned_text)

    # Remove excessive blank lines (collapse multiple newlines to one)
    cleaned_text = re.sub(r"\n\s*\n+", "\n\n", cleaned_text)

    # Strip leading/trailing whitespace
    return cleaned_text.strip()

In [5]:
def extract_context(pdf_paths: List[str], result_kg: str) -> str:
    result = ""
    result = result + "## RETRIEVED CLINICAL CASES FROM VECTOR DATABASE\n"
    for i, pdf_path in enumerate(pdf_paths):
        result = result + f"The {i+1}th case:\n"
        result = result + clean_nejm_text(extract_text_from_pdf(pdf_path)) + "\n"

    result = result + f"## RETRIEVED DISEASES WITH THEIR SYMPTOMS FROM KNOWLEDGE GRAPH\n{result_kg}\n"

    return result


In [6]:
def load_images(images_used: List[str]):
    results = []
    for image in images_used:
        if image:
            with open(image, 'rb') as f:
                image_bytes = f.read()
            results.append(
                types.Part.from_bytes(
                    data=image_bytes,
                    mime_type='image/jpeg',
                ),
            )
        else:
            continue
    return results

In [7]:
def load_base_prompt_context_relevance(query: str, retrieved_content: str) -> str:
    return f"""
You are **Dr. Trop**, an infectious‑disease expert clinician and RAG auditor.

## Task
Judge whether the combined evidence (Part A + Part B) is sufficient
to arrive at the *correct diagnosis* for the clinical case in the query.
Ignore subsidiary questions in the case; focus on diagnosis only.

## Evidence Types
* **Part A – Similar Clinical Cases**  
  • Each snippet is a previously solved case that the vector DB says is
    most similar to the query case.  
* **Part B – Disease ⇄ Symptom Facts**  
  • Each snippet is a structured fact (disease, key symptoms, lab signs)
    from the knowledge graph, intended for differential diagnosis.

## Scoring Rubric
| Score | Meaning |
|-------|---------|
| **2** | Evidence *collectively* gives all facts needed for a confident, specific diagnosis. |
| **1** | Evidence is partially useful but omits critical findings or leaves >1 plausible diagnoses. |
| **0** | Evidence is irrelevant or clearly insufficient for diagnosis. |

## Evaluation Protocol
1. **Paraphrase** the query case in ≤1 sentence.  
2. **Analyse Part A (Similar Cases)** – for each snippet:  
   a. Summarise in one sentence.  
   b. Does it share decisive findings with the query? (Yes/No & why)  
   c. Assign a snippet score (0‑2).  
3. **Analyse Part B (Disease ⇄ Symptom Facts)** – for each snippet:  
   a. Summarise in one sentence.  
   b. State if it directly supports or refutes a likely diagnosis.  
   c. Assign a snippet score (0‑2).  
4. **Synthesis & Self‑Consistency** – reconcile insights from both parts; if reasoning paths diverge, rethink and choose the most consistent.  
5. **Overall Score** – output one final 0‑2 value using the rubric.  
6. If the evidence is clearly inadequate, say: *“I do not have sufficient information.”*  
7. **Important:** your very last character must be that single digit (0|1|2) with nothing after it.

## Few‑Shot Examples
### Example 1
*Query:* “30‑year‑old in Vietnam with paroxysmal fever, splenomegaly, and ring‑form parasites on blood smear.”  
*Part A Snippet:* “Case of falciparum malaria treated successfully with DHA‑PPQ.”  
*Part B Snippet:* “Falciparum malaria → fever, anemia, splenomegaly; thick smear positive for ring forms.”  
*Evaluation:*  
- Part A snippet: score 2 (high match)  
- Part B snippet: score 2 (directly supports)  
**Overall = 2**

### Example 2
*Query:* “Child with rash and cough; query seeks measles vs rubella.”  
*Part A Snippet:* “Rash illness caused by parvovirus B19.”  
*Part B Snippet:* “Measles: cough, coryza, conjunctivitis, Koplik spots.”  
*Evaluation:*  
- Part A snippet: score 0 (different disease)  
- Part B snippet: score 1 (partial—gives measles signs but no rash description)  
**Overall = 1**

---

# QUERY
{query}

# RETRIEVED CONTEXT  
{retrieved_content}

# YOUR ANALYSIS AND FINAL SCORE"""


In [8]:
def safe_append_json(filepath, record):
    # Ensure the file exists with a valid empty JSON array
    if not os.path.exists(filepath):
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump([], f)

    # Load safely, reset if corrupted
    with open(filepath, "r+", encoding="utf-8") as f:
        try:
            data = json.load(f)
            if not isinstance(data, list):
                data = []  # Force reset if JSON is not a list
        except json.JSONDecodeError:
            data = []  # Reset if file is empty or invalid

        data.append(record)

        # Overwrite with updated JSON
        f.seek(0)
        json.dump(data, f, indent=4, ensure_ascii=False)
        f.truncate()

In [9]:
evaluated_data_context_relevance = []

In [11]:
evaluated_file = "data/processed/results_context_relevance.json"

processed_files = []
with open(evaluated_file, "r", encoding="utf-8") as f:
    processed_data = json.load(f)
for case in processed_data:
    for key, value in case.items():
            if key in ['original_file']:
                processed_files.append(value)

for test in tqdm(data):
    if test["original_file"] in processed_files:
        continue

    query_text = load_base_prompt_context_relevance(
        test["query"],
        extract_context(test["pdf_paths"], test["results_kg"])
    )

    image_parts = load_images(test["images_used"])  

    contents = []
    if image_parts:                                 
        contents.append(image_parts)                 
    contents.append(query_text)                      

    response = client.models.generate_content(
        model="gemini-2.0-flash-lite",
        contents=contents
    )

    test["evaluation"] = response.text
    test["score"] = response.text[-2]
    evaluated_data_context_relevance.append(test)
    safe_append_json(evaluated_file, test)
    time.sleep(4)

100%|██████████| 93/93 [16:24<00:00, 10.59s/it]


# Answer Faithfulness

In [None]:
def load_base_prompt_answer_faithfulness(
    query: str,
    draft_answer: str,
    retrieved_content: str,
) -> str:

    return f"""
You are **Dr. Trop**, an infectious‑disease specialist and RAG auditor.

## Task
Evaluate whether the *Draft Answer* below is **faithful** to the combined evidence
(Part A + Part B).  Judge factual support—NOT style or completeness.

## Evidence Types
* **Part A – Similar Clinical Cases**  
* **Part B – Disease ⇄ Symptom Facts**

## Faithfulness Scoring Rubric
| Score | Meaning |
|-------|---------|
| **2** | Every key claim in the draft answer is directly supported by ≥1 evidence snippet, and no claim contradicts the evidence. |
| **1** | Most claims are supported, but ≥1 minor claim lacks support **or** minor inconsistencies exist. |
| **0** | Major claims are unsupported or contradicted; hallucination or clear conflict with evidence. |

## Evaluation Protocol
1. **Restate** the draft answer’s main diagnosis in ≤1 sentence.  
2. **Claim‑by‑claim check**: For each distinct factual claim in the draft answer:  
   a. Quote or paraphrase the claim (≤15 words).  
   b. Cite supporting snippet IDs from Parts A/B *or* mark “No Support” / “Contradicted”.  
3. **Conflict scan**: Note any claim that opposes evidence.  
4. **Faithfulness Verdict**: Decide overall score (0‑2) using the rubric.  
5. If >50 % of claims lack support, explicitly say: *“Answer largely unsupported.”*  
6. **Important:** Finish your entire reply with the single digit score (0|1|2) and nothing after it.

## Few‑Shot Examples
### Example 1  (Score 2)
*Draft Answer:* “The patient has falciparum malaria; ACT (DHA‑PPQ) is indicated.”  
*Evidence snippets:* (Part A case of falciparum, Part B symptom profile).  
*Evaluation:* All claims supported → Overall = 2

### Example 2  (Score 0)
*Draft Answer:* “Child most likely has parvovirus B19 rash illness.”  
*Evidence snippets:* Part A & B highlight measles signs (Koplik spots) only.  
*Evaluation:* Central claim contradicted → Overall = 0

---

# QUERY (clinical case)
{query}

# DRAFT ANSWER TO EVALUATE
{draft_answer}

# RETRIEVED CONTEXT  
{retrieved_content}

# YOUR ANALYSIS AND FINAL SCORE"""


In [13]:
evaluated_data_answer_faithfulness = []

In [19]:
evaluated_file = "data/processed/results_answer_faithfulness.json"

processed_files = []
with open(evaluated_file, "r", encoding="utf-8") as f:
    processed_data = json.load(f)
for case in processed_data:
    for key, value in case.items():
            if key in ['original_file']:
                processed_files.append(value)

for test in tqdm(data):
    if test["original_file"] not in processed_files:
        query_text = load_base_prompt_answer_faithfulness(
            query=test["query"],
            draft_answer=test["answer"],
            retrieved_content=extract_context(test["pdf_paths"], test["results_kg"])
        )

        image_parts = load_images(test["images_used"])  

        contents = []
        if image_parts:                                 
            contents.append(image_parts)                 
        contents.append(query_text)                      

        response = client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=contents
        )

        test["evaluation"] = response.text
        test["score"] = response.text[-1]
        evaluated_data_answer_faithfulness.append(test)
        safe_append_json(evaluated_file, test)
        time.sleep(4)

100%|██████████| 93/93 [11:50<00:00,  7.64s/it]


# Answer Relevance

In [22]:
def load_base_prompt_answer_relevance(
    query: str,
    draft_answer: str,
    retrieved_content: str,
) -> str:
    return f"""
You are **Dr. Trop**, an infectious‑disease specialist and Retrieval‑Augmented Generation auditor.

## Task
Assess the **Answer Relevance**: Does the *Draft Answer* appropriately address
the clinical‑case query **and** draw on the information contained in the
retrieved context (Part A + Part B)?

## Evidence Types
* **Part A – Similar Clinical Cases**  
* **Part B – Disease ⇄ Symptom Facts**

## Relevance Scoring Rubric
| Score | Meaning |
|-------|---------|
| **2** | Answer directly and completely addresses the query’s diagnostic goal, uses facts consistent with ≥1 retrieved snippet, and avoids off‑topic content. |
| **1** | Answer touches on the query but is incomplete **or** includes minor off‑topic / unsupported material. |
| **0** | Answer does not meaningfully address the query, or is dominated by content unrelated to the retrieved context. |

## Evaluation Protocol
1. **Restate** the essential question being asked (≤1 sentence).  
2. **Summarise** the draft answer’s key points (≤2 sentences).  
3. **Relevance Check**:  
   a. Does the answer directly solve the diagnostic question? (Yes/No & why)  
   b. Identify which snippets (Part A/B) the answer appears to rely on—or state “No linkage”.  
   c. Note any off‑topic or speculative content.  
4. **Verdict**: Choose a final relevance score (0‑2) using the rubric.  
5. If the answer barely references the retrieved evidence, say: *“Answer minimally grounded.”*  
6. **Important:** End your whole reply with that single digit (0|1|2) and nothing after it.

## Few‑Shot Examples
### Example 1 (Score 2 – Highly Relevant)
*Query:* “30‑year‑old with paroxysmal fever and ring‑form parasites—diagnosis?”  
*Draft Answer:* “These findings indicate falciparum malaria; ACT such as DHA‑PPQ is recommended.”  
*Context:* Part A similar malaria case; Part B malaria symptom profile.  
*Why:* Directly answers; cites same diagnostic facts → **Overall = 2**

### Example 2 (Score 1 – Partially Relevant)
*Query:* “Child with fever and purpuric rash—differential?”  
*Draft Answer:* “Could be meningococcemia. Start ceftriaxone.”  
*Context:* Part A lists meningococcemia & dengue; Part B symptom facts.  
*Why:* Mentions one plausible diagnosis but ignores dengue; no link to platelet findings → **Overall = 1**

### Example 3 (Score 0 – Irrelevant)
*Query:* “Identify viral family of dengue virus.”  
*Draft Answer:* “Treat dengue shock with IV fluids.”  
*Context:* Part A taxonomic case; Part B Flaviviridae facts.  
*Why:* Talks about treatment, not taxonomy → **Overall = 0**

---

# QUERY (clinical case)
{query}

# DRAFT ANSWER TO EVALUATE
{draft_answer}

# RETRIEVED CONTEXT  
{retrieved_content}

# YOUR ANALYSIS AND FINAL SCORE"""

In [23]:
evaluated_data_answer_relevance = []

In [45]:
evaluated_file = "data/processed/results_answer_relevance.json"

processed_files = []
with open(evaluated_file, "r", encoding="utf-8") as f:
    processed_data = json.load(f)
for case in processed_data:
    for key, value in case.items():
            if key in ['original_file']:
                processed_files.append(value)

for test in tqdm(data):
    if test["original_file"] not in processed_files:
        query_text = load_base_prompt_answer_relevance(
            query=test["query"],
            draft_answer=test["answer"],
            retrieved_content=extract_context(test["pdf_paths"], test["results_kg"])
        )

        image_parts = load_images(test["images_used"])  

        contents = []
        if image_parts:                                 
            contents.append(image_parts)                 
        contents.append(query_text)                      

        response = client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=contents
        )

        test["evaluation"] = response.text
        test["score"] = response.text[-1]
        evaluated_data_answer_relevance.append(test)
        safe_append_json(evaluated_file, test)
        time.sleep(4)

100%|██████████| 93/93 [00:15<00:00,  6.16it/s]


# Result

In [48]:
files = ["data/processed/results_context_relevance.json", "data/processed/results_answer_faithfulness.json", "data/processed/results_answer_relevance.json"]

In [27]:
def score_str_to_int(s: str, valid: set[str] = {"0", "1", "2"}) -> int:
    if not s:
        return None
    last = s.strip()[-1]
    if last in valid:
        return int(last)
    return None

In [49]:
for file in files:
    res = 0
    print(f"Evaluation of {file[23:-5]}")
    with open(file, "r", encoding="utf-8") as f:
        eval_data = json.load(f)
    for i, case in enumerate(eval_data):
        # print(f"{i} {case['score']}\n{case['original_file']}")
        res += score_str_to_int(case["score"])
    res /= (len(eval_data) * 2)
    print(res)


Evaluation of context_relevance
0.3655913978494624
Evaluation of answer_faithfulness
0.3010752688172043
Evaluation of answer_relevance
0.20430107526881722


# Accuracy

In [7]:
import re
import pandas as pd

# --- helpers ---
DELIMS = r"[,/;|]"

# light stopword list to avoid false positives
STOPWORDS = {
    "disease","syndrome","infection","infections","acute","chronic","severe","mild",
    "of","and","or","the","a","an","with","without","due","to","by","in","on"
}

def normalize(text: str) -> str:
    """lowercase and collapse spaces"""
    return re.sub(r"\s+", " ", text.lower()).strip()

def tokens(text: str):
    """alphanumeric tokens"""
    return re.findall(r"[a-z0-9]+", normalize(text))

def candidate_keywords(candidate: str):
    """
    Split candidates on common delimiters and extract meaningful keywords.
    Example: 'Severe Falciparum Malaria / Cerebral malaria'
    -> {'falciparum','malaria','cerebral'}
    """
    parts = re.split(DELIMS, candidate)
    kws = set()
    for part in parts:
        for t in tokens(part):
            if t not in STOPWORDS and len(t) >= 3:
                kws.add(t)
    return kws

def reference_bag(reference: str):
    """bag of tokens for quick membership tests"""
    return set(tokens(reference))

def match_score(candidate: str, reference: str) -> float:
    """
    Returns ratio of candidate keywords found in the reference (0..1).
    """
    cand_kws = candidate_keywords(candidate)
    if not cand_kws:
        return 0.0
    ref_bag = reference_bag(reference)
    hits = sum(1 for kw in cand_kws if kw in ref_bag)
    return hits / len(cand_kws)

def phrase_match(candidate: str, reference: str) -> bool:
    """
    Exact phrase match after normalization (ignores punctuation/casing).
    """
    c = normalize(re.sub(r"[^\w\s]", " ", candidate))
    r = normalize(re.sub(r"[^\w\s]", " ", reference))
    # require at least one 3+ char token to avoid matching trivial strings
    return bool(c) and any(len(t) >= 3 for t in c.split()) and c in r

def match_any(candidate: str, reference: str) -> bool:
    """
    True if ANY important keyword is in the reference.
    """
    cand_kws = candidate_keywords(candidate)
    ref_bag = reference_bag(reference)
    return any(kw in ref_bag for kw in cand_kws)

def match_all(candidate: str, reference: str) -> bool:
    """
    True if ALL important keywords are in the reference.
    """
    cand_kws = candidate_keywords(candidate)
    ref_bag = reference_bag(reference)
    return len(cand_kws) > 0 and cand_kws.issubset(ref_bag)

def extract_disease_name(answer: str) -> str:
    if not isinstance(answer, str):
        return ""
    # Remove Markdown bolds (#, >, *) but KEEP hyphens
    cleaned = re.sub(r"[\*\#\>]", "", answer).strip()
    # Match "DISEASE_NAME:" in any format (with spaces/newlines)
    match = re.search(r"DISEASE_NAME\s*:\s*(.+)", cleaned, flags=re.IGNORECASE)
    if not match:
        return ""
    disease = match.group(1).strip()
    # Stop at the first line break (if extra text follows)
    disease = disease.split("\n")[0].strip()
    return disease

In [14]:
def evaluate(result_file, test_path = "data/processed/93_cases"):
    with open(result_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    test_diseases = []
    for case in data:
        file = os.path.join(test_path, case["original_file"])
        with open(file, "r", encoding="utf-8") as f:
            test_case = json.load(f)
        test_diseases.append(
            {
                "original_file": case["original_file"],
                "candidate_disease": extract_disease_name(case["answer"]),
                "reference_disease": test_case["extracted_disease"]
            }
        )
        # if not match_all(extract_disease_name(case["answer"]), test_case["extracted_disease"]):
        #     print(f'{extract_disease_name(case["answer"])} --- {test_case["extracted_disease"]}')
        # print(f'{case["answer"]}')
        
    df_test_diseases = pd.DataFrame(test_diseases, columns=["original_file", "candidate_disease", "reference_disease"])
    # df_test_diseases.info()
    # Optional: fix candidate_disease extraction to avoid cutting the last char
    df_test_diseases["score"] = df_test_diseases.apply(
        lambda r: match_score(r["candidate_disease"], r["reference_disease"]), axis=1
    )
    df_test_diseases["hit_any"] = df_test_diseases.apply(
        lambda r: match_any(r["candidate_disease"], r["reference_disease"]), axis=1
    )
    df_test_diseases["hit_all"] = df_test_diseases.apply(
        lambda r: match_all(r["candidate_disease"], r["reference_disease"]), axis=1
    )
    df_test_diseases["phrase"] = df_test_diseases.apply(
        lambda r: phrase_match(r["candidate_disease"], r["reference_disease"]), axis=1
    )
    # Print mismatched cases in a more readable table
    mismatched = df_test_diseases[df_test_diseases["hit_any"] == 0][["original_file", "candidate_disease", "reference_disease"]]
    # print("\n=== MISMATCHED CASES ===")
    # print(mismatched.to_string(index=False, justify="left", max_colwidth=40))
    # - strict accuracy: all keywords must appear
    strict_acc = df_test_diseases["hit_all"].mean()
    # - lenient accuracy: at least one keyword appears
    lenient_acc = df_test_diseases["hit_any"].mean()
    # - thresholded score: e.g., ≥ 0.6 of candidate keywords present
    thr = 0.6
    threshold_acc = (df_test_diseases["score"] >= thr).mean()
    print(round(strict_acc*100, 2))


In [15]:
experiments = [
    ("Only Gemma3 4B", "data/processed/medgemma/result_gemma3.json"),
    ("Gemma3 4B with Vector DB", "data/processed/medgemma/result_gemma3_vectordb.json"),
    ("Only MedGemma 4B", "data/processed/medgemma/results_medgemma.json"),
    ("MedGemma 4B with Vector DB", "data/processed/medgemma/result_reranker_1_vectordb.json"),
    ("MedGemma 4B with Knowledge Graph", "data/processed/medgemma/results_reranker_1_kg.json"),
    ("MedGemma 4B with Vector DB + Knowledge Graph", "data/processed/medgemma/results_reranker_1.json"),
    ("Only Quantized 4-bit Gemma3 27B", "data/processed/medgemma/result_gemma3_27b.json"),
    ("Quantized 4-bit Gemma3 27B with Vector DB", "data/processed/medgemma/result_gemma3_27_vectordb.json"),
    ("Only Quantized MedGemma 27B", "data/processed/medgemma/result_medgemma_27b.json"),
]

for name, path in experiments:
    print(f"=== {name} ===")
    evaluate(path)


=== Only Gemma3 4B ===
52.4
=== Gemma3 4B with Vector DB ===
37.6
=== Only MedGemma 4B ===
55.6
=== MedGemma 4B with Vector DB ===
36.4
=== MedGemma 4B with Knowledge Graph ===
44.8
=== MedGemma 4B with Vector DB + Knowledge Graph ===
22.4
=== Only Quantized 4-bit Gemma3 27B ===
24.8
=== Quantized 4-bit Gemma3 27B with Vector DB ===
25.6
=== Only Quantized MedGemma 27B ===
50.0


In [134]:
df_test_diseases[df_test_diseases["hit_any"] == 1]

Unnamed: 0,original_file,candidate_disease,reference_disease,score,hit_any,hit_all,phrase
0,1---A-20-Year-Old-Woman-from-Sudan-With-Fever-...,easoning – Patient Case Evaluation\n\nHere’s a...,Ebola virus disease,0.006452,True,False,False
1,10---A-55-Year-Old-Indigenous-Woman-from-Austr...,Sarcoptes scabiei infestation (Crusted/Norwegi...,Crusted scabies,0.333333,True,False,False
2,11---A-45-Year-Old-Male-Security-Guard-from-Ma...,easoning: Patient Case Evaluation\n\nHere’s a ...,spinal tuberculosis,0.006897,True,False,False
4,13---A-16-Year-Old-Girl-from-Malawi-With-Fever...,se Evaluation\n\nHere’s a clinical reasoning o...,Typhoid fever\nMalaria\nSchistosomiasis,0.018018,True,False,False
5,14---A-22-Year-Old-Woman-from-Bangladesh-With-...,"Cholera ( *Vibrio cholerae* O1, Serotype Ogawa...",cholera,0.142857,True,False,False
...,...,...,...,...,...,...,...
244,NEJMcpc0804149.json,easoning: Patient with Fever and Confusion\n\n...,Eastern equine encephalitis,0.011765,True,False,False
245,NEJMcpc0805311.json,Allergic Fungal Sinusitis (or Severe Chronic S...,Allergic fungal sinusitis,0.600000,True,False,False
246,NEJMcpc0805312.json,Acrodermatitis enteropathica,Acrodermatitis enteropathica-like syndrome,1.000000,True,True,True
248,NEJMcpc0806982.json,easoning – Patient Case Evaluation\n\nHere's a...,Intestinal schistosomiasis.,0.003448,True,False,False


In [30]:
result_file = "data/processed/medgemma/results_medgemma.json"
with open(result_file, "r", encoding="utf-8") as f:
    data = json.load(f)
    
test_diseases = []
test_path = "data/processed/93_cases"
for case in data:
    file = os.path.join(test_path, case["original_file"])
    with open(file, "r", encoding="utf-8") as f:
        test_case = json.load(f)
    test_diseases.append(
        {
            "original_file": case["original_file"],
            "candidate_disease": case["answer"][case["answer"].find("DISEASE_NAME: ") + 14: -1],
            "reference_disease": test_case["diagnosis"]
        }
    )

df_test_diseases = pd.DataFrame(test_diseases, columns=["original_file", "candidate_disease", "reference_disease"])
df_test_diseases


Unnamed: 0,original_file,candidate_disease,reference_disease
0,1---A-20-Year-Old-Woman-from-Sudan-With-Fever-...,Ebola Virus Disease,Ebola virus disease
1,10---A-55-Year-Old-Indigenous-Woman-from-Austr...,Crusted Scabies,Answer to Question 1 What is Your Provisional ...
2,11---A-45-Year-Old-Male-Security-Guard-from-Ma...,HIV-associated myelopathy,A presumed diagnosis of spinal TB was made
3,12---A-29-Year-Old-Man-from-The-Gambia-With-G_...,Chancroid,The Case Continued...\nThe patient was treated...
4,13---A-16-Year-Old-Girl-from-Malawi-With-Fever...,Malaria,"Typhoid fever, Malaria, and Schistosomiasis"
...,...,...,...
245,NEJMcpc0805311.json,Allergic Fungal Sinusitis,Allergic fungal sinusitis.
246,NEJMcpc0805312.json,Acrodermatitis enteropathica\n\nReasoning:\n\n...,Dermatitis due to zinc deficiency (acrodermati...
247,NEJMcpc0806980.json,Brucellosis,ANATOMICAL DIAGNOSIS\nLaboratory-acquired infe...
248,NEJMcpc0806982.json,Acute HIV infection\nReasoning:\n\n1. **Infor...,# Diagnosis\nΑΝΑΤΟMICAL DIAGNOSIS\nIntestinal ...


In [31]:
# Optional: fix candidate_disease extraction to avoid cutting the last char
def extract_candidate(ans: str) -> str:
    # safer than slicing with [-1]
    m = re.search(r"DISEASE_NAME:\s*(.+)", ans, flags=re.IGNORECASE)
    return m.group(1).strip() if m else ""

# If needed:
# df_test_diseases["candidate_disease"] = df_test_diseases["candidate_disease"].map(extract_candidate)

df_test_diseases["score"] = df_test_diseases.apply(
    lambda r: match_score(r["candidate_disease"], r["reference_disease"]), axis=1
)
df_test_diseases["hit_any"] = df_test_diseases.apply(
    lambda r: match_any(r["candidate_disease"], r["reference_disease"]), axis=1
)
df_test_diseases["hit_all"] = df_test_diseases.apply(
    lambda r: match_all(r["candidate_disease"], r["reference_disease"]), axis=1
)
df_test_diseases["phrase"] = df_test_diseases.apply(
    lambda r: phrase_match(r["candidate_disease"], r["reference_disease"]), axis=1
)

# Choose your metric:
# - strict accuracy: all keywords must appear
strict_acc = df_test_diseases["hit_all"].mean()

# - lenient accuracy: at least one keyword appears
lenient_acc = df_test_diseases["hit_any"].mean()

# - thresholded score: e.g., ≥ 0.6 of candidate keywords present
thr = 0.6
threshold_acc = (df_test_diseases["score"] >= thr).mean()

print(strict_acc, lenient_acc, threshold_acc)


0.464 0.768 0.5


# Seperate name of diseases

In [None]:
test_diseases = []
test_path = "data/processed/93_cases"

In [48]:
processed_files_name = [disease["query"] for disease in test_diseases]
processed_files_name

["1 A 20-Year-Old Woman from Sudan With Fever, Haemorrhage and Shock DANIEL G. BAUSCH Clinical Presentation History A 20-year-old housewife presents to a hospital in northern Uganda with a 2-day history of fever, severe asthenia, chest and abdominal pain, nausea, vomiting, diarrhoea and slight non-productive cough. The patient is a Sudanese refugee living in a camp in the region. She denies any contact with sick people. Clinical Findings The patient is prostrate and semiconscious on admission. Vital signs: temperature 39.6°C, (103.3°F) blood pressure 90/60 mmHg, pulse 90bpm, and respiratory rate 24 cycles per minute. Physical examination revealed abdominal ten- derness, especially in the right upper quadrant, hepatosple- nomegaly and bleeding from the gums. The lungs were clear. No rash or lymphadenopathy was noted. Questions 1. Is the patient's history and clinical presentation consistent with a haemorrhagic fever (HF) syndrome? 2. What degree of nursing precautions need to be impleme

In [49]:
len(processed_files_name)

157

In [50]:
import os
import json
import time

test_diseases = []

for case in data:
    file_path = os.path.join(test_path, case["original_file"])

    # Load the JSON file for this case
    with open(file_path, "r", encoding="utf-8") as f:
        test_case = json.load(f)

    # If we've already added it before, skip
    if test_case.get("extracted_disease"):  # <-- check the file data, not `case`
        test_diseases.append(test_case)
        continue

    diagnosis = (test_case.get("diagnosis") or "").strip()
    if not diagnosis:
        print(f"Skip {case['original_file']} — empty diagnosis")
        continue

    try:
        response = client.models.generate_content(
            model="gemini-2.0-flash-lite",
            contents=(
                f"This is the diagnosis: {diagnosis}. "
                "Only extract full name of the disease and don't do anything else."
            ),
        )
        # Depending on the SDK, response.text may or may not exist—guard it:
        extracted_disease = (getattr(response, "text", "") or "").strip()
        if not extracted_disease:
            print(f"Empty extraction for {case['original_file']}")
            continue

        print(extracted_disease)

        # Add new field and write back to file
        test_case["extracted_disease"] = extracted_disease
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(test_case, f, ensure_ascii=False, indent=2)

        case["extracted_disease"] = extracted_disease

        test_diseases.append(test_case)
        time.sleep(2)  # rate limiting

    except Exception as e:
        print(f"API error on {case['original_file']}: {e}")
        continue


Mycobacterium marinum
lymphoma
Infective endocarditis; Staphylococcus epidermidis
Rheumatic mitral stenosis.
aspergillus species
Secondary Syphilis
herpes zoster affecting the first branch of the trigeminal nerve
Mycobacterium marinum Infection of the Hand
Granulomatous Pneumocystis carinii pneumonia
HIV-associated nephropathy
Cutaneous Melanoma Metastases
Whipple's disease
Amiodarone-Induced Skin Discoloration
Cat scratch encephalitis due to B. quintana.
Organizing subdural hygroma
Tuberculous peritonitis
Eales' disease (retinal periphlebitis); Tuberculous mediastinal lymphadenitis
Trichophyton mentagrophytes variant mentagrophytes.
Syphilitic aortitis with coronary ostial stenosis
Xanthoma disseminatum.
Lymphocytic interstitial pneumonitis associated with primary biliary cirrhosis
Subacute sclerosing panencephalitis
vertebral tuberculosis
Verruga Peruana
Necrosis of the anterolateral papillary muscle of the left ventricle, with mitral regurgitation (after aortic-valve replacement bec

# BLEU and ROUGE

In [57]:
def extract_final_diagnosis(answer: str) -> str:
    """
    Extract text between the 'Reflection & Final Diagnosis' section header
    and the next DIAGNOSIS header (WORKING/PROVISIONAL/FINAL) or end of text.
    Robust to spacing, **bold**, numbering '4.', and newline styles.
    """
    if not isinstance(answer, str):
        return ""

    # Normalize newlines
    text = answer.replace("\r\n", "\n")

    # Start: optional leading newline, optional "4.", optional **, flexible spaces
    start = r"(?:^|\n)\s*(?:4\.\s*)?\*?\*?\s*Reflection\s*&\s*Final\s*Diagnosis\s*:\s*\*?\*?\s*\n+"

    # End: the next diagnosis header OR end of string
    end = r"(?=\n{1,3}(?:WORKING|PROVISIONAL|FINAL)[_\s-]*DIAGNOSIS\s*:|\Z)"

    # Capture lazily between start and end
    pattern = start + r"(.*?)" + end

    m = re.search(pattern, text, flags=re.IGNORECASE | re.DOTALL)
    return m.group(1).strip() if m else ""

def clean_text(text: str) -> str:
    """
    Clean extracted diagnosis text:
    - Remove markdown bullets (*, -, +) at line starts
    - Remove markdown bold/italic markers (** or *)
    - Collapse multiple spaces/newlines into a single space
    - Strip leading/trailing whitespace
    """
    if not isinstance(text, str):
        return ""
    
    # Normalize newlines
    text = text.replace("\r\n", "\n")
    
    # Remove bullet markers at line start
    text = re.sub(r"^[\s]*[\*\-\+]\s*", "", text, flags=re.MULTILINE)
    
    # Remove markdown bold/italic (**word**, *word*)
    text = re.sub(r"\*{1,2}([^*]+)\*{1,2}", r"\1", text)
    
    # Replace newlines with space
    text = text.replace("\n", " ")
    
    # Collapse multiple spaces
    text = re.sub(r"\s+", " ", text)
    
    return text.strip()

In [61]:
result_file = "data/processed/medgemma/results_reranker_1.json"
with open(result_file, "r", encoding="utf-8") as f:
    data = json.load(f)
    
test_diseases = []
test_path = "data/processed/93_cases"

for case in data:
    file = os.path.join(test_path, case["original_file"])
    with open(file, "r", encoding="utf-8") as f:
        test_case = json.load(f)

    test_diseases.append(
        {
            "original_file": case["original_file"],
            "candidate_disease": clean_text(extract_final_diagnosis(case["answer"])),
            "reference_disease": clean_text(test_case["diagnosis"])
        }
    )


df_test_diseases = pd.DataFrame(test_diseases, columns=["original_file", "candidate_disease", "reference_disease"])
df_test_diseases.head()

Unnamed: 0,original_file,candidate_disease,reference_disease
0,1---A-20-Year-Old-Woman-from-Sudan-With-Fever-...,The patient's presentation is highly suggestiv...,Ebola virus disease
1,10---A-55-Year-Old-Indigenous-Woman-from-Austr...,,Answer to Question 1 What is Your Provisional ...
2,11---A-45-Year-Old-Male-Security-Guard-from-Ma...,,A presumed diagnosis of spinal TB was made
3,12---A-29-Year-Old-Man-from-The-Gambia-With-G_...,The patient's presentation is highly suggestiv...,The Case Continued... The patient was treated ...
4,13---A-16-Year-Old-Girl-from-Malawi-With-Fever...,"The patient's presentation, including fever, a...","Typhoid fever, Malaria, and Schistosomiasis"


In [62]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# ---------- lightweight normalization/tokenization ----------
def normalize(text: str) -> str:
    if not isinstance(text, str):
        return ""
    return re.sub(r"\s+", " ", text.lower()).strip()

def word_tokens(text: str):
    # alphanumeric tokens only, no external downloads needed
    return re.findall(r"[a-z0-9]+", normalize(text))

# ---------- BLEU helpers ----------
_smooth = SmoothingFunction().method3  # robust for short phrases

def bleu_all(candidate: str, reference: str):
    cand = word_tokens(candidate)
    ref = word_tokens(reference)
    if not cand or not ref:
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
    refs = [ref]
    return {
        "bleu1": sentence_bleu(refs, cand, weights=(1.0, 0, 0, 0), smoothing_function=_smooth),
        "bleu2": sentence_bleu(refs, cand, weights=(0.5, 0.5, 0, 0), smoothing_function=_smooth),
        "bleu3": sentence_bleu(refs, cand, weights=(1/3, 1/3, 1/3, 0), smoothing_function=_smooth),
        "bleu4": sentence_bleu(refs, cand, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=_smooth),
    }

# ---------- ROUGE helpers ----------
_rouge = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

def rouge_all(candidate: str, reference: str):
    # rouge_scorer expects (target/reference, prediction/candidate)
    cand = candidate if isinstance(candidate, str) else ""
    ref = reference if isinstance(reference, str) else ""
    scores = _rouge.score(ref, cand)
    return {
        "rouge1_f": scores["rouge1"].fmeasure, "rouge1_p": scores["rouge1"].precision, "rouge1_r": scores["rouge1"].recall,
        "rouge2_f": scores["rouge2"].fmeasure, "rouge2_p": scores["rouge2"].precision, "rouge2_r": scores["rouge2"].recall,
        "rougeL_f": scores["rougeL"].fmeasure, "rougeL_p": scores["rougeL"].precision, "rougeL_r": scores["rougeL"].recall,
    }

# ---------- compute per-row scores ----------
bleu_rows = []
rouge_rows = []
for _, row in df_test_diseases.iterrows():
    bleu_rows.append(bleu_all(row["candidate_disease"], row["reference_disease"]))
    rouge_rows.append(rouge_all(row["candidate_disease"], row["reference_disease"]))

df_bleu = pd.DataFrame(bleu_rows)
df_rouge = pd.DataFrame(rouge_rows)

df_scores = pd.concat([df_test_diseases.reset_index(drop=True), df_bleu, df_rouge], axis=1)

# ---------- macro averages ----------
bleu_summary = df_bleu.mean(numeric_only=True)
rouge_summary = df_rouge.mean(numeric_only=True)

print("BLEU (macro avg):")
print(bleu_summary.to_string())
print("\nROUGE (macro avg):")
print(rouge_summary.to_string())

# Optional: inspect a few rows
df_scores.head(5)

BLEU (macro avg):
bleu1    0.020507
bleu2    0.008390
bleu3    0.004781
bleu4    0.003041

ROUGE (macro avg):
rouge1_f    0.033940
rouge1_p    0.035353
rouge1_r    0.100416
rouge2_f    0.004991
rouge2_p    0.005949
rouge2_r    0.014221
rougeL_f    0.023591
rougeL_p    0.023371
rougeL_r    0.080920


Unnamed: 0,original_file,candidate_disease,reference_disease,bleu1,bleu2,bleu3,bleu4,rouge1_f,rouge1_p,rouge1_r,rouge2_f,rouge2_p,rouge2_r,rougeL_f,rougeL_p,rougeL_r
0,1---A-20-Year-Old-Woman-from-Sudan-With-Fever-...,The patient's presentation is highly suggestiv...,Ebola virus disease,0.02439,0.01227,0.007778,0.005223,0.047059,0.02439,0.666667,0.0,0.0,0.0,0.047059,0.02439,0.666667
1,10---A-55-Year-Old-Indigenous-Woman-from-Austr...,,Answer to Question 1 What is Your Provisional ...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,11---A-45-Year-Old-Male-Security-Guard-from-Ma...,,A presumed diagnosis of spinal TB was made,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,12---A-29-Year-Old-Man-from-The-Gambia-With-G_...,The patient's presentation is highly suggestiv...,The Case Continued... The patient was treated ...,0.092974,0.033227,0.014963,0.008491,0.232258,0.382979,0.166667,0.052288,0.086957,0.037383,0.129032,0.212766,0.092593
4,13---A-16-Year-Old-Girl-from-Malawi-With-Fever...,"The patient's presentation, including fever, a...","Typhoid fever, Malaria, and Schistosomiasis",0.06,0.024744,0.014719,0.009596,0.109091,0.06,0.6,0.0,0.0,0.0,0.072727,0.04,0.4


In [None]:
thuyet phuc ng ko co kt AI, ly do la gi, phai co cach thuyet phuc,