# 🧠 Qwen2.5 Brain MRI Disease Diagnosis

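Predict the disease diagnosis for each case from its clinical history and textual image findings (captions) using Qwen2.5-VL. Only text is sent to the model; the images themselves are not.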

## 1. Setup & Imports

In [None]:
import os, json
from openai import OpenAI
from dotenv import load_dotenv
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

# Location of the dataset annotations and of the directory that receives every
# output of this notebook (predictions, metrics, plot).
DATASET_DIR = "VLM-Seminar25-Dataset/nova_brain"
ANNOT_PATH = os.path.join(DATASET_DIR, "annotations.json")
RESULTS_DIR = "../results/nova_brain"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Decode the annotations explicitly as UTF-8: the platform default encoding
# (e.g. cp1252 on Windows) can fail on, or silently mis-decode, non-ASCII
# characters in the clinical text.
with open(ANNOT_PATH, "r", encoding="utf-8") as f:
    annotations = json.load(f)
# The top-level keys of the annotations mapping are used as case identifiers.
case_ids = list(annotations.keys())

# Load variables from the local env file, then read the API key from the
# environment and build the client against the Nebius endpoint.
load_dotenv(dotenv_path="config/user.env")
api_key = os.environ.get("NEBIUS_API_KEY")
client = OpenAI(base_url="https://api.studio.nebius.com/v1/", api_key=api_key)

In [None]:
# Set to True to query the model for every case in the inference cell below;
# when False, that cell only initializes `diagnosis_results` to an empty list.
do_new_inference = False

## 2. Model Inference

In [None]:
# Collect one record per case: the model's free-text diagnosis next to the
# reference diagnosis from the annotations. The loop only runs when
# `do_new_inference` is True; otherwise the list stays empty.
diagnosis_results = []
if do_new_inference:
    for case_id in tqdm(case_ids):
        case = annotations[case_id]
        clinical_history = case.get("clinical_history", "")
        # One "<image name>: <caption>" entry per annotated image.
        findings = [
            f"{img_name}: {img_info.get('caption', '')}"
            for img_name, img_info in case.get("image_findings", {}).items()
        ]
        findings_str = " ".join(findings)
        prompt = f"Based on the clinical history: {clinical_history} and image findings: {findings_str}, provide your diagnosis for the disease."
        # Text-only request: the prompt is the sole content part of the message.
        completion = client.chat.completions.create(
            model="Qwen/Qwen2.5-VL-72B-Instruct",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt}
                    ],
                }
            ],
        )
        # The message content is optional in the API response; fall back to an
        # empty prediction instead of raising AttributeError on None, which
        # would abort the loop and drop every prediction gathered so far.
        content = completion.choices[0].message.content
        pred = (content or "").strip()
        diagnosis_results.append({
            "case_id": case_id,
            "prediction": pred,
            "ground_truth": case.get("final_diagnosis", "")
        })

## 3. Save Model Predictions

In [None]:
# Persist predictions only when this run actually produced them. Without new
# inference `diagnosis_results` is an empty list, and writing it would
# overwrite the previously saved predictions with `[]`.
if do_new_inference:
    with open(os.path.join(RESULTS_DIR, "qwen2.5_diagnosis_results.json"), "w") as f:
        json.dump(diagnosis_results, f, indent=2)
    print("Saved diagnosis results.")
else:
    print("Skipped saving: no new inference was run; keeping the existing results file.")

Load the saved predictions back from disk.

In [None]:
# Read the stored predictions from the results directory and report how many
# case records the file contains.
results_path = os.path.join(RESULTS_DIR, "qwen2.5_diagnosis_results.json")
with open(results_path) as results_file:
    diagnosis_results = json.load(results_file)
print(f"Number of cases diagnosed: {len(diagnosis_results)}")

## 4. Evaluation & Metrics

In [None]:
# An external metric can be plugged in here if /code/eval_scripts/ is available:
# from eval_scripts.diagnosis_eval import compute_metrics

# Split the stored records into parallel lists of references and predictions.
gt = [record["ground_truth"] for record in diagnosis_results]
pred = [record["prediction"] for record in diagnosis_results]

# Placeholder metric: case-insensitive exact match after trimming whitespace
# (swap in a more suitable metric if needed).
correct = sum(
    truth.strip().lower() == guess.strip().lower()
    for truth, guess in zip(gt, pred)
)
if gt:
    accuracy = correct / len(gt)
else:
    # No records: report zero rather than dividing by zero.
    accuracy = 0.0

# Write the metric to a JSON file in the results directory.
eval_metrics = {"accuracy": accuracy}
metrics_path = os.path.join(RESULTS_DIR, "diagnosis_eval_metrics.json")
with open(metrics_path, "w") as metrics_file:
    json.dump(eval_metrics, metrics_file, indent=2)
print("Saved evaluation metrics.")

# Single-bar chart of the accuracy, labelled with its value just above the bar.
fig, ax = plt.subplots(figsize=(4, 4))
ax.bar(["Accuracy"], [accuracy], color="cornflowerblue")
ax.set_ylim(0, 1)
ax.set_title("Qwen2.5 MRI Diagnosis Accuracy")
ax.text(0, accuracy + 0.02, f"{accuracy:.2f}", ha="center", fontsize=12)
fig.savefig(os.path.join(RESULTS_DIR, "diagnosis_metrics.png"))
plt.show()
plt.close()
print("Saved accuracy plot.")