# Experiment: ViTDEIT Baseline for RARE25 Challenge

This notebook trains and validates a `deit_base_patch16_224` Vision Transformer on the RARE25 dataset, following the same structure as the ViT baseline experiment. The goal is to optimize for PPV@90Recall on a highly imbalanced binary classification task.

## Outline
1. Import Required Libraries
2. Load and Preprocess Dataset
3. Initialize vit_large_patch14_dinov2.lvd142m Model
4. Configure Training Parameters
5. Train the Model
6. Validate the Model
7. Visualize Training and Validation Metrics


In [1]:
# Import Required Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve


  from .autonotebook import tqdm as notebook_tqdm


# Load and Preprocess Dataset
Load the RARE25 dataset and apply preprocessing/augmentations for DeiT input (224x224).

In [10]:
# Load DeiT-Base model for binary classification
model = timm.create_model('deit_base_patch16_224', pretrained=True, num_classes=2)

data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=True)

print('Transforms set up using timm for Deit ViT:', transform) 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Transforms set up using timm for Deit ViT: Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=None)
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [12]:
# Load dataset (HuggingFace or local CSV)
from datasets import load_dataset

ds = load_dataset("TimJaspersTue/RARE25-train")
df = ds['train'].to_pandas()

# Show class distribution
print('Class distribution:')
print(df['label'].value_counts())


# Custom dataset class
class ImageDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        img_dict = self.df.iloc[idx]['image']
        if 'array' in img_dict:
            image = Image.fromarray(np.array(img_dict['array']))
        else:
            image = Image.open(img_dict['path'])
        label = self.df.iloc[idx]['label']
        if self.transform:
            image = self.transform(image)
        return image, label

# Split data
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)
train_transform = timm.data.create_transform(**data_config, is_training=True)
val_transform = timm.data.create_transform(**data_config, is_training=False) #Validation loader should not be sampled; keep shuffle=False so evaluation reflects real imbalance.

train_dataset = ImageDataset(train_df, transform=train_transform)
val_dataset = ImageDataset(val_df, transform=val_transform)



Class distribution:
label
0    2937
1     158
Name: count, dtype: int64


In [13]:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# --- STEP 1: Compute class weights based on inverse frequency ---
# Class distribution: 0 = non-dysplastic (2937), 1 = neoplasia (158)
class_counts = train_df['label'].value_counts().to_dict()
# Inverse frequency weighting: rarer classes get higher weights
class_weights = {cls: 1.0/count for cls, count in class_counts.items()}

# Assign a weight to each sample in the training set
sample_weights = train_df['label'].map(class_weights).values
sample_weights = torch.DoubleTensor(sample_weights)

# --- STEP 2: Create the WeightedRandomSampler ---
# This makes the DataLoader oversample the minority class (neoplasia),
# so the model sees more positive examples during training.
sampler = WeightedRandomSampler(
    weights=sample_weights, 
    num_samples=len(sample_weights),  # number of samples per epoch
    replacement=True                  # allows resampling for balance
)

# --- STEP 3: Use sampler instead of shuffle in DataLoader ---
train_dataset = ImageDataset(train_df, transform=train_transform)
val_dataset = ImageDataset(val_df, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print("Using WeightedRandomSampler to rebalance training batches.")


sample_weights = train_df['label'].map(class_weights).values
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)



Using WeightedRandomSampler to rebalance training batches.


# Configure Training Parameters
Set up optimizer, Focal Loss, and learning rate scheduler.

In [None]:
# Focal Loss for binary classification
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        """
        :param gamma: focusing parameter (higher → more focus on hard examples)
        :param alpha: class weight balancing (float or list of floats). 
                      If None, no weighting is applied.
        :param reduction: 'mean', 'sum', or 'none'
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, (list, torch.Tensor)):
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        else:
            self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs: logits of shape (batch, num_classes)
        # targets: ground-truth labels of shape (batch,)
        log_probs = F.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)

        # gather log_probs corresponding to targets
        ce_loss = F.nll_loss(log_probs, targets, reduction="none")
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                at = self.alpha.to(inputs.device).gather(0, targets)
            else:
                at = self.alpha
            ce_loss = ce_loss * at

        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss


criterion = FocalLoss(gamma=2.0, alpha=[0.25, 0.75]) 
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)



# Train the Model
Train DeiT-Base on the training set and track loss/accuracy.

In [17]:
import copy
import numpy as np
from sklearn.metrics import precision_recall_curve

# --- Helper function: Compute PPV at a fixed recall level ---
def ppv_at_recall(y_true, y_probs, recall_level=0.9):
    """
    Compute Positive Predictive Value (Precision) at given Recall.
    - y_true: ground truth labels (0/1)
    - y_probs: predicted probabilities for class 1
    - recall_level: desired recall threshold (default 0.9 = 90%)
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)
    try:
        idx = np.where(recall >= recall_level)[0][-1]  # last index where recall >= recall_level
        return precision[idx], thresholds[idx]
    except IndexError:
        return 0.0, 0.5  # if recall never reaches recall_level

# --- Training loop with PPV@90Recall checkpointing ---
num_epochs = 20
best_ppv = 0.0
best_state = None

for epoch in range(num_epochs):
    # Training
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)  # Focal Loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    all_labels = []
    all_probs = []
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())

    val_epoch_loss = val_loss / len(val_loader.dataset)
    ppv, threshold = ppv_at_recall(np.array(all_labels), np.array(all_probs), recall_level=0.9)

    print(f"Epoch {epoch+1}: Train Loss={epoch_loss:.4f}, Val Loss={val_epoch_loss:.4f}, PPV@90Recall={ppv:.4f}")

    # Save checkpoint based on PPV@90Recall
    if ppv > best_ppv:
        best_ppv = ppv
        best_state = copy.deepcopy(model.state_dict())
        torch.save(best_state, "best_deit_model_ppv.pth")
        print(f"✅ Saved new best model with PPV@90Recall={best_ppv:.4f}")

print(f"Best PPV@90Recall achieved: {best_ppv:.4f}")


Training Epoch 1/20: 100%|██████████| 78/78 [06:30<00:00,  5.01s/it]



Epoch 1: Train Loss=0.0523, Val Loss=0.0202, PPV@90Recall=0.1312
✅ Saved new best model with PPV@90Recall=0.1312
✅ Saved new best model with PPV@90Recall=0.1312


Training Epoch 2/20: 100%|██████████| 78/78 [06:37<00:00,  5.09s/it]



Epoch 2: Train Loss=0.0304, Val Loss=0.0193, PPV@90Recall=0.1160


Training Epoch 3/20: 100%|██████████| 78/78 [06:39<00:00,  5.12s/it]



Epoch 3: Train Loss=0.0227, Val Loss=0.0174, PPV@90Recall=0.1585
✅ Saved new best model with PPV@90Recall=0.1585
✅ Saved new best model with PPV@90Recall=0.1585


Training Epoch 4/20: 100%|██████████| 78/78 [06:36<00:00,  5.08s/it]



Epoch 4: Train Loss=0.0179, Val Loss=0.0271, PPV@90Recall=0.2148
✅ Saved new best model with PPV@90Recall=0.2148
✅ Saved new best model with PPV@90Recall=0.2148


Training Epoch 5/20: 100%|██████████| 78/78 [06:37<00:00,  5.10s/it]



Epoch 5: Train Loss=0.0193, Val Loss=0.0202, PPV@90Recall=0.1648


Training Epoch 6/20: 100%|██████████| 78/78 [06:34<00:00,  5.06s/it]



Epoch 6: Train Loss=0.0159, Val Loss=0.0206, PPV@90Recall=0.3412
✅ Saved new best model with PPV@90Recall=0.3412
✅ Saved new best model with PPV@90Recall=0.3412


Training Epoch 7/20: 100%|██████████| 78/78 [06:35<00:00,  5.07s/it]



Epoch 7: Train Loss=0.0129, Val Loss=0.0227, PPV@90Recall=0.2397


Training Epoch 8/20: 100%|██████████| 78/78 [07:21<00:00,  5.66s/it]



Epoch 8: Train Loss=0.0147, Val Loss=0.0263, PPV@90Recall=0.1779


Training Epoch 9/20: 100%|██████████| 78/78 [06:31<00:00,  5.02s/it]



Epoch 9: Train Loss=0.0140, Val Loss=0.0196, PPV@90Recall=0.2148


Training Epoch 10/20: 100%|██████████| 78/78 [06:32<00:00,  5.04s/it]



Epoch 10: Train Loss=0.0115, Val Loss=0.0193, PPV@90Recall=0.3295


Training Epoch 11/20: 100%|██████████| 78/78 [1:09:26<00:00, 53.42s/it]   



Epoch 11: Train Loss=0.0119, Val Loss=0.0231, PPV@90Recall=0.2148


Training Epoch 12/20: 100%|██████████| 78/78 [1:31:48<00:00, 70.62s/it]   



Epoch 12: Train Loss=0.0136, Val Loss=0.0183, PPV@90Recall=0.2148


Training Epoch 13/20: 100%|██████████| 78/78 [48:23<00:00, 37.23s/it]   



Epoch 13: Train Loss=0.0105, Val Loss=0.0183, PPV@90Recall=0.1518


Training Epoch 14/20: 100%|██████████| 78/78 [06:30<00:00,  5.01s/it]



Epoch 14: Train Loss=0.0094, Val Loss=0.0327, PPV@90Recall=0.2843


Training Epoch 15/20: 100%|██████████| 78/78 [11:47<00:00,  9.07s/it]  



Epoch 15: Train Loss=0.0087, Val Loss=0.0333, PPV@90Recall=0.1847


Training Epoch 16/20: 100%|██████████| 78/78 [23:48<00:00, 18.31s/it] 



Epoch 16: Train Loss=0.0080, Val Loss=0.0262, PPV@90Recall=0.2900


Training Epoch 17/20: 100%|██████████| 78/78 [1:09:37<00:00, 53.56s/it]   



Epoch 17: Train Loss=0.0096, Val Loss=0.0325, PPV@90Recall=0.2148


Training Epoch 18/20: 100%|██████████| 78/78 [59:49<00:00, 46.02s/it]   



Epoch 18: Train Loss=0.0105, Val Loss=0.0239, PPV@90Recall=0.2437


Training Epoch 19/20: 100%|██████████| 78/78 [34:29<00:00, 26.53s/it]   



Epoch 19: Train Loss=0.0094, Val Loss=0.0238, PPV@90Recall=0.3452
✅ Saved new best model with PPV@90Recall=0.3452
✅ Saved new best model with PPV@90Recall=0.3452


Training Epoch 20/20: 100%|██████████| 78/78 [23:07<00:00, 17.79s/it]   



Epoch 20: Train Loss=0.0071, Val Loss=0.0228, PPV@90Recall=0.3372
Best PPV@90Recall achieved: 0.3452


In [18]:
import shutil
shutil.copy('best_deit_model_ppv.pth', 'resources/vit_beit_finetuned.pth')

'resources/vit_beit_finetuned.pth'

# Validate the Model
Evaluate the trained model on the validation set and compute PPV@90Recall.

In [None]:
# Compute PPV@90Recall on validation set
def ppv_at_recall(y_true, y_scores, recall_level=0.9):
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    try:
        idx = next(i for i, r in enumerate(recall) if r < recall_level) - 1
    except StopIteration:
        idx = len(recall) - 1
    ppv = precision[idx]
    threshold = thresholds[idx] if idx < len(thresholds) else 1.0
    return ppv, threshold

model.eval()
all_labels = []
all_probs = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
        all_probs.extend(probs)
        all_labels.extend(labels.numpy())
ppv, threshold = ppv_at_recall(np.array(all_labels), np.array(all_probs), recall_level=0.9)
print(f'PPV at 90% recall: {ppv:.4f} (threshold: {threshold:.4f})')


# Visualize Training and Validation Metrics
Plot loss and accuracy curves to analyze model performance.