Imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from tqdm import tqdm
import os
import re

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Dataset extraction from google drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Dataset organization

In [None]:
class XrayDataset(Dataset):
    def __init__(self, csv_path, image_dir, tokenizer, transform):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.transform = transform

        self.labels = sorted(self.df["label"].unique())
        self.label2idx = {l: i for i, l in enumerate(self.labels)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = row["label"]
        label_idx = self.label2idx[label]

        image = Image.open(
            os.path.join(self.image_dir, row["image"])
        ).convert("RGB")
        image = self.transform(image)

        text = f"A chest X-ray showing findings of {label}."
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        return {
            "image": image,
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "label_idx": label_idx
        }

Robust label extraction

In [None]:
def robust_extract_label_from_text(text, label_list):
    text_clean = re.sub(r"[^a-zA-Z0-9]", "", text.lower())
    for label in label_list:
        label_clean = re.sub(r"[^a-zA-Z0-9]", "", label.lower())
        if label_clean in text_clean:
            return label
    return None

VLM Model

In [None]:
class XrayVLM(nn.Module):
    def __init__(self, embed_dim=256):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.image_encoder = nn.Sequential(*modules)
        self.img_fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.text_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.text_fc = nn.Linear(self.text_encoder.config.hidden_size, embed_dim)

    def forward(self, images, input_ids, attention_mask):
        img_feat = self.image_encoder(images).squeeze(-1).squeeze(-1)
        img_emb = self.img_fc(img_feat)
        img_emb = F.normalize(img_emb, dim=1)

        txt_feat = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        txt_emb = self.text_fc(txt_feat)
        txt_emb = F.normalize(txt_emb, dim=1)
        return img_emb, txt_emb

    def encode_image(self, images):
        img_feat = self.image_encoder(images).squeeze(-1).squeeze(-1)
        img_emb = self.img_fc(img_feat)
        return F.normalize(img_emb, dim=1)

    def encode_text(self, input_ids, attention_mask):
        txt_feat = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        txt_emb = self.text_fc(txt_feat)
        return F.normalize(txt_emb, dim=1)

Contrastive Loss

In [None]:
class MedicalContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super().__init__()
        self.temperature = temperature
        self.ce = nn.CrossEntropyLoss()

    def forward(self, img_emb, txt_emb):
        logits = img_emb @ txt_emb.T / self.temperature
        labels = torch.arange(len(logits), device=logits.device)
        loss = (self.ce(logits, labels) + self.ce(logits.T, labels)) / 2
        return loss

Parameter definition and dataset split

In [2]:
BATCH_SIZE = 16
EPOCHS = 5
LR = 1e-4
TEMPERATURE = 0.05
IMAGE_SIZE = 224

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

CSV_PATH = "/content/drive/MyDrive/xray_data/metadata.csv"
IMG_DIR = "/content/drive/MyDrive/xray_data/images"

from torch.utils.data import random_split

full_dataset = XrayDataset(CSV_PATH, IMG_DIR, tokenizer, transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size]
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=16, shuffle=False)
train_label_list = train_dataset.dataset.labels


model = XrayVLM(embed_dim=256).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = MedicalContrastiveLoss(TEMPERATURE)


Using device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 129MB/s]


pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

BertModel LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.decoder.weight             | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Loop body

In [3]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        img = batch["image"].to(device)
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)

        img_emb, txt_emb = model(img, ids, mask)
        loss = criterion(img_emb, txt_emb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")

100%|██████████| 35/35 [04:25<00:00,  7.59s/it]


Epoch 1/5 - Loss: 2.7886


100%|██████████| 35/35 [00:20<00:00,  1.69it/s]


Epoch 2/5 - Loss: 2.4038


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 3/5 - Loss: 1.8503


100%|██████████| 35/35 [00:20<00:00,  1.68it/s]


Epoch 4/5 - Loss: 1.4938


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]

Epoch 5/5 - Loss: 1.3901





Label embeddings

In [None]:
model.eval()
label_list = full_dataset.labels
label_prompts = [
    f"Radiographic evidence of {l} in a chest X-ray."
    for l in label_list
]

tokens = tokenizer(
    label_prompts,
    padding=True,
    truncation=True,
    max_length=32,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    label_embs = model.encode_text(
        tokens["input_ids"],
        tokens["attention_mask"]
    )

label_embs = F.normalize(label_embs, dim=1).detach()

Evaluate function

In [4]:
@torch.no_grad()
def evaluate(model, dataloader, label_embs, device):
    model.eval()

    total = 0
    correct_top1 = 0
    correct_top3 = 0
    pos_sims = []
    neg_sims = []

    for batch in dataloader:
        images = batch["image"].to(device)
        gt_indices = batch["label_idx"].to(device)

        img_embs = model.encode_image(images)
        sims = img_embs @ label_embs.T

        top1 = sims.argmax(dim=1)
        top3 = sims.topk(3, dim=1).indices

        correct_top1 += (top1 == gt_indices).sum().item()
        correct_top3 += sum(
            gt_indices[i].item() in top3[i].tolist()
            for i in range(len(gt_indices))
        )

        total += images.size(0)

        pos_sim = sims[torch.arange(len(gt_indices)), gt_indices]
        pos_sims.append(pos_sim.mean().item())

        mask = torch.ones_like(sims, dtype=torch.bool)
        mask[torch.arange(len(gt_indices)), gt_indices] = False
        neg_sim = sims[mask].view(len(gt_indices), -1).mean(dim=1)
        neg_sims.append(neg_sim.mean().item())

    print(f"Top-1 Accuracy: {correct_top1 / total:.4f}")
    print(f"Top-3 Accuracy: {correct_top3 / total:.4f}")
    print(f"Mean positive similarity: {sum(pos_sims)/len(pos_sims):.4f}")
    print(f"Mean negative similarity: {sum(neg_sims)/len(neg_sims):.4f}")

Evaluation

In [5]:
evaluate(
    model=model,
    dataloader=val_loader,
    label_embs=label_embs,
    device=device
)
# Top-1 Accuracy: 0.0857
# Top-3 Accuracy: 0.1857
# Mean positive similarity: 0.2940
# Mean negative similarity: 0.2202

Top-1 Accuracy: 0.0857
Top-3 Accuracy: 0.1857
Mean positive similarity: 0.2940
Mean negative similarity: 0.2202
