# Zero-Shot Prompting

In [None]:
import os, gc, io, time, base64, json
import pandas as pd
from PIL import Image
from tqdm import tqdm
import requests

# ----------------- USER CONFIG -----------------
# Input locations and the CSV that main() writes to (and resumes from).
MERGED_CSV    = "/kaggle/input/dsetindiana/indiana_merged.csv"   # <-- change if needed
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV    = "gemini_zero_shot_results_batch.csv"  # renamed

MAX_ROWS          = 100     # max unique uids to process
MAX_NEW_TOKENS    = 400     # short paragraph target; sent as generationConfig.maxOutputTokens
FLUSH_EVERY       = 50      # write accumulated results to OUTPUT_CSV every N rows
CLEAN_CACHE_EVERY = 25      # call gc.collect() every N rows
DOWNSCALE_IMAGES  = True    # when False, downscale() returns images unchanged
MAX_IMAGE_SIDE    = 1024    # longest side (pixels) passed to downscale()

# ----------------- PROMPTING -----------------
# Sent as the request's systemInstruction (stripped of surrounding whitespace).
SYSTEM_INSTRUCTION = """
You are an expert radiologist.
Your role is to carefully analyze chest X-rays and provide accurate, clinically useful findings.
Use precise, professional language and avoid speculation.
"""

# Sent as the user text part, alongside the image; also stored in the output CSV's "prompt" column.
PROMPT = """
You are an expert radiologist. Analyze the following chest X-ray and provide a short, single-paragraph summary of the findings.
Focus only on the most relevant abnormalities (if any), or clearly state if the film appears normal.
Keep the response concise and professional, suitable for a radiology report.
"""

# ----------------- GEMINI CONFIG -----------------
# Must contain at least one key before main() is run: GeminiClient rejects an empty list.
GEMINI_API_KEYS = [
    # os.environ.get("GEMINI_API_KEY"),  # example (single key)
    # "YOUR_KEY_1", "YOUR_KEY_2", ...
]
# Gemini 2.5 Flash endpoint (update if needed)
GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"

# ----------------- IMAGE HELPERS -----------------
def downscale(img, max_side=1024):
    """Shrink *img* so that its longest side is at most *max_side* pixels.

    The image is returned unchanged when ``DOWNSCALE_IMAGES`` is false or when
    it already fits. Otherwise a resized copy is returned; the aspect ratio is
    preserved up to integer truncation of the new dimensions.
    """
    if not DOWNSCALE_IMAGES:
        return img
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # Clamp to 1 px: with an extreme aspect ratio int() truncates the short
    # side to 0, and resize() does not accept a zero dimension.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size, Image.LANCZOS)

def pil_to_base64_jpeg(img, quality=90):
    """Encode a PIL image as JPEG; return ``(base64_text, mime_type)``.

    Images whose mode is not RGB are converted to RGB before encoding.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    with io.BytesIO() as stream:
        rgb.save(stream, format="JPEG", quality=quality, optimize=True)
        raw = stream.getvalue()
    encoded = base64.b64encode(raw).decode("utf-8")
    return encoded, "image/jpeg"

# ----------------- GEMINI CLIENT -----------------
class GeminiClient:
    """Small REST client for the Gemini ``generateContent`` endpoint.

    Holds a list of API keys plus a cursor into it. ``generate`` retries a
    request with exponential backoff and moves to the next key on
    rate-limit/quota (429) and auth (401/403) responses.
    """

    # Statuses retried after a backoff sleep; any other non-200 status is final.
    _RETRYABLE_STATUSES = (429, 500, 502, 503, 504)
    _AUTH_STATUSES = (401, 403)
    _MAX_BACKOFF_SECONDS = 16

    def __init__(self, api_keys, endpoint, timeout=60):
        """Store the client configuration.

        Args:
            api_keys: Iterable of API key strings; ``None``/empty entries
                (e.g. an unset ``os.environ.get(...)``) are dropped.
            endpoint: Full ``generateContent`` URL.
            timeout: Timeout in seconds passed to ``requests.post``.

        Raises:
            ValueError: If no usable key remains.
        """
        usable = [k for k in (api_keys or []) if k]
        if not usable:
            # Raised explicitly (not asserted) so the check survives ``python -O``.
            raise ValueError("At least one API key is required.")
        self.api_keys = usable
        self.endpoint = endpoint
        self.timeout = timeout
        self.idx = 0

    def _current_key(self):
        """Return the key the cursor currently points at."""
        return self.api_keys[self.idx]

    def _rotate_key(self):
        """Advance the cursor to the next key, wrapping around."""
        self.idx = (self.idx + 1) % len(self.api_keys)

    @staticmethod
    def _error_message(response):
        """Return the ``error.message`` field of an error response, else its raw text."""
        try:
            body = response.json()
        except Exception:
            return response.text
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            return str(error.get("message", response.text))
        return response.text

    @staticmethod
    def _extract_text(data):
        """Turn a decoded 200 response into ``(text, error)``; one of them is None."""
        if not isinstance(data, dict):
            return None, f"ERROR: parse failure: unexpected {type(data).__name__} payload"

        candidates = data.get("candidates")
        if not candidates:
            block_reason = (data.get("promptFeedback") or {}).get("blockReason")
            if block_reason:
                return None, f"ERROR: response blocked by safety: {block_reason}"
            return None, "ERROR: no candidates returned"

        try:
            first = candidates[0]
            parts = (first.get("content") or {}).get("parts") or []
            texts = [p.get("text", "") for p in parts if isinstance(p, dict) and "text" in p]
            out_text = "\n".join(t for t in texts if t).strip()
            if not out_text:
                # A candidate may carry no text parts at all; surface the
                # reported finishReason instead of an opaque KeyError.
                finish = first.get("finishReason")
                detail = f" (finishReason={finish})" if finish else ""
                return None, "ERROR: empty text in candidate" + detail
            return out_text, None
        except Exception as e:
            return None, f"ERROR: parse failure: {e}"

    def generate(self, system_instruction, user_text, img_b64, img_mime, max_output_tokens=400, temperature=0.0):
        """
        Send one text+image request and return (text, error) where one of them is None.

        Makes up to ``2 * len(api_keys)`` attempts. Network errors and
        429/5xx responses are retried with exponential backoff; the key is
        rotated on 429 (or a quota/rate message) and on 401/403. Any other
        status is returned immediately as an error string.
        """
        payload = {
            "systemInstruction": {"parts": [{"text": system_instruction.strip()}]},
            "contents": [{
                "role": "user",
                "parts": [
                    {"text": user_text.strip()},
                    {"inline_data": {"mime_type": img_mime, "data": img_b64}}
                ]
            }],
            "generationConfig": {
                "maxOutputTokens": int(max_output_tokens),
                "temperature": float(temperature),
                "topK": 40,
                "topP": 0.95,
            },
            # Safety settings optional; using defaults
        }
        body = json.dumps(payload)  # serialise once, reuse across retries

        max_attempts = len(self.api_keys) * 2
        backoff = 1.0
        last_error = "no attempt made"

        for attempt in range(1, max_attempts + 1):
            try:
                r = requests.post(
                    self.endpoint,
                    # Key goes in a header rather than the query string so it
                    # cannot end up in URLs echoed by exceptions or logs.
                    headers={
                        "Content-Type": "application/json",
                        "x-goog-api-key": self._current_key(),
                    },
                    data=body,
                    timeout=self.timeout,
                )
            except requests.RequestException as e:
                # transient network error: backoff and retry
                last_error = f"{type(e).__name__}: {e}"
            else:
                if r.status_code == 200:
                    try:
                        data = r.json()
                    except Exception as e:
                        return None, f"ERROR: bad JSON from Gemini: {e}"
                    return self._extract_text(data)

                msg = self._error_message(r)
                if r.status_code in self._RETRYABLE_STATUSES:
                    lowered = msg.lower()
                    # If quota/rate-limit-ish, rotate key; plain 5xx keeps the key
                    if r.status_code == 429 or "quota" in lowered or "rate" in lowered:
                        self._rotate_key()
                elif r.status_code in self._AUTH_STATUSES:
                    # invalid/expired key: rotate
                    self._rotate_key()
                else:
                    # Non-retryable error
                    return None, f"ERROR {r.status_code}: {msg}"
                last_error = f"HTTP {r.status_code}: {msg}"

            # No point sleeping after the final attempt.
            if attempt < max_attempts:
                time.sleep(backoff)
                backoff = min(backoff * 2, self._MAX_BACKOFF_SECONDS)

        return None, f"ERROR: all API keys exhausted or repeated failures (last error: {last_error})"

# ----------------- MAIN -----------------
def main():
    """Run the zero-shot batch and write the results to ``OUTPUT_CSV``.

    Reads ``MERGED_CSV``, keeps the first ``MAX_ROWS`` unique uids, skips uids
    already present in ``OUTPUT_CSV``, sends each remaining image to Gemini and
    appends one row per uid (model text or an ``ERROR...`` string).
    """
    print("Reading merged CSV...")
    base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    base_df = base_df.drop_duplicates(subset="uid").reset_index(drop=True)

    if len(base_df) > MAX_ROWS:
        base_df = base_df.iloc[:MAX_ROWS].copy()

    existing_df = None

    if os.path.exists(OUTPUT_CSV):
        try:
            existing_df = pd.read_csv(OUTPUT_CSV)
            # NOTE(review): rows whose output is an ERROR string also count as
            # done, so failed uids are not retried on resume.
            done_uids = set(existing_df["uid"].astype(str))
            before = len(base_df)
            base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
            print(f"Resuming: {len(done_uids)} already processed; {len(base_df)} remaining (filtered {before - len(base_df)}).")
        except Exception as e:
            print("Could not read existing output; starting fresh.", e)
            existing_df = None

    print("Rows to process in this run:", len(base_df))
    if len(base_df) == 0:
        print("Nothing new to process. (All done.)")
        return

    client = GeminiClient(GEMINI_API_KEYS, GEMINI_ENDPOINT)

    results_buffer = []
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Gemini batch")

    def flush():
        """Append buffered rows to the saved frame and rewrite OUTPUT_CSV."""
        nonlocal existing_df
        if not results_buffer:
            return
        new_df = pd.DataFrame(results_buffer)
        existing_df = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
        existing_df.to_csv(OUTPUT_CSV, index=False)
        results_buffer.clear()
        pbar.set_postfix(saved_rows=len(existing_df))

    try:
        for i, row in enumerate(pbar, start=1):
            uid      = row.uid
            filename = row.filename
            finding  = row.findings

            try:
                # Path building is inside the try: a missing filename is read
                # as NaN (a float) and makes os.path.join raise TypeError.
                img_path = os.path.join(IMAGES_FOLDER, filename)
                if not os.path.exists(img_path):
                    model_out = "ERROR: image not found"
                else:
                    with Image.open(img_path) as img_in:
                        img = downscale(img_in, MAX_IMAGE_SIDE)
                        img_b64, mime = pil_to_base64_jpeg(img)

                    text, err = client.generate(
                        system_instruction=SYSTEM_INSTRUCTION,
                        user_text=PROMPT,
                        img_b64=img_b64,
                        img_mime=mime,
                        max_output_tokens=MAX_NEW_TOKENS,
                        temperature=0.0
                    )
                    model_out = text if err is None else err
            except Exception as e:
                model_out = f"ERROR: {e}"

            results_buffer.append({
                "uid":              uid,
                "filename":         filename,
                "original_finding": finding,
                "prompt":           PROMPT.strip(),
                "gemini_output":    model_out
            })

            # Light memory hygiene (even though no CUDA here)
            if i % CLEAN_CACHE_EVERY == 0:
                gc.collect()

            # Flush partial results
            if i % FLUSH_EVERY == 0:
                flush()
    finally:
        # Always persist what was collected, including on interrupt or crash.
        flush()
        pbar.close()

    print("✅ Finished processing. CSV saved:", OUTPUT_CSV)

if __name__ == "__main__":
    main()

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Render long text columns without truncation
pd.set_option("display.max_colwidth", None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    # Summarise the saved results and show the first rows
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Preview the first two saved rows: image, reference finding, model output.
    for _, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image. convert() returns a copy, so the file can be closed
        # as soon as the RGB image has been produced.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 Gemini Output:")
        print(row["gemini_output"])
        print("=" * 80)

## Evaluation

In [None]:
!pip install bert-score

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case *text* and split it on whitespace into a list of tokens."""
    lowered = text.lower()
    return lowered.split()

# ---- Custom BLEU (clipped 1..n-gram precisions, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Sentence-level BLEU of *candidate* against a single *reference*.

    Computes clipped n-gram precisions for orders 1..n, takes their geometric
    mean and multiplies it by the brevity penalty. No smoothing is applied, so
    the score is 0.0 as soon as any order has no matching n-gram (this also
    covers candidates shorter than *n* tokens and empty candidates).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate would otherwise divide by zero in the brevity penalty.
    if cand_len == 0 or n < 1:
        return 0.0

    # brevity penalty
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    # log of each n-gram precision
    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(tuple(ref_tokens[j:j + order]) for j in range(ref_len - order + 1))
        cand_ngrams = Counter(tuple(cand_tokens[j:j + order]) for j in range(cand_len - order + 1))

        overlap = sum((cand_ngrams & ref_ngrams).values())
        if overlap == 0:
            # A zero precision zeroes the geometric mean; return explicitly
            # instead of pushing a sentinel log value through exp().
            return 0.0
        log_precisions.append(np.log(overlap / sum(cand_ngrams.values())))

    # geometric mean of precisions
    return bp * np.exp(np.mean(log_precisions))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """Return ``(recall, precision, f1)`` for clipped n-gram overlap of the two texts."""
    def ngram_counts(tokens):
        return Counter(tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1))

    ref_counts = ngram_counts(tokenize(reference))
    cand_counts = ngram_counts(tokenize(candidate))

    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0

    denom = precision + recall
    if denom > 0:
        f1 = (2 * precision * recall) / denom
    else:
        f1 = 0

    return recall, precision, f1

def rouge_l(reference, candidate):
    """Return ``(recall, precision, f1)`` based on the longest common token subsequence."""
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    m, n = len(ref_tokens), len(cand_tokens)

    # LCS length, keeping only the previous and current DP rows
    prev_row = [0] * (n + 1)
    for ref_tok in ref_tokens:
        cur_row = [0]
        for j, cand_tok in enumerate(cand_tokens):
            if ref_tok == cand_tok:
                cur_row.append(prev_row[j] + 1)
            else:
                cur_row.append(max(prev_row[j + 1], cur_row[j]))
        prev_row = cur_row
    lcs = prev_row[n]

    recall = lcs / m if m > 0 else 0
    precision = lcs / n if n > 0 else 0
    total = recall + precision
    if total > 0:
        f1 = (2 * recall * precision) / total
    else:
        f1 = 0
    return recall, precision, f1

# ---------------- Main ----------------

OUTPUT_CSV = "gemini_zero_shot_results_batch.csv"
df = pd.read_csv(OUTPUT_CSV)

# Score only rows that have both a reference and a real model output.
# Rows written by the batch step on failure start with "ERROR", and missing
# text would otherwise be stringified to "nan" and scored as if it were a report.
has_text = df["original_finding"].notna() & df["gemini_output"].notna()
is_error = df["gemini_output"].astype(str).str.startswith("ERROR")
scored = df[has_text & ~is_error]
print(f"Scoring {len(scored)} of {len(df)} rows "
      f"(skipped {len(df) - len(scored)} with missing text or ERROR output).")

refs = scored["original_finding"].astype(str).tolist()
cands = scored["gemini_output"].astype(str).tolist()

# BLEU and ROUGE F1 per (reference, candidate) pair
bleu_scores = [bleu_score(ref, cand) for ref, cand in zip(refs, cands)]
rouge1_f1 = [rouge_n(ref, cand, 1)[2] for ref, cand in zip(refs, cands)]
rouge2_f1 = [rouge_n(ref, cand, 2)[2] for ref, cand in zip(refs, cands)]
rouge3_f1 = [rouge_n(ref, cand, 3)[2] for ref, cand in zip(refs, cands)]
rougel_f1 = [rouge_l(ref, cand)[2] for ref, cand in zip(refs, cands)]

if not refs:
    print("No scorable rows found; nothing to evaluate.")
else:
    # BERTScore
    P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

    # ---------------- Results ----------------
    print("Average Metrics on dataset:")
    print(f"BLEU:     {np.mean(bleu_scores):.4f}")
    print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
    print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
    print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
    print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
    print(f"BERTScore: {F1.mean().item():.4f}")

# 3-Shot Prompting

In [None]:
import os, gc, io, time, base64, json
import pandas as pd
from PIL import Image
from tqdm import tqdm
import requests

# ----------------- USER CONFIG -----------------
# Input locations and the CSV that main() writes to (and resumes from).
MERGED_CSV    = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV    = "gemini_3shot_results_batch.csv"

MAX_ROWS          = 100     # max unique uids to process
MAX_NEW_TOKENS    = 400     # sent as generationConfig.maxOutputTokens
FLUSH_EVERY       = 50      # write accumulated results to OUTPUT_CSV every N rows
CLEAN_CACHE_EVERY = 25      # call gc.collect() every N rows
DOWNSCALE_IMAGES  = True    # when False, downscale() returns images unchanged
MAX_IMAGE_SIDE    = 1024    # longest side (pixels) passed to downscale()

# ----------------- PROMPTING (3-shot) -----------------
# Sent as the request's systemInstruction (stripped of surrounding whitespace).
SYSTEM_INSTRUCTION = """
You are an expert radiologist. 
Your role is to carefully analyze chest X-rays and provide accurate, clinically useful findings.
Use precise, professional language and avoid speculation.
"""

# Three example findings embedded in the prompt as text only; the example
# images themselves are not sent with the request.
# NOTE(review): main() processes the first MAX_ROWS unique uids of MERGED_CSV.
# If uids 1, 4 and 5 are among them, their reference text appears in the prompt
# of the very rows they are scored on — confirm and exclude them if so.
FEW_SHOT_EXAMPLES = """
Here are three examples of high-quality chest X-ray findings:

Example 1 (uid=1, filename=1_IM-0001-4001.dcm.png):
"The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no signs of a pleural effusion. There is no evidence of pneumothorax."

Example 2 (uid=4, filename=4_IM-2050-1001.dcm.png):
"There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease and bullous emphysema. There are irregular opacities in the left lung apex, that could represent a cavitary lesion in the left lung apex. There are streaky opacities in the right upper lobe, consistent with scarring. The cardiomediastinal silhouette is normal in size and contour. There is no pneumothorax or large pleural effusion."

Example 3 (uid=5, filename=5_IM-2117-1003002.dcm.png):
"The cardiomediastinal silhouette and pulmonary vasculature are within normal limits. There is no pneumothorax or pleural effusion. There are no focal areas of consolidation. Cholecystectomy clips are present. Small T-spine osteophytes are noted. There is biapical pleural thickening, unchanged from prior. Mildly hyperexpanded lungs."
"""

# User text part: the examples above followed by the task instruction.
PROMPT = f"""
{FEW_SHOT_EXAMPLES}

Now, based on the above style and level of detail, analyze the following chest X-ray 
and provide a short, single-paragraph summary of the findings. 
Focus only on the most relevant abnormalities (if any), or clearly state if the film appears normal. 
Keep the response concise and professional, suitable for a radiology report.
"""

# ----------------- GEMINI CONFIG -----------------
# Must contain at least one key before main() is run: GeminiClient rejects an empty list.
GEMINI_API_KEYS = [
    # os.environ.get("GEMINI_API_KEY"),  # example (single key)
    # "YOUR_KEY_1", "YOUR_KEY_2", ...
]
GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"

# ----------------- IMAGE HELPERS -----------------
def downscale(img, max_side=1024):
    """Shrink *img* so that its longest side is at most *max_side* pixels.

    The image is returned unchanged when ``DOWNSCALE_IMAGES`` is false or when
    it already fits. Otherwise a resized copy is returned; the aspect ratio is
    preserved up to integer truncation of the new dimensions.
    """
    if not DOWNSCALE_IMAGES:
        return img
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # Clamp to 1 px: with an extreme aspect ratio int() truncates the short
    # side to 0, and resize() does not accept a zero dimension.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size, Image.LANCZOS)

def pil_to_base64_jpeg(img, quality=90):
    """Encode a PIL image as JPEG; return ``(base64_text, mime_type)``.

    Images whose mode is not RGB are converted to RGB before encoding.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    with io.BytesIO() as stream:
        rgb.save(stream, format="JPEG", quality=quality, optimize=True)
        raw = stream.getvalue()
    encoded = base64.b64encode(raw).decode("utf-8")
    return encoded, "image/jpeg"

# ----------------- GEMINI CLIENT -----------------
class GeminiClient:
    """Small REST client for the Gemini ``generateContent`` endpoint.

    Holds a list of API keys plus a cursor into it. ``generate`` retries a
    request with exponential backoff and moves to the next key on
    rate-limit/quota (429) and auth (401/403) responses.
    """

    # Statuses retried after a backoff sleep; any other non-200 status is final.
    _RETRYABLE_STATUSES = (429, 500, 502, 503, 504)
    _AUTH_STATUSES = (401, 403)
    _MAX_BACKOFF_SECONDS = 16

    def __init__(self, api_keys, endpoint, timeout=60):
        """Store the client configuration.

        Args:
            api_keys: Iterable of API key strings; ``None``/empty entries
                (e.g. an unset ``os.environ.get(...)``) are dropped.
            endpoint: Full ``generateContent`` URL.
            timeout: Timeout in seconds passed to ``requests.post``.

        Raises:
            ValueError: If no usable key remains.
        """
        usable = [k for k in (api_keys or []) if k]
        if not usable:
            # Raised explicitly (not asserted) so the check survives ``python -O``.
            raise ValueError("At least one API key is required.")
        self.api_keys = usable
        self.endpoint = endpoint
        self.timeout = timeout
        self.idx = 0

    def _current_key(self):
        """Return the key the cursor currently points at."""
        return self.api_keys[self.idx]

    def _rotate_key(self):
        """Advance the cursor to the next key, wrapping around."""
        self.idx = (self.idx + 1) % len(self.api_keys)

    @staticmethod
    def _error_message(response):
        """Return the ``error.message`` field of an error response, else its raw text."""
        try:
            body = response.json()
        except Exception:
            return response.text
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            return str(error.get("message", response.text))
        return response.text

    @staticmethod
    def _extract_text(data):
        """Turn a decoded 200 response into ``(text, error)``; one of them is None."""
        if not isinstance(data, dict):
            return None, f"ERROR: parse failure: unexpected {type(data).__name__} payload"

        candidates = data.get("candidates")
        if not candidates:
            block_reason = (data.get("promptFeedback") or {}).get("blockReason")
            if block_reason:
                return None, f"ERROR: response blocked by safety: {block_reason}"
            return None, "ERROR: no candidates returned"

        try:
            first = candidates[0]
            parts = (first.get("content") or {}).get("parts") or []
            texts = [p.get("text", "") for p in parts if isinstance(p, dict) and "text" in p]
            out_text = "\n".join(t for t in texts if t).strip()
            if not out_text:
                # A candidate may carry no text parts at all; surface the
                # reported finishReason instead of an opaque KeyError.
                finish = first.get("finishReason")
                detail = f" (finishReason={finish})" if finish else ""
                return None, "ERROR: empty text in candidate" + detail
            return out_text, None
        except Exception as e:
            return None, f"ERROR: parse failure: {e}"

    def generate(self, system_instruction, user_text, img_b64, img_mime,
                 max_output_tokens=400, temperature=0.0):
        """
        Send one text+image request and return (text, error) where one of them is None.

        Makes up to ``2 * len(api_keys)`` attempts. Network errors and
        429/5xx responses are retried with exponential backoff; the key is
        rotated on 429 (or a quota/rate message) and on 401/403. Any other
        status is returned immediately as an error string.
        """
        payload = {
            "systemInstruction": {"parts": [{"text": system_instruction.strip()}]},
            "contents": [{
                "role": "user",
                "parts": [
                    {"text": user_text.strip()},
                    {"inline_data": {"mime_type": img_mime, "data": img_b64}}
                ]
            }],
            "generationConfig": {
                "maxOutputTokens": int(max_output_tokens),
                "temperature": float(temperature),
                "topK": 40,
                "topP": 0.95
            }
        }
        body = json.dumps(payload)  # serialise once, reuse across retries

        max_attempts = len(self.api_keys) * 2
        backoff = 1.0
        last_error = "no attempt made"

        for attempt in range(1, max_attempts + 1):
            try:
                r = requests.post(
                    self.endpoint,
                    # Key goes in a header rather than the query string so it
                    # cannot end up in URLs echoed by exceptions or logs.
                    headers={
                        "Content-Type": "application/json",
                        "x-goog-api-key": self._current_key(),
                    },
                    data=body,
                    timeout=self.timeout,
                )
            except requests.RequestException as e:
                # transient network error: backoff and retry
                last_error = f"{type(e).__name__}: {e}"
            else:
                if r.status_code == 200:
                    try:
                        data = r.json()
                    except Exception as e:
                        return None, f"ERROR: bad JSON from Gemini: {e}"
                    return self._extract_text(data)

                msg = self._error_message(r)
                if r.status_code in self._RETRYABLE_STATUSES:
                    lowered = msg.lower()
                    # rate limit or quota message: rotate key; plain 5xx keeps the key
                    if r.status_code == 429 or "quota" in lowered or "rate" in lowered:
                        self._rotate_key()
                elif r.status_code in self._AUTH_STATUSES:
                    # bad/expired key; rotate and retry
                    self._rotate_key()
                else:
                    return None, f"ERROR {r.status_code}: {msg}"
                last_error = f"HTTP {r.status_code}: {msg}"

            # No point sleeping after the final attempt.
            if attempt < max_attempts:
                time.sleep(backoff)
                backoff = min(backoff * 2, self._MAX_BACKOFF_SECONDS)

        return None, f"ERROR: all API keys exhausted or repeated failures (last error: {last_error})"

# ----------------- MAIN -----------------
def main():
    """Run the 3-shot batch and write the results to ``OUTPUT_CSV``.

    Reads ``MERGED_CSV``, keeps the first ``MAX_ROWS`` unique uids, skips uids
    already present in ``OUTPUT_CSV``, sends each remaining image to Gemini and
    appends one row per uid (model text or an ``ERROR...`` string).
    """
    print("Reading merged CSV...")
    base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    base_df = base_df.drop_duplicates(subset="uid").reset_index(drop=True)

    if len(base_df) > MAX_ROWS:
        base_df = base_df.iloc[:MAX_ROWS].copy()

    existing_df = None

    if os.path.exists(OUTPUT_CSV):
        try:
            existing_df = pd.read_csv(OUTPUT_CSV)
            # NOTE(review): rows whose output is an ERROR string also count as
            # done, so failed uids are not retried on resume.
            done_uids = set(existing_df["uid"].astype(str))
            before = len(base_df)
            base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
            print(f"Resuming: {len(done_uids)} already processed; {len(base_df)} remaining (filtered {before - len(base_df)}).")
        except Exception as e:
            print("Could not read existing output; starting fresh.", e)
            existing_df = None

    print("Rows to process in this run:", len(base_df))
    if len(base_df) == 0:
        print("Nothing new to process. (All done.)")
        return

    client = GeminiClient(GEMINI_API_KEYS, GEMINI_ENDPOINT)

    results_buffer = []
    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Gemini 3-shot batch")

    def flush():
        """Append buffered rows to the saved frame and rewrite OUTPUT_CSV."""
        nonlocal existing_df
        if not results_buffer:
            return
        new_df = pd.DataFrame(results_buffer)
        existing_df = new_df if existing_df is None else pd.concat([existing_df, new_df], ignore_index=True)
        existing_df.to_csv(OUTPUT_CSV, index=False)
        results_buffer.clear()
        pbar.set_postfix(saved_rows=len(existing_df))

    try:
        for i, row in enumerate(pbar, start=1):
            uid       = row.uid
            filename  = row.filename
            finding   = row.findings

            try:
                # Path building is inside the try: a missing filename is read
                # as NaN (a float) and makes os.path.join raise TypeError.
                img_path = os.path.join(IMAGES_FOLDER, filename)
                if not os.path.exists(img_path):
                    model_out = "ERROR: image not found"
                else:
                    with Image.open(img_path) as img_in:
                        img = downscale(img_in, MAX_IMAGE_SIDE)
                        img_b64, mime = pil_to_base64_jpeg(img)

                    text, err = client.generate(
                        system_instruction=SYSTEM_INSTRUCTION,
                        user_text=PROMPT,
                        img_b64=img_b64,
                        img_mime=mime,
                        max_output_tokens=MAX_NEW_TOKENS,
                        temperature=0.0
                    )
                    model_out = text if err is None else err
            except Exception as e:
                model_out = f"ERROR: {e}"

            results_buffer.append({
                "uid":              uid,
                "filename":         filename,
                "original_finding": finding,
                "prompt":           "3-shot radiology prompt",
                "gemini_output":    model_out
            })

            # Light memory hygiene
            if i % CLEAN_CACHE_EVERY == 0:
                gc.collect()

            # Flush partial results
            if i % FLUSH_EVERY == 0:
                flush()
    finally:
        # Always persist what was collected, including on interrupt or crash.
        flush()
        pbar.close()

    print("✅ Finished processing. CSV saved:", OUTPUT_CSV)

if __name__ == "__main__":
    main()

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Render long text columns without truncation
pd.set_option("display.max_colwidth", None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    # Summarise the saved results and show the first rows
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Preview the first two saved rows: image, reference finding, model output.
    for _, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image. convert() returns a copy, so the file can be closed
        # as soon as the RGB image has been produced.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 Gemini Output:")
        print(row["gemini_output"])
        print("=" * 80)

## Evaluation

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case *text* and split it on whitespace into a list of tokens."""
    lowered = text.lower()
    return lowered.split()

# ---- Custom BLEU (clipped 1..n-gram precisions, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Sentence-level BLEU of *candidate* against a single *reference*.

    Computes clipped n-gram precisions for orders 1..n, takes their geometric
    mean and multiplies it by the brevity penalty. No smoothing is applied, so
    the score is 0.0 as soon as any order has no matching n-gram (this also
    covers candidates shorter than *n* tokens and empty candidates).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate would otherwise divide by zero in the brevity penalty.
    if cand_len == 0 or n < 1:
        return 0.0

    # brevity penalty
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    # log of each n-gram precision
    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(tuple(ref_tokens[j:j + order]) for j in range(ref_len - order + 1))
        cand_ngrams = Counter(tuple(cand_tokens[j:j + order]) for j in range(cand_len - order + 1))

        overlap = sum((cand_ngrams & ref_ngrams).values())
        if overlap == 0:
            # A zero precision zeroes the geometric mean; return explicitly
            # instead of pushing a sentinel log value through exp().
            return 0.0
        log_precisions.append(np.log(overlap / sum(cand_ngrams.values())))

    # geometric mean of precisions
    return bp * np.exp(np.mean(log_precisions))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """Return ``(recall, precision, f1)`` for clipped n-gram overlap of the two texts."""
    def ngram_counts(tokens):
        return Counter(tuple(tokens[k:k + n]) for k in range(len(tokens) - n + 1))

    ref_counts = ngram_counts(tokenize(reference))
    cand_counts = ngram_counts(tokenize(candidate))

    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0

    denom = precision + recall
    if denom > 0:
        f1 = (2 * precision * recall) / denom
    else:
        f1 = 0

    return recall, precision, f1

def rouge_l(reference, candidate):
    """Return ``(recall, precision, f1)`` based on the longest common token subsequence."""
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    m, n = len(ref_tokens), len(cand_tokens)

    # LCS length, keeping only the previous and current DP rows
    prev_row = [0] * (n + 1)
    for ref_tok in ref_tokens:
        cur_row = [0]
        for j, cand_tok in enumerate(cand_tokens):
            if ref_tok == cand_tok:
                cur_row.append(prev_row[j] + 1)
            else:
                cur_row.append(max(prev_row[j + 1], cur_row[j]))
        prev_row = cur_row
    lcs = prev_row[n]

    recall = lcs / m if m > 0 else 0
    precision = lcs / n if n > 0 else 0
    total = recall + precision
    if total > 0:
        f1 = (2 * recall * precision) / total
    else:
        f1 = 0
    return recall, precision, f1

# ---------------- Main ----------------

OUTPUT_CSV = "gemini_3shot_results_batch.csv"
df = pd.read_csv(OUTPUT_CSV)

# Score only rows that have both a reference and a real model output.
# Rows written by the batch step on failure start with "ERROR", and missing
# text would otherwise be stringified to "nan" and scored as if it were a report.
has_text = df["original_finding"].notna() & df["gemini_output"].notna()
is_error = df["gemini_output"].astype(str).str.startswith("ERROR")
scored = df[has_text & ~is_error]
print(f"Scoring {len(scored)} of {len(df)} rows "
      f"(skipped {len(df) - len(scored)} with missing text or ERROR output).")

refs = scored["original_finding"].astype(str).tolist()
cands = scored["gemini_output"].astype(str).tolist()

# BLEU and ROUGE F1 per (reference, candidate) pair
bleu_scores = [bleu_score(ref, cand) for ref, cand in zip(refs, cands)]
rouge1_f1 = [rouge_n(ref, cand, 1)[2] for ref, cand in zip(refs, cands)]
rouge2_f1 = [rouge_n(ref, cand, 2)[2] for ref, cand in zip(refs, cands)]
rouge3_f1 = [rouge_n(ref, cand, 3)[2] for ref, cand in zip(refs, cands)]
rougel_f1 = [rouge_l(ref, cand)[2] for ref, cand in zip(refs, cands)]

if not refs:
    print("No scorable rows found; nothing to evaluate.")
else:
    # BERTScore
    P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

    # ---------------- Results ----------------
    print("Average Metrics on dataset:")
    print(f"BLEU:     {np.mean(bleu_scores):.4f}")
    print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
    print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
    print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
    print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
    print(f"BERTScore: {F1.mean().item():.4f}")

# 5-Shot Prompting

In [None]:
import os, gc, io, time, base64, json
import pandas as pd
from PIL import Image
from tqdm import tqdm
import requests

# ----------------- USER CONFIG -----------------
# Input locations and the CSV for this section's results.
MERGED_CSV    = "/kaggle/input/dsetindiana/indiana_merged.csv"
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
OUTPUT_CSV    = "gemini_5shot_results_batch.csv"

MAX_ROWS          = 100     # presumably the max number of unique uids to process, as in the earlier sections
MAX_NEW_TOKENS    = 400     # presumably the output-token cap for each request — confirm against main()
FLUSH_EVERY       = 50      # presumably rows between CSV writes — confirm against main()
CLEAN_CACHE_EVERY = 25      # presumably rows between gc.collect() calls — confirm against main()
DOWNSCALE_IMAGES  = True    # when False, downscale() returns images unchanged
MAX_IMAGE_SIDE    = 1024    # presumably the longest side (pixels) passed to downscale()

# ----------------- CONFIGURATION -----------------
SYSTEM_INSTRUCTION = """
You are an expert radiologist. 
Your role is to carefully analyze chest X-rays and provide accurate, clinically useful findings.
Use precise, professional language and avoid speculation.
"""

# 5-shot examples, embedded in the prompt as text only.
# NOTE(review): if the evaluated rows include uids 1, 4, 5, 7 or 8, their
# reference text appears in the prompt of the rows they are scored on —
# confirm and exclude them if so.
FEW_SHOT_EXAMPLES = """
Here are five examples of high-quality chest X-ray findings:

Example 1 (uid=1, filename=1_IM-0001-4001.dcm.png):
"The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no signs of a pleural effusion. There is no evidence of pneumothorax."

Example 2 (uid=4, filename=4_IM-2050-1001.dcm.png):
"There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease and bullous emphysema. There are irregular opacities in the left lung apex, that could represent a cavitary lesion in the left lung apex. There are streaky opacities in the right upper lobe, consistent with scarring. The cardiomediastinal silhouette is normal in size and contour. There is no pneumothorax or large pleural effusion."

Example 3 (uid=5, filename=5_IM-2117-1003002.dcm.png):
"The cardiomediastinal silhouette and pulmonary vasculature are within normal limits. There is no pneumothorax or pleural effusion. There are no focal areas of consolidation. Cholecystectomy clips are present. Small T-spine osteophytes are noted. There is biapical pleural thickening, unchanged from prior. Mildly hyperexpanded lungs."

Example 4 (uid=7, filename=7_IM-2263-1001.dcm.png):
"The cardiac contours are normal. Mild basilar atelectasis is present. The lungs are otherwise clear. Thoracic spondylosis is seen. Lower cervical spine arthritis is noted."

Example 5 (uid=8, filename=8_IM-2333-1001.dcm.png):
"The heart, pulmonary vasculature, and mediastinum are within normal limits. There is no pleural effusion or pneumothorax. There is no focal air space opacity to suggest pneumonia. An interim lower cervical spinal fusion is partly evaluated."
"""

# Prompt text: the examples above followed by the task instruction.
PROMPT = f"""
{FEW_SHOT_EXAMPLES}

Now, based on the above style and level of detail, analyze the following chest X-ray 
and provide a short, single-paragraph summary of the findings. 
Focus only on the most relevant abnormalities (if any), or clearly state if the film appears normal. 
Keep the response concise and professional, suitable for a radiology report.
"""

# ----------------- GEMINI CONFIG -----------------
# Fill in at least one key before running; the client in this section rejects an empty list.
GEMINI_API_KEYS = [
    # os.environ.get("GEMINI_API_KEY"),  # example (single key)
    # "YOUR_KEY_1", "YOUR_KEY_2", ...
]
GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"

# ----------------- IMAGE HELPERS -----------------
def downscale(img, max_side=1024):
    """Shrink ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is preserved. The image is returned untouched when the
    module-level ``DOWNSCALE_IMAGES`` switch is off or when it already fits.

    Args:
        img: PIL image (anything exposing ``size`` and ``resize``).
        max_side: Upper bound, in pixels, for the longer image side.

    Returns:
        The original image object, or a LANCZOS-resampled copy.
    """
    if not DOWNSCALE_IMAGES:
        return img
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # int() truncates, so for extreme aspect ratios the short side could become
    # 0 pixels, which is not a valid resize target; keep at least one pixel.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size, Image.LANCZOS)

def pil_to_base64_jpeg(img, quality=90):
    """Encode a PIL image as base64 JPEG text.

    Images whose mode is not ``"RGB"`` are converted to RGB before encoding.

    Args:
        img: PIL image to encode.
        quality: JPEG quality passed to ``img.save``.

    Returns:
        Tuple ``(base64_text, "image/jpeg")``.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    with io.BytesIO() as stream:
        rgb.save(stream, format="JPEG", quality=quality, optimize=True)
        jpeg_bytes = stream.getvalue()
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode("utf-8"), "image/jpeg"

# ----------------- GEMINI CLIENT -----------------
class GeminiClient:
    """Small REST client for Gemini ``generateContent`` with API-key rotation.

    Each call sends one text prompt plus one inline image and returns the
    generated text. Rate-limit and authentication failures rotate to the next
    configured key; retryable failures wait with exponential backoff.
    """

    # Statuses that are retried (rate limiting and transient server errors).
    _RETRYABLE_STATUSES = (429, 500, 502, 503, 504)
    # Statuses meaning the current key was rejected.
    _AUTH_STATUSES = (401, 403)
    # Upper bound, in seconds, for the backoff delay between attempts.
    _MAX_BACKOFF = 16

    def __init__(self, api_keys, endpoint, timeout=60):
        """Store connection settings.

        Args:
            api_keys: Iterable of API keys; ``None``/empty entries are ignored.
            endpoint: Full ``generateContent`` URL (without a ``key`` query).
            timeout: Per-request timeout in seconds.

        Raises:
            ValueError: If no usable API key is supplied.
        """
        keys = [k for k in (api_keys or []) if k]
        if not keys:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise ValueError("At least one API key is required.")
        self.api_keys = keys
        self.endpoint = endpoint
        self.timeout = timeout
        self.idx = 0

    def _current_key(self):
        """Return the key currently selected for requests."""
        return self.api_keys[self.idx]

    def _rotate_key(self):
        """Advance to the next key, wrapping around at the end of the list."""
        self.idx = (self.idx + 1) % len(self.api_keys)

    @staticmethod
    def _build_payload(system_instruction, user_text, img_b64, img_mime,
                       max_output_tokens, temperature):
        """Assemble the JSON body for one text + inline-image request."""
        return {
            "systemInstruction": {"parts": [{"text": system_instruction.strip()}]},
            "contents": [{
                "role": "user",
                "parts": [
                    {"text": user_text.strip()},
                    {"inline_data": {"mime_type": img_mime, "data": img_b64}},
                ],
            }],
            "generationConfig": {
                "maxOutputTokens": int(max_output_tokens),
                "temperature": float(temperature),
                "topK": 40,
                "topP": 0.95,
            },
        }

    @staticmethod
    def _error_message(response):
        """Return the API error message of a non-200 response, else its raw text."""
        try:
            body = response.json()
        except Exception:
            return response.text
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            return str(error.get("message") or response.text)
        return response.text

    @staticmethod
    def _extract_text(data):
        """Turn a decoded 200 response into ``(text, None)`` or ``(None, error)``."""
        try:
            candidates = data.get("candidates") if isinstance(data, dict) else None
            if not candidates:
                feedback = (data.get("promptFeedback") if isinstance(data, dict) else None) or {}
                block_reason = feedback.get("blockReason")
                if block_reason:
                    return None, f"ERROR: response blocked by safety: {block_reason}"
                return None, "ERROR: no candidates returned"

            first = candidates[0] or {}
            parts = (first.get("content") or {}).get("parts") or []
            texts = [p.get("text", "") for p in parts if "text" in p]
            out_text = "\n".join(t for t in texts if t).strip()
            if not out_text:
                # A candidate may carry no text at all (for example when it was
                # cut off); surface its finishReason so the cause is visible.
                finish = first.get("finishReason")
                detail = f" (finishReason={finish})" if finish else ""
                return None, f"ERROR: empty text in candidate{detail}"
            return out_text, None
        except Exception as e:
            return None, f"ERROR: parse failure: {e}"

    def _wait(self, backoff):
        """Sleep for ``backoff`` seconds and return the next (capped) delay."""
        time.sleep(backoff)
        return min(backoff * 2, self._MAX_BACKOFF)

    def generate(self, system_instruction, user_text, img_b64, img_mime,
                 max_output_tokens=400, temperature=0.0):
        """Request a completion for one prompt/image pair.

        Args:
            system_instruction: System prompt text.
            user_text: User prompt text.
            img_b64: Base64-encoded image data.
            img_mime: MIME type of the image (e.g. ``"image/jpeg"``).
            max_output_tokens: Value sent as ``maxOutputTokens``.
            temperature: Sampling temperature.

        Returns:
            ``(text, None)`` on success, otherwise ``(None, message)`` where
            ``message`` starts with ``"ERROR"``.
        """
        body = json.dumps(self._build_payload(
            system_instruction, user_text, img_b64, img_mime,
            max_output_tokens, temperature,
        ))

        backoff = 1.0
        last_error = None
        max_attempts = len(self.api_keys) * 2

        for _ in range(max_attempts):
            try:
                r = requests.post(
                    self.endpoint,
                    headers={
                        "Content-Type": "application/json",
                        # Header auth keeps the key out of the URL, so it cannot
                        # leak through logged URLs or exception messages.
                        "x-goog-api-key": self._current_key(),
                    },
                    data=body,
                    timeout=self.timeout,
                )
            except requests.RequestException as e:
                last_error = f"request failed: {e}"
                backoff = self._wait(backoff)
                continue

            if r.status_code == 200:
                try:
                    data = r.json()
                except Exception as e:
                    return None, f"ERROR: bad JSON from Gemini: {e}"
                return self._extract_text(data)

            msg = self._error_message(r)
            if r.status_code in self._RETRYABLE_STATUSES:
                # rate limit or transient server errors
                if r.status_code == 429 or "quota" in msg.lower() or "rate" in msg.lower():
                    self._rotate_key()
            elif r.status_code in self._AUTH_STATUSES:
                # invalid/expired key
                self._rotate_key()
            else:
                return None, f"ERROR {r.status_code}: {msg}"

            last_error = f"HTTP {r.status_code}: {msg}"
            backoff = self._wait(backoff)

        detail = f" (last error: {last_error})" if last_error else ""
        return None, "ERROR: all API keys exhausted or repeated failures" + detail

# ----------------- MAIN -----------------
def main():
    """Run the 5-shot Gemini batch and persist one result row per uid.

    Reads ``MERGED_CSV``, keeps the first row per ``uid`` (at most
    ``MAX_ROWS``), sends each image together with ``PROMPT`` to Gemini and
    writes the outputs to ``OUTPUT_CSV``. Uids already present in
    ``OUTPUT_CSV`` are skipped so an interrupted run can be resumed; rows whose
    stored output is an ``ERROR`` marker are retried instead of being treated
    as finished.
    """
    print("Reading merged CSV...")
    base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    base_df = base_df.drop_duplicates(subset="uid").reset_index(drop=True)

    if len(base_df) > MAX_ROWS:
        base_df = base_df.iloc[:MAX_ROWS].copy()

    existing_df = None
    done_uids = set()

    if os.path.exists(OUTPUT_CSV):
        try:
            existing_df = pd.read_csv(OUTPUT_CSV)
            if "gemini_output" in existing_df.columns:
                # Drop earlier failures for uids in this run so they are retried;
                # otherwise a transient API error would be permanent.
                wanted = set(base_df["uid"].astype(str))
                retry = (
                    existing_df["gemini_output"].astype(str).str.startswith("ERROR")
                    & existing_df["uid"].astype(str).isin(wanted)
                )
                if retry.any():
                    print(f"Retrying {int(retry.sum())} previously failed row(s).")
                    existing_df = existing_df[~retry].reset_index(drop=True)
            done_uids = set(existing_df["uid"].astype(str))
            before = len(base_df)
            base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
            print(f"Resuming: {len(done_uids)} already processed; {len(base_df)} remaining (filtered {before - len(base_df)}).")
        except Exception as e:
            print("Could not read existing output; starting fresh.", e)
            existing_df = None

    print("Rows to process in this run:", len(base_df))
    if len(base_df) == 0:
        print("Nothing new to process. (All done.)")
        return

    client = GeminiClient(GEMINI_API_KEYS, GEMINI_ENDPOINT)

    results_buffer = []

    def flush():
        """Append the buffered rows to the saved frame and rewrite OUTPUT_CSV."""
        nonlocal existing_df
        if not results_buffer:
            return
        new_df = pd.DataFrame(results_buffer)
        if existing_df is None or existing_df.empty:
            existing_df = new_df
        else:
            existing_df = pd.concat([existing_df, new_df], ignore_index=True)
        existing_df.to_csv(OUTPUT_CSV, index=False)
        results_buffer.clear()

    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Gemini 5-shot batch")

    try:
        for i, row in enumerate(pbar, start=1):
            uid       = row.uid
            filename  = row.filename
            finding   = row.findings
            img_path  = os.path.join(IMAGES_FOLDER, filename)

            if not os.path.exists(img_path):
                model_out = "ERROR: image not found"
            else:
                try:
                    with Image.open(img_path) as img_in:
                        img = downscale(img_in, MAX_IMAGE_SIDE)
                        img_b64, mime = pil_to_base64_jpeg(img)

                    text, err = client.generate(
                        system_instruction=SYSTEM_INSTRUCTION,
                        user_text=PROMPT,
                        img_b64=img_b64,
                        img_mime=mime,
                        max_output_tokens=MAX_NEW_TOKENS,
                        temperature=0.0
                    )
                    model_out = text if err is None else err
                except Exception as e:
                    # Record the failure for this row and keep the batch going.
                    model_out = f"ERROR: {e}"

            results_buffer.append({
                "uid":              uid,
                "filename":         filename,
                "original_finding": finding,
                "prompt":           "5-shot radiology prompt",
                "gemini_output":    model_out
            })

            if i % CLEAN_CACHE_EVERY == 0:
                gc.collect()

            if i % FLUSH_EVERY == 0 or i == len(base_df):
                flush()
                pbar.set_postfix(saved_rows=len(existing_df))
    finally:
        # Save whatever is still buffered if the loop is interrupted part-way.
        flush()

    print("✅ Finished processing. CSV saved:", OUTPUT_CSV)

if __name__ == "__main__":
    main()

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Let pandas print long text columns in full instead of truncating them.
pd.set_option("display.max_colwidth", None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    # Quick sanity check of the saved results: size, schema, first rows.
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Preview the first two saved rows: image, reference finding, model output.
    for i, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image. The context manager closes the file once the pixels have
        # been copied by convert(), instead of leaving the handle open.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 Gemini Output:")
        print(row["gemini_output"])
        print("=" * 80)

## Evaluation

In [None]:
import os
import pandas as pd
import numpy as np
from collections import Counter
from bert_score import score as bert_score

# ---------------- Utility functions ----------------

def tokenize(text):
    """Lower-case ``text`` and split it on whitespace into a list of tokens."""
    lowered = text.lower()
    return lowered.split()

# ---- Custom BLEU (n-gram precisions up to order n, with brevity penalty) ----
def bleu_score(reference, candidate, n=4):
    """Sentence-level BLEU between two strings, without smoothing.

    Both strings are tokenized with ``tokenize``. The score is the brevity
    penalty times the geometric mean of the clipped n-gram precisions for
    orders ``1..n``.

    Args:
        reference: Reference text.
        candidate: Generated text to score.
        n: Highest n-gram order to include.

    Returns:
        A float in ``[0, 1]``. The result is 0.0 when the candidate has no
        tokens or when any n-gram order has no overlap with the reference
        (which includes candidates shorter than ``n`` tokens).
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    ref_len, cand_len = len(ref_tokens), len(cand_tokens)

    # An empty candidate has no n-grams at all; returning early also avoids
    # dividing by zero in the brevity penalty below.
    if cand_len == 0:
        return 0.0

    # brevity penalty: only candidates shorter than the reference are penalised
    bp = np.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    # n-gram precisions, kept as logs for the geometric mean
    log_precisions = []
    for order in range(1, n + 1):
        ref_ngrams = Counter(
            tuple(ref_tokens[j:j + order]) for j in range(ref_len - order + 1)
        )
        cand_ngrams = Counter(
            tuple(cand_tokens[j:j + order]) for j in range(cand_len - order + 1)
        )

        overlap = sum((cand_ngrams & ref_ngrams).values())
        total = sum(cand_ngrams.values())
        # Without smoothing a single zero precision makes the geometric mean 0.
        if total == 0 or overlap == 0:
            return 0.0
        log_precisions.append(np.log(overlap / total))

    if not log_precisions:
        return 0.0

    return float(bp * np.exp(np.mean(log_precisions)))

# ---- Custom ROUGE ----
def rouge_n(reference, candidate, n=1):
    """ROUGE-N overlap between two strings.

    Both strings are tokenized with ``tokenize`` and compared on clipped
    n-gram counts.

    Returns:
        Tuple ``(recall, precision, f1)``; each component is 0 when its
        denominator would be empty.
    """
    def count_ngrams(tokens):
        # One n-gram per start position that still leaves n tokens available.
        last_start = len(tokens) - n + 1
        return Counter(tuple(tokens[start:start + n]) for start in range(last_start))

    ref_counts = count_ngrams(tokenize(reference))
    cand_counts = count_ngrams(tokenize(candidate))

    # Counter intersection keeps the minimum count of every shared n-gram.
    matched = sum((cand_counts & ref_counts).values())

    recall = matched / sum(ref_counts.values()) if ref_counts else 0
    precision = matched / sum(cand_counts.values()) if cand_counts else 0

    if (precision + recall) > 0:
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0

    return recall, precision, f1

def rouge_l(reference, candidate):
    """ROUGE-L between two strings, based on the longest common subsequence.

    Both strings are tokenized with ``tokenize``.

    Returns:
        Tuple ``(recall, precision, f1)`` where recall is LCS / reference
        length and precision is LCS / candidate length; components are 0 when
        the corresponding length is 0.
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)
    m, n = len(ref_tokens), len(cand_tokens)

    # LCS (Longest Common Subsequence), keeping only the previous DP row.
    prev_row = [0] * (n + 1)
    for ref_tok in ref_tokens:
        row = [0]
        for j, cand_tok in enumerate(cand_tokens, start=1):
            if ref_tok == cand_tok:
                row.append(prev_row[j - 1] + 1)
            else:
                row.append(max(prev_row[j], row[j - 1]))
        prev_row = row
    lcs = prev_row[n]

    recall = lcs / m if m > 0 else 0
    precision = lcs / n if n > 0 else 0
    if (recall + precision) > 0:
        f1 = (2 * recall * precision) / (recall + precision)
    else:
        f1 = 0
    return recall, precision, f1

# ---------------- Main ----------------

OUTPUT_CSV = "gemini_5shot_results_batch.csv"
df = pd.read_csv(OUTPUT_CSV)

# Score only rows that have a reference finding and a real model answer.
# Rows saved as "ERROR: ..." or with a missing reference would otherwise be
# compared as literal text ("nan", "ERROR ...") and distort the averages.
generated = df["gemini_output"].astype(str)
usable = (
    df["original_finding"].notna()
    & df["gemini_output"].notna()
    & ~generated.str.startswith("ERROR")
)
skipped = int((~usable).sum())
if skipped:
    print(f"Skipping {skipped} row(s) with a missing reference or a failed generation.")

refs = df.loc[usable, "original_finding"].astype(str).tolist()
cands = generated[usable].tolist()

bleu_scores = []
rouge1_f1, rouge2_f1, rouge3_f1, rougel_f1 = [], [], [], []

for ref, cand in zip(refs, cands):
    # BLEU
    bleu_scores.append(bleu_score(ref, cand))

    # ROUGE (F1 component only)
    rouge1_f1.append(rouge_n(ref, cand, 1)[2])
    rouge2_f1.append(rouge_n(ref, cand, 2)[2])
    rouge3_f1.append(rouge_n(ref, cand, 3)[2])
    rougel_f1.append(rouge_l(ref, cand)[2])

if not refs:
    print("No valid rows to evaluate.")
else:
    # BERTScore
    P, R, F1 = bert_score(cands, refs, lang="en", verbose=True)

    # ---------------- Results ----------------
    print("Average Metrics on dataset:")
    print(f"BLEU:     {np.mean(bleu_scores):.4f}")
    print(f"ROUGE-1:  {np.mean(rouge1_f1):.4f}")
    print(f"ROUGE-2:  {np.mean(rouge2_f1):.4f}")
    print(f"ROUGE-3:  {np.mean(rouge3_f1):.4f}")
    print(f"ROUGE-L:  {np.mean(rougel_f1):.4f}")
    print(f"BERTScore: {F1.mean().item():.4f}")

# CoT Prompting

In [None]:
import os, gc, io, time, base64, json
import pandas as pd
from PIL import Image
from tqdm import tqdm
import requests

# ----------------- USER CONFIG -----------------
# Input table; main() reads its uid, filename and findings columns.
MERGED_CSV    = "/kaggle/input/dsetindiana/indiana_merged.csv"
# Directory that the CSV's filename column is joined onto to locate each image.
IMAGES_FOLDER = "/kaggle/input/chest-xrays-indiana-university/images/images_normalized"
# Results file; main() also reads it back to skip uids that were already saved.
OUTPUT_CSV    = "gemini_cot_results_batch.csv"

MAX_ROWS          = 100     # cap on unique uids taken from the merged CSV
MAX_NEW_TOKENS    = 1200    # passed to GeminiClient.generate as max_output_tokens
FLUSH_EVERY       = 50      # rewrite OUTPUT_CSV after this many processed rows
CLEAN_CACHE_EVERY = 25      # call gc.collect() after this many processed rows
DOWNSCALE_IMAGES  = True    # when False, downscale() returns images unchanged
MAX_IMAGE_SIDE    = 1024    # max_side handed to downscale() for each image

# ─── Configuration ───
# System prompt for the CoT run. The quotes and dash around "normal" are real
# Unicode punctuation; the mis-decoded byte sequences that stood in their place
# would otherwise be sent to the model verbatim.
SYSTEM_INSTRUCTION = """
You are an expert radiologist. 
You are required to actively search for subtle and chronic abnormalities and explicitly consider possible misses. 
For every region, your default is NOT “normal”—you must convince yourself it is truly normal by excluding all 
reasonable pathology, or explain every possible abnormality, however mild, even if not clinically significant.
"""

# User prompt for the CoT run: Step 1 asks for one assessment block per listed
# region, Step 2 asks for a second, more sceptical pass over a subset of them.
# NOTE(review): Step 2 lists 11 of the 13 Step-1 regions (Bones and Soft tissues
# are not re-evaluated) - confirm this is intentional.
PROMPT = """
You are an expert radiologist. For the following chest X-ray, do not rush. Perform these two steps in order:

-------------------------------
Step 1: FULL ORGAN ASSESSMENT
-------------------------------
For **each** of these organs/regions, provide a **separate assessment** in this explicit format:

Region: [name]
Status: Normal / Abnormal
Findings: [describe findings; if normal, list ALL subtle/chronic pathologies actively ruled out by name; if abnormal, specify type, severity, and confidence level.]

Regions to assess:
1. Heart
2. Aorta
3. Cardiac silhouette
4. Diaphragm
5. Costophrenic angle
6. Hilar region
7. Lungs
8. Mediastinum
9. Pleura
10. Pulmonary arteries
11. Trachea
12. Bones (ribs, clavicles, spine, AC joints)
13. Soft tissues

**Do NOT skip any region. Do NOT summarize. Use one block per region.**

---------------------------------------
Step 2: REEVALUATION OF CHALLENGED ORGANS
---------------------------------------
Now, re-examine ONLY these regions with even greater scrutiny, challenging your initial call and explicitly searching for subtle, chronic, or borderline findings. Even mild uncertainty must be described.

Organs to reevaluate (both possible false negatives and false positives):
- Aorta
- Cardiac silhouette
- Diaphragm
- Hilar region
- Lungs
- Mediastinum
- Pleura
- Pulmonary arteries
- Trachea
- Heart
- Costophrenic angle

For each, restate:
Region: [name]
Reevaluation: Normal / Abnormal
Details: [List anything you may have missed or overcalled, including subtle findings or areas of uncertainty. If normal, state what subtle findings you explicitly ruled out again. If abnormal, specify details, severity, and confidence.]

**Finish with a one-sentence summary of any remaining uncertainty for any organ.**

Strictly follow the two steps above. Do not omit or merge regions. Structure your answer exactly as instructed.
"""

# ─── Gemini Config ───
# List keys explicitly below, or leave the list empty to have them read from
# the comma-separated GEMINI_API_KEYS environment variable (recommended).
GEMINI_API_KEYS = [
    # os.environ.get("GEMINI_API_KEY"),  # example (single key)
    # "YOUR_KEY_1", "YOUR_KEY_2", ...
]
if len(GEMINI_API_KEYS) == 0:
    # Split on commas, trim surrounding whitespace, drop empty pieces.
    env_keys = [
        piece
        for piece in map(str.strip, os.environ.get("GEMINI_API_KEYS", "").split(","))
        if piece
    ]
    GEMINI_API_KEYS = env_keys

GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"

# ----------------- IMAGE HELPERS -----------------
def downscale(img, max_side=1024):
    """Shrink ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is preserved. The image is returned untouched when the
    module-level ``DOWNSCALE_IMAGES`` switch is off or when it already fits.

    Args:
        img: PIL image (anything exposing ``size`` and ``resize``).
        max_side: Upper bound, in pixels, for the longer image side.

    Returns:
        The original image object, or a LANCZOS-resampled copy.
    """
    if not DOWNSCALE_IMAGES:
        return img
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # int() truncates, so for extreme aspect ratios the short side could become
    # 0 pixels, which is not a valid resize target; keep at least one pixel.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size, Image.LANCZOS)

def pil_to_base64_jpeg(img, quality=90):
    """Encode a PIL image as base64 JPEG text.

    Images whose mode is not ``"RGB"`` are converted to RGB before encoding.

    Args:
        img: PIL image to encode.
        quality: JPEG quality passed to ``img.save``.

    Returns:
        Tuple ``(base64_text, "image/jpeg")``.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    with io.BytesIO() as stream:
        rgb.save(stream, format="JPEG", quality=quality, optimize=True)
        jpeg_bytes = stream.getvalue()
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode("utf-8"), "image/jpeg"

# ----------------- GEMINI CLIENT -----------------
class GeminiClient:
    """Small REST client for Gemini ``generateContent`` with API-key rotation.

    Each call sends one text prompt plus one inline image and returns the
    generated text. Rate-limit and authentication failures rotate to the next
    configured key; retryable failures wait with exponential backoff.
    """

    # Statuses that are retried (rate limiting and transient server errors).
    _RETRYABLE_STATUSES = (429, 500, 502, 503, 504)
    # Statuses meaning the current key was rejected.
    _AUTH_STATUSES = (401, 403)
    # Upper bound, in seconds, for the backoff delay between attempts.
    _MAX_BACKOFF = 16

    def __init__(self, api_keys, endpoint, timeout=90):
        """Store connection settings.

        Args:
            api_keys: Iterable of API keys; ``None``/empty entries are ignored.
            endpoint: Full ``generateContent`` URL (without a ``key`` query).
            timeout: Per-request timeout in seconds.

        Raises:
            ValueError: If no usable API key is supplied.
        """
        keys = [k for k in (api_keys or []) if k]
        if not keys:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise ValueError(
                "At least one Gemini API key is required. Set GEMINI_API_KEYS or env GEMINI_API_KEYS."
            )
        self.api_keys = keys
        self.endpoint = endpoint
        self.timeout = timeout
        self.idx = 0

    def _current_key(self):
        """Return the key currently selected for requests."""
        return self.api_keys[self.idx]

    def _rotate_key(self):
        """Advance to the next key, wrapping around at the end of the list."""
        self.idx = (self.idx + 1) % len(self.api_keys)

    @staticmethod
    def _build_payload(system_instruction, user_text, img_b64, img_mime,
                       max_output_tokens, temperature):
        """Assemble the JSON body for one text + inline-image request.

        Only sampling and length settings are configured here; the two-step
        answer format is requested through the prompt text itself.
        """
        return {
            "systemInstruction": {"parts": [{"text": system_instruction.strip()}]},
            "contents": [{
                "role": "user",
                "parts": [
                    {"text": user_text.strip()},
                    {"inline_data": {"mime_type": img_mime, "data": img_b64}},
                ],
            }],
            "generationConfig": {
                "maxOutputTokens": int(max_output_tokens),
                "temperature": float(temperature),
                "topK": 40,
                "topP": 0.95,
            },
        }

    @staticmethod
    def _error_message(response):
        """Return the API error message of a non-200 response, else its raw text."""
        try:
            body = response.json()
        except Exception:
            return response.text
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            return str(error.get("message") or response.text)
        return response.text

    @staticmethod
    def _extract_text(data):
        """Turn a decoded 200 response into ``(text, None)`` or ``(None, error)``."""
        try:
            candidates = data.get("candidates") if isinstance(data, dict) else None
            if not candidates:
                feedback = (data.get("promptFeedback") if isinstance(data, dict) else None) or {}
                block_reason = feedback.get("blockReason")
                if block_reason:
                    return None, f"ERROR: response blocked by safety: {block_reason}"
                return None, "ERROR: no candidates returned"

            first = candidates[0] or {}
            parts = (first.get("content") or {}).get("parts") or []
            texts = [p.get("text", "") for p in parts if "text" in p]
            out_text = "\n".join(t for t in texts if t).strip()
            if not out_text:
                # A candidate may carry no text at all (for example when it was
                # cut off); surface its finishReason so the cause is visible.
                finish = first.get("finishReason")
                detail = f" (finishReason={finish})" if finish else ""
                return None, f"ERROR: empty text in candidate{detail}"
            return out_text, None
        except Exception as e:
            return None, f"ERROR: parse failure: {e}"

    def _wait(self, backoff):
        """Sleep for ``backoff`` seconds and return the next (capped) delay."""
        time.sleep(backoff)
        return min(backoff * 2, self._MAX_BACKOFF)

    def generate(self, system_instruction, user_text, img_b64, img_mime,
                 max_output_tokens=1200, temperature=0.0):
        """Request a completion for one prompt/image pair.

        Args:
            system_instruction: System prompt text.
            user_text: User prompt text.
            img_b64: Base64-encoded image data.
            img_mime: MIME type of the image (e.g. ``"image/jpeg"``).
            max_output_tokens: Value sent as ``maxOutputTokens``.
            temperature: Sampling temperature.

        Returns:
            ``(text, None)`` on success, otherwise ``(None, message)`` where
            ``message`` starts with ``"ERROR"``.
        """
        body = json.dumps(self._build_payload(
            system_instruction, user_text, img_b64, img_mime,
            max_output_tokens, temperature,
        ))

        backoff = 1.0
        last_error = None
        max_attempts = len(self.api_keys) * 2

        for _ in range(max_attempts):
            try:
                r = requests.post(
                    self.endpoint,
                    headers={
                        "Content-Type": "application/json",
                        # Header auth keeps the key out of the URL, so it cannot
                        # leak through logged URLs or exception messages.
                        "x-goog-api-key": self._current_key(),
                    },
                    data=body,
                    timeout=self.timeout,
                )
            except requests.RequestException as e:
                last_error = f"request failed: {e}"
                backoff = self._wait(backoff)
                continue

            if r.status_code == 200:
                try:
                    data = r.json()
                except Exception as e:
                    return None, f"ERROR: bad JSON from Gemini: {e}"
                return self._extract_text(data)

            msg = self._error_message(r)
            if r.status_code in self._RETRYABLE_STATUSES:
                # rate limit / transient server errors
                if r.status_code == 429 or "quota" in msg.lower() or "rate" in msg.lower():
                    self._rotate_key()
            elif r.status_code in self._AUTH_STATUSES:
                # invalid/expired key
                self._rotate_key()
            else:
                return None, f"ERROR {r.status_code}: {msg}"

            last_error = f"HTTP {r.status_code}: {msg}"
            backoff = self._wait(backoff)

        detail = f" (last error: {last_error})" if last_error else ""
        return None, "ERROR: all API keys exhausted or repeated failures" + detail

# ----------------- MAIN -----------------
def main():
    """Run the CoT Gemini batch and persist one result row per uid.

    Reads ``MERGED_CSV``, keeps the first row per ``uid`` (at most
    ``MAX_ROWS``), sends each image together with ``PROMPT`` to Gemini and
    writes the outputs to ``OUTPUT_CSV``. Uids already present in
    ``OUTPUT_CSV`` are skipped so an interrupted run can be resumed; rows whose
    stored output is an ``ERROR`` marker are retried instead of being treated
    as finished.
    """
    print("Reading merged CSV...")
    base_df = pd.read_csv(MERGED_CSV, usecols=["uid", "filename", "findings"])
    base_df = base_df.drop_duplicates(subset="uid").reset_index(drop=True)

    if len(base_df) > MAX_ROWS:
        base_df = base_df.iloc[:MAX_ROWS].copy()

    existing_df = None
    done_uids = set()

    if os.path.exists(OUTPUT_CSV):
        try:
            existing_df = pd.read_csv(OUTPUT_CSV)
            if "gemini_output" in existing_df.columns:
                # Drop earlier failures for uids in this run so they are retried;
                # otherwise a transient API error would be permanent.
                wanted = set(base_df["uid"].astype(str))
                retry = (
                    existing_df["gemini_output"].astype(str).str.startswith("ERROR")
                    & existing_df["uid"].astype(str).isin(wanted)
                )
                if retry.any():
                    print(f"Retrying {int(retry.sum())} previously failed row(s).")
                    existing_df = existing_df[~retry].reset_index(drop=True)
            done_uids = set(existing_df["uid"].astype(str))
            before = len(base_df)
            base_df = base_df[~base_df["uid"].astype(str).isin(done_uids)].reset_index(drop=True)
            print(f"Resuming: {len(done_uids)} already processed; {len(base_df)} remaining (filtered {before - len(base_df)}).")
        except Exception as e:
            print("Could not read existing output; starting fresh.", e)
            existing_df = None

    print("Rows to process in this run:", len(base_df))
    if len(base_df) == 0:
        print("Nothing new to process. (All done.)")
        return

    client = GeminiClient(GEMINI_API_KEYS, GEMINI_ENDPOINT)

    results_buffer = []

    def flush():
        """Append the buffered rows to the saved frame and rewrite OUTPUT_CSV."""
        nonlocal existing_df
        if not results_buffer:
            return
        new_df = pd.DataFrame(results_buffer)
        if existing_df is None or existing_df.empty:
            existing_df = new_df
        else:
            existing_df = pd.concat([existing_df, new_df], ignore_index=True)
        existing_df.to_csv(OUTPUT_CSV, index=False)
        results_buffer.clear()

    pbar = tqdm(base_df.itertuples(index=False), total=len(base_df), desc="Gemini CoT batch")

    try:
        for i, row in enumerate(pbar, start=1):
            uid       = row.uid
            filename  = row.filename
            finding   = row.findings
            img_path  = os.path.join(IMAGES_FOLDER, filename)

            if not os.path.exists(img_path):
                model_out = "ERROR: image not found"
            else:
                try:
                    with Image.open(img_path) as img_in:
                        img = downscale(img_in, MAX_IMAGE_SIDE)
                        img_b64, mime = pil_to_base64_jpeg(img)

                    text, err = client.generate(
                        system_instruction=SYSTEM_INSTRUCTION,
                        user_text=PROMPT,
                        img_b64=img_b64,
                        img_mime=mime,
                        max_output_tokens=MAX_NEW_TOKENS,
                        temperature=0.0
                    )
                    model_out = text if err is None else err
                except Exception as e:
                    # Record the failure for this row and keep the batch going.
                    model_out = f"ERROR: {e}"

            results_buffer.append({
                "uid":              uid,
                "filename":         filename,
                "original_finding": finding,
                "prompt":           "CoT radiology prompt (2-step organ-wise + reevaluation)",
                "gemini_output":    model_out
            })

            if i % CLEAN_CACHE_EVERY == 0:
                gc.collect()

            if i % FLUSH_EVERY == 0 or i == len(base_df):
                flush()
                pbar.set_postfix(saved_rows=len(existing_df))
    finally:
        # Save whatever is still buffered if the loop is interrupted part-way.
        flush()

    print("✅ Finished processing. CSV saved:", OUTPUT_CSV)

if __name__ == "__main__":
    main()

In [None]:
# ----------------- sec5b: Verify CSV contents -----------------
import os
import pandas as pd

# Let pandas print long text columns in full instead of truncating them.
pd.set_option("display.max_colwidth", None)

if not os.path.exists(OUTPUT_CSV):
    print("CSV not found - something went wrong.")
else:
    # Quick sanity check of the saved results: size, schema, first rows.
    out = pd.read_csv(OUTPUT_CSV)
    print("Rows in CSV:", len(out))
    print("Columns:", out.columns.tolist())
    display(out.head())

In [None]:
import matplotlib.pyplot as plt

if os.path.exists(OUTPUT_CSV):
    out = pd.read_csv(OUTPUT_CSV)

    # Preview the first two saved rows: image, reference finding, model output.
    for i, row in out.head(2).iterrows():
        img_path = os.path.join(IMAGES_FOLDER, row["filename"])

        if not os.path.exists(img_path):
            print(f"[{row['uid']}] Image not found: {row['filename']}")
            continue

        # Show image. The context manager closes the file once the pixels have
        # been copied by convert(), instead of leaving the handle open.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"UID: {row['uid']} | File: {row['filename']}", fontsize=10)
        plt.show()

        # Show text info
        print("🔹 Original Finding:")
        print(row["original_finding"])
        print("\n🔹 Gemini Output:")
        print(row["gemini_output"])
        print("=" * 80)

## LLM-as-a-Judge

In [None]:
import os
import pandas as pd
from tqdm import tqdm
import google.generativeai as genai
import json
import re

# ---------------- Gemini Setup ----------------
# Read the key from the GEMINI_API_KEY environment variable when it is set, so
# the real key never has to be written into the notebook; the placeholder is
# kept as the fallback for manual editing.
genai.configure(api_key=os.environ.get("GEMINI_API_KEY", "YOUR_KEY_HERE"))
model = genai.GenerativeModel("gemini-2.5-flash")

# ---------------- Eval Function ----------------
def judge_with_gemini(original_finding, medgemma_output):
    """Ask Gemini to score a generated report against the reference finding.

    Args:
        original_finding: Reference text.
        medgemma_output: Generated text to judge. The parameter name and the
            "MedGemma Output" label in the prompt are kept unchanged for
            compatibility; the caller in this notebook passes the
            ``gemini_output`` column.

    Returns:
        Tuple ``(similarity, professional_writing, correctness)`` of floats.
        ``(0.0, 0.0, 0.0)`` is returned when the API call fails or the reply
        cannot be parsed as JSON.
    """
    prompt = f"""
You are acting as an impartial medical evaluator (LLM-as-a-judge). 
Compare the following two texts:

Original Finding:
{original_finding}

MedGemma Output:
{medgemma_output}

Evaluate along these 3 dimensions, and output ONLY a JSON object with keys:
- "similarity": float (0 to 1) → semantic meaning similarity.
- "professional_writing": float (0 to 1) → clarity, conciseness, and medical domain appropriateness of writing style.
- "correctness": float (0 to 1) → medical accuracy and correctness of findings.

Do not explain. Just output the JSON object.
"""

    try:
        response = model.generate_content(prompt)
        text = response.text.strip()
    except Exception as e:
        # One failed request (or a response without usable text) should not
        # abort the whole judging loop; score the row like a parse failure.
        print("⚠️ Gemini call failed:", e)
        return (0.0, 0.0, 0.0)

    # Clean common formatting issues
    text = re.sub(r"^```json", "", text)
    text = re.sub(r"^```", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()

    # If the model wrapped the JSON in extra prose, keep only the {...} span.
    embedded = re.search(r"\{.*\}", text, re.DOTALL)
    if embedded:
        text = embedded.group(0)

    try:
        scores = json.loads(text)
        return (
            float(scores.get("similarity", 0)),
            float(scores.get("professional_writing", 0)),
            float(scores.get("correctness", 0)),
        )
    except Exception as e:
        print("⚠️ Parsing error:", e, "Raw response:", text)
        return (0.0, 0.0, 0.0)

# ---------------- Main Script ----------------
OUTPUT_CSV = "gemini_cot_results_batch.csv"
EVAL_CSV   = "gemini_cot_eval_results.csv"

if os.path.exists(OUTPUT_CSV):
    # Judge at most the first 100 saved rows.
    df = pd.read_csv(OUTPUT_CSV).head(100).copy()

    sims, pros, corrs = [], [], []

    pairs = zip(df["original_finding"], df["gemini_output"])
    for reference, generated in tqdm(pairs, total=len(df), desc="Gemini LLM-as-a-judge"):
        s, p, c = judge_with_gemini(reference, generated)
        sims.append(s); pros.append(p); corrs.append(c)

    df["similarity"] = sims
    df["professional_writing"] = pros
    df["correctness"] = corrs

    df.to_csv(EVAL_CSV, index=False)
    print(f"✅ Evaluation done. Saved to {EVAL_CSV}")
else:
    # Say why nothing happened instead of finishing silently.
    print(f"{OUTPUT_CSV} not found - run the CoT batch cell first.")

In [None]:
import pandas as pd

# Peek at the first three judged rows of the saved evaluation file.
df = pd.read_csv("/kaggle/working/gemini_cot_eval_results.csv")
print(df.iloc[:3])

In [None]:
import pandas as pd

# Read the judged rows written by the LLM-as-a-judge cell.
df = pd.read_csv("/kaggle/working/gemini_cot_eval_results.csv")

# Mean of each judge score across all rows.
avg_similarity, avg_professional, avg_correctness = (
    df[column].mean()
    for column in ("similarity", "professional_writing", "correctness")
)

# One line per metric, three decimals.
print(
    "Average Scores:",
    f"Similarity:            {avg_similarity:.3f}",
    f"Professional Writing:  {avg_professional:.3f}",
    f"Correctness:           {avg_correctness:.3f}",
    sep="\n",
)