In [None]:
!pip install --upgrade --quiet accelerate bitsandbytes transformers

In [None]:
pip uninstall tokenizers -y

In [None]:
pip install tokenizers==0.13.3

In [None]:
pip install --upgrade transformers

In [None]:
import os
from getpass import getpass

# Never hard-code access tokens in a notebook: they leak through version control
# and shared copies. Take the token from the environment (e.g. Kaggle/Colab
# secrets) and prompt interactively only as a fallback.
# NOTE(review): the token that used to be hard-coded here has been exposed and
# should be revoked in the Hugging Face account settings.
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("Hugging Face token: ")


from huggingface_hub import login
login(os.environ["HF_TOKEN"])

# Zero-Shot Prompting

In [None]:
import os, gc
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import pipeline, BitsAndBytesConfig

# =============== USER CONFIG ===============
MERGED_CSV     = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER  = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV     = "medical_llama3_zero_shot.csv"

MAX_ROWS            = 100     # only the first MAX_ROWS unique uids are processed
MAX_NEW_TOKENS_TXT  = 400     # Medical-Llama3-8B
MAX_NEW_TOKENS_VLM  = 160     # LLaVA-Med observations
FLUSH_EVERY         = 50      # write partial results to OUTPUT_CSV every N rows
CLEAN_CACHE_EVERY   = 25      # empty the CUDA cache / run gc every N rows
DOWNSCALE_IMAGES    = True
MAX_IMAGE_SIDE      = 1024

LLAVA_MODEL_ID = "microsoft/llava-med-v1.5-mistral-7b"
LLAMA_MODEL_ID = "YOUR_ORG/Medical-Llama3-8B"  # <-- set your HF repo

# 4-bit NF4 quantisation shared by both models.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

SYSTEM_RAD = "You are an expert radiologist. Provide precise, clinically useful findings for chest X-rays."

# Stage 1 prompt (vision model): objective observations only.
VLM_PROMPT = (
    "You are a radiology assistant. Briefly list objective visual observations from this chest X-ray in 3–6 short bullets. "
    "Avoid interpretations or differentials; stick to what is visibly present (e.g., 'right upper lobe opacity', "
    "'cardiomediastinal silhouette enlarged', 'costophrenic blunting', 'no visible pneumothorax'). Be concise."
)

# Stage 2 instruction (text model), appended after the observations.
ZERO_SHOT_INSTR = (
    "Using the observations above, write a single concise radiology-style paragraph with only clinically relevant findings. "
    "State if normal when appropriate. Avoid differential diagnosis and avoid repeating the bullets verbatim."
)

def downscale(img, max_side=1024):
    """Shrink ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is kept and LANCZOS resampling is used. The image is
    returned untouched when ``DOWNSCALE_IMAGES`` is off or it already fits.
    """
    if DOWNSCALE_IMAGES:
        width, height = img.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            return img.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
    return img

# Fail fast on the placeholder repo id, before spending minutes loading the VLM.
if LLAMA_MODEL_ID.startswith("YOUR_ORG/"):
    raise ValueError(
        f"LLAMA_MODEL_ID is still the placeholder {LLAMA_MODEL_ID!r}; set it to your Medical-Llama3-8B repo."
    )

# Reuse pipelines that already hold the requested checkpoints (e.g. when this
# cell is re-run). `vlm = pipeline(...)` keeps the old pipeline alive until the
# new one is assigned, so rebuilding would put two copies on the GPU at once.
if "vlm" in globals() and getattr(vlm.model, "name_or_path", None) == LLAVA_MODEL_ID:
    print("Reusing loaded VLM (LLaVA-Med).")
else:
    print("Loading VLM (LLaVA-Med)...")
    vlm = pipeline(
        task="image-text-to-text",
        model=LLAVA_MODEL_ID,
        device_map="auto",
        quantization_config=bnb_cfg,
        trust_remote_code=True,
    )
    # Greedy decoding so observations are reproducible.
    if hasattr(vlm.model, "generation_config"):
        vlm.model.generation_config.do_sample = False

if "llama" in globals() and getattr(llama.model, "name_or_path", None) == LLAMA_MODEL_ID:
    print("Reusing loaded Medical-Llama3-8B.")
else:
    print("Loading Medical-Llama3-8B (text-only)...")
    llama = pipeline(
        task="text-generation",
        model=LLAMA_MODEL_ID,
        device_map="auto",
        quantization_config=bnb_cfg,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

def gen_observation(img, max_new=MAX_NEW_TOKENS_VLM):
    """Ask LLaVA-Med for short bullet-point visual observations of one X-ray.

    Args:
        img: RGB PIL image of the chest X-ray (the caller converts to RGB).
        max_new: maximum number of new tokens to generate.

    Returns:
        The text of the last message in the generated conversation, stripped.
    """
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_RAD}]},
        {"role": "user",   "content": [
            {"type": "text",  "text": VLM_PROMPT},
            {"type": "image", "image": img},
        ]}
    ]
    # `torch.cuda.amp.autocast` is deprecated; `torch.autocast("cuda", ...)` is the supported form.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = vlm(text=messages, max_new_tokens=max_new)
    # Chat-style input: the last message of the returned conversation is taken as the reply.
    return out[0]["generated_text"][-1]["content"].strip()

def craft_llama_prompt(observation_text):
    """Assemble the zero-shot text prompt for Medical-Llama3.

    Layout: system role, a blank line, the "Observations:" section holding the
    VLM output, a blank line, then the zero-shot instruction.
    """
    sections = [
        SYSTEM_RAD,
        f"Observations:\n{observation_text}",
        ZERO_SHOT_INSTR,
    ]
    return "\n\n".join(sections)

def gen_report_text(prompt_text, max_new=MAX_NEW_TOKENS_TXT):
    """Generate a report paragraph from a plain-text prompt with Medical-Llama3.

    Decoding is greedy (``do_sample=False``). Returns only the newly generated
    text, stripped of surrounding whitespace.
    """
    # `return_full_text=False` makes the pipeline return just the continuation.
    # The previous prefix check fell back to returning prompt + answer whenever
    # the decoded text did not start with the exact prompt string, so the prompt
    # could end up being scored as part of the "report".
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = llama(prompt_text, max_new_tokens=max_new, do_sample=False, return_full_text=False)
    return out[0]["generated_text"].strip()

print("Reading merged CSV...")
base_df = (
    pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    .drop_duplicates("uid")
    .reset_index(drop=True)
)
if len(base_df) > MAX_ROWS:
    base_df = base_df.iloc[:MAX_ROWS].copy()

# Resume support: uids already present in OUTPUT_CSV are skipped. uids are
# compared as strings so int/str dtype differences between the CSVs don't matter.
# If read_csv itself fails, existing_df stays None and the first flush in the
# loop below overwrites the file.
existing_df, done_uids = None, set()
if os.path.exists(OUTPUT_CSV):
    try:
        existing_df = pd.read_csv(OUTPUT_CSV)
        done_uids = set(existing_df["uid"].astype(str))
        before = len(base_df)
        pending_mask = ~base_df["uid"].astype(str).isin(done_uids)
        base_df = base_df[pending_mask].reset_index(drop=True)
        print(f"Resuming: {len(done_uids)} done; {len(base_df)} remaining (filtered {before - len(base_df)}).")
    except Exception as e:
        print("Starting fresh (could not read existing):", e)

print(f"Rows to process: {len(base_df)}")
results_buffer = []
if len(base_df) > 0:
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Medical-Llama3 Zero-shot")

    for i, row in enumerate(pbar, start=1):
        uid, filename, finding = row.uid, row.filename, row.findings
        img_path = os.path.join(IMAGES_FOLDER, filename)

        # Stage 1: LLaVA-Med describes the image as short bullet observations.
        obs = "ERROR: image not found"
        if os.path.exists(img_path):
            try:
                # `with` releases the file handle as soon as the pixels are loaded.
                with Image.open(img_path) as src:
                    img = downscale(src.convert("RGB"), MAX_IMAGE_SIDE)
                try:
                    obs = gen_observation(img)
                finally:
                    img.close()  # also on failure, so failed rows don't keep image buffers alive
            except Exception as e:
                obs = f"ERROR: {e}"

        # Stage 2: Medical-Llama3 turns the observations into a report. Skipped
        # when stage 1 failed: prompting with an error string would yield a
        # fabricated report that is later scored as if it were real.
        if obs.startswith("ERROR:"):
            report = "ERROR: no visual observation"
        else:
            prompt = craft_llama_prompt(obs)
            try:
                report = gen_report_text(prompt)
            except torch.cuda.OutOfMemoryError:
                # Retry once with a shorter budget; a second failure is recorded
                # rather than raised so one sample cannot abort the whole batch.
                torch.cuda.empty_cache()
                try:
                    report = gen_report_text(prompt, max_new=200)
                except torch.cuda.OutOfMemoryError:
                    torch.cuda.empty_cache()
                    report = "ERROR: CUDA OOM (retry failed)"
                except Exception as e:
                    report = f"ERROR: {e}"
            except Exception as e:
                report = f"ERROR: {e}"

        results_buffer.append({
            "uid": uid,
            "filename": filename,
            "original_finding": finding,
            "vlm_visual_observation": obs,
            "llama_mode": "zero_shot",
            "medical_llama3_output": report
        })

        # Periodically release cached GPU blocks and Python garbage.
        if i % CLEAN_CACHE_EVERY == 0:
            torch.cuda.empty_cache(); gc.collect()

        # Flush every FLUSH_EVERY rows and after the last row, so an interrupted
        # run can resume from OUTPUT_CSV.
        if i % FLUSH_EVERY == 0 or i == len(base_df):
            new_df = pd.DataFrame(results_buffer)
            combined = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
            combined.to_csv(OUTPUT_CSV, index=False)
            existing_df = combined
            results_buffer.clear()
            pbar.set_postfix(saved_rows=len(existing_df))

print("✅ Finished. CSV saved:", OUTPUT_CSV)

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Print long report text in full instead of truncating it with "..."
pd.set_option('display.max_colwidth', None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Spot-check the first two rows: image next to reference and model text.
    for _, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image; `with` closes the file handle once the pixels are loaded.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()
        plt.close(fig)  # release the figure explicitly instead of relying on the backend

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 medical_llama3 Output:")
        print(row["medical_llama3_output"])
        print("=" * 80)

## Evaluation

In [None]:
!pip install bert-score

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case ``text`` and split it into alphanumeric word tokens.

    Punctuation acts as a separator (as in the ``rouge_score`` package's default
    tokenizer), so "limits." and "limits" count as the same token in the
    BLEU/ROUGE overlap computations of this cell.
    """
    import re  # local import keeps this evaluation cell self-contained

    return re.findall(r"[a-z0-9]+", text.lower())

# ---- Custom BLEU (unigram-based, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Unsmoothed sentence-level BLEU-``n`` (default BLEU-4) against one reference.

    Combines clipped n-gram precisions for orders 1..n by geometric mean and
    applies the standard brevity penalty. If any order has zero matches (or the
    candidate is too short to contain n-grams of that order), the score is 0.

    Returns a float in [0, 1]; 0.0 for an empty candidate or ``n < 1``.
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate has no n-grams; returning early also avoids the
    # ZeroDivisionError the brevity penalty (ref_len / cand_len) would raise.
    if cand_len == 0 or n < 1:
        return 0.0

    # Brevity penalty: only candidates shorter than the reference are penalised.
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(tuple(ref_tokens[j:j + order]) for j in range(ref_len - order + 1))
        cand_ngrams = Counter(tuple(cand_tokens[j:j + order]) for j in range(cand_len - order + 1))
        total = sum(cand_ngrams.values())
        overlap = sum((cand_ngrams & ref_ngrams).values())  # clipped counts
        if total == 0 or overlap == 0:
            # log(0) -> geometric mean 0; the old -9999 sentinel reached the
            # same 0.0 via exp() underflow.
            return 0.0
        log_precisions.append(np.log(overlap / total))

    return float(bp * np.exp(np.mean(log_precisions)))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """ROUGE-N between ``reference`` and ``candidate``.

    Returns ``(recall, precision, f1)`` computed from clipped n-gram overlap;
    each ratio is 0 when its denominator is empty.
    """
    def ngram_counts(tokens):
        return Counter(tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1))

    ref_counts = ngram_counts(tokenize(reference))
    cand_counts = ngram_counts(tokenize(candidate))
    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0
    return recall, precision, f1

def rouge_l(reference, candidate):
    """ROUGE-L from the longest common subsequence (LCS) of the token lists.

    Returns ``(recall, precision, f1)``: LCS length over reference length, LCS
    length over candidate length, and their harmonic mean (0 on empty input).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_count, cand_count = len(ref_tokens), len(cand_tokens)

    # Standard LCS dynamic programme, keeping only the previous row.
    prev = [0] * (cand_count + 1)
    for ref_tok in ref_tokens:
        curr = [0] * (cand_count + 1)
        for j, cand_tok in enumerate(cand_tokens):
            if ref_tok == cand_tok:
                curr[j + 1] = prev[j] + 1
            else:
                curr[j + 1] = max(prev[j + 1], curr[j])
        prev = curr
    lcs = prev[cand_count]

    recall = lcs / ref_count if ref_count > 0 else 0
    precision = lcs / cand_count if cand_count > 0 else 0
    total = recall + precision
    f1 = (2 * recall * precision) / total if total > 0 else 0
    return recall, precision, f1

# ---------------- Main ----------------

# Must be the file the zero-shot generation cell writes (its OUTPUT_CSV);
# "medical_llama3_zero_shot_results_batch.csv" is not produced by that cell.
OUTPUT_CSV = "medical_llama3_zero_shot.csv"
df = pd.read_csv(OUTPUT_CSV)

# Rows without a reference finding can't be scored: str(NaN) would make the
# reference the literal word "nan". Generation errors are kept (and counted)
# so failures still lower the averages instead of being hidden.
n_no_ref = int(df["original_finding"].isna().sum())
df = df.dropna(subset=["original_finding"]).reset_index(drop=True)
n_errors = int(df["medical_llama3_output"].astype(str).str.startswith("ERROR").sum())
print(f"Scoring {len(df)} rows ({n_no_ref} without reference skipped, {n_errors} generation errors included).")

bleu_scores = []
rouge1_f1, rouge2_f1, rouge3_f1, rougel_f1 = [], [], [], []
refs, cands = [], []

for _, row in df.iterrows():
    ref = str(row["original_finding"])
    cand = str(row["medical_llama3_output"])

    # BLEU
    bleu_scores.append(bleu_score(ref, cand))

    # ROUGE (only the F1 component of each variant is kept)
    _, _, r1 = rouge_n(ref, cand, 1)
    _, _, r2 = rouge_n(ref, cand, 2)
    _, _, r3 = rouge_n(ref, cand, 3)
    _, _, rl = rouge_l(ref, cand)

    rouge1_f1.append(r1)
    rouge2_f1.append(r2)
    rouge3_f1.append(r3)
    rougel_f1.append(rl)

    refs.append(ref)
    cands.append(cand)

# ---------------- Results ----------------
if refs:
    # BERTScore
    P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

    print("Average Metrics on dataset:")
    print(f"BLEU:     {np.mean(bleu_scores):.4f}")
    print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
    print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
    print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
    print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
    print(f"BERTScore: {F1.mean().item():.4f}")
else:
    # Averages (and BERTScore) are undefined on an empty set.
    print("No rows with a reference finding - nothing to score.")

# 3-Shot Prompting

In [None]:
import os, gc
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import pipeline, BitsAndBytesConfig

# =============== USER CONFIG ===============
MERGED_CSV     = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER  = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV     = "medical_llama3_few_shot.csv"

MAX_ROWS            = 100     # only the first MAX_ROWS unique uids are processed
MAX_NEW_TOKENS_TXT  = 400     # Medical-Llama3-8B
MAX_NEW_TOKENS_VLM  = 160     # LLaVA-Med observations
FLUSH_EVERY         = 50      # write partial results to OUTPUT_CSV every N rows
CLEAN_CACHE_EVERY   = 25      # empty the CUDA cache / run gc every N rows
DOWNSCALE_IMAGES    = True
MAX_IMAGE_SIDE      = 1024

LLAVA_MODEL_ID = "microsoft/llava-med-v1.5-mistral-7b"
LLAMA_MODEL_ID = "YOUR_ORG/Medical-Llama3-8B"  # <-- set your HF repo

# 4-bit NF4 quantisation shared by both models.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

SYSTEM_RAD = "You are an expert radiologist. Provide precise, clinically useful findings for chest X-rays."

# Stage 1 prompt (vision model): objective observations only.
VLM_PROMPT = (
    "You are a radiology assistant. Briefly list objective visual observations from this chest X-ray in 3–6 short bullets. "
    "Avoid interpretations or differentials; stick to visible findings only. Be concise."
)

# Three hand-written observation -> report pairs shown to Medical-Llama3 (stage 2).
FEW_SHOT_EXAMPLES = """
Example A:
Observations:
- Cardiomediastinal silhouette within normal size.
- No focal airspace consolidation.
- Costophrenic angles are sharp.
- No visible pneumothorax.

Report:
The cardiac and mediastinal contours are within normal limits. Lungs are clear without focal consolidation. No pleural effusion or pneumothorax identified.

Example B:
Observations:
- Increased interstitial markings bilaterally.
- Hyperinflated lungs.
- Biapical pleural thickening.
- No acute focal consolidation.

Report:
Findings suggest chronic interstitial changes with hyperinflation. Biapical pleural thickening is noted. No acute focal consolidation, pleural effusion, or pneumothorax.

Example C:
Observations:
- Left upper lobe patchy opacity.
- Slight elevation of the left hemidiaphragm.
- Cardiac size normal.
- No large effusion.

Report:
Patchy airspace opacity in the left upper lobe concerning for pneumonia in the appropriate clinical setting. Cardiomediastinal silhouette is normal. No large pleural effusion or pneumothorax.
""".strip()

def downscale(img, max_side=1024):
    """Shrink ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is kept and LANCZOS resampling is used. The image is
    returned untouched when ``DOWNSCALE_IMAGES`` is off or it already fits.
    """
    if DOWNSCALE_IMAGES:
        width, height = img.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            return img.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
    return img

# Fail fast on the placeholder repo id, before spending minutes loading the VLM.
if LLAMA_MODEL_ID.startswith("YOUR_ORG/"):
    raise ValueError(
        f"LLAMA_MODEL_ID is still the placeholder {LLAMA_MODEL_ID!r}; set it to your Medical-Llama3-8B repo."
    )

# Reuse pipelines already in memory with the same checkpoints (e.g. from the
# zero-shot section). `vlm = pipeline(...)` keeps the old pipeline alive until
# the new one is assigned, so reloading would put two copies on the GPU at once.
if "vlm" in globals() and getattr(vlm.model, "name_or_path", None) == LLAVA_MODEL_ID:
    print("Reusing loaded VLM (LLaVA-Med).")
else:
    print("Loading VLM (LLaVA-Med)...")
    vlm = pipeline(
        task="image-text-to-text",
        model=LLAVA_MODEL_ID,
        device_map="auto",
        quantization_config=bnb_cfg,
        trust_remote_code=True,
    )
    # Greedy decoding so observations are reproducible.
    if hasattr(vlm.model, "generation_config"):
        vlm.model.generation_config.do_sample = False

if "llama" in globals() and getattr(llama.model, "name_or_path", None) == LLAMA_MODEL_ID:
    print("Reusing loaded Medical-Llama3-8B.")
else:
    print("Loading Medical-Llama3-8B (text-only)...")
    llama = pipeline(
        task="text-generation",
        model=LLAMA_MODEL_ID,
        device_map="auto",
        quantization_config=bnb_cfg,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

def gen_observation(img, max_new=MAX_NEW_TOKENS_VLM):
    """Ask LLaVA-Med for short bullet-point visual observations of one X-ray.

    Args:
        img: RGB PIL image of the chest X-ray (the caller converts to RGB).
        max_new: maximum number of new tokens; defaults to the config value
            instead of a duplicated literal.

    Returns:
        The text of the last message in the generated conversation, stripped.
    """
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_RAD}]},
        {"role": "user",   "content": [
            {"type": "text",  "text": VLM_PROMPT},
            {"type": "image", "image": img},
        ]}
    ]
    # `torch.cuda.amp.autocast` is deprecated; `torch.autocast("cuda", ...)` is the supported form.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = vlm(text=messages, max_new_tokens=max_new)
    # Chat-style input: the last message of the returned conversation is taken as the reply.
    return out[0]["generated_text"][-1]["content"].strip()

def craft_llama_prompt(observation_text):
    """Assemble the few-shot prompt for Medical-Llama3.

    Order: system role, the worked examples, a style cue, the new
    observations, and finally the "Write the Report:" cue.
    """
    parts = [
        SYSTEM_RAD,
        "\n\n",
        FEW_SHOT_EXAMPLES,
        "\n\nNow use the same style.\n\n",
        "Observations:\n",
        str(observation_text),
        "\n\nWrite the Report:",
    ]
    return "".join(parts)

def gen_report_text(prompt_text, max_new=MAX_NEW_TOKENS_TXT):
    """Generate a report paragraph from a plain-text prompt with Medical-Llama3.

    Decoding is greedy (``do_sample=False``). Returns only the newly generated
    text, stripped of surrounding whitespace. ``max_new`` defaults to the
    config value instead of a duplicated literal.
    """
    # `return_full_text=False` makes the pipeline return just the continuation.
    # The previous prefix check fell back to returning prompt + answer whenever
    # the decoded text did not start with the exact prompt string, so the
    # few-shot examples could end up being scored as part of the "report".
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = llama(prompt_text, max_new_tokens=max_new, do_sample=False, return_full_text=False)
    return out[0]["generated_text"].strip()

print("Reading merged CSV...")
base_df = (
    pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    .drop_duplicates("uid")
    .reset_index(drop=True)
)
if len(base_df) > MAX_ROWS:
    base_df = base_df.iloc[:MAX_ROWS].copy()

# Resume support: uids already present in OUTPUT_CSV are skipped. uids are
# compared as strings so int/str dtype differences between the CSVs don't matter.
# If read_csv itself fails, existing_df stays None and the first flush in the
# loop below overwrites the file.
existing_df, done_uids = None, set()
if os.path.exists(OUTPUT_CSV):
    try:
        existing_df = pd.read_csv(OUTPUT_CSV)
        done_uids = set(existing_df["uid"].astype(str))
        before = len(base_df)
        pending_mask = ~base_df["uid"].astype(str).isin(done_uids)
        base_df = base_df[pending_mask].reset_index(drop=True)
        print(f"Resuming: {len(done_uids)} done; {len(base_df)} remaining (filtered {before - len(base_df)}).")
    except Exception as e:
        print("Starting fresh (could not read existing):", e)

print(f"Rows to process: {len(base_df)}")
results_buffer = []
if len(base_df) > 0:
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Medical-Llama3 Few-shot")

    for i, row in enumerate(pbar, start=1):
        uid, filename, finding = row.uid, row.filename, row.findings
        img_path = os.path.join(IMAGES_FOLDER, filename)

        # Stage 1: LLaVA-Med describes the image as short bullet observations.
        obs = "ERROR: image not found"
        if os.path.exists(img_path):
            try:
                # `with` releases the file handle as soon as the pixels are loaded.
                with Image.open(img_path) as src:
                    img = downscale(src.convert("RGB"), MAX_IMAGE_SIDE)
                try:
                    obs = gen_observation(img, max_new=MAX_NEW_TOKENS_VLM)
                finally:
                    img.close()  # also on failure, so failed rows don't keep image buffers alive
            except Exception as e:
                obs = f"ERROR: {e}"

        # Stage 2: Medical-Llama3 turns the observations into a report. Skipped
        # when stage 1 failed: prompting with an error string would yield a
        # fabricated report that is later scored as if it were real.
        if obs.startswith("ERROR:"):
            report = "ERROR: no visual observation"
        else:
            prompt = craft_llama_prompt(obs)
            try:
                report = gen_report_text(prompt, max_new=MAX_NEW_TOKENS_TXT)
            except torch.cuda.OutOfMemoryError:
                # Retry once with a shorter budget; a second failure is recorded
                # rather than raised so one sample cannot abort the whole batch.
                torch.cuda.empty_cache()
                try:
                    report = gen_report_text(prompt, max_new=200)
                except torch.cuda.OutOfMemoryError:
                    torch.cuda.empty_cache()
                    report = "ERROR: CUDA OOM (retry failed)"
                except Exception as e:
                    report = f"ERROR: {e}"
            except Exception as e:
                report = f"ERROR: {e}"

        results_buffer.append({
            "uid": uid,
            "filename": filename,
            "original_finding": finding,
            "vlm_visual_observation": obs,
            "llama_mode": "few_shot",
            "medical_llama3_output": report
        })

        # Periodically release cached GPU blocks and Python garbage.
        if i % CLEAN_CACHE_EVERY == 0:
            torch.cuda.empty_cache(); gc.collect()

        # Flush every FLUSH_EVERY rows and after the last row, so an interrupted
        # run can resume from OUTPUT_CSV.
        if i % FLUSH_EVERY == 0 or i == len(base_df):
            new_df = pd.DataFrame(results_buffer)
            combined = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
            combined.to_csv(OUTPUT_CSV, index=False)
            existing_df = combined
            results_buffer.clear()
            pbar.set_postfix(saved_rows=len(existing_df))

print("✅ Finished. CSV saved:", OUTPUT_CSV)

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Print long report text in full instead of truncating it with "..."
pd.set_option('display.max_colwidth', None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Spot-check the first two rows: image next to reference and model text.
    for _, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image; `with` closes the file handle once the pixels are loaded.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()
        plt.close(fig)  # release the figure explicitly instead of relying on the backend

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 medical_llama3 Output:")
        print(row["medical_llama3_output"])
        print("=" * 80)

## Evaluation

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case ``text`` and split it into alphanumeric word tokens.

    Punctuation acts as a separator (as in the ``rouge_score`` package's default
    tokenizer), so "limits." and "limits" count as the same token in the
    BLEU/ROUGE overlap computations of this cell.
    """
    import re  # local import keeps this evaluation cell self-contained

    return re.findall(r"[a-z0-9]+", text.lower())

# ---- Custom BLEU (unigram-based, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Unsmoothed sentence-level BLEU-``n`` (default BLEU-4) against one reference.

    Combines clipped n-gram precisions for orders 1..n by geometric mean and
    applies the standard brevity penalty. If any order has zero matches (or the
    candidate is too short to contain n-grams of that order), the score is 0.

    Returns a float in [0, 1]; 0.0 for an empty candidate or ``n < 1``.
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate has no n-grams; returning early also avoids the
    # ZeroDivisionError the brevity penalty (ref_len / cand_len) would raise.
    if cand_len == 0 or n < 1:
        return 0.0

    # Brevity penalty: only candidates shorter than the reference are penalised.
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(tuple(ref_tokens[j:j + order]) for j in range(ref_len - order + 1))
        cand_ngrams = Counter(tuple(cand_tokens[j:j + order]) for j in range(cand_len - order + 1))
        total = sum(cand_ngrams.values())
        overlap = sum((cand_ngrams & ref_ngrams).values())  # clipped counts
        if total == 0 or overlap == 0:
            # log(0) -> geometric mean 0; the old -9999 sentinel reached the
            # same 0.0 via exp() underflow.
            return 0.0
        log_precisions.append(np.log(overlap / total))

    return float(bp * np.exp(np.mean(log_precisions)))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """ROUGE-N between ``reference`` and ``candidate``.

    Returns ``(recall, precision, f1)`` computed from clipped n-gram overlap;
    each ratio is 0 when its denominator is empty.
    """
    def ngram_counts(tokens):
        return Counter(tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1))

    ref_counts = ngram_counts(tokenize(reference))
    cand_counts = ngram_counts(tokenize(candidate))
    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0
    return recall, precision, f1

def rouge_l(reference, candidate):
    """ROUGE-L from the longest common subsequence (LCS) of the token lists.

    Returns ``(recall, precision, f1)``: LCS length over reference length, LCS
    length over candidate length, and their harmonic mean (0 on empty input).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_count, cand_count = len(ref_tokens), len(cand_tokens)

    # Standard LCS dynamic programme, keeping only the previous row.
    prev = [0] * (cand_count + 1)
    for ref_tok in ref_tokens:
        curr = [0] * (cand_count + 1)
        for j, cand_tok in enumerate(cand_tokens):
            if ref_tok == cand_tok:
                curr[j + 1] = prev[j] + 1
            else:
                curr[j + 1] = max(prev[j + 1], curr[j])
        prev = curr
    lcs = prev[cand_count]

    recall = lcs / ref_count if ref_count > 0 else 0
    precision = lcs / cand_count if cand_count > 0 else 0
    total = recall + precision
    f1 = (2 * recall * precision) / total if total > 0 else 0
    return recall, precision, f1

# ---------------- Main ----------------

# Must be the file the 3-shot generation cell writes (its OUTPUT_CSV);
# "medical_llama3_3shot_results_batch.csv" is not what that cell produces.
OUTPUT_CSV = "medical_llama3_few_shot.csv"
df = pd.read_csv(OUTPUT_CSV)

# Rows without a reference finding can't be scored: str(NaN) would make the
# reference the literal word "nan". Generation errors are kept (and counted)
# so failures still lower the averages instead of being hidden.
n_no_ref = int(df["original_finding"].isna().sum())
df = df.dropna(subset=["original_finding"]).reset_index(drop=True)
n_errors = int(df["medical_llama3_output"].astype(str).str.startswith("ERROR").sum())
print(f"Scoring {len(df)} rows ({n_no_ref} without reference skipped, {n_errors} generation errors included).")

bleu_scores = []
rouge1_f1, rouge2_f1, rouge3_f1, rougel_f1 = [], [], [], []
refs, cands = [], []

for _, row in df.iterrows():
    ref = str(row["original_finding"])
    cand = str(row["medical_llama3_output"])

    # BLEU
    bleu_scores.append(bleu_score(ref, cand))

    # ROUGE (only the F1 component of each variant is kept)
    _, _, r1 = rouge_n(ref, cand, 1)
    _, _, r2 = rouge_n(ref, cand, 2)
    _, _, r3 = rouge_n(ref, cand, 3)
    _, _, rl = rouge_l(ref, cand)

    rouge1_f1.append(r1)
    rouge2_f1.append(r2)
    rouge3_f1.append(r3)
    rougel_f1.append(rl)

    refs.append(ref)
    cands.append(cand)

# ---------------- Results ----------------
if refs:
    # BERTScore
    P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

    print("Average Metrics on dataset:")
    print(f"BLEU:     {np.mean(bleu_scores):.4f}")
    print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
    print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
    print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
    print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
    print(f"BERTScore: {F1.mean().item():.4f}")
else:
    # Averages (and BERTScore) are undefined on an empty set.
    print("No rows with a reference finding - nothing to score.")

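# 5-Shot Prompting

This section prompts LLaVA-Med directly with each image and five example findings; Medical-Llama3 is not used here. Outputs are still stored in the `medical_llama3_output` column so the evaluation code can be reused unchanged.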

In [None]:
import os, gc
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import pipeline, BitsAndBytesConfig

# ----------------- USER CONFIG -----------------
# This section prompts LLaVA-Med directly with each image; no separate text model is loaded here.
MERGED_CSV    = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
# NOTE(review): the file name says "3shot" but this cell runs the 5-shot prompt.
# Rows whose uid is already in this file are skipped on resume, so results from
# another run stored under this name would be silently reused. Confirm the name.
OUTPUT_CSV    = "medical_llama3_3shot_results_batch.csv"   # see NOTE(review) above

MAX_ROWS          = 100   # only the first MAX_ROWS unique uids are processed
MAX_NEW_TOKENS    = 400   # generation budget per image
FLUSH_EVERY       = 50    # write partial results to OUTPUT_CSV every N rows
CLEAN_CACHE_EVERY = 25    # empty the CUDA cache / run gc every N rows
DOWNSCALE_IMAGES  = True
MAX_IMAGE_SIDE    = 1024

# ─── Configuration ───
# System message sent with every request.
SYSTEM_INSTRUCTION = """
You are an expert radiologist. 
Your role is to carefully analyze chest X-rays and provide accurate, clinically useful findings.
"""

# 5-shot examples
# NOTE(review): these quote the reference findings of uids 1, 4, 5, 7 and 8.
# If those uids are among the processed rows, their answer is in the prompt;
# confirm they are excluded from generation/evaluation.
FEW_SHOT_EXAMPLES = """
Here are five examples of high-quality chest X-ray findings:

Example 1 (uid=1, filename=1_IM-0001-4001.dcm.png):
"The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no signs of a pleural effusion. There is no evidence of pneumothorax."

Example 2 (uid=4, filename=4_IM-2050-1001.dcm.png):
"There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease and bullous emphysema. There are irregular opacities in the left lung apex, that could represent a cavitary lesion in the left lung apex. There are streaky opacities in the right upper lobe, consistent with scarring. The cardiomediastinal silhouette is normal in size and contour. There is no pneumothorax or large pleural effusion."

Example 3 (uid=5, filename=5_IM-2117-1003002.dcm.png):
"The cardiomediastinal silhouette and pulmonary vasculature are within normal limits. There is no pneumothorax or pleural effusion. There are no focal areas of consolidation. Cholecystectomy clips are present. Small T-spine osteophytes are noted. There is biapical pleural thickening, unchanged from prior. Mildly hyperexpanded lungs."

Example 4 (uid=7, filename=7_IM-2263-1001.dcm.png):
"The cardiac contours are normal. Mild basilar atelectasis is present. The lungs are otherwise clear. Thoracic spondylosis is seen. Lower cervical spine arthritis is noted."

Example 5 (uid=8, filename=8_IM-2333-1001.dcm.png):
"The heart, pulmonary vasculature, and mediastinum are within normal limits. There is no pleural effusion or pneumothorax. There is no focal air space opacity to suggest pneumonia. An interim lower cervical spinal fusion is partly evaluated."
"""

# The same text prompt is used for every row; the image itself is attached
# per row in the message list built in the main loop.
PROMPT = f"""
{FEW_SHOT_EXAMPLES}

Now, based on the above style and level of detail, analyze the following chest X-ray 
and provide a short, single-paragraph summary of the findings. 
Focus only on the most relevant abnormalities (if any), or clearly state if the film appears normal. 
Keep the response concise and professional, suitable for a radiology report.
"""

# ----------------- LOAD / REUSE PIPELINE -----------------
# Reuse an LLaVA-Med pipeline already in GPU memory: `pipe` from a previous run
# of this cell, or `vlm` from the earlier prompting sections when it holds the
# same checkpoint. Loading a second copy next to it can exhaust GPU memory.
LLAVA_MED_ID = "microsoft/llava-med-v1.5-mistral-7b"

if "pipe" in globals():
    print("Reusing existing `pipe` (model already in memory).")
elif "vlm" in globals() and getattr(vlm.model, "name_or_path", None) == LLAVA_MED_ID:
    pipe = vlm
    # Same greedy setting this cell applies to a freshly loaded pipeline.
    if hasattr(pipe.model, "generation_config"):
        pipe.model.generation_config.do_sample = False
    print("Reusing LLaVA-Med pipeline `vlm` loaded in an earlier section.")
else:
    print("Loading LLaVA-Med v1.5 (Mistral-7B) 4-bit...")
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    pipe = pipeline(
        "image-text-to-text",
        model=LLAVA_MED_ID,
        device_map="auto",
        quantization_config=bnb_cfg,
        trust_remote_code=True,  # helps with model-specific processors
    )
    # Greedy decoding so outputs are reproducible.
    if hasattr(pipe.model, "generation_config"):
        pipe.model.generation_config.do_sample = False
    print("Model loaded and ready.")

# ----------------- DOWNSCALE FUNCTION -----------------
def downscale(img, max_side=1024):
    """Shrink ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is kept and LANCZOS resampling is used. The image is
    returned untouched when ``DOWNSCALE_IMAGES`` is off or it already fits.
    """
    if DOWNSCALE_IMAGES:
        width, height = img.size
        longest = max(width, height)
        if longest > max_side:
            ratio = max_side / longest
            return img.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
    return img

# ----------------- GENERATION FUNCTION -----------------
def safe_generate(messages, max_new_tokens):
    """Run the LLaVA-Med pipeline on a chat ``messages`` list and return the reply text.

    On CUDA OOM the cache is emptied and the call is retried once with at most
    160 new tokens. Returns ``"ERROR: CUDA OOM (retry failed)"`` if the retry
    also runs out of memory, and ``"ERROR: <exc>"`` for other first-attempt
    failures. A non-OOM exception during the retry propagates to the caller.
    """
    def _run(budget):
        # `torch.autocast("cuda", ...)` replaces the deprecated `torch.cuda.amp.autocast`.
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            out = pipe(text=messages, max_new_tokens=budget)
        # Chat-style input: the last message of the returned conversation is taken as the reply.
        return out[0]["generated_text"][-1]["content"].strip()

    try:
        return _run(max_new_tokens)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        try:
            return _run(min(160, max_new_tokens))
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return "ERROR: CUDA OOM (retry failed)"
    except Exception as e:
        return f"ERROR: {e}"

# ----------------- MAIN LOOP -----------------
import re

print("Reading merged CSV...")
base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
base_df = base_df.drop_duplicates(subset="uid").reset_index(drop=True)

if len(base_df) > MAX_ROWS:
    base_df = base_df.iloc[:MAX_ROWS].copy()

# FEW_SHOT_EXAMPLES quotes the reference findings of specific uids ("uid=N").
# Generating (and later scoring) those rows would put the answer in the prompt,
# so they are removed. This runs after the MAX_ROWS cap, so the remaining rows
# are a subset of the first MAX_ROWS unique uids.
# NOTE(review): rows for these uids already saved in OUTPUT_CSV by an earlier
# run are not removed from that file.
example_uids = set(re.findall(r"uid=(\d+)", FEW_SHOT_EXAMPLES))
is_example = base_df["uid"].astype(str).isin(example_uids)
if is_example.any():
    print(f"Excluding {int(is_example.sum())} rows used as few-shot examples (uids {sorted(example_uids)}).")
    base_df = base_df[~is_example].reset_index(drop=True)

existing_df = None
done_uids = set()

# Resume support: uids already in OUTPUT_CSV are skipped (compared as strings).
if os.path.exists(OUTPUT_CSV):
    try:
        existing_df = pd.read_csv(OUTPUT_CSV)
        done_uids = set(existing_df["uid"].astype(str))
        before = len(base_df)
        base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
        print(f"Resuming: {len(done_uids)} already processed; {len(base_df)} remaining (filtered {before - len(base_df)}).")
    except Exception as e:
        print("Could not read existing output; starting fresh.", e)
        existing_df = None

print("Rows to process in this run:", len(base_df))
if len(base_df) == 0:
    print("Nothing new to process. (All done.)")

if len(base_df) > 0:
    results_buffer = []
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="LLaVA-Med batch")

    for i, row in enumerate(pbar, start=1):
        uid       = row.uid
        filename  = row.filename
        finding   = row.findings
        img_path  = os.path.join(IMAGES_FOLDER, filename)

        if not os.path.exists(img_path):
            model_out = "ERROR: image not found"
        else:
            try:
                # `with` releases the file handle as soon as the pixels are loaded.
                with Image.open(img_path) as src:
                    img = downscale(src.convert("RGB"), MAX_IMAGE_SIDE)
                try:
                    messages = [
                        {"role": "system", "content": [{"type": "text", "text": SYSTEM_INSTRUCTION}]},
                        {"role": "user",   "content": [
                            {"type": "text",  "text": PROMPT},
                            {"type": "image", "image": img},
                        ]}
                    ]
                    model_out = safe_generate(messages, MAX_NEW_TOKENS)
                finally:
                    img.close()  # also on failure, so failed rows don't keep image buffers alive
            except Exception as e:
                model_out = f"ERROR: {e}"

        results_buffer.append({
            "uid":              uid,
            "filename":         filename,
            "original_finding": finding,
            "prompt":           "5-shot radiology prompt",
            "medical_llama3_output":  model_out  # kept name for downstream compatibility
        })

        # Memory hygiene
        if i % CLEAN_CACHE_EVERY == 0:
            torch.cuda.empty_cache()
            gc.collect()

        # Flush partial results every FLUSH_EVERY rows and after the last row,
        # so an interrupted run can resume from OUTPUT_CSV.
        if i % FLUSH_EVERY == 0 or i == len(base_df):
            new_df = pd.DataFrame(results_buffer)
            combined = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
            combined.to_csv(OUTPUT_CSV, index=False)
            existing_df = combined
            results_buffer.clear()
            pbar.set_postfix(saved_rows=len(existing_df))

    print("✅ Finished processing. CSV saved:", OUTPUT_CSV)


In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Print long report text in full instead of truncating it with "..."
pd.set_option('display.max_colwidth', None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Spot-check the first two rows: image next to reference and model text.
    for _, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image; `with` closes the file handle once the pixels are loaded.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()
        plt.close(fig)  # release the figure explicitly instead of relying on the backend

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 medical_llama3 Output:")
        print(row["medical_llama3_output"])
        print("=" * 80)

## Evaluation

In [None]:
!pip install bert-score

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case ``text`` and split it on runs of whitespace into tokens."""
    lowered = text.lower()
    return lowered.split()

# ---- Custom BLEU (clipped n-gram precisions for n = 1..4 by default, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Sentence-level, unsmoothed BLEU of ``candidate`` against one ``reference``.

    Clipped n-gram precisions for orders 1..n are combined by geometric mean and
    multiplied by the standard brevity penalty. Without smoothing, a zero
    precision at any order makes the score 0.

    Args:
        reference: ground-truth text.
        candidate: generated text.
        n: highest n-gram order (default 4, i.e. BLEU-4).

    Returns:
        float in [0, 1]. 0.0 for an empty candidate (previously this raised
        ZeroDivisionError in the brevity penalty).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate has no n-grams, and would divide by zero below.
    if cand_len == 0:
        return 0.0

    # Brevity penalty: penalise candidates shorter than the reference.
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(tuple(ref_tokens[j:j + order]) for j in range(len(ref_tokens) - order + 1))
        cand_ngrams = Counter(tuple(cand_tokens[j:j + order]) for j in range(len(cand_tokens) - order + 1))

        total = sum(cand_ngrams.values())
        overlap = sum((cand_ngrams & ref_ngrams).values())
        if total == 0 or overlap == 0:
            # Zero precision at any order -> geometric mean is exactly 0.
            return 0.0
        log_precisions.append(np.log(overlap / total))

    if not log_precisions:  # n < 1: no orders to score
        return 0.0

    return float(bp * np.exp(np.mean(log_precisions)))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """ROUGE-N overlap between ``reference`` and ``candidate``.

    Counts clipped n-gram matches and returns a ``(recall, precision, f1)``
    tuple; each component is 0 when its denominator would be empty.
    """
    def ngram_counts(tokens):
        # Multiset of all contiguous n-token windows.
        return Counter(tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1))

    ref_counts = ngram_counts(tokenize(reference))
    cand_counts = ngram_counts(tokenize(candidate))
    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0
    denom = precision + recall
    f1 = (2 * precision * recall) / denom if denom > 0 else 0

    return recall, precision, f1

def rouge_l(reference, candidate):
    """ROUGE-L from the longest common subsequence (LCS) of the token lists.

    Returns ``(recall, precision, f1)`` with recall = LCS / len(reference) and
    precision = LCS / len(candidate); each is 0 when that side is empty.
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    m, n = len(ref_tokens), len(cand_tokens)

    # LCS dynamic programme keeping only the previous row of the table.
    prev_row = [0] * (n + 1)
    for ref_tok in ref_tokens:
        cur_row = [0] * (n + 1)
        for j, cand_tok in enumerate(cand_tokens):
            if ref_tok == cand_tok:
                cur_row[j + 1] = prev_row[j] + 1
            else:
                cur_row[j + 1] = max(prev_row[j + 1], cur_row[j])
        prev_row = cur_row
    lcs = prev_row[n]

    recall = lcs / m if m > 0 else 0
    precision = lcs / n if n > 0 else 0
    denom = recall + precision
    f1 = (2 * recall * precision) / denom if denom > 0 else 0
    return recall, precision, f1

# ---------------- Main ----------------

# NOTE(review): this overrides OUTPUT_CSV; the zero-shot generation cell at the
# top of the notebook writes "medical_llama3_zero_shot.csv". Confirm this is the
# results file you intend to score.
OUTPUT_CSV = "medical_llama3_5shot_results_batch.csv"
df = pd.read_csv(OUTPUT_CSV)

# Empty cells are read back as NaN, and str(NaN) == "nan" would be scored as a
# real one-token text, so map missing values to "".
refs = df["original_finding"].fillna("").astype(str).tolist()
cands = df["medical_llama3_output"].fillna("").astype(str).tolist()

bleu_scores = []
rouge1_f1, rouge2_f1, rouge3_f1, rougel_f1 = [], [], [], []

for ref, cand in zip(refs, cands):
    # BLEU: an empty candidate scores 0 (guarded so the brevity penalty never
    # divides by a zero candidate length).
    bleu_scores.append(bleu_score(ref, cand) if cand.split() else 0.0)

    # ROUGE: keep the F1 component of each (recall, precision, f1) tuple
    rouge1_f1.append(rouge_n(ref, cand, 1)[2])
    rouge2_f1.append(rouge_n(ref, cand, 2)[2])
    rouge3_f1.append(rouge_n(ref, cand, 3)[2])
    rougel_f1.append(rouge_l(ref, cand)[2])

# BERTScore
P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

# ---------------- Results ----------------
print("Average Metrics on dataset:")
print(f"BLEU:     {np.mean(bleu_scores):.4f}")
print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
print(f"BERTScore: {F1.mean().item():.4f}")

# CoT Prompting

In [None]:
import os, gc
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import pipeline, BitsAndBytesConfig

# =============== USER CONFIG ===============
# Inputs (Kaggle dataset mounts) and the results file for this CoT run.
MERGED_CSV = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV = "medical_llama3_cot.csv"

# Run size and generation budgets.
MAX_ROWS = 100
MAX_NEW_TOKENS_TXT = 1_200   # Llama report budget; reduce if OOM
MAX_NEW_TOKENS_VLM = 300     # VLM observation budget; concise obs; reduce if OOM

# Housekeeping cadence (in processed rows).
FLUSH_EVERY = 50             # rows between CSV checkpoints
CLEAN_CACHE_EVERY = 25       # rows between CUDA cache / GC sweeps

# Image pre-processing.
DOWNSCALE_IMAGES = True
MAX_IMAGE_SIDE = 1_024

# Model repositories.
LLAVA_MODEL_ID = "microsoft/llava-med-v1.5-mistral-7b"
LLAMA_MODEL_ID = "YOUR_ORG/Medical-Llama3-8B"  # <-- set your HF repo

# 4-bit NF4 weights, nested (double) quantization, bf16 compute.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# =============== PROMPTS ===============
# SYSTEM_RAD: persona text. It is sent as the VLM's system message in
# gen_observation and is also prepended to the Llama prompt by
# craft_llama_prompt.
# NOTE(review): it demands an exhaustive, skeptical search, which pulls against
# VLM_PROMPT's request for 3-6 short, concise bullets. Confirm this is intended.
# NOTE(review): sequences such as "‚Äú", "‚Äù", "‚Äî" and "‚Äì" below look like UTF-8
# punctuation mis-decoded as Mac Roman. If the file really contains them, the
# models receive the garbled characters. Verify the file encoding.
SYSTEM_RAD = """
You are an expert radiologist.
You are required to actively search for subtle and chronic abnormalities and explicitly consider possible misses.
For every region, your default is NOT ‚Äúnormal‚Äù‚Äîyou must convince yourself it is truly normal by excluding all
reasonable pathology, or explain every possible abnormality, however mild, even if not clinically significant.
""".strip()

# VLM (image) -> observations
# User instruction for the VLM; its reply becomes the "Observations" fed to Llama.
VLM_PROMPT = (
    "You are a radiology assistant. Briefly list objective visual observations from this chest X-ray in 3‚Äì6 short bullets. "
    "Avoid interpretations or differentials; stick to visible findings only (e.g., 'right upper lobe opacity', "
    "'cardiomediastinal silhouette enlarged', 'costophrenic blunting', 'no visible pneumothorax'). Be concise."
)

# Llama3 (text-only) -> final structured two-step output, based ONLY on observations
# Appended after the "Observations:" section by craft_llama_prompt. Step 1 covers
# 13 regions; Step 2 re-examines 11 of them (all except Bones and Soft tissues).
COT_STRUCTURED_INSTR = """
You are an expert radiologist. Using ONLY the Observations provided (no outside assumptions), do not rush.
Perform these two steps in order and output the structure exactly as specified.

-------------------------------
Step 1: FULL ORGAN ASSESSMENT
-------------------------------
For each region below, provide a separate assessment in this exact format:

Region: [name]
Status: Normal / Abnormal
Findings: [If normal, list ALL subtle/chronic pathologies explicitly ruled out by name, based on the Observations.
If abnormal, specify the abnormality type(s), severity (mild/moderate/severe if inferable), and confidence level (low/moderate/high).]

Regions to assess:
1. Heart
2. Aorta
3. Cardiac silhouette
4. Diaphragm
5. Costophrenic angle
6. Hilar region
7. Lungs
8. Mediastinum
9. Pleura
10. Pulmonary arteries
11. Trachea
12. Bones (ribs, clavicles, spine, AC joints)
13. Soft tissues

Do NOT skip any region. Do NOT summarize. Use one block per region.

---------------------------------------
Step 2: REEVALUATION OF CHALLENGED ORGANS
---------------------------------------
Re-examine ONLY these regions with greater scrutiny, challenging your initial call and explicitly searching for subtle,
chronic, or borderline findings. Even mild uncertainty must be described. Base this strictly on the Observations.

Organs to reevaluate:
- Aorta
- Cardiac silhouette
- Diaphragm
- Hilar region
- Lungs
- Mediastinum
- Pleura
- Pulmonary arteries
- Trachea
- Heart
- Costophrenic angle

For each, restate:
Region: [name]
Reevaluation: Normal / Abnormal
Details: [List anything you may have missed or overcalled, including subtle findings or areas of uncertainty. If normal,
state what subtle findings you explicitly ruled out again. If abnormal, specify details, severity, and confidence.]

Finish with exactly one sentence summarizing any remaining uncertainty for any organ.

Constraints:
- Base your output ONLY on the Observations text; if something is not supported by Observations, mark as "uncertain".
- Keep the writing concise and clinically useful. Follow the exact structure above.
""".strip()

# =============== HELPERS ===============
def downscale(img, max_side=1024):
    """Shrink ``img`` so its longer side is at most ``max_side`` pixels.

    Keeps the aspect ratio (new dimensions are truncated to ints) and uses
    LANCZOS resampling. Returns ``img`` untouched when DOWNSCALE_IMAGES is
    False or the image already fits.
    """
    if not DOWNSCALE_IMAGES:
        return img
    width, height = img.size
    longest = width if width >= height else height
    if longest > max_side:
        scale = max_side / longest
        new_size = (int(width * scale), int(height * scale))
        return img.resize(new_size, Image.LANCZOS)
    return img

# =============== LOAD MODELS ===============
# The BitsAndBytes config goes through `model_kwargs`. `pipeline()` forwards
# only `model_kwargs` (plus device_map / torch_dtype / trust_remote_code) to
# `from_pretrained`; other keywords are handed to the pipeline object itself.
# A top-level `quantization_config=` therefore would not quantize the weights.
print("Loading VLM (LLaVA-Med) for visual observations...")
vlm = pipeline(
    task="image-text-to-text",
    model=LLAVA_MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    model_kwargs={"quantization_config": bnb_cfg},
)
# Disable sampling so observations are deterministic.
if hasattr(vlm.model, "generation_config"):
    vlm.model.generation_config.do_sample = False

print("Loading Medical-Llama3-8B (text-only) for structured CoT output...")
llama = pipeline(
    task="text-generation",
    model=LLAMA_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    model_kwargs={"quantization_config": bnb_cfg},
)

def gen_observation(img, max_new=MAX_NEW_TOKENS_VLM):
    """Ask the VLM for short, objective bullet observations about ``img``.

    Args:
        img: image passed to the pipeline as the user's image content
            (the main loop supplies an RGB, possibly downscaled PIL image).
        max_new: generation budget in new tokens.

    Returns:
        The assistant's reply text, stripped of surrounding whitespace.
    """
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_RAD}]},
        {"role": "user",   "content": [
            {"type": "text",  "text": VLM_PROMPT},
            {"type": "image", "image": img},
        ]}
    ]
    # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda", ...) is the
    # supported spelling with the same effect.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = vlm(text=messages, max_new_tokens=max_new)

    generated = out[0]["generated_text"]
    # With chat-style input the pipeline returns the conversation (a list of
    # messages) and the reply is the last one; also accept a plain string.
    if isinstance(generated, list):
        generated = generated[-1]["content"]
    return str(generated).strip()

def craft_llama_prompt(observation_text: str) -> str:
    """Assemble the text-only Llama prompt.

    Sections, separated by a blank line: the SYSTEM_RAD persona, an
    ``Observations:`` block containing ``observation_text``, and the two-step
    COT_STRUCTURED_INSTR instruction.
    """
    sections = [
        SYSTEM_RAD,
        f"Observations:\n{observation_text}",
        COT_STRUCTURED_INSTR,
    ]
    return "\n\n".join(sections)

def gen_report_text(prompt_text, max_new=MAX_NEW_TOKENS_TXT):
    """Generate the structured report for ``prompt_text`` with greedy decoding.

    Args:
        prompt_text: full prompt (see craft_llama_prompt).
        max_new: generation budget in new tokens.

    Returns:
        Only the newly generated text (no prompt echo), stripped.
    """
    # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda", ...) is equivalent.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = llama(
            prompt_text,
            max_new_tokens=max_new,
            do_sample=False,
            # Ask the pipeline for the continuation only. Prefix-trimming the
            # decoded full text fails whenever detokenisation does not reproduce
            # the prompt byte-for-byte, which would store the whole prompt as the report.
            return_full_text=False,
        )
    text = out[0]["generated_text"]
    # Defensive: a custom (trust_remote_code) pipeline might still echo the prompt.
    if text.startswith(prompt_text):
        text = text[len(prompt_text):]
    return text.strip()

# =============== MAIN ===============
# Smaller report budget used for one retry after a CUDA out-of-memory error.
OOM_RETRY_TOKENS = 400

print("Reading merged CSV...")
base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"]).drop_duplicates("uid").reset_index(drop=True)
if len(base_df) > MAX_ROWS:
    base_df = base_df.iloc[:MAX_ROWS].copy()

# Resume support: skip uids already present in OUTPUT_CSV from an earlier run.
existing_df, done_uids = None, set()
if os.path.exists(OUTPUT_CSV):
    try:
        existing_df = pd.read_csv(OUTPUT_CSV)
        done_uids = set(existing_df["uid"].astype(str))
        before = len(base_df)
        base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
        print(f"Resuming: {len(done_uids)} done; {len(base_df)} remaining (filtered {before - len(base_df)}).")
    except Exception as e:
        # Really start fresh: don't append new rows to a partially-read frame.
        existing_df, done_uids = None, set()
        print("Starting fresh (could not read existing):", e)

print(f"Rows to process: {len(base_df)}")
results_buffer = []
if len(base_df) > 0:
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Medical-Llama3 CoT (2-step structured)")

    for i, row in enumerate(pbar, start=1):
        uid, filename, finding = row.uid, row.filename, row.findings
        img_path = os.path.join(IMAGES_FOLDER, filename)

        obs, report = "ERROR: image not found", ""
        have_obs = False
        if os.path.exists(img_path):
            try:
                # `with` closes the file handle even if decoding/generation fails.
                with Image.open(img_path) as raw:
                    img = downscale(raw.convert("RGB"), MAX_IMAGE_SIDE)
                obs = gen_observation(img, max_new=MAX_NEW_TOKENS_VLM)
                img.close()
                have_obs = True
            except Exception as e:
                obs = f"ERROR: {e}"

        # Only call the report model on real observations. Prompting it with an
        # error string yields an ungrounded report that would later be scored as
        # a genuine output; `report` stays "" in that case.
        if have_obs:
            prompt = craft_llama_prompt(obs)
            try:
                report = gen_report_text(prompt, max_new=MAX_NEW_TOKENS_TXT)
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                try:
                    report = gen_report_text(prompt, max_new=OOM_RETRY_TOKENS)
                except Exception as e:
                    report = f"ERROR: {e}"
            except Exception as e:
                report = f"ERROR: {e}"

        results_buffer.append({
            "uid": uid,
            "filename": filename,
            "original_finding": finding,
            "vlm_visual_observation": obs,
            "llama_mode": "cot_structured_2step",
            "medical_llama3_output": report
        })

        # Periodic memory hygiene.
        if i % CLEAN_CACHE_EVERY == 0:
            torch.cuda.empty_cache()
            gc.collect()

        # Checkpoint every FLUSH_EVERY rows and on the final row.
        if i % FLUSH_EVERY == 0 or i == len(base_df):
            new_df = pd.DataFrame(results_buffer)
            combined = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
            combined.to_csv(OUTPUT_CSV, index=False)
            existing_df = combined
            results_buffer.clear()
            pbar.set_postfix(saved_rows=len(existing_df))

print("‚úÖ Finished. CSV saved:", OUTPUT_CSV)

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Long generated reports should be shown in full, not truncated with "...".
pd.set_option('display.max_colwidth', None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    # Quick sanity check of what the generation loop persisted.
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", list(out.columns))
    display(out.head())

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Number of saved rows to preview (image + reference vs. generated text).
N_PREVIEW = 2

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    for _, row in out.head(N_PREVIEW).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image. `with` closes the file handle; plt.close frees the figure
        # so repeated runs do not accumulate open figures.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()
        plt.close(fig)

        # Show text info
        print("üîπ Original Finding:")
        print(row["original_finding"])
        print("\nüîπ medical_llama3 Output:")
        print(row["medical_llama3_output"])
        print("=" * 80)
else:
    print("CSV not found - something went wrong.")

## LLM-as-a-Judge

In [None]:
import os
import pandas as pd
from tqdm import tqdm
import google.generativeai as genai
import json
import re

# ---------------- Gemini Setup ----------------
# Never hard-code API keys in a notebook. Read the key from the environment
# (e.g. a Kaggle secret exported as GEMINI_API_KEY) and fall back to an
# interactive prompt.
# NOTE(review): the key previously committed in this cell is exposed and
# should be revoked/rotated.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not GEMINI_API_KEY:
    from getpass import getpass
    GEMINI_API_KEY = getpass("Gemini API key: ")

genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-2.5-flash")

# ---------------- Eval Function ----------------
def judge_with_gemini(original_finding, medical_llama3_output):
    """Score a generated report against the reference finding with Gemini.

    Args:
        original_finding: reference (ground-truth) findings text.
        medical_llama3_output: generated report to be judged.

    Returns:
        ``(similarity, professional_writing, correctness)`` floats as reported
        by the judge (a key missing from the reply defaults to 0). If the API
        call fails, the reply is blocked, or the reply is not parseable JSON,
        all three are NaN. Failed judgments are thus excluded by pandas
        ``mean`` instead of being averaged in as zeros.
    """
    prompt = f"""
You are acting as an impartial medical evaluator (LLM-as-a-judge). 
Compare the following two texts:

Original Finding:
{original_finding}

medical_llama3 Output:
{medical_llama3_output}

Evaluate along these 3 dimensions, and output ONLY a JSON object with keys:
- "similarity": float (0 to 1) ‚Üí semantic meaning similarity.
- "professional_writing": float (0 to 1) ‚Üí clarity, conciseness, and medical domain appropriateness of writing style.
- "correctness": float (0 to 1) ‚Üí medical accuracy and correctness of findings.

Do not explain. Just output the JSON object.
"""
    failed = (float("nan"), float("nan"), float("nan"))

    try:
        response = model.generate_content(prompt)
        # `.text` raises when the reply has no text candidate (e.g. blocked).
        text = response.text.strip()
    except Exception as e:
        # Network / quota / safety errors: record a missing judgment instead
        # of aborting the whole evaluation loop.
        print("Gemini call failed:", e)
        return failed

    # Extract the JSON object, tolerating ```json fences or stray prose.
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    payload = match.group(0) if match else text

    try:
        scores = json.loads(payload)
        return (
            float(scores.get("similarity", 0)),
            float(scores.get("professional_writing", 0)),
            float(scores.get("correctness", 0)),
        )
    except Exception as e:
        print("‚ö†Ô∏è Parsing error:", e, "Raw response:", text)
        return failed

# ---------------- Main Script ----------------
# NOTE(review): this overrides OUTPUT_CSV; the CoT generation cell above writes
# "medical_llama3_cot.csv". Confirm which run you intend to judge.
OUTPUT_CSV = "medical_llama3_results_batch.csv"
EVAL_CSV   = "medical_llama3_eval_results.csv"

if os.path.exists(OUTPUT_CSV):
    df = pd.read_csv(OUTPUT_CSV).head(100).copy()

    sims, pros, corrs = [], [], []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Gemini LLM-as-a-judge"):
        s, p, c = judge_with_gemini(row["original_finding"], row["medical_llama3_output"])
        sims.append(s); pros.append(p); corrs.append(c)

    df["similarity"] = sims
    df["professional_writing"] = pros
    df["correctness"] = corrs

    df.to_csv(EVAL_CSV, index=False)
    print(f"‚úÖ Evaluation done. Saved to {EVAL_CSV}")

    # Surface judgments that failed (stored as NaN) so averages can be trusted.
    n_failed = int(df["similarity"].isna().sum())
    if n_failed:
        print(f"{n_failed} of {len(df)} rows could not be judged (stored as NaN).")
else:
    # Without this the later cells fail with a less obvious FileNotFoundError.
    print(f"{OUTPUT_CSV} not found - nothing to evaluate.")

In [None]:
import pandas as pd

# Peek at the first judged rows; display() renders the rich HTML table
# instead of print()'s plain-text dump.
df = pd.read_csv("/kaggle/working/medical_llama3_eval_results.csv")
display(df.head(3))

In [None]:
import pandas as pd

# Load evaluation results
df = pd.read_csv("/kaggle/working/medical_llama3_eval_results.csv")

# Column-wise means of the three judge scores.
averages = df[["similarity", "professional_writing", "correctness"]].mean()
avg_similarity = averages["similarity"]
avg_professional = averages["professional_writing"]
avg_correctness = averages["correctness"]

# Print as an aligned table (labels are pre-padded to a common width).
print("Average Scores:")
for label, value in (
    ("Similarity:            ", avg_similarity),
    ("Professional Writing:  ", avg_professional),
    ("Correctness:           ", avg_correctness),
):
    print(f"{label}{value:.3f}")