In [4]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import roc_auc_score
import shutil
from concurrent.futures import ThreadPoolExecutor
import wandb

In [5]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmrahbar-2001[0m ([33mmrahbar-2001-university-of-isfahan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [33]:
# ======= LRF =======
class LRFModel(nn.Module):
    def __init__(self, backbone_name='vit_base_patch16_224', rank=64, num_classes=5):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=True)
        in_features = self.backbone.head.in_features
        self.backbone.reset_classifier(0)
        self.low_rank_head = nn.Sequential(
            nn.Linear(in_features, rank, bias=False),
            nn.ReLU(),
            nn.Linear(rank, num_classes)
        )

    def forward(self, x):
        feats = self.backbone.forward_features(x)  # [B, 197, D] €åÿß [B, D]
        if feats.ndim == 3:
            feats = feats[:, 0]  # €åÿß feats.mean(dim=1) ÿ®ÿ≥ÿ™Ÿá ÿ®Ÿá ŸÖŸÇÿßŸÑŸá
        return self.low_rank_head(feats)


In [9]:
class CheXpertDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):

        self.labels_df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        # ŸÖÿ≥€åÿ± ŸÜÿ≥ÿ®€å ÿ™ÿµŸà€åÿ± ÿßÿ≤ ÿ≥ÿ™ŸàŸÜ CSV
        img_rel_path = self.labels_df.iloc[idx]['Path']

        # ÿ≠ÿ∞ŸÅ Ÿæ€åÿ¥ŸàŸÜÿØ "CheXpert-v1.0-small" ÿß⁄Øÿ± ÿØÿ± ŸÖÿ≥€åÿ± ÿ®ŸàÿØ
        if img_rel_path.startswith("CheXpert-v1.0-small"):
            img_rel_path = img_rel_path[len("CheXpert-v1.0-small")+1:]  # +1 ÿ®ÿ±ÿß€å ÿ≠ÿ∞ŸÅ ÿßÿ≥ŸÑÿ¥ ÿ®ÿπÿØ€å

        # ŸÖÿ≥€åÿ± ⁄©ÿßŸÖŸÑ ÿ™ÿµŸà€åÿ± ÿ®ÿß join ⁄©ÿ±ÿØŸÜ ŸÖÿ≥€åÿ± ÿ±€åÿ¥Ÿá Ÿà ŸÖÿ≥€åÿ± ŸÜÿ≥ÿ®€å ÿßÿµŸÑÿßÿ≠ ÿ¥ÿØŸá
        img_path = os.path.join(self.img_dir, img_rel_path)

        # ÿ®ÿßÿ±⁄Øÿ∞ÿßÿ±€å ÿ™ÿµŸà€åÿ± ÿ®ÿß ÿ™ÿ®ÿØ€åŸÑ ÿ®Ÿá RGB (3 ⁄©ÿßŸÜÿßŸÑŸá)
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"[WARNING] Could not load image: {img_path} -- {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # ⁄Øÿ±ŸÅÿ™ŸÜ ŸÑ€åÿ®ŸÑ‚ÄåŸáÿß Ÿà infer_objects ÿ®ÿ±ÿß€å ÿ¨ŸÑŸà⁄Ø€åÿ±€å ÿßÿ≤ warning
        labels = self.labels_df.iloc[idx][['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']]
        labels = labels.infer_objects(copy=False).fillna(0).values.astype('float32')

        return image, labels



In [10]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [11]:
df = pd.read_csv('/content/drive/MyDrive/chexpert_data_v2/train.csv')
dfvalid = pd.read_csv('/content/drive/MyDrive/chexpert_data_v2/valid.csv')
df_subset = df.sample(frac=0.3, random_state=42).reset_index(drop=True)
df_subset.to_csv("chexpert_30percent.csv", index=False)


In [13]:
csv_path = 'chexpert_30percent.csv'
source_root = '/content/drive/MyDrive/chexpert_data_v2'
target_root = '/content/chexpert_data_v2_selected'
df = pd.read_csv(csv_path)
image_paths = df['Path'].str.replace('CheXpert-v1.0-small/', '', regex=False).tolist()


In [None]:

def copy_file(rel_path):
    src = os.path.join(source_root, rel_path)
    dst = os.path.join(target_root, rel_path)
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    try:
        shutil.copy2(src, dst)
        return True
    except:
        return False

os.makedirs(target_root, exist_ok=True)
with ThreadPoolExecutor(max_workers=8) as executor:
    list(tqdm(executor.map(copy_file, image_paths), total=len(image_paths)))

In [14]:
# ======= DataLoaders =======
train_dataset = CheXpertDataset('chexpert_30percent.csv', target_root, transform)
val_dataset = CheXpertDataset('/content/drive/MyDrive/chexpert_data_v2/valid.csv',
                              '/content/drive/MyDrive/chexpert_data_v2', transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)



In [15]:
model = LRFModel().to(device)
criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
scaler = GradScaler()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  scaler = GradScaler()


In [27]:
wandb.init(
    project="chexpert-lrf-vit",
    name="run-vit-lrf-v1",
    config={
        "lr": 1e-4,
        "batch_size": 128,
        "epochs": 10,
        "model": "ViT + LRF",
        "rank": 64
    }
)

In [31]:
def validate(model, dataloader, subset_ratio=0.3):
    model.eval()
    val_loss, all_labels, all_outputs = 0, [], []
    max_batches = int(len(dataloader) * subset_ratio)
    auc_scores = {}

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="üß™ Validating", leave=False)):
            if i > max_batches: break
            images, labels = images.to(device), labels.to(device).float()
            with autocast():
                outputs = model(images)  # shape: [B, 197, 5]
                outputs = outputs.mean(dim=1)  # üîß ŸáŸÖ€åŸÜ ÿÆÿ∑ ⁄©ŸÑ ŸÖÿ¥⁄©ŸÑ ÿ±Ÿà ÿ≠ŸÑ ŸÖ€å‚Äå⁄©ŸÜŸá
                mask = (labels != -1).float()
                loss_raw = criterion(outputs, labels)
                loss = (loss_raw * mask).sum() / mask.sum()
            val_loss += loss.item() * images.size(0)
            all_labels.append(labels.cpu().numpy())
            all_outputs.append(torch.sigmoid(outputs).cpu().numpy())

    avg_loss = val_loss / (max_batches * dataloader.batch_size)

    try:
        all_labels = np.concatenate(all_labels)
        all_outputs = np.concatenate(all_outputs)
        for i, disease in enumerate(['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']):
            try:
                auc = roc_auc_score(all_labels[:, i], all_outputs[:, i])
                auc_scores[disease] = auc
                print(f"‚úÖ AUC {disease}: {auc:.4f}")
            except:
                auc_scores[disease] = float('nan')
                print(f"‚ö†Ô∏è AUC {disease}: Not enough data")
    except:
        print("‚ö†Ô∏è AUC skipped due to shape issues")

    return avg_loss, auc_scores


In [26]:
def validate(model, dataloader, subset_ratio=0.3):
    model.eval()
    val_loss, all_labels, all_outputs = 0, [], []
    max_batches = int(len(dataloader) * subset_ratio)
    auc_scores = {}

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="üß™ Validating", leave=False)):
            if i > max_batches: break
            images, labels = images.to(device), labels.to(device).float()
            with autocast():
                outputs = model(images)
                mask = (labels != -1).float()
                loss_raw = criterion(outputs, labels)
                loss = (loss_raw * mask).sum() / mask.sum()
            val_loss += loss.item() * images.size(0)
            all_labels.append(labels.cpu().numpy())
            all_outputs.append(torch.sigmoid(outputs).cpu().numpy())

    avg_loss = val_loss / (max_batches * dataloader.batch_size)

    try:
        all_labels = np.concatenate(all_labels)
        all_outputs = np.concatenate(all_outputs)
        for i, disease in enumerate(['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']):
            try:
                auc = roc_auc_score(all_labels[:, i], all_outputs[:, i])
                auc_scores[disease] = auc
                print(f"‚úÖ AUC {disease}: {auc:.4f}")
            except:
                auc_scores[disease] = float('nan')
                print(f"‚ö†Ô∏è AUC {disease}: Not enough data")
    except:
        print("‚ö†Ô∏è AUC skipped due to shape issues")

    return avg_loss, auc_scores


In [None]:
best_val_loss = float('inf')

for epoch in range(10):
    print(f"\nüìö Epoch {epoch+1}/10")

    train_loss = train_one_epoch(model, train_loader)
    val_loss, auc_scores = validate(model, val_loader)

    # ÿ∞ÿÆ€åÿ±Ÿá ÿ®Ÿáÿ™ÿ±€åŸÜ ŸÖÿØŸÑ
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        print(f"üíæ Saved best model at epoch {epoch+1} with val loss {val_loss:.4f}")

    # üìà ŸÑÿß⁄Ø‚Äå⁄Ø€åÿ±€å ÿØÿ± wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "lr": scheduler.get_last_lr()[0],
        **{f"AUC_{k}": v for k, v in auc_scores.items()}
    })

    scheduler.step()



üìö Epoch 1/10


  with autocast():
  with autocast():
  with autocast():


‚úÖ AUC Atelectasis: 0.6877
‚úÖ AUC Cardiomegaly: 0.7956
‚úÖ AUC Consolidation: 0.8729
‚úÖ AUC Edema: 0.8680
‚úÖ AUC Pleural Effusion: 0.8901
üíæ Saved best model at epoch 1 with val loss 0.9185

üìö Epoch 2/10


  with autocast():
  with autocast():
  with autocast():
                                                            

‚úÖ AUC Atelectasis: 0.6156
‚úÖ AUC Cardiomegaly: 0.7951
‚úÖ AUC Consolidation: 0.8631
‚úÖ AUC Edema: 0.8527
‚úÖ AUC Pleural Effusion: 0.8936

üìö Epoch 3/10


  with autocast():
  with autocast():
  with autocast():
                                                            

‚úÖ AUC Atelectasis: 0.7073
‚úÖ AUC Cardiomegaly: 0.7459
‚úÖ AUC Consolidation: 0.8729
‚úÖ AUC Edema: 0.8309
‚úÖ AUC Pleural Effusion: 0.8824

üìö Epoch 4/10


  with autocast():


In [35]:
# import torch
# import gc

# gc.collect()          # ÿ¨ŸÖÿπ‚Äåÿ¢Ÿàÿ±€å ÿ≠ÿßŸÅÿ∏Ÿá ÿ¢ÿ≤ÿßÿØ ÿ¥ÿØŸá ÿØÿ± Ÿæÿß€åÿ™ŸàŸÜ
# torch.cuda.empty_cache()  # ÿ¢ÿ≤ÿßÿØÿ≥ÿßÿ≤€å ⁄©ÿ¥ ÿ≠ÿßŸÅÿ∏Ÿá GPU
