# Colab Notebook: Multimodal MIMIC-CXR (Images + Metadata)
#   - Fusion from scratch: Concatenation & Cross-Attention
#   - Foundation Model (MedCLIP) Adaptation: Linear, Partial FT, LoRA
#   - Metrics: AUC, Accuracy, F1 (per class)

In [19]:
!pip install -q pydicom
!pip install -q timm==1.0.9 albumentations==1.2.1 torchmetrics==1.4.0.post0 scikit-learn==1.5.2 numpy==1.24 scipy==1.10 medclip==0.0.3 open_clip_torch peft transformers timm albumentations pandas jax jaxlib

import os, math, json, random, warnings
warnings.filterwarnings("ignore")
from typing import List
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import timm
import torchmetrics # Use torchmetrics instead of sklearn
from torch.cuda.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pydicom

try:
    from medclip import MedCLIPModel, MedCLIPVisionModelViT
    _HAS_MEDCLIP = True
except Exception:
    _HAS_MEDCLIP = False
try:
    import open_clip
    _HAS_OPENCLIP = True
except Exception:
    _HAS_OPENCLIP = True
from peft import LoraConfig, get_peft_model

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m90.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herror
[1;31merror[0m: [1msubprocess-exited-with-error[0m

[31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
[31m│[0m exit code: [1;36m1[0m
[31m╰─>[0m See above for output.

[1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
Device: cuda


In [4]:
DATA_PATH = "./MIMIC-CXR"
!mkdir -p {DATA_PATH}
!wget https://uni-bonn.sciebo.de/s/YHuwFOg6q6sw1ZX/download -O {DATA_PATH}/MIMIC-CXR.zip
!unzip {DATA_PATH}/MIMIC-CXR.zip -d {DATA_PATH}
!curl -L --progress-bar https://uni-bonn.sciebo.de/s/XbomHCb6yL6nYN4/download -o ./radiomics.csv
!curl -L --progress-bar https://uni-bonn.sciebo.de/s/e7fKPxDYcs83J67/download -o ./labels.csv

print("Data download complete.")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18691393/56760870/a6698e16-f7a79eac-4fcf0ed1-88e88142-3abc8371.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18713249/57588311/4ffc60ec-ffd298d5-42243c1c-2ec47d75-64323bc4.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18713249/56117933/eb80ed86-da05e903-ac591a0e-e3cc826e-2a2c6ab9.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18991516/56755931/e68d876b-82b96bc4-73549fb3-144d8866-d5836ec5.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18871870/50832312/6744c4f1-d659b0de-9616f798-def4d319-7c248ea7.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18871870/59153339/f3fcbda8-43ced708-1386efce-06910f23-abd0e1c3.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18871870/57448965/8e50972c-052a94b0-4c9a7a2a-ed550529-bf329b84.png  
  inflating: ./MIMIC-CXR/MIMIC-CXR/segmentation/18871870/52923609/a1ff3160-76069a6b-5ce15b1b-46564928-b7a3c300.png  

In [53]:
from sklearn.model_selection import train_test_split

radiomics = pd.read_csv("radiomics.csv")
labels_df = pd.read_csv("labels.csv")

# Merge on study_id instead of StudyInstanceUID
df = pd.merge(radiomics, labels_df, on="study_id", how="inner")

# Use dicom_id_y for path construction as StudyInstanceUID is not in merged df
df["path"] = df["dicom_id_y"].apply(lambda x: os.path.join(DATA_PATH, f"{x}.jpg"))
META_COLS = [c for c in df.columns if c.startswith("radiomics_") or c.startswith("feature_")]
label_cols = [c for c in labels_df.columns if c != "StudyInstanceUID"] # Keep this as is, refers to original labels_df cols
# Use study_id for splitting
uids = df["study_id"].unique()
train_ids, test_ids = train_test_split(uids, test_size=0.2, random_state=42)
val_ids, test_ids = train_test_split(test_ids, test_size=0.5, random_state=42)
def assign_split(uid):
    if uid in train_ids: return "train"
    elif uid in val_ids: return "val"
    else: return "test"
# Apply split based on study_id
df["split"] = df["study_id"].apply(assign_split)
df.to_csv("mimic_mmd.csv", index=False)
print("Saved merged dataset -> mimic_mmd.csv")
print(df["split"].value_counts())

Saved merged dataset -> mimic_mmd.csv
split
train    664
val       81
test      78
Name: count, dtype: int64


In [88]:
CSV_PATH, OUTPUT_DIR = "./mimic_mmd.csv", "./outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
IMG_SIZE, BATCH, EPOCHS_SCRATCH, EPOCHS_VLM = 256, 32, 5, 3

df = pd.read_csv(CSV_PATH)
import ast
import numpy as np

def parse_label_string(x):
    if isinstance(x, str):
        # Remove brackets, split by space, convert to float
        vals = [float(v) for v in x.replace('[','').replace(']','').split()]
        return np.array(vals, dtype=np.float32)
    elif isinstance(x, (list, np.ndarray)):
        return np.array(x, dtype=np.float32)
    else:
        return np.zeros(14, dtype=np.float32)  # fallback

df['labels_encoded_y'] = df['labels_encoded_y'].apply(parse_label_string)

DATA_PATH_IMAGES = "./MIMIC-CXR/MIMIC-CXR/"

LABELS = ['labels_encoded_y']
print("Using LABELS =", LABELS)

# Radiomics / meta feature columns
META_COLS = [c for c in df.columns if c.startswith("radiomics_") or c.startswith("feature_") or c.startswith("original_")]

train_meta = df[df["split"] == "train"][META_COLS].astype(np.float32)
META_MEAN, META_STD = train_meta.mean(0).values, (train_meta.std(0).values + 1e-6)

from torchvision import transforms

train_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class CXRDataset(Dataset):
    def __init__(self, frame: pd.DataFrame, split: str, labels: List[str], tfms):
        self.df = frame[frame["split"] == split].reset_index(drop=True)
        self.labels, self.tfms = labels, tfms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]

        img_path = os.path.join(DATA_PATH_IMAGES, row["dicom_path"])
        try:
            dcm = pydicom.dcmread(img_path)
            img_array = dcm.pixel_array
            if img_array.ndim == 2:  # grayscale
                if img_array.dtype != np.uint8:
                    img_array = img_array.astype(np.float32)
                    img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min()) * 255
                    img_array = img_array.astype(np.uint8)
                img = Image.fromarray(img_array).convert("RGB")
            elif img_array.ndim == 3 and img_array.shape[2] == 3:  # RGB
                if img_array.dtype != np.uint8:
                    img_array = img_array.astype(np.uint8)
                img = Image.fromarray(img_array)
            else:
                raise ValueError(f"Unsupported image format or dimensions: {img_array.shape}")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise e

        img_t = self.tfms(img)
        meta = (row[META_COLS].values.astype(np.float32) - META_MEAN) / META_STD

        label_field = row[self.labels[0]]
        y = np.array(label_field, dtype=np.float32)

        return img_t, torch.from_numpy(meta), torch.from_numpy(y)

# build datasets + loaders
train_ds, val_ds, test_ds = [
    CXRDataset(df, s, LABELS, t)
    for s, t in [("train", train_tfms), ("val", val_tfms), ("test", val_tfms)]
]
train_loader, val_loader, test_loader = [
    DataLoader(ds, batch_size=BATCH, shuffle=s, num_workers=2, pin_memory=True)
    for ds, s in [(train_ds, True), (val_ds, False), (test_ds, False)]
]

print(f"Splits -> train: {len(train_ds)} | val: {len(val_ds)} | test: {len(test_ds)}")
print("Labels:", LABELS)
print("Radiomics feature dims:", len(META_COLS))

SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? (ipython-input-2845365625.py, line 3)

In [74]:
def build_resnet(backbone="resnet50", pretrained=False, out_dim=1024):
    m = timm.create_model(backbone, pretrained=pretrained, num_classes=0, global_pool="")
    feat_dim = m.num_features
    projector = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(feat_dim, out_dim), nn.ReLU(inplace=True))
    return m, projector, feat_dim

class ConcatFusion(nn.Module):
    def __init__(self, img_backbone="resnet50", img_dim=768, meta_in_dim=128, hidden=512, n_classes=2):
        super().__init__()
        self.cnn, self.proj, _ = build_resnet(img_backbone, pretrained=False, out_dim=img_dim)
        self.meta_proj = nn.Sequential(nn.Linear(meta_in_dim, 256), nn.ReLU(inplace=True))
        self.head = nn.Sequential(nn.Linear(img_dim + 256, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, n_classes))
    def forward(self, x, meta):
        img_features = self.proj(self.cnn.forward_features(x))
        meta_features = self.meta_proj(meta)
        return self.head(torch.cat([img_features, meta_features], dim=1))

class CrossAttnFusion(nn.Module):
    def __init__(self, img_backbone="resnet50", token_dim=512, meta_in_dim=128, n_heads=8, n_layers=1, n_classes=2):
        super().__init__()
        self.cnn, _, fc = build_resnet(img_backbone, pretrained=False, out_dim=token_dim)
        self.token_proj, self.meta_proj = nn.Conv2d(fc, token_dim, kernel_size=1), nn.Linear(meta_in_dim, token_dim)
        self.cls = nn.Parameter(torch.zeros(1, 1, token_dim))
        enc_layer = nn.TransformerEncoderLayer(d_model=token_dim, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.head = nn.Linear(token_dim, n_classes)
        nn.init.normal_(self.cls, std=0.02)
    def forward(self, x, meta):
        feat = self.cnn.forward_features(x)
        tokens = self.token_proj(feat).flatten(2).transpose(1, 2)
        cls_token = self.cls.expand(tokens.shape[0], -1, -1)
        # Add metadata information to the CLS token
        cls_token = cls_token + self.meta_proj(meta).unsqueeze(1)
        tokens = torch.cat([cls_token, tokens], dim=1)
        # Get the output corresponding to the CLS token
        return self.head(self.encoder(tokens)[:, 0, :])

def count_trainable_params(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)
print("Models defined.")

Models defined.


In [75]:
@torch.no_grad()
def evaluate(model, loader, labels):
    """
    Returns:
      {"per_label": {lab: {"AUC": float|nan, "F1": float, "ACC": float}},
       "macro": {"AUC": float|nan, "F1": float, "ACC": float}}
    """
    import numpy as np
    import torch
    from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

    model.eval()
    all_logits, all_targets = [], []

    for batch in loader:
        # Support (imgs, meta, y) and (imgs, y)
        if len(batch) == 3:
            imgs, meta, y = batch
            imgs, meta, y = imgs.to(DEVICE), meta.to(DEVICE), y.to(DEVICE).float()
            logits = model(imgs, meta)
        else:
            imgs, y = batch
            imgs, y = imgs.to(DEVICE), y.to(DEVICE).float()
            logits = model(imgs)

        all_logits.append(logits.detach().cpu())
        all_targets.append(y.detach().cpu())

    logits = torch.cat(all_logits, dim=0).numpy()
    targets = torch.cat(all_targets, dim=0).numpy()
    probs = 1.0 / (1.0 + np.exp(-logits))           # sigmoid
    preds = (probs >= 0.5).astype(np.float32)

    per_label = {}
    aucs, f1s, accs = [], [], []
    for j, lab in enumerate(labels):
        y_true = targets[:, j].astype(int)
        y_prob = probs[:, j]
        y_pred = preds[:, j]

        # AUC needs both classes present
        if (y_true.max() == 0) or (y_true.min() == 1) or (y_true.sum() == 0) or (y_true.sum() == len(y_true)):
            auc = float("nan")
        else:
            try:
                auc = roc_auc_score(y_true, y_prob)
            except Exception:
                auc = float("nan")

        f1 = f1_score(y_true, y_pred, zero_division=0, average="binary" if len(np.unique(y_true)) == 2 else "macro")
        acc = accuracy_score(y_true, y_pred)
        per_label[lab] = {"AUC": float(auc), "F1": float(f1), "ACC": float(acc)}
        if not np.isnan(auc): aucs.append(auc)
        f1s.append(f1); accs.append(acc)

    macro_auc = float(np.mean(aucs)) if aucs else float("nan")
    macro_f1  = float(np.mean(f1s))
    macro_acc = float(np.mean(accs))
    return {"per_label": per_label, "macro": {"AUC": macro_auc, "F1": macro_f1, "ACC": macro_acc}}


In [76]:
batch = next(iter(train_loader))
print("Label shape:", batch[2].shape)
print("Label min/max:", batch[2].min().item(), batch[2].max().item())


Label shape: torch.Size([32, 14])
Label min/max: 0.0 1.0


In [77]:
def train_model(model, train_loader, val_loader, labels, epochs=5, lr=1e-3, weight_decay=1e-4):
    import torch, numpy as np
    import torch.nn as nn

    criterion = nn.BCEWithLogitsLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    best = {"score": -1.0, "state": None}
    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if len(batch) == 3:
                imgs, meta, y = batch
                imgs, meta, y = imgs.to(DEVICE), meta.to(DEVICE), y.to(DEVICE).float()
                logits = model(imgs, meta)
            else:
                imgs, y = batch
                imgs, y = imgs.to(DEVICE), y.to(DEVICE).float()
                logits = model(imgs)

            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, max_norm=5.0)
            optimizer.step()
            running += loss.item() * y.size(0)

        val_metrics = evaluate(model, val_loader, labels)
        val_auc = val_metrics["macro"]["AUC"]
        print(f"[Epoch {ep}/{epochs}] train_loss={running/len(train_loader.dataset):.4f} | val_macro_auc={val_auc:.4f}")

        if (val_auc == val_auc) and (val_auc >= best["score"]):  # NaN-safe
            best["score"] = float(val_auc)
            best["state"] = {k: v.cpu() for k, v in model.state_dict().items()}

    if best["state"] is not None:
        model.load_state_dict(best["state"])
    return model, best

In [110]:
def count_trainable_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

class VLMWithHead(torch.nn.Module):
    def __init__(self, visual_backbone, feat_dim, meta_dim, n_classes, head_hidden=256):
        super().__init__()
        self.backbone = visual_backbone
        self.pool = torch.nn.AdaptiveAvgPool2d((1,1)) if hasattr(visual_backbone, "conv1") else None
        self.feat_dim = feat_dim
        self.head = torch.nn.Sequential(
            torch.nn.Linear(feat_dim + meta_dim, head_hidden),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(head_hidden, n_classes),
        )

    def forward_visual(self, imgs):
        x = self.backbone(imgs)
        if isinstance(x, (tuple, list)):
            x = x[0]
        if x.ndim == 4:
            x = self.pool(x).flatten(1)
        return x

    def forward(self, imgs, meta):
        img_feat = self.forward_visual(imgs)
        return self.head(torch.cat([img_feat, meta], dim=1))

def build_openclip_visual(name="ViT-B-32", pretrained="laion2b_s34b_b79k"):
    import open_clip, torch
    import torch.nn.functional as F

    model, _, _ = open_clip.create_model_and_transforms(
        name, pretrained=pretrained, device=DEVICE
    )
    visual = model.visual

    if hasattr(visual, "positional_embedding"):
        pe = visual.positional_embedding  # [1, 1+N, C] or [1+N, C]
        if pe.ndim == 2:
            pe = pe.unsqueeze(0)  # make it [1, 1+N, C]

        patch = visual.patch_size[0] if isinstance(visual.patch_size, (tuple, list)) else int(visual.patch_size)
        target_hw = max(1, IMG_SIZE // patch)

        cls_pe = pe[:, :1]
        grid_pe = pe[:, 1:]  # [1, N, C]

        if hasattr(visual, "grid_size") and isinstance(visual.grid_size, (tuple, list)):
            old_h, old_w = visual.grid_size
        else:
            side = int(round(grid_pe.shape[1] ** 0.5))
            old_h, old_w = side, side

        old_tokens = old_h * old_w
        if grid_pe.shape[1] != old_tokens or (1 + target_hw * target_hw) != pe.shape[1]:
            print(f"Resizing positional embeddings from {1+grid_pe.shape[1]} → {1+target_hw*target_hw}")
            grid_pe = grid_pe[:, : old_h * old_w, :].reshape(1, old_h, old_w, -1).permute(0, 3, 1, 2)
            grid_pe = F.interpolate(grid_pe, size=(target_hw, target_hw), mode="bicubic", align_corners=False)
            grid_pe = grid_pe.permute(0, 2, 3, 1).reshape(1, target_hw * target_hw, -1)
            pe_new = torch.cat([cls_pe, grid_pe], dim=1)
            visual.positional_embedding = torch.nn.Parameter(pe_new)

    # ---- Feature dimension ----
    feat_dim = getattr(visual, "output_dim", None)
    if feat_dim is None:
        proj = getattr(visual, "proj", None)
        feat_dim = proj.shape[0] if isinstance(proj, torch.nn.Parameter) else 768

    return visual, feat_dim

class LoRALinear(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, r: int, alpha: int, dropout: float):
        super().__init__()
        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.scaling = alpha / r
        self.weight = base.weight
        self.bias = base.bias
        # freeze the base weight & bias
        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

        self.lora_A = torch.nn.Parameter(torch.zeros(self.r, self.in_features))
        self.lora_B = torch.nn.Parameter(torch.zeros(self.out_features, self.r))
        torch.nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        torch.nn.init.zeros_(self.lora_B)

        self.dropout = torch.nn.Dropout(dropout) if dropout and dropout > 0 else torch.nn.Identity()

    @classmethod
    def from_linear(cls, lin: torch.nn.Linear, r: int, alpha: int, dropout: float):
        m = cls(lin, r, alpha, dropout)
        return m

    def forward(self, x):
        # base
        y = torch.nn.functional.linear(x, self.weight, self.bias)
        # lora branch
        x_d = self.dropout(x)
        lora_update = torch.nn.functional.linear(
            torch.nn.functional.linear(x_d, self.lora_A),  # x @ A^T
            self.lora_B,                                  # (xA) @ B^T
        )
        return y + self.scaling * lora_update

def _gather_target_modules_for_lora(module):
    targets = []
    for name, sub in module.named_modules():
        if isinstance(sub, torch.nn.Linear):
            nm = name.lower()
            # target attention & proj-like linear layers
            if ("qkv" in nm) or (nm.endswith("proj")) or (".proj" in nm) or ("out_proj" in nm):
                targets.append(name)
    return targets

def apply_lora(module, r=8, alpha=16, dropout=0.05):
    to_replace = _gather_target_modules_for_lora(module)
    replaced = []
    for name in to_replace:
        parts = name.split(".")
        parent = module
        for p in parts[:-1]:
            parent = getattr(parent, p)
        leaf = parts[-1]
        old_lin = getattr(parent, leaf)
        if not isinstance(old_lin, torch.nn.Linear):
            continue
        new_lin = LoRALinear.from_linear(old_lin, r=r, alpha=alpha, dropout=dropout)

        device = old_lin.weight.device
        new_lin.lora_A = torch.nn.Parameter(new_lin.lora_A.to(device))
        new_lin.lora_B = torch.nn.Parameter(new_lin.lora_B.to(device))

        setattr(parent, leaf, new_lin)
        replaced.append(name)

    if not replaced:
        print("WARNING: No attention linear layers found for LoRA; proceeding without LoRA.")
    else:
        # freeze everything else
        for p in module.parameters():
            p.requires_grad_(False)
        for n, p in module.named_parameters():
            if n.endswith("lora_A") or n.endswith("lora_B"):
                p.requires_grad_(True)

        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        pct = 100.0 * trainable / total if total else 0.0
        print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: {pct:.4f}")

    return module, replaced


def run_vlm_adaptation(mode="linear", epochs=3, lr=5e-4, lora_r=8):
    meta_dim = len(META_COLS)
    sample_labels = df['labels_encoded_y'].iloc[0]
    if isinstance(sample_labels, str):
        sample_labels = [float(v) for v in sample_labels.replace('[','').replace(']','').split()]
    n_classes = len(sample_labels)
    print("meta_dim =", meta_dim, "| n_classes =", n_classes)

    # Build OpenCLIP ViT visual encoder
    visual, feat_dim = build_openclip_visual()
    visual = visual.to(DEVICE)

    # Freeze by default
    for p in visual.parameters():
        p.requires_grad = False

    model = VLMWithHead(visual, feat_dim, meta_dim, n_classes, head_hidden=256).to(DEVICE)

    if mode == "linear":
        for p in model.backbone.parameters(): p.requires_grad = False
        for p in model.head.parameters(): p.requires_grad = True
        note = "linear probe (head only)"

    elif mode == "partial":
        # unfreeze last block(s) in ViT (simple heuristic)
        unfrozen = 0
        for n, m in model.backbone.named_modules():
            if ".blocks." in n and n.endswith(".mlp.fc2"):
                parent = model.backbone
                for part in n.split(".mlp.fc2")[0].split("."):
                    parent = getattr(parent, part)
                for p in parent.parameters(): p.requires_grad = True
                unfrozen += 1
        for p in model.head.parameters(): p.requires_grad = True
        note = f"partial tuning ({unfrozen} block(s)) + head"

    elif mode == "lora":
        model.backbone, targeted = apply_lora(model.backbone, r=lora_r, alpha=2*lora_r, dropout=0.05)
        for p in model.head.parameters(): p.requires_grad = True
        note = f"LoRA on attention ({len(targeted)} layers) + head"
    else:
        raise ValueError("mode must be one of: 'linear', 'partial', 'lora'")

    print(f"Trainable params: {count_trainable_params(model)} ← {note}")

    model, best = train_model(model, train_loader, val_loader, LABELS, epochs=epochs, lr=lr)
    test_metrics = evaluate(model, test_loader, LABELS)
    print("TEST macro:", test_metrics["macro"])
    return model, test_metrics, count_trainable_params(model)

In [104]:
meta_dim = len(META_COLS)
sample = df['labels_encoded_y'].iloc[0]
if isinstance(sample, str):
        sample = [float(v) for v in sample.replace('[','').replace(']','').split()]
n_classes = len(sample)

print(f"meta_dim = {meta_dim} | n_classes = {n_classes}")
print("--- Training ConcatFusion ---")
concat_model = ConcatFusion("resnet50", 768, meta_dim, 512, n_classes).to(DEVICE)
print("Trainable params:", count_trainable_params(concat_model))
concat_model, best_concat = train_model(
    concat_model, train_loader, val_loader, LABELS,
    epochs=EPOCHS_SCRATCH, lr=1e-3
)
concat_test = evaluate(concat_model, test_loader, LABELS)
print("\nConcatFusion TEST Results:")
print(json.dumps(concat_test["macro"], indent=2))
torch.save(
    {"state_dict": concat_model.state_dict()},
    os.path.join(OUTPUT_DIR, "concat_fusion_best.pt")
)

print("\n" + "="*40 + "\n")

print("--- Training CrossAttnFusion ---")
xattn_model = CrossAttnFusion("resnet50", 512, meta_dim, 8, 1, n_classes).to(DEVICE)
print("Trainable params:", count_trainable_params(xattn_model))
xattn_model, best_xattn = train_model(
    xattn_model, train_loader, val_loader, LABELS,
    epochs=EPOCHS_SCRATCH, lr=1e-3
)
xattn_test = evaluate(xattn_model, test_loader, LABELS)
print("\nCrossAttnFusion TEST Results:")
print(json.dumps(xattn_test["macro"], indent=2))
torch.save(
    {"state_dict": xattn_model.state_dict()},
    os.path.join(OUTPUT_DIR, "xattn_fusion_best.pt")
)


meta_dim = 93 | n_classes = 14
--- Training ConcatFusion ---
Trainable params: 25637710
[Epoch 1/5] train_loss=0.3056 | val_macro_auc=0.6010
[Epoch 2/5] train_loss=0.2378 | val_macro_auc=0.6707
[Epoch 3/5] train_loss=0.2232 | val_macro_auc=0.7141
[Epoch 4/5] train_loss=0.2160 | val_macro_auc=0.7192
[Epoch 5/5] train_loss=0.2114 | val_macro_auc=0.7121

ConcatFusion TEST Results:
{
  "AUC": 0.4411764705882353,
  "F1": 0.0,
  "ACC": 0.8589743589743589
}


--- Training CrossAttnFusion ---
Trainable params: 27765326
[Epoch 1/5] train_loss=0.2822 | val_macro_auc=0.6202
[Epoch 2/5] train_loss=0.2449 | val_macro_auc=0.6364
[Epoch 3/5] train_loss=0.2420 | val_macro_auc=0.6253
[Epoch 4/5] train_loss=0.2436 | val_macro_auc=0.6141
[Epoch 5/5] train_loss=0.2402 | val_macro_auc=0.6747

CrossAttnFusion TEST Results:
{
  "AUC": 0.5397058823529413,
  "F1": 0.0,
  "ACC": 0.8717948717948718
}


In [111]:
results_summary = []
for mode in ["linear", "partial", "lora"]:
    print("\n" + "="*12, mode.upper(), "="*12)
    vlm_model, test_mets, n_trainable = run_vlm_adaptation(
        mode=mode,
        epochs=EPOCHS_VLM,
        lr=(1e-3 if mode=="linear" else 5e-4),
        lora_r=8
    )
    results_summary.append({
        "model": f"openCLIP-ViT-B-32-{mode}",
        "trainable_params": n_trainable,
        "test_macro_auc": test_mets["macro"]["AUC"],
        "test_macro_f1":  test_mets["macro"]["F1"],
        "test_macro_acc": test_mets["macro"]["ACC"],
    })
    # Save the whole wrapped model (works for PEFT too)
    torch.save({"model_state": vlm_model.state_dict()},
               os.path.join(OUTPUT_DIR, f"vlm_{mode}_best.pt"))

import pandas as pd
df_results = pd.DataFrame(results_summary)
print(df_results)


meta_dim = 93 | n_classes = 14
Resizing positional embeddings from 50 → 65
Trainable params: 158734 ← linear probe (head only)
[Epoch 1/3] train_loss=0.3529 | val_macro_auc=0.6242
[Epoch 2/3] train_loss=0.2379 | val_macro_auc=0.6424
[Epoch 3/3] train_loss=0.2260 | val_macro_auc=0.6364
TEST macro: {'AUC': 0.5073529411764707, 'F1': 0.0, 'ACC': 0.8717948717948718}

meta_dim = 93 | n_classes = 14
Resizing positional embeddings from 50 → 65
Trainable params: 158734 ← partial tuning (0 block(s)) + head
[Epoch 1/3] train_loss=0.4079 | val_macro_auc=0.6010
[Epoch 2/3] train_loss=0.2459 | val_macro_auc=0.6162
[Epoch 3/3] train_loss=0.2342 | val_macro_auc=0.6343
TEST macro: {'AUC': 0.4955882352941176, 'F1': 0.0, 'ACC': 0.8717948717948718}

meta_dim = 93 | n_classes = 14
Resizing positional embeddings from 50 → 65
trainable params: 516,096 || all params: 88,376,832 || trainable%: 0.5840
Trainable params: 674830 ← LoRA on attention (24 layers) + head
[Epoch 1/3] train_loss=0.3855 | val_macro_auc=

In [112]:
summary_rows = [
    {"model": "ConcatFusion-ResNet50",
     "trainable_params": count_trainable_params(concat_model),
     "test_macro_auc": concat_test["macro"]["AUC"],
     "test_macro_f1":  concat_test["macro"]["F1"],
     "test_macro_acc": concat_test["macro"]["ACC"]},
    {"model": "XAttnFusion-ResNet50",
     "trainable_params": count_trainable_params(xattn_model),
     "test_macro_auc": xattn_test["macro"]["AUC"],
     "test_macro_f1":  xattn_test["macro"]["F1"],
     "test_macro_acc": xattn_test["macro"]["ACC"]},
]
summary_rows += results_summary

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(os.path.join(OUTPUT_DIR, "summary_results.csv"), index=False)

print("\n" + "="*25 + " FINAL SUMMARY " + "="*25)
print(summary_df.to_string())
print(f"\nAll artifacts saved to: {OUTPUT_DIR}")


                       model  trainable_params  test_macro_auc  test_macro_f1  test_macro_acc
0      ConcatFusion-ResNet50          25637710        0.441176            0.0        0.858974
1       XAttnFusion-ResNet50          27765326        0.539706            0.0        0.871795
2   openCLIP-ViT-B-32-linear            158734        0.507353            0.0        0.871795
3  openCLIP-ViT-B-32-partial            158734        0.495588            0.0        0.871795
4     openCLIP-ViT-B-32-lora            674830        0.485294            0.0        0.871795

All artifacts saved to: ./outputs
