# 🧠 Qwen2.5 Brain MRI Description Generation

Generate a free-text medical description for each MRI slice with Qwen2.5-VL (72B Instruct), then evaluate the descriptions against the ground-truth captions using BLEU-4.

## 1. Setup & Imports

In [None]:
import os, json, base64
from openai import OpenAI
from dotenv import load_dotenv
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

def encode_image_to_data_uri(path: str) -> str:
    """Return the file at *path* encoded as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension (e.g. ``.jpg`` ->
    ``image/jpeg``) so the URI label matches the actual image encoding.
    Unknown or non-image extensions fall back to ``image/png``.

    Raises:
        OSError: if *path* cannot be opened for reading.
    """
    import mimetypes

    mime, _ = mimetypes.guess_type(path)
    if mime is None or not mime.startswith("image/"):
        mime = "image/png"
    with open(path, "rb") as fh:
        # base64 output is pure ASCII, so decoding as ascii is exact.
        b64 = base64.b64encode(fh.read()).decode("ascii")
    return f"data:{mime};base64,{b64}"

# Dataset inputs; relative paths resolve against the notebook's working directory.
DATASET_DIR = "VLM-Seminar25-Dataset/nova_brain"
IMAGES_DIR = os.path.join(DATASET_DIR, "images")
ANNOT_PATH = os.path.join(DATASET_DIR, "annotations.json")
# Output directory for predictions, metrics and plots (created if missing).
RESULTS_DIR = "../results/nova_brain"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Top-level keys are case ids. The inference cell reads each case's
# "image_findings" mapping (image file name -> dict with a "caption").
with open(ANNOT_PATH, "r", encoding="utf-8") as f:
    annotations = json.load(f)
case_ids = list(annotations.keys())

# The Nebius endpoint is OpenAI-compatible; its key comes from config/user.env.
load_dotenv(dotenv_path="config/user.env")
api_key = os.environ.get("NEBIUS_API_KEY")
if api_key:
    client = OpenAI(base_url="https://api.studio.nebius.com/v1/", api_key=api_key)
else:
    # The OpenAI client constructor raises when no API key can be resolved,
    # which would abort this cell even when only cached predictions are
    # evaluated. Defer the failure to the point where the client is needed.
    client = None
    print("NEBIUS_API_KEY not set; new inference is unavailable (cached results can still be evaluated).")

In [None]:
# Toggle for the inference cell: True queries the Qwen2.5-VL endpoint for
# every annotated image; False makes no API calls and leaves
# `description_results` empty.
do_new_inference: bool = False

## 2. Model Inference

In [None]:
# Query Qwen2.5-VL once per annotated image and pair each prediction with the
# image's ground-truth caption. Nothing is sent unless do_new_inference is True.
MODEL_NAME = "Qwen/Qwen2.5-VL-72B-Instruct"
PROMPT = "Please describe the given medical image."

description_results = []
if do_new_inference:
    if client is None:
        raise RuntimeError("No API client available; set NEBIUS_API_KEY in config/user.env.")
    for case_id in tqdm(case_ids):
        case = annotations[case_id]
        for img_name, img_info in case.get("image_findings", {}).items():
            data_uri = encode_image_to_data_uri(os.path.join(IMAGES_DIR, img_name))
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": PROMPT},
                            {"type": "image_url", "image_url": {"url": data_uri}},
                        ],
                    }
                ],
            )
            # The SDK types message.content as optional (it can be None, e.g.
            # for an empty reply); store "" so later tokenization still works.
            content = completion.choices[0].message.content
            pred = (content or "").strip()
            description_results.append({
                "case_id": case_id,
                "image": img_name,
                "prediction": pred,
                "ground_truth": img_info.get("caption", ""),
            })

## 3. Save Model Predictions

In [None]:
# Persist freshly generated predictions. When inference was skipped,
# description_results is empty, and writing it would replace previously
# cached predictions with []; in that case the file is left unchanged.
results_path = os.path.join(RESULTS_DIR, "qwen2.5_description_results.json")
if description_results:
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(description_results, f, indent=2)
    print("Saved description results.")
else:
    print("No new predictions to save; leaving the results file unchanged.")

Load the saved predictions from the results directory.

In [None]:
# Reload the prediction records (list of dicts) from the results directory.
results_path = os.path.join(RESULTS_DIR, "qwen2.5_description_results.json")
with open(results_path) as fh:
    description_results = json.load(fh)
print(f"Number of descriptions: {len(description_results)}")

## 4. Evaluation & Metrics

In [None]:
# Score every prediction against its ground-truth caption with sentence-level
# BLEU-4 (uniform 1- to 4-gram weights), then save and plot the mean score.
import sys

# evaluate_bleu lives in the dataset's scripts folder. Add that folder to the
# import path only once, so re-running this cell does not append duplicates.
BLEU_SCRIPTS_DIR = "VLM-Seminar25-Dataset/scripts"
if BLEU_SCRIPTS_DIR not in sys.path:
    sys.path.append(BLEU_SCRIPTS_DIR)
from evaluate_bleu import sentence_bleu
from nltk.tokenize import word_tokenize

gt = [x["ground_truth"] for x in description_results]
pred = [x["prediction"] for x in description_results]

# NOTE(review): presumably evaluate_bleu.sentence_bleu mirrors NLTK's API. If
# so, without a smoothing function a candidate with no matching 4-gram scores
# ~0, which is common for short captions. Confirm before comparing models.
bleu_scores = [
    sentence_bleu([word_tokenize(ref)], word_tokenize(cand), weights=(0.25, 0.25, 0.25, 0.25))
    for ref, cand in zip(gt, pred)
]
mean_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
print(f"Mean BLEU-4: {mean_bleu:.4f}")

eval_metrics = {"mean_bleu_4": mean_bleu}
with open(os.path.join(RESULTS_DIR, "description_eval_metrics.json"), "w") as f:
    json.dump(eval_metrics, f, indent=2)
print("Saved BLEU metrics.")

plt.figure(figsize=(4, 4))
plt.bar(["BLEU-4"], [mean_bleu], color="orchid")
plt.ylim(0, 1)
plt.title("Qwen2.5 MRI Description BLEU-4")
# Keep the value label inside the fixed [0, 1] y-range even for scores near 1.
plt.text(0, min(mean_bleu + 0.02, 0.95), f"{mean_bleu:.2f}", ha='center', fontsize=12)
plt.savefig(os.path.join(RESULTS_DIR, "description_metrics.png"))
plt.show()
plt.close()
print("Saved BLEU-4 plot.")

## 5. Visualize Results: Correct Descriptions (BLEU ≥ 0.5)
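Show up to four examples whose sentence-level BLEU-4 score is high (≥ 0.5). Note that BLEU measures n-gram overlap with the reference caption, not clinical correctness.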

In [None]:
from PIL import Image

def show_description_examples(examples, title, max_n=4, wrap_width=60):
    """Plot up to *max_n* examples side by side, each captioned with GT and prediction.

    Args:
        examples: Sequence of ``(img_path, ground_truth, prediction, bleu)``
            tuples. Only the first *max_n* are drawn.
        title: Figure super-title; also used in the "no examples" message.
        max_n: Maximum number of panels to draw.
        wrap_width: Character width at which the ground-truth and prediction
            text is wrapped so long captions stay under their own panel.

    Returns:
        None. Prints a message instead of plotting when there is nothing to show.
    """
    import textwrap

    n = min(len(examples), max_n)
    if n <= 0:
        print(f"No examples for {title}")
        return
    # squeeze=False always yields a 2-D array of axes, even for one panel.
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)
    for ax, (img_path, gt_text, pred_text, bleu) in zip(axes[0], examples[:n]):
        # The context manager closes the file once the RGB copy is made.
        with Image.open(img_path) as im:
            ax.imshow(im.convert("RGB"))
        ax.set_title(f"BLEU: {bleu:.2f}", fontsize=14)
        ax.axis('off')
        caption = (
            f"GT: {textwrap.fill(str(gt_text), wrap_width)}\n\n"
            f"Pred: {textwrap.fill(str(pred_text), wrap_width)}"
        )
        # Axes-relative coordinates anchor the caption just below the image,
        # independent of its pixel size, so it does not collide with the title.
        ax.text(0.0, -0.02, caption, transform=ax.transAxes, va="top", fontsize=10)
    plt.suptitle(title, fontsize=18)
    plt.tight_layout()
    plt.show()

# Gather (image path, ground truth, prediction, BLEU) for every sample scoring
# at or above 0.5, preserving the order of the results.
correct = [
    (os.path.join(IMAGES_DIR, rec["image"]), rec["ground_truth"], rec["prediction"], score)
    for rec, score in zip(description_results, bleu_scores)
    if score >= 0.5
]

show_description_examples(correct, "Correct Descriptions (BLEU ≥ 0.5)")

## 6. Visualize Results: Incorrect Descriptions (BLEU < 0.5)
Show up to four examples whose sentence-level BLEU-4 score is low (< 0.5).

In [None]:
# Gather (image path, ground truth, prediction, BLEU) for every sample scoring
# below 0.5, preserving the order of the results.
incorrect = [
    (os.path.join(IMAGES_DIR, rec["image"]), rec["ground_truth"], rec["prediction"], score)
    for rec, score in zip(description_results, bleu_scores)
    if score < 0.5
]

show_description_examples(incorrect, "Incorrect Descriptions (BLEU < 0.5)")