# 🫁 Qwen2.5 Chest X-ray Classification

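Classify each chest X-ray as 'healthy' or 'unhealthy' with the Qwen2.5-VL-72B vision-language model (queried through OpenRouter), then evaluate the predictions against the ground-truth labels using accuracy and F1 score ('unhealthy' as the positive class).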

In [4]:
import os, json, base64
from openai import OpenAI
from dotenv import load_dotenv
from tqdm import tqdm

def encode_image_to_data_uri(path: str) -> str:
    """Read an image file and return it as a base64-encoded ``data:`` URI.

    The MIME type is guessed from the file extension; when it cannot be
    guessed, or is not an ``image/*`` type, it falls back to ``image/png``
    (the format of this dataset's images).

    Args:
        path: Path to the image file.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes  # local import keeps this helper self-contained

    mime, _ = mimetypes.guess_type(path)
    if mime is None or not mime.startswith("image/"):
        mime = "image/png"
    with open(path, "rb") as f:
        # base64 output is pure ASCII, so decoding to str is lossless.
        b64 = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{b64}"

DATASET_DIR = "VLM-Seminar25-Dataset/chest_xrays"
IMAGES_DIR = os.path.join(DATASET_DIR, "images")
ANNOT_PATH = os.path.join(DATASET_DIR, "annotations_len_50.json")

# Ground-truth annotations keyed by image id; the evaluation cell reads each
# entry's "status" label.
with open(ANNOT_PATH, "r", encoding="utf-8") as f:
    annotations = json.load(f)
image_ids = list(annotations.keys())

load_dotenv(dotenv_path="../config/user.env")
api_key = os.environ.get("OPENROUTER_API_KEY")
if not api_key:
    # Fail early with a clear message: with api_key=None the OpenAI client
    # would fall back to OPENAI_API_KEY (wrong provider) or raise an
    # unrelated error.
    raise RuntimeError(
        "OPENROUTER_API_KEY is not set; add it to ../config/user.env or the environment."
    )
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

In [None]:
# Query Qwen2.5-VL (via OpenRouter) once per image and collect its label.
# The prompt is identical for every image, so it is built once, outside the loop.
prompt = "Given the medical image, classify it as 'healthy' or 'unhealthy'. It is very important that you only output either 'healthy' or 'unhealthy'."


def normalize_prediction(raw):
    """Clean a raw model reply into a lower-case label string.

    Surrounding whitespace, quotes, backticks, asterisks (markdown emphasis)
    and sentence punctuation are removed, so ``"**Unhealthy.**"`` becomes
    ``"unhealthy"``. A ``None`` reply (no text content) becomes ``""``.
    Other text is returned as-is, so off-format replies still count as errors
    in the evaluation instead of being guessed.
    """
    return (raw or "").strip().lower().strip(" \t\r\n.,!'\"`*")


classification_results = []
for img_id in tqdm(image_ids):
    img_path = os.path.join(IMAGES_DIR, img_id + ".png")
    data_uri = encode_image_to_data_uri(img_path)
    completion = client.chat.completions.create(
        model="qwen/qwen2.5-vl-72b-instruct:free",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": data_uri}},
                ],
            }
        ],
    )
    # message.content can be None for an empty response, so normalize it
    # instead of calling .strip() on it directly.
    pred = normalize_prediction(completion.choices[0].message.content)
    classification_results.append({"id": img_id, "prediction": pred})

 24%|██▍       | 12/50 [01:17<05:35,  8.84s/it]

In [None]:
# Save predictions next to the dataset annotations (reusing DATASET_DIR so the
# location stays in sync with the setup cell).
RESULTS_PATH = os.path.join(DATASET_DIR, "qwen2_5_classification_results.json")
with open(RESULTS_PATH, "w", encoding="utf-8") as f:
    json.dump(classification_results, f, indent=2)
print("Saved classification results.")

In [None]:
# Evaluation using direct import from provided script
import sys

# Make the dataset's helper scripts importable. The guard avoids appending the
# same entry again every time this cell is re-run.
SCRIPTS_DIR = "VLM-Seminar25-Dataset/scripts"
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)
from evaluate_metrics import accuracy_score, f1_score

# Ground-truth labels and predictions, aligned by image id.
gt = [annotations[x["id"]]["status"] for x in classification_results]
pred = [x["prediction"] for x in classification_results]

# Off-vocabulary replies are still scored (as errors), but report them so
# prompt/parsing problems are visible instead of silently lowering the metrics.
invalid = [p for p in pred if p not in ("healthy", "unhealthy")]
if invalid:
    print(
        f"Warning: {len(invalid)} of {len(pred)} predictions are not "
        f"'healthy'/'unhealthy': {sorted(set(invalid))}"
    )

accuracy = accuracy_score(gt, pred)
f1 = f1_score(gt, pred, pos_label='unhealthy')  # 'unhealthy' is the positive class
print(f"Accuracy: {accuracy:.3f}")
print(f"F1 Score: {f1:.3f}")

## Visualize Results: Correct Predictions

Show images where the prediction matches the ground truth.

In [None]:
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os

# Visualization config (from instruction.ipynb)
FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
FONT_SIZE = 32

def draw_text(draw, pos, text, font, color, outline=2, outline_color='black'):
    """Draw ``text`` with a solid outline so it stays readable on any background.

    The text is first stamped in ``outline_color`` at all eight offsets of
    ``outline`` pixels around ``pos`` (diagonals and cardinal directions),
    then drawn once in ``color`` at ``pos`` itself.

    Args:
        draw: A ``PIL.ImageDraw.ImageDraw`` (or any object with a compatible
            ``text(xy, text, font=..., fill=...)`` method).
        pos: ``(x, y)`` position of the text.
        text: String to draw.
        font: Font passed through to ``draw.text``.
        color: Fill colour of the text itself.
        outline: Outline thickness in pixels; ``0`` or less disables it.
        outline_color: Colour of the outline (default ``'black'``).
    """
    x, y = pos
    if outline > 0:
        offsets = (-outline, 0, outline)
        for dx in offsets:
            for dy in offsets:
                if dx or dy:  # skip the centre; it gets the real colour below
                    draw.text((x + dx, y + dy), text, font=font, fill=outline_color)
    draw.text((x, y), text, font=font, fill=color)

def visualize_img(img_id, gt_label, pred_label):
    """Load an image and overlay its ground-truth and predicted labels.

    Args:
        img_id: Image id; the file read is ``IMAGES_DIR/<img_id>.png``.
        gt_label: Ground-truth label text.
        pred_label: Predicted label text.

    Returns:
        An RGB ``PIL.Image`` with ``"GT: ...\\nPred: ..."`` drawn near the
        top-left corner, in green when the labels match and red otherwise.
    """
    img_path = os.path.join(IMAGES_DIR, img_id + ".png")
    # convert() returns an independent image, so the file can be closed immediately.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
    except OSError:
        # Font file not available on this machine: use Pillow's built-in font.
        font = ImageFont.load_default()
    text = f"GT: {gt_label}\nPred: {pred_label}"
    color = "green" if gt_label == pred_label else "red"
    draw_text(draw, (10, 10), text, font, color=color)
    return img

# Partition the results into correctly and incorrectly classified examples.
# Each entry is an (img_id, gt_label, pred_label) tuple, as show_examples expects.
correct, incorrect = [], []
for result in classification_results:
    sample_id = result["id"]
    entry = (sample_id, annotations[sample_id]["status"], result["prediction"])
    bucket = correct if entry[1] == entry[2] else incorrect
    bucket.append(entry)

def show_examples(examples, title, max_n=6):
    """Plot up to ``max_n`` annotated examples side by side in one figure.

    Each example is an ``(img_id, gt_label, pred_label)`` tuple; it is
    rendered with ``visualize_img`` and titled with its id. When there is
    nothing to show, a notice is printed instead and no figure is created.
    """
    count = min(len(examples), max_n)
    if not count:
        print(f"No examples for {title}")
        return
    # squeeze=False always yields a 2-D array of axes, even for a single panel.
    fig, grid = plt.subplots(1, count, figsize=(5 * count, 5), squeeze=False)
    for ax, (img_id, gt_label, pred_label) in zip(grid[0], examples[:count]):
        ax.imshow(visualize_img(img_id, gt_label, pred_label))
        ax.set_title(f"ID: {img_id}", fontsize=14)
        ax.axis('off')
    fig.suptitle(title, fontsize=18)
    fig.tight_layout()
    plt.show()

# Plot up to six correctly classified images (show_examples' default max_n).
show_examples(examples=correct, title="Correct Predictions")

## Visualize Results: Incorrect Predictions

Show images where the prediction does not match the ground truth.

In [None]:
# Plot up to six misclassified images (show_examples' default max_n).
show_examples(examples=incorrect, title="Incorrect Predictions")