In [54]:
# ==========================================
#  CheXpert + ViT + LRFL - Fine-Tuning Notebook
# ==========================================
# Fine-tuning a ViT model with Low-Rank Feature Learning (LRFL)
# Morvarid Rahbar
# 4033624008
# ==========================================

In [55]:
# Built-in Libraries
import copy
import gc
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

# Image & Data
import numpy as np
import pandas as pd
from PIL import Image

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

# Models
import timm
from torchvision import transforms

# Optimizer & Scheduler
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Evaluation
from sklearn.metrics import (
    roc_auc_score,
    f1_score,
    hamming_loss,
    accuracy_score
)

# Utilities
from tqdm import tqdm

# Experiment Tracking
import wandb


In [56]:
# wandb.login()

In [57]:
# Make Google Drive visible under /content/drive (Colab-only API).
from google.colab import drive
drive.mount("/content/drive")


# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(" Using device:", device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
 Using device: cuda


In [53]:
# LRFL Module
class LRFLModel(nn.Module):
    """Pretrained timm backbone + low-rank projection + linear classifier.

    Pipeline: backbone features -> LayerNorm -> Linear(embed_dim, rank) ->
    ReLU -> Dropout -> Linear(rank, num_classes).

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        rank: Output width of the low-rank projection.
        num_classes: Number of logits produced by the classifier.
        dropout: Dropout probability applied after the projection.
    """

    def __init__(self, backbone_name='vit_base_patch16_224', rank=64, num_classes=5, dropout=0.1):
        super().__init__()
        self.rank, self.num_classes = rank, num_classes

        # The backbone is created with num_classes=0; the classification
        # layers used by this model are the ones defined below.
        self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0)
        self.embed_dim = self.backbone.num_features

        projection_layers = [
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, rank, bias=False),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.low_rank_proj = nn.Sequential(*projection_layers)
        self.classifier = nn.Linear(rank, num_classes)

    def forward_features(self, x):
        """Return one feature vector per sample from the backbone.

        A 3-D backbone output is reduced by keeping index 0 along dim 1
        (for ViT-style models this is presumably the class token -- TODO
        confirm for other backbones). Any other output is returned unchanged.
        """
        tokens = self.backbone.forward_features(x)
        return tokens[:, 0, :] if tokens.dim() == 3 else tokens

    def forward(self, x, return_feats=False):
        """Compute logits; with ``return_feats=True`` return ``(logits, feats)``.

        ``feats`` are the backbone features *before* the low-rank projection.
        """
        feats = self.forward_features(x)
        logits = self.classifier(self.low_rank_proj(feats))
        if return_feats:
            return logits, feats
        return logits


In [58]:
# BaseLine Model
# class BaseModel(nn.Module):
#     def __init__(self, backbone_name='vit_base_patch16_224', num_classes=5):
#         super().__init__()
#         self.backbone = timm.create_model(backbone_name, pretrained=True)

#         in_features = self.backbone.head.in_features
#         self.backbone.reset_classifier(0)

#         self.classifier = nn.Linear(in_features, num_classes)

#     def forward(self, x):
#         feats = self.backbone.forward_features(x)

#         logits = self.classifier(feats)
#         return logits

#     def get_image_embedding(self, x):
#         feats = self.backbone.forward_features(x)
#         return feats

In [59]:
class CheXpertDataset(Dataset):
    """CheXpert images for the five pathology columns listed in ``label_cols``.

    Labels are cleaned once at construction time: missing values become 0
    and uncertain labels (-1) become 1.

    Args:
        csv_file: CSV with a ``Path`` column plus the five label columns.
        root_dir: Directory that image paths are resolved against, after the
            ``CheXpert-v1.0-small/`` prefix has been stripped.
        transform: Optional callable applied to the RGB PIL image.
        return_labels: When False (default) each item is ``(image, path)``,
            which is what the embedding-extraction loop consumes. When True
            each item is ``(image, labels)`` with ``labels`` a float32 array
            of shape (5,), which is what ``train_one_epoch`` / ``evaluate``
            unpack.
    """

    def __init__(self, csv_file, root_dir, transform=None, return_labels=False):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.return_labels = return_labels
        self.label_cols = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

        # NaN -> 0 (treated as negative), -1 (uncertain) -> 1 (treated as positive).
        self.data[self.label_cols] = self.data[self.label_cols].fillna(0).replace(-1, 1)

        # CSV paths carry this dataset prefix; it is removed before joining with root_dir.
        self.prefix = "CheXpert-v1.0-small/"

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``; see ``return_labels`` for the tuple layout."""
        row = self.data.iloc[idx]
        path = row["Path"]

        if path.startswith(self.prefix):
            path = path[len(self.prefix):]

        image_path = os.path.join(self.root_dir, path)
        # The context manager closes the file handle as soon as the pixels
        # have been converted, instead of leaving it to the garbage collector.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.return_labels:
            labels = row[self.label_cols].values.astype(np.float32)
            return image, labels

        # Embedding mode: only the image and its relative path are needed.
        return image, path


In [60]:
#  Transforms


# Per-channel normalisation constants (the commonly used ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training-time augmentation: small rotation, random crop resized to 224x224,
# mild brightness/contrast jitter, then tensor conversion and normalisation.
_train_steps = [
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)


In [61]:
# LRFL Regularization
def compute_uv(features, rank):
    """Return the leading singular-vector blocks of ``features``.

    The reduced SVD is computed without gradient tracking, so the returned
    tensors are constants with respect to autograd.

    Returns:
        ``(U[:, :rank], Vh[:rank, :])``. Slicing silently yields fewer than
        ``rank`` vectors when the reduced SVD has fewer than ``rank`` components.
    """
    with torch.no_grad():
        decomposition = torch.linalg.svd(features, full_matrices=False)
        left_vectors = decomposition.U[:, :rank]
        right_vectors_h = decomposition.Vh[:rank, :]
        return left_vectors, right_vectors_h

def lrfl_loss_fn(logits, labels, features, U, V, eta=1e-3):
    """Multi-label BCE loss plus a feature regulariser scaled by ``eta``.

    The regulariser is the sum of all entries of ``(U.T @ features) @ V.T``,
    divided by the batch size ``features.size(0)``.

    NOTE(review): with U/V taken from ``compute_uv`` on the same features this
    is presumably the sum of the leading singular values -- confirm that this
    is the intended LRFL penalty.
    """
    criterion = nn.BCEWithLogitsLoss()
    classification_loss = criterion(logits, labels)

    projected = torch.matmul(torch.matmul(U.T, features), V.T)
    penalty = projected.sum()

    batch_size = features.size(0)
    return classification_loss + eta * penalty / batch_size


In [62]:
# Sanity check: resolve the first CSV path against the Drive copy of the
# dataset and confirm that the image file is really there.
df = pd.read_csv('/content/drive/MyDrive/chexpert_data_v2/train.csv')
# The CSV stores paths with a dataset prefix that the Drive layout does not have.
test_path = df.at[0, 'Path'].replace('CheXpert-v1.0-small/', '')
full_path = os.path.join('/content/drive/MyDrive/chexpert_data_v2', test_path)
print(" Exists:", os.path.exists(full_path))


 Exists: True


In [63]:
# Source locations on Drive.
csv_path = "/content/drive/MyDrive/chexpert_data_v2/train.csv"
img_root = "/content/drive/MyDrive/chexpert_data_v2"

# Local destination for the sampled subset (created if it does not exist yet).
target_root = Path("/content/chexpert_data_v2_selected")
os.makedirs(target_root, exist_ok=True)

In [64]:
# Draw a fixed-seed 30% sample of the training CSV and persist it, so that
# later cells can reload exactly the same subset.
df = pd.read_csv(csv_path)
subset_df = (
    df.sample(frac=0.3, random_state=42)
      .reset_index(drop=True)
)
subset_df.to_csv("chexpert_30percent.csv", index=False)

In [65]:
def copy_file(rel_path_str):
    """Copy one image from ``img_root`` to ``target_root``, keeping its relative path.

    Args:
        rel_path_str: Path as stored in the CSV; a leading
            ``CheXpert-v1.0-small/`` prefix is stripped before use.

    Returns:
        ``None`` when the copy succeeds, otherwise the relative ``Path`` of
        the file that could not be copied (best-effort: no exception is
        raised for filesystem errors).
    """
    prefix = "CheXpert-v1.0-small/"
    if rel_path_str.startswith(prefix):
        rel_path_str = rel_path_str[len(prefix):]
    rel_path = Path(rel_path_str)
    src = Path(img_root) / rel_path
    dst = target_root / rel_path
    try:
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
    except OSError:
        # Only filesystem errors are treated as "this file failed". A bare
        # `except:` would also swallow KeyboardInterrupt/SystemExit and
        # programming errors, making a long copy impossible to interrupt.
        return rel_path
    return None

# Copy the sampled images to local disk in parallel (the work is I/O-bound,
# so threads are sufficient).
image_paths = subset_df["Path"].tolist()
with ThreadPoolExecutor(max_workers=32) as executor:
    copy_results = list(tqdm(executor.map(copy_file, image_paths), total=len(image_paths), desc="üì• Copying files"))

# copy_file returns None on success and the relative path on failure.
# Surface failures here instead of discarding them, so that missing images
# are noticed before training rather than as a crash inside the DataLoader.
failed_copies = [result for result in copy_results if result is not None]
if failed_copies:
    print(f"WARNING: {len(failed_copies)} of {len(image_paths)} files could not be copied, e.g. {failed_copies[:5]}")

üì• Copying files: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 67024/67024 [04:00<00:00, 278.83it/s]


In [66]:
# Sanity check: list the dataset folder on Drive (the recorded output shows
# train/, valid/ and their CSV files).
!ls /content/drive/MyDrive/chexpert_data_v2/

train  train.csv  valid  valid.csv


In [67]:
# wandb.init(
#     project="chexpert-lrf-vit",
#     name="run-vit-lrf-v1",
#     config={
#         "lr": 1e-4,
#         "batch_size": 128,
#         "epochs": 10,
#         "model": "ViT + LRF",
#         "rank": 64
#     }
# )

In [68]:
#   One Epoch Training
def train_one_epoch(model, dataloader, optimizer, device, rank, eta):
    """Train ``model`` for one pass over ``dataloader`` using the LRFL loss.

    Args:
        model: Module whose call ``model(images, return_feats=True)`` returns
            ``(logits, feats)``.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        rank: Number of singular vectors passed to ``compute_uv``.
        eta: Weight of the regulariser in ``lrfl_loss_fn``.

    Returns:
        ``(auc, avg_loss)``: macro-averaged ROC-AUC over the predictions
        collected during the epoch, and the mean per-batch loss.
    """
    model.train()
    all_preds, all_labels = [], []
    total_loss = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, feats = model(images, return_feats=True)

        # U and V are computed without gradients; they enter the loss as constants.
        U, V = compute_uv(feats, rank)
        loss = lrfl_loss_fn(logits, labels, feats, U, V, eta)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        all_preds.append(torch.sigmoid(logits).detach().cpu())
        all_labels.append(labels.cpu())

    preds = torch.cat(all_preds).numpy()
    trues = torch.cat(all_labels).numpy()

    # Per-class AUC, as in `evaluate`: a class whose labels are all identical
    # has no defined AUC, and must not abort the epoch after all the training
    # work is already done. Such classes are recorded as NaN and skipped by
    # the mean, which otherwise equals the macro average.
    per_class_auc = []
    for class_idx in range(trues.shape[1]):
        try:
            per_class_auc.append(roc_auc_score(trues[:, class_idx], preds[:, class_idx]))
        except ValueError:
            per_class_auc.append(np.nan)
    auc = np.nanmean(per_class_auc)

    avg_loss = total_loss / len(dataloader)

    return auc, avg_loss


In [69]:
#   Validation
def evaluate(model, dataloader, device, return_loss=False, rank=32, eta=1e-3):
    """Evaluate ``model`` on ``dataloader`` and print multi-label metrics.

    Probabilities are thresholded at 0.5 for F1, Hamming loss and accuracy.
    Classes for which ``roc_auc_score`` raises ``ValueError`` are excluded
    from the mean AUC.

    Returns:
        ``mean_auc`` by default, or ``(mean_auc, avg_loss, f1_macro, hamming)``
        when ``return_loss`` is True.
    """
    model.eval()
    prob_batches, label_batches = [], []
    running_loss = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            logits, feats = model(images, return_feats=True)
            U, V = compute_uv(feats, rank)
            running_loss += lrfl_loss_fn(logits, labels, feats, U, V, eta).item()

            prob_batches.append(torch.sigmoid(logits).cpu())
            label_batches.append(labels.cpu())

    preds = torch.cat(prob_batches).numpy()   # [N, C] probabilities
    trues = torch.cat(label_batches).numpy()  # [N, C] targets
    bin_preds = (preds > 0.5).astype(int)     # hard decisions

    def _class_auc(col):
        # AUC of a single class column; NaN when sklearn rejects the column.
        try:
            return roc_auc_score(trues[:, col], preds[:, col])
        except ValueError:
            return np.nan

    mean_auc = np.nanmean([_class_auc(col) for col in range(trues.shape[1])])

    # Threshold-based metrics.
    f1_macro = f1_score(trues, bin_preds, average='macro', zero_division=0)
    f1_micro = f1_score(trues, bin_preds, average='micro', zero_division=0)
    hamming = hamming_loss(trues, bin_preds)
    acc = accuracy_score(trues, bin_preds)  # not very meaningful in multilabel

    avg_loss = running_loss / len(dataloader)

    print(f"\n   Evaluation Metrics:")
    print(f" -   Mean AUC:      {mean_auc:.4f}")
    print(f" -   F1 Macro:      {f1_macro:.4f}")
    print(f" -   F1 Micro:      {f1_micro:.4f}")
    print(f" -   Hamming Loss:  {hamming:.4f}")
    print(f" -   Accuracy:      {acc:.4f}")

    return (mean_auc, avg_loss, f1_macro, hamming) if return_loss else mean_auc



In [70]:
# # Display the first 5 samples from train_dataset
# import matplotlib.pyplot as plt

# for i in range(5):
#     img, label, path = train_dataset[i]
#     plt.imshow(img.permute(1, 2, 0))  # for display: [C,H,W] -> [H,W,C]
#     plt.title(f"Path: {path}\nLabels: {label}")
#     plt.axis('off')
#     plt.show()

In [71]:
def train(model, train_loader, val_loader, optimizer, device, num_epochs=30,
          rank=32, eta=1e-3, patience=5, checkpoint_path="best_model.pth", monitor="val_loss"):
    """Train with early stopping and best-checkpoint saving.

    Args:
        model, train_loader, val_loader, optimizer, device: Forwarded to
            ``train_one_epoch`` / ``evaluate``.
        num_epochs: Maximum number of epochs.
        rank, eta: LRFL hyper-parameters forwarded to both helpers.
        patience: Stop after this many consecutive epochs without improvement.
        checkpoint_path: Where the best ``state_dict`` is written.
        monitor: Metric that decides "improvement": ``"val_loss"`` or
            ``"hamming"`` (lower is better), ``"val_auc"`` or ``"val_f1"``
            (higher is better).

    Raises:
        ValueError: If ``monitor`` is not one of the four supported names.
    """
    # True -> lower is better, False -> higher is better.
    lower_is_better_by_metric = {
        "val_loss": True,
        "hamming": True,
        "val_auc": False,
        "val_f1": False,
    }
    # Reject unknown names up front: otherwise such a name would start from
    # -inf but be compared with `<` like val_loss, so no epoch could ever
    # count as an improvement and no checkpoint would be saved.
    if monitor not in lower_is_better_by_metric:
        raise ValueError(
            f"Unknown monitor {monitor!r}; expected one of {sorted(lower_is_better_by_metric)}"
        )
    lower_is_better = lower_is_better_by_metric[monitor]

    # The initial best value depends on the direction of the monitored metric.
    best_metric = float('inf') if lower_is_better else -float('inf')

    epochs_no_improve = 0

    for epoch in range(num_epochs):
        print(f"\n Epoch {epoch+1}/{num_epochs}")
        train_auc, train_loss = train_one_epoch(model, train_loader, optimizer, device, rank, eta)
        val_auc, val_loss, val_f1, val_hamming = evaluate(model, val_loader, device, return_loss=True, rank=rank, eta=eta)

        print(f"Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f} | Val Loss: {val_loss:.4f} | F1 Macro: {val_f1:.4f} | Hamming: {val_hamming:.4f}")

        current_metric = {
            "val_loss": val_loss,
            "hamming": val_hamming,
            "val_auc": val_auc,
            "val_f1": val_f1,
        }[monitor]
        if lower_is_better:
            improvement = current_metric < best_metric
        else:
            improvement = current_metric > best_metric

        if improvement:
            best_metric = current_metric
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            print(" Best model saved.")
        else:
            epochs_no_improve += 1
            print(f" No improvement for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= patience:
            print(" Early stopping triggered!")
            break


In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Validation must be deterministic: the random rotation/crop/colour jitter in
# train_transform would make validation metrics (and therefore early stopping
# and checkpoint selection) change from run to run. Only resize + normalise,
# with the same output size and normalisation constants as train_transform.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = CheXpertDataset("chexpert_30percent.csv", "/content/chexpert_data_v2_selected", train_transform)
val_dataset   = CheXpertDataset("/content/drive/MyDrive/chexpert_data_v2/valid.csv", "/content/drive/MyDrive/chexpert_data_v2", eval_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=False)



In [73]:
# Build the LRFL model (ViT-B/16 backbone, rank-64 projection, 5 outputs) on
# the selected device and fine-tune every parameter with AdamW.
model = LRFLModel(backbone_name="vit_base_patch16_224", rank=64, num_classes=5)
model = model.to(device)
optimizer = optim.AdamW(params=model.parameters(), lr=2e-5)

In [75]:
# for images, labels in train_loader:
#     print("Batch labels shape:", labels.shape)
#     print("First batch of labels:", labels[:5])
#     break
# # # To be stricter, uncertain labels (-1) are mapped to the positive label (+1).

In [None]:
# train(model, train_loader, val_loader, optimizer, device,
#       num_epochs=20, rank=32, eta=1e-3, patience=5, checkpoint_path="best_lrfl_model.pth")

In [77]:
# for batch in train_loader:
#     print(len(batch))
#     break

In [79]:
# # Extract image embeddings and save
# model.load_state_dict(torch.load("/content/drive/MyDrive/best_lrfl_model_v2.pth"))
# model.eval()

# image_embeddings = {}

# with torch.no_grad():
#     for images, paths in tqdm(train_loader):
#         images = images.to(device)
#         embs = model.get_image_embedding(images)  # (B, R)

#         for path, emb in zip(paths, embs):
#             image_embeddings[path] = emb.cpu()

# torch.save(image_embeddings, "/content/drive/MyDrive/image_embeddings_LRF30_v2.pt")
# print(" Image embeddings saved!")


In [80]:

# Rebuild the model and load the fine-tuned weights. map_location="cpu" lets
# a checkpoint saved from a GPU run load on a CPU-only runtime as well; the
# model is moved to `device` afterwards.
model = LRFLModel(rank=64, num_classes=5)
state_dict = torch.load("/content/drive/MyDrive/best_lrfl_model_v2.pth", map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Embeddings must be reproducible, so do not reuse train_loader: it applies
# random augmentation (rotation / crop / colour jitter) and shuffles. Use a
# dedicated loader with deterministic preprocessing over the same images.
embed_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
embed_dataset = CheXpertDataset("chexpert_30percent.csv", "/content/chexpert_data_v2_selected", embed_transform)
embed_loader = DataLoader(embed_dataset, batch_size=16, shuffle=False, num_workers=2)

# Maps relative image path -> low-rank embedding (CPU tensor of length `rank`).
image_embeddings = {}

with torch.no_grad():
    for images, paths in tqdm(embed_loader):
        images = images.to(device)

        feats = model.forward_features(images)       # [B, embed_dim]
        embs = model.low_rank_proj(feats)            # [B, rank]

        for path, emb in zip(paths, embs):
            image_embeddings[path] = emb.cpu()


torch.save(image_embeddings, "/content/drive/MyDrive/image_embeddings_LRF30_v2.pt")
print("‚úÖ Image embeddings saved successfully!")


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4189/4189 [14:50<00:00,  4.71it/s]


‚úÖ Image embeddings saved successfully!


In [None]:
# Flush GPU Memory
def clear_cuda_cache():
    """Run Python garbage collection and, if CUDA is available, empty its cache.

    On a CPU-only runtime only the garbage collection runs: the CUDA calls
    are skipped because ``torch.cuda.synchronize()`` fails when there is no
    CUDA device, and the notebook otherwise supports a CPU fallback.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
clear_cuda_cache()


In [None]:

# print(torch.cuda.memory_summary(device=None, abbreviated=False))
