## Task 2 - CLIP Fine-Tuning on the Visual Encoder

In [None]:
#@title GPU / Python / Torch sanity
import os, sys, subprocess, json, platform, torch
print("Python :", sys.version)
print("CUDA   :", torch.version.cuda)
print("Torch  :", torch.__version__)
print("Device :", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
!nvidia-smi || true

In [None]:
# some imports
import os, time, math, random
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # stable ordering
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, ConcatDataset
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel, logging
from peft import LoraConfig, get_peft_model, TaskType
from torchinfo import summary
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import json
import warnings

# Ensure reproducibility
def set_seed(seed: int = 42):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode and turns its autotuner
    (``benchmark``) off.
    """
    # Same order as always: stdlib -> NumPy -> torch CPU -> torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at notebook start.
set_seed(42)

In [None]:
# Global experiment settings shared by the cells below
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "openai/clip-vit-large-patch14" # pre-trained CLIP model (ViT-L/14)
BATCH_SIZE = 32  # chosen for RTX 3090 (24GB); adjust if you hit OOM
gradient_accumulation_steps = 1 # passed to train_epoch as grad_accum_steps; adjust based on your GPU memory
# For Linear Probe & LoRA
NUM_EPOCHS = 10  # epochs per run; lower it for a quick smoke test
SEED = 42  # seeds the generator used for the Flowers train/val random_split
print(f"Using device: {DEVICE}")
print("device_count:", torch.cuda.device_count())
print("device 0:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

# Default LRs
LR_HEAD_LINEAR = 1e-3   # linear probe head lr
LR_HEAD_LORA   = 1e-3   # LoRA run head lr
LR_LORA        = 1e-4   # LoRA adapter lr
WEIGHT_DECAY   = 1e-4   # weight decay handed to every Adam optimizer below

In [None]:
# CLIP settings
# --- Load CLIP Processor ---
# Loaded from the same MODEL_ID as the CLIP weights used later in the notebook.
processor = CLIPProcessor.from_pretrained(MODEL_ID)
# --- Define a transform to process images for CLIP ---
class CLIPTransform:
    """Callable that turns one image into an un-batched CLIP pixel tensor.

    Wraps a processor object that is called as
    ``processor(images=..., return_tensors="pt")`` and returns a mapping with
    a ``"pixel_values"`` entry.
    """

    def __init__(self, processor):
        # Stored as-is; it is invoked on every __call__.
        self.processor = processor

    def __call__(self, image):
        encoded = self.processor(images=image, return_tensors="pt")
        pixel_values = encoded["pixel_values"]
        # The processor output carries a leading batch axis; drop it so the
        # DataLoader can stack samples into batches itself.
        return pixel_values.squeeze(0)

# Shared transform instance: handed to the Flowers102 datasets as `transform`
# and called directly by HFCUBWrapper and load_image_tensor.
clip_transform = CLIPTransform(processor)

In [None]:
# dataset related imports
from torchvision.datasets import Flowers102 
from datasets import load_dataset

# --- Flowers102 ---
# prepare Flowers102 dataset
# NOTE(review): no `download=` argument is passed, so the data presumably has
# to be present under root="." already - confirm.
flowers102_test_dts = Flowers102(root=".", split="test", transform=clip_transform) # evaluation on this set
flowers102_train_dts = Flowers102(root=".", split="train", transform=clip_transform)
flowers102_val_dts = Flowers102(root=".", split="val", transform=clip_transform)

print(f"Total training samples (orig train): {len(flowers102_train_dts)}")
print(f"Total validation samples (orig val): {len(flowers102_val_dts)}")
print(f"Total test samples: {len(flowers102_test_dts)}") # should be 6149

# prepare class names for Flowers102
# cat_to_name.json is read from the working directory; it is looked up with
# the string keys "1".."102" below.
with open("cat_to_name.json", "r") as f:
    cat_to_name = json.load(f)

# List position i holds the name stored under key str(i + 1).
# Assumes dataset labels are 0-based and line up with these keys - TODO confirm.
flowers102_class_names = [cat_to_name[str(i + 1)] for i in range(102)]

# --- CUB-200-2011 ---
birds_200 = load_dataset("bentrevett/caltech-ucsd-birds-200-2011")
# Carve a validation set (10%) out of the official train split; the fixed seed
# keeps the partition reproducible.
split = birds_200["train"].train_test_split(test_size=0.1, seed=42, shuffle=True)

cub_bird_train_dts = split["train"]
cub_bird_val_dts = split["test"]   # the "test" half of the split serves as validation
cub_bird_test_dts = birds_200["test"]  # official test split, used for final evaluation

print(f"Total training samples: {len(cub_bird_train_dts)}")
print(f"Total validation samples: {len(cub_bird_val_dts)}")
print(f"Total test samples: {len(cub_bird_test_dts)}") # should be 5794

# prepare class names for CUB-200-2011
cub_class_names_raw = birds_200["train"].features["label"].names
# Keep the text after the last '.' and turn underscores into spaces
# (presumably raw names look like "001.Black_footed_Albatross" - verify).
cub_class_names = [name.split('.')[-1].replace('_', ' ') for name in cub_class_names_raw]

# Wrap HF dataset to return (pixel_values, label) tensors compatible with default collate
class HFCUBWrapper(torch.utils.data.Dataset):
    """Adapter exposing a Hugging Face dataset as ``(pixel_values, label)`` pairs.

    Each record is expected to provide an ``"image"`` and a ``"label"`` entry;
    the image goes through the module-level ``clip_transform`` and the label
    is converted to a plain ``int`` so the default collate function can batch
    the samples.
    """

    def __init__(self, hf_ds):
        self.hf_ds = hf_ds

    def __len__(self):
        return len(self.hf_ds)

    def __getitem__(self, idx):
        record = self.hf_ds[idx]
        pixel_values = clip_transform(record["image"])
        return pixel_values, int(record["label"])
    
# Rebind the CUB names to wrapped datasets; from here on they yield
# (pixel_values, label) pairs instead of raw HF records.
cub_bird_train_dts = HFCUBWrapper(cub_bird_train_dts)
cub_bird_val_dts = HFCUBWrapper(cub_bird_val_dts)
cub_bird_test_dts = HFCUBWrapper(cub_bird_test_dts)

# === Create DataLoaders (only test for now, train/val in next cell) ===
# shuffle=False: evaluation order is irrelevant to the metrics.
flowers102_test_loader = DataLoader(flowers102_test_dts, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
cub200_test_loader = DataLoader(cub_bird_test_dts, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


In [None]:
# === Train/Val splits and DataLoaders ===
# Flowers: merge original train+val (2040) and split into 1836 train / 204 val
from torch.utils.data import ConcatDataset
flowers_trainval = ConcatDataset([flowers102_train_dts, flowers102_val_dts])
# Guard: the hard-coded split sizes below only make sense for 2040 samples.
assert len(flowers_trainval) == (len(flowers102_train_dts) + len(flowers102_val_dts)) == 2040
train_len, val_len = 1836, 204
# Dedicated generator so the split depends only on SEED, not on global RNG state.
g = torch.Generator().manual_seed(SEED)
flowers_train_dts, flowers_val_dts_new = random_split(flowers_trainval, [train_len, val_len], generator=g)
flowers_train_loader = DataLoader(flowers_train_dts, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
flowers_val_loader   = DataLoader(flowers_val_dts_new, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
# Alias so all three Flowers loaders share the `flowers_*_loader` naming.
flowers_test_loader  = flowers102_test_loader
print(f"Flowers -> train: {len(flowers_train_dts)}, val: {len(flowers_val_dts_new)}, test: {len(flowers102_test_dts)}")

# CUB: we already did a 90/10 split to get 5394/600
cub_train_loader = DataLoader(cub_bird_train_dts, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
cub_val_loader   = DataLoader(cub_bird_val_dts, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
# Alias so all three CUB loaders share the `cub_*_loader` naming.
cub_test_loader  = cub200_test_loader
print(f"CUB-200 -> train: {len(cub_bird_train_dts)}, val: {len(cub_bird_val_dts)}, test: {len(cub_bird_test_dts)}")

In [None]:
# === Training utilities ===
def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over the last axis equals the target."""
    hits = logits.argmax(dim=-1).eq(targets)
    return hits.float().mean().item()

# Safe image features that work with/without LoRA wrapping
def get_image_features_safe(model, pixel_values: torch.Tensor, require_backbone_grad: bool):
    """Return projected image features for ``pixel_values``.

    Args:
        model: object exposing ``get_image_features`` or, as a fallback,
            ``vision_model`` and ``visual_projection``.
        pixel_values: batch of preprocessed images.
        require_backbone_grad: when False the whole computation runs under
            ``torch.no_grad()``; when True autograd tracks it.

    Returns:
        The projected features. They are *not* L2-normalised on either path.
    """
    def compute():
        # First try the official helper
        try:
            return model.get_image_features(pixel_values=pixel_values)
        except TypeError:
            # Manual: vision forward -> pooled -> visual projection
            vision_out = model.vision_model(pixel_values=pixel_values, return_dict=True)
            pooled = getattr(vision_out, "pooler_output", None)
            if pooled is None:
                # Tuple-style output: the pooled vector is taken from index 1.
                pooled = vision_out[1]
            # Return the raw projection. Normalising here (as this branch used
            # to) would make the feature scale depend on which branch ran,
            # because the helper above is returned without any normalisation.
            return model.visual_projection(pooled)

    if require_backbone_grad:
        return compute()
    with torch.no_grad():
        return compute()

@torch.no_grad()
def evaluate_epoch(model, head, data_loader, device, require_backbone_grad: bool):
    """Run one gradient-free pass over ``data_loader``.

    Both ``model`` and ``head`` are switched to eval mode. Features are always
    computed without gradients; ``require_backbone_grad`` is accepted for
    signature symmetry with ``train_epoch`` but is not consulted.

    Returns:
        ``(mean_loss, accuracy)`` averaged over samples; ``(0.0, 0.0)`` when
        the loader yields nothing.
    """
    model.eval()
    head.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    n_correct = 0.0
    n_seen = 0
    for batch_images, batch_labels in data_loader:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        features = get_image_features_safe(model, batch_images, require_backbone_grad=False)
        scores = head(features)
        batch_size = batch_labels.size(0)
        # The criterion returns a batch mean; weight it by the batch size so
        # the final division yields a per-sample average.
        loss_sum += criterion(scores, batch_labels).item() * batch_size
        n_correct += (scores.argmax(-1) == batch_labels).sum().item()
        n_seen += batch_size
    denom = max(1, n_seen)  # avoid dividing by zero on an empty loader
    return loss_sum / denom, n_correct / denom

def train_epoch(model, head, data_loader, device, optimizer, grad_accum_steps: int, require_backbone_grad: bool):
    """Train for one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: feature extractor handed to ``get_image_features_safe``.
        head: classifier applied to the extracted features.
        data_loader: iterable of ``(pixel_values, labels)`` batches.
        device: device the batches are moved to.
        optimizer: optimizer stepped every ``grad_accum_steps`` batches.
        grad_accum_steps: number of batches whose gradients are accumulated
            before each optimizer step; must be >= 1.
        require_backbone_grad: True puts ``model`` in train mode and tracks
            gradients through it; False keeps it in eval mode and extracts
            features under ``no_grad``.

    Raises:
        ValueError: if ``grad_accum_steps`` is smaller than 1.
    """
    if grad_accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")
    if require_backbone_grad:
        model.train()
    else:
        model.eval()
    head.train()
    ce = nn.CrossEntropyLoss()
    total_loss, total_acc, total_n = 0.0, 0.0, 0
    has_pending_grads = False  # True while backward() ran without a following step()
    optimizer.zero_grad(set_to_none=True)
    for step, (pixel_values, labels) in enumerate(tqdm(data_loader, leave=False)):
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        feats = get_image_features_safe(model, pixel_values, require_backbone_grad=require_backbone_grad)
        logits = head(feats)
        # Scale so the accumulated gradient matches one large-batch mean.
        loss = ce(logits, labels) / grad_accum_steps
        loss.backward()
        has_pending_grads = True
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            has_pending_grads = False
        bs = labels.size(0)
        # Undo the accumulation scaling so the reported loss is the true mean.
        total_loss += loss.item() * grad_accum_steps * bs
        total_acc  += (logits.argmax(-1) == labels).sum().item()
        total_n    += bs
    if has_pending_grads:
        # The batch count was not a multiple of grad_accum_steps: apply the
        # trailing partial group instead of letting the next epoch's
        # zero_grad() discard it.
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return total_loss / max(1,total_n), total_acc / max(1,total_n)

def plot_curves(history, title_prefix="", save_path: str | None = None):
    """Draw train/val loss and accuracy curves side by side.

    ``history`` must provide the lists ``train_loss``, ``val_loss``,
    ``train_acc`` and ``val_acc``. The figure is written to ``save_path``
    when one is given, and is always shown.
    """
    n_epochs = len(history['train_loss'])
    epochs = range(1, n_epochs + 1)
    # One entry per subplot: (y-label / title suffix, train series, val series).
    panels = (
        ('loss', history['train_loss'], history['val_loss']),
        ('accuracy', history['train_acc'], history['val_acc']),
    )
    plt.figure(figsize=(10,4))
    for position, (metric, train_series, val_series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_series, label='train')
        plt.plot(epochs, val_series, label='val')
        plt.xlabel('epoch')
        plt.ylabel(metric)
        plt.title(f'{title_prefix} {metric}')
        plt.legend()
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()

In [None]:
print("--- Starting Method: Linear Probing ---")

# 1) Load full CLIP (we'll only use the vision tower)
clip_linear = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
vision_model = clip_linear.vision_model
# Module-level name: run_linear_probe_one reads it to size the classifier head.
visual_projection = clip_linear.visual_projection

# 2) Freeze backbone and projection (no grads in vision encoder)
for p in vision_model.parameters():
    p.requires_grad = False
for p in visual_projection.parameters():
    p.requires_grad = False
clip_linear.eval()  # ensure backbone runs in eval for deterministic features

# Helper to run one dataset

def run_linear_probe_one(name, num_classes, loaders, save_prefix):
    """Train and evaluate a linear head on frozen ``clip_linear`` features.

    Args:
        name: dataset label used in log lines, plot titles and the metrics row.
        num_classes: output size of the linear head.
        loaders: ``(train_loader, val_loader, test_loader)`` tuple.
        save_prefix: token embedded in the curve and checkpoint file names.

    Returns:
        ``(head, history, test_metrics)`` where ``history`` holds per-epoch
        train/val loss and accuracy lists and ``test_metrics`` has the keys
        ``test_loss`` and ``test_acc``.

    Raises:
        NameError: if the results-saving helpers have not been defined yet.
    """
    # The saving helpers are defined in a separate cell. Check for them up
    # front so a forgotten cell fails immediately rather than with a
    # NameError after every epoch has already been trained.
    missing = [dep for dep in ("save_history_png", "append_metrics_csv", "CKPT_DIR") if dep not in globals()]
    if missing:
        raise NameError(
            "Run the 'Results saving helpers' cell first; missing: " + ", ".join(missing)
        )

    train_loader, val_loader, test_loader = loaders
    # Feature dim from projected image features
    in_dim = visual_projection.out_features if hasattr(visual_projection, 'out_features') else clip_linear.config.projection_dim
    head = nn.Linear(in_dim, num_classes).to(DEVICE)
    # Only the head is optimised; the backbone was frozen when it was loaded.
    optimizer = torch.optim.Adam(head.parameters(), lr=LR_HEAD_LINEAR, weight_decay=WEIGHT_DECAY)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    for epoch in range(NUM_EPOCHS):
        tr_loss, tr_acc = train_epoch(clip_linear, head, train_loader, DEVICE, optimizer, gradient_accumulation_steps, require_backbone_grad=False)
        val_loss, val_acc = evaluate_epoch(clip_linear, head, val_loader, DEVICE, require_backbone_grad=False)
        history["train_loss"].append(tr_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(val_acc)
        print(f"[{name}] Epoch {epoch+1}/{NUM_EPOCHS} | train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | val loss {val_loss:.4f} acc {val_acc*100:.2f}%")

    # Save curves and head
    head_path = CKPT_DIR / f"linear_{save_prefix}_head.pt"
    save_history_png(history, title_prefix=f"LinearProbe {name}", filename=f"linear_{save_prefix}_curves.png")
    torch.save(head.state_dict(), head_path)
    print(f"Saved head: {head_path}")

    # Final numbers come from the last-epoch head (no best-epoch selection).
    test_loss, test_acc = evaluate_epoch(clip_linear, head, test_loader, DEVICE, require_backbone_grad=False)
    append_metrics_csv([{"method":"Linear","dataset":name,"epochs":NUM_EPOCHS,"test_acc":test_acc,"val_acc":history['val_acc'][-1]}])
    print(f"[{name}] Test: loss {test_loss:.4f} | acc {test_acc*100:.2f}%")
    return head, history, {"test_loss":test_loss, "test_acc":test_acc}

# Flowers-102 (102 classes)
# Keeps the trained head, the per-epoch history and the test metrics under
# flowers_linear_* names; the head is reused by the NTU inference cell.
flowers_linear_head, flowers_linear_hist, flowers_linear_test = run_linear_probe_one(
    name="Flowers102",
    num_classes=102,
    loaders=(flowers_train_loader, flowers_val_loader, flowers_test_loader),
    save_prefix="flowers"
)

# CUB-200-2011 (200 classes)
# Same procedure with a fresh head; results are kept under cub_linear_* names.
cub_linear_head, cub_linear_hist, cub_linear_test = run_linear_probe_one(
    name="CUB-200",
    num_classes=200,
    loaders=(cub_train_loader, cub_val_loader, cub_test_loader),
    save_prefix="cub"
)

In [None]:
print("--- Starting Method: LoRA Fine-Tuning ---")

# 1) Load CLIP and inject LoRA into the full CLIP model (targets still match vision modules)
clip_lora = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
# NOTE: this rebinds the module-level `visual_projection` that the linear-probe
# cell also assigns; from here on it refers to the LoRA model's projection.
visual_projection = clip_lora.visual_projection

# 2) LoRA config (target Q/V projections in ViT attention blocks)
# The same config object is reused by load_lora_model when rebuilding a model
# from a checkpoint.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # matches both text/vision; only vision path gets gradients
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
)

# 3) Wrap with PEFT at the CLIPModel level
clip_lora = get_peft_model(clip_lora, lora_config)
print("LoRA Model - Trainable Parameters:")
# NOTE(review): the count printed here presumably includes adapters on the
# text tower as well, since target_modules is not restricted to the vision
# tower - confirm.
clip_lora.print_trainable_parameters()

# 4) Freeze non-LoRA weights and the visual projection
# Trainability is decided purely by parameter name: anything containing
# "lora_" trains, everything else is frozen.
for n, p in clip_lora.named_parameters():
    if "lora_" in n:
        p.requires_grad = True
    else:
        p.requires_grad = False
# Explicitly freeze the projection that was captured before wrapping.
for p in visual_projection.parameters():
    p.requires_grad = False

# Helper to run one dataset with LoRA
def run_lora_one(name, num_classes, loaders, save_prefix):
    """Fine-tune LoRA adapters plus a linear head on one dataset.

    Each call trains its own deep copy of the module-level ``clip_lora``.
    Training the shared instance in place would let a second dataset start
    from adapters already tuned on the first one, and every returned "model"
    would be the same object, so an earlier head would end up paired with
    adapters it was not trained with.

    Args:
        name: dataset label used in log lines, plot titles and the metrics row.
        num_classes: output size of the linear head.
        loaders: ``(train_loader, val_loader, test_loader)`` tuple.
        save_prefix: token embedded in the curve and checkpoint file names.

    Returns:
        ``(head, history, test_metrics, model)`` where ``model`` is the
        LoRA-wrapped CLIP copy trained by this call. Note that every call
        therefore keeps one additional full model on ``DEVICE``.

    Raises:
        NameError: if the results-saving helpers have not been defined yet.
    """
    import copy

    # The saving helpers are defined in a separate cell. Check for them up
    # front so a forgotten cell fails immediately rather than with a
    # NameError after every epoch has already been trained.
    missing = [dep for dep in ("save_history_png", "append_metrics_csv", "CKPT_DIR") if dep not in globals()]
    if missing:
        raise NameError(
            "Run the 'Results saving helpers' cell first; missing: " + ", ".join(missing)
        )

    train_loader, val_loader, test_loader = loaders
    # Private copy: identical starting adapters for every dataset, and the
    # requires_grad flags set on clip_lora are carried over.
    model = copy.deepcopy(clip_lora)
    in_dim = visual_projection.out_features if hasattr(visual_projection, 'out_features') else clip_lora.config.projection_dim
    head = nn.Linear(in_dim, num_classes).to(DEVICE)

    # Two parameter groups: LoRA adapters and head
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam([
        {"params": lora_params, "lr": LR_LORA},
        {"params": head.parameters(), "lr": LR_HEAD_LORA},
    ], weight_decay=WEIGHT_DECAY)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    for epoch in range(NUM_EPOCHS):
        tr_loss, tr_acc = train_epoch(model, head, train_loader, DEVICE, optimizer, gradient_accumulation_steps, require_backbone_grad=True)
        val_loss, val_acc = evaluate_epoch(model, head, val_loader, DEVICE, require_backbone_grad=True)
        history["train_loss"].append(tr_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(val_acc)
        print(f"[LoRA {name}] Epoch {epoch+1}/{NUM_EPOCHS} | train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | val loss {val_loss:.4f} acc {val_acc*100:.2f}%")

    # Save curves and adapters/head
    head_path = CKPT_DIR / f"lora_{save_prefix}_head.pt"
    clip_path = CKPT_DIR / f"lora_{save_prefix}_clip.pt"
    save_history_png(history, title_prefix=f"LoRA {name}", filename=f"lora_{save_prefix}_curves.png")
    torch.save(head.state_dict(), head_path)
    # Full wrapped-model state dict (not adapters only): this is the format
    # load_lora_model expects.
    torch.save(model.state_dict(), clip_path)
    print(f"Saved head: {head_path}")
    print(f"Saved LoRA model: {clip_path}")

    test_loss, test_acc = evaluate_epoch(model, head, test_loader, DEVICE, require_backbone_grad=True)
    append_metrics_csv([{"method":"LoRA","dataset":name,"epochs":NUM_EPOCHS,"test_acc":test_acc,"val_acc":history['val_acc'][-1]}])
    print(f"[LoRA {name}] Test: loss {test_loss:.4f} | acc {test_acc*100:.2f}%")
    return head, history, {"test_loss":test_loss, "test_acc":test_acc}, model

# Flowers-102 (102 classes)
# The fourth return value is the LoRA-wrapped model reported by run_lora_one;
# the NTU inference cell pairs it with flowers_lora_head.
flowers_lora_head, flowers_lora_hist, flowers_lora_test, flowers_lora_model = run_lora_one(
    name="Flowers102",
    num_classes=102,
    loaders=(flowers_train_loader, flowers_val_loader, flowers_test_loader),
    save_prefix="flowers"
)

# CUB-200-2011 (200 classes)
# Same procedure for CUB; results are kept under cub_lora_* names.
cub_lora_head, cub_lora_hist, cub_lora_test, cub_lora_model = run_lora_one(
    name="CUB-200",
    num_classes=200,
    loaders=(cub_train_loader, cub_val_loader, cub_test_loader),
    save_prefix="cub"
)

In [None]:
# === Results saving helpers ===
import os
from pathlib import Path
import pandas as pd
# Output locations (relative to the working directory): plots and metrics go
# to RESULTS_DIR, model weights to CKPT_DIR. Both are created if missing.
RESULTS_DIR = Path("results/task2")
CKPT_DIR = Path("checkpoints/task2")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(parents=True, exist_ok=True)

def save_history_png(history, title_prefix, filename):
    """Plot ``history`` via ``plot_curves`` and store the figure as ``RESULTS_DIR / filename``."""
    target = RESULTS_DIR / filename
    plot_curves(history, save_path=str(target), title_prefix=title_prefix)
    print(f"Saved curve: {target}")

def append_metrics_csv(rows, filename="metrics.csv"):
    """Append ``rows`` (a list of dicts) to the CSV at ``RESULTS_DIR / filename``.

    The file is created when absent; otherwise its existing rows are read
    back, the new ones are placed after them, and the whole table is
    rewritten.
    """
    target = RESULTS_DIR / filename
    table = pd.DataFrame(rows)
    if target.exists():
        # Existing rows first, new rows after them.
        table = pd.concat([pd.read_csv(target), table], ignore_index=True)
    table.to_csv(target, index=False)
    print(f"Saved metrics: {target}")

In [None]:
# === NTU images: inference on 4 trained models ===
from PIL import Image
import torch.nn.functional as F

def load_image_tensor(path):
    """Open ``path`` as RGB and return ``(batched_pixel_values_on_DEVICE, pil_image)``."""
    img = Image.open(path).convert('RGB')
    # clip_transform yields one un-batched sample; add the batch axis back.
    batch = clip_transform(img).unsqueeze(0)
    return batch.to(DEVICE), img

def topk_from_logits(logits, class_names, k=5):
    """Return the ``k`` most probable classes as ``(name, probability)`` pairs.

    Args:
        logits: scores for a single sample; a leading batch axis of size 1
            is squeezed away.
        class_names: sequence mapping class index to display name.
        k: number of entries wanted. It is capped at the number of classes so
            that small label sets do not make ``torch.topk`` raise.

    Returns:
        List of ``(class_name, probability)`` tuples, most probable first.
    """
    probs = F.softmax(logits, dim=-1).squeeze(0)
    k = min(k, probs.shape[-1])
    top = torch.topk(probs, k)
    return [
        (class_names[i], float(v))
        for i, v in zip(top.indices.tolist(), top.values.tolist())
    ]

def predict_one(model, head, image_path, class_names, title_prefix, save_name):
    """Classify one image, plot its top-5 probabilities and return them.

    Args:
        model: feature extractor handed to ``get_image_features_safe``.
        head: classifier applied to the extracted features.
        image_path: image file to classify.
        class_names: sequence mapping class index to display name.
        title_prefix: title of the bar chart.
        save_name: file name of the chart inside ``RESULTS_DIR``.

    Returns:
        The top-5 list of ``(class_name, probability)`` tuples.
    """
    # Inference must not depend on the mode the caller left the modules in:
    # a model rebuilt from a checkpoint may still have dropout active.
    model.eval()
    head.eval()
    pixel_values, _ = load_image_tensor(image_path)
    with torch.no_grad():  # no autograd graph is needed for the head either
        feats = get_image_features_safe(model, pixel_values, require_backbone_grad=False)
        logits = head(feats)
    topk = topk_from_logits(logits, class_names, k=5)
    # Plot bar
    # Reversed so the series run from least to most probable.
    labels = [l for l,_ in topk][::-1]
    vals   = [v for _,v in topk][::-1]
    plt.figure(figsize=(6,3))
    sns.barplot(x=vals, y=labels, palette="mako")
    plt.xlim(0,1)
    for i,v in enumerate(vals):
        plt.text(v+0.01, i, f"{v*100:.1f}%", va='center')
    plt.title(f"{title_prefix}")
    out = RESULTS_DIR / save_name
    plt.tight_layout()
    plt.savefig(out, bbox_inches='tight')
    plt.show()
    print(f"Saved: {out}")
    return topk

# The two NTU photos, resolved relative to the working directory.
bird_path = "img/bird_ntu.jpg"
flower_path = "img/flower_ntu.jpg"

# Each section below is guarded by a globals() lookup so this cell can run
# even when a training cell was skipped; it then prints a hint instead.

# Linear probe models use the frozen clip_linear for features
if 'clip_linear' in globals() and 'flowers_linear_head' in globals() and 'cub_linear_head' in globals():
    print("-- Linear probe: NTU images --")
    predict_one(clip_linear, flowers_linear_head, flower_path, flowers102_class_names, "Linear Flowers102 - flower_ntu", "ntu_linear_flowers_flower.png")
    predict_one(clip_linear, cub_linear_head,     bird_path,   cub_class_names,        "Linear CUB-200 - bird_ntu",     "ntu_linear_cub_bird.png")
else:
    print("Linear probe heads not found. Run the Linear Probing cell first.")

# LoRA predictions pair each head with the model object that run_lora_one
# returned alongside it (flowers_lora_model / cub_lora_model).
if 'flowers_lora_model' in globals() and 'flowers_lora_head' in globals():
    print("-- LoRA Flowers: NTU images --")
    predict_one(flowers_lora_model, flowers_lora_head, flower_path, flowers102_class_names, "LoRA Flowers102 - flower_ntu", "ntu_lora_flowers_flower.png")
else:
    print("LoRA Flowers model not found. Run the LoRA cell.")

if 'cub_lora_model' in globals() and 'cub_lora_head' in globals():
    print("-- LoRA CUB: NTU images --")
    predict_one(cub_lora_model, cub_lora_head, bird_path,   cub_class_names, "LoRA CUB-200 - bird_ntu",   "ntu_lora_cub_bird.png")
else:
    print("LoRA CUB model not found. Run the LoRA cell.")

In [None]:
# === Optional: load checkpoints and run NTU predictions without retraining ===
def load_linear_head(save_prefix, in_dim, num_classes):
    """Rebuild a linear-probe head from ``CKPT_DIR / linear_<save_prefix>_head.pt``.

    The head is created on ``DEVICE`` with shape ``(in_dim, num_classes)``,
    filled from the saved state dict and returned in eval mode.
    """
    classifier = nn.Linear(in_dim, num_classes).to(DEVICE)
    ckpt_path = CKPT_DIR / f"linear_{save_prefix}_head.pt"
    classifier.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    classifier.eval()
    return classifier

def load_lora_model(save_prefix):
    """Rebuild a LoRA-wrapped CLIP from ``CKPT_DIR / lora_<save_prefix>_clip.pt``.

    A fresh base CLIP is wrapped with the module-level ``lora_config`` and
    then filled from the saved state dict of the wrapped model.

    Returns:
        The wrapped model on ``DEVICE``, in eval mode.
    """
    # Reload base CLIP and then load raw state dict (saved above)
    mdl = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
    mdl = get_peft_model(mdl, lora_config)
    sd = torch.load(CKPT_DIR / f"lora_{save_prefix}_clip.pt", map_location=DEVICE)
    mdl.load_state_dict(sd)
    # Hand back an inference-ready model, as load_linear_head does for heads;
    # otherwise the newly created adapter layers may still apply dropout.
    mdl.eval()
    return mdl

def try_ntu_from_checkpoints():
    """Run the NTU predictions from saved checkpoints, section by section.

    Three independent best-effort sections (linear heads, LoRA Flowers, LoRA
    CUB): any exception inside a section is reported with a "not available"
    message and the next section still runs.
    """
    flower_img = 'img/flower_ntu.jpg'
    bird_img = 'img/bird_ntu.jpg'

    def feature_dim(clip):
        # Width of the projected image features that feed the linear head.
        proj = clip.visual_projection
        return proj.out_features if hasattr(proj, 'out_features') else clip.config.projection_dim

    def lora_model_and_head(prefix, num_classes):
        # Model first, then head weights, then the head itself.
        clip = load_lora_model(prefix)
        state = torch.load(CKPT_DIR / f'lora_{prefix}_head.pt', map_location=DEVICE)
        classifier = nn.Linear(feature_dim(clip), num_classes).to(DEVICE)
        classifier.load_state_dict(state)
        classifier.eval()
        return clip, classifier

    try:
        # Linear
        base = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
        base.eval()
        in_dim = feature_dim(base)
        fl_head = load_linear_head('flowers', in_dim, 102)
        cb_head = load_linear_head('cub', in_dim, 200)
        print("-- Linear from checkpoints --")
        predict_one(base, fl_head, flower_img, flowers102_class_names, 'Linear Flowers102 - flower_ntu (ckpt)', 'ntu_linear_ckpt_flowers_flower.png')
        predict_one(base, fl_head, bird_img, flowers102_class_names, 'Linear Flowers102 - bird_ntu (ckpt)', 'ntu_linear_ckpt_flowers_bird.png')
        predict_one(base, cb_head, flower_img, cub_class_names, 'Linear CUB-200 - flower_ntu (ckpt)', 'ntu_linear_ckpt_cub_flower.png')
        predict_one(base, cb_head, bird_img, cub_class_names, 'Linear CUB-200 - bird_ntu (ckpt)', 'ntu_linear_ckpt_cub_bird.png')
    except Exception as e:
        print("Linear checkpoints not available:", e)

    try:
        print("-- LoRA from checkpoints --")
        fl_m, head_fl = lora_model_and_head('flowers', 102)
        predict_one(fl_m, head_fl, flower_img, flowers102_class_names, 'LoRA Flowers102 - flower_ntu (ckpt)', 'ntu_lora_ckpt_flowers_flower.png')
        predict_one(fl_m, head_fl, bird_img, flowers102_class_names, 'LoRA Flowers102 - bird_ntu (ckpt)', 'ntu_lora_ckpt_flowers_bird.png')
    except Exception as e:
        print("LoRA Flowers checkpoints not available:", e)

    try:
        cb_m, head_cb = lora_model_and_head('cub', 200)
        predict_one(cb_m, head_cb, flower_img, cub_class_names, 'LoRA CUB-200 - flower_ntu (ckpt)', 'ntu_lora_ckpt_cub_flower.png')
        predict_one(cb_m, head_cb, bird_img, cub_class_names, 'LoRA CUB-200 - bird_ntu (ckpt)', 'ntu_lora_ckpt_cub_bird.png')
    except Exception as e:
        print("LoRA CUB checkpoints not available:", e)

# Uncomment to run if you've trained and saved checkpoints previously
# try_ntu_from_checkpoints()