# imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
import timm

from my_transformers import CorruptDistillVisionTransformer
from utils import load_experimental_TinyImageNet
from train_test_module import compute_ece
import json

%load_ext autoreload
%autoreload 2



In [2]:
class TrainTestBaseline:
    def __init__(self, model:nn.Module, train_loader, test_loader,
                 num_img_types, batch_size, device
                 ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.num_img_types = num_img_types
        self.device = device
        self.batch_size = batch_size

        for param in model.parameters(): param.requires_grad = False
        for param in model.head.parameters(): param.requires_grad = True

        self.all_test_metrics, self.all_train_metrics = [], []

    def test(self, print_metrics=False):
        top1_correct_preds, top5_correct_preds = 0, 0
        total_samples, total_ece, total_entropy = 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_loader:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                preds = self.model(x_batch)
                del x_batch 
                
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE loss
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                del y_batch, c_batch
                
        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples, 
            "top5_acc" : top5_correct_preds / total_samples, 
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(), 
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics
    
    def train(self, optimizer, scheduler, augmenter, save_path, 
              num_epochs=1, label_smoothing=0.1, erasing_p=0, print_metrics=False
              ):
        erase = T.RandomErasing(p=erasing_p)
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std  = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        best_acc, epochs_no_improve, min_delta, patience = 0, 0, 0.003, 5

        for epoch in range(1, num_epochs+1):
            self.model.train()
            print(f"------- Epoch {epoch} -------")
            total_samples, top1_correct_preds, loss_total = 0, 0, 0

            for x_batch, y_batch, _ in self.train_loader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                x_batch, y_batch = augmenter(x_batch, y_batch)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                x_batch = (x_batch - mean) / std
                x_batch = torch.stack([erase(img) for img in x_batch])
                preds = self.model(x_batch)

                del x_batch                 # free memory
                total_samples += y_batch.size(0)
                # top-1 acc.  # if coming from mixup/cutmix
                if len(y_batch.shape)==2 : top1_correct_preds += (torch.argmax(preds, dim=1) == torch.argmax(y_batch, dim=-1)).sum().item()
                else : top1_correct_preds += (torch.argmax(preds, dim=1) == y_batch).sum().item()
                #loss   # ybatch can be passed in directly even if cutmix/mixup applied
                loss = F.cross_entropy(preds, y_batch, label_smoothing=label_smoothing)
                # backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step() 

                loss_total += loss.item() * y_batch.size(0)
                del y_batch, preds, loss            # free memory

            test_metrics = self.test()
            train_metrics = {
                "loss_total": loss_total/total_samples,
                "top1_acc" : top1_correct_preds / total_samples
                }
            current_acc = test_metrics['top1_acc']
            if print_metrics : 
                print(f"train-loss: {train_metrics['loss_total']:.2f} -- train-acc: {train_metrics['top1_acc']:.2f} -- "
                      f"test-acc: {current_acc:.2f}"
                      )
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)

            # early stopping
            if current_acc - best_acc > min_delta:
                best_acc = current_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} — no improvement in {patience} epochs.")
                break
        
        # save trained model and metrics
        if save_path:
            torch.save(self.model.state_dict(), f"{save_path}.pth")
            print(f"Model saved to {save_path}")
            
            with open(f"{save_path}_train_metrics.json", "w") as f1:
                json.dump(self.all_train_metrics, f1, indent=4)
            with open(f"{save_path}_test_metrics.json", "w") as f2:
                json.dump(self.all_test_metrics, f2, indent=4)

# ------------------- Template --------------------------

In [3]:
# setting seed 
torch.cuda.manual_seed(22)
random.seed(22)
torch.manual_seed(22)

device = "cuda" if torch.cuda.is_available() else "cpu"

corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# Hyper-parameters
NUM_IMG_TYPES = len(corrupt_types)+1
NUM_CLASSES = 200
DROPOUT = 0
DROP_PATH = 0.1

ERASE_P = 0.25
RANDAUG_P = 0.5
MIXUP_P = 0.8
CUTMIX_P = 1
AUGMIX_P = 0


BATCH_SIZE = 256
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3

In [4]:
train_loader = DataLoader(dataset=TensorDataset(*torch.load("train_data.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_data.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=False)
# itr = iter(test_loader)
# train_batch = [next(itr)]
# test_batch = [next(itr)]

# deit3_small_patch16_224.fb_in22k_ft_in1k -- 22M
deit3_small = timm.create_model('deit3_small_patch16_224.fb_in22k_ft_in1k', pretrained=True).cuda()
deit3_small.head = nn.Linear(in_features=384, out_features=NUM_CLASSES, bias=True).cuda()

In [5]:
from timm.models.layers import DropPath
def set_drop_path(model, drop_path):
    for i in range(len(model.blocks)):
        model.blocks[i].drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        model.blocks[i].drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

In [6]:
# === LR Scheduler ===
total_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = WARMUP_EPOCHS * len(train_loader)

learning_rates = [5e-5, 1e-5, 5e-6]
optimizer = optim.AdamW(deit3_small.parameters(), lr=5e-4, weight_decay=0.03)
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-5, total_iters=warmup_steps)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS-WARMUP_EPOCHS)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
)

from train_test_module import MyAugments
augmenter = MyAugments(NUM_CLASSES) # , mixup_p=MIXUP_P, randaug_p=RANDAUG_P, cutmix_p=CUTMIX_P, augmix_p=AUGMIX_P

# Experiment 1

In [None]:
baseline_module = TrainTestBaseline(deit3_small, train_loader, test_loader, NUM_IMG_TYPES, BATCH_SIZE, device)
baseline_module.train(optimizer, scheduler, augmenter, "deit3-ex1", num_epochs=NUM_EPOCHS, erasing_p=ERASE_P, print_metrics=True)


------- Epoch 1 -------
train-loss: 4.73 -- train-acc: 0.19 -- test-acc: 0.41
------- Epoch 2 -------


In [None]:
baseline_module.test()

{'top1_acc': 0.49425,
 'top5_acc': 0.6808,
 'top1_acc_per_type': [0.40549999475479126,
  0.46000000834465027,
  0.5404999852180481,
  0.33125001192092896,
  0.734000027179718],
 'top5_acc_per_type': [0.6122499704360962,
  0.6582499742507935,
  0.7310000061988831,
  0.5212500095367432,
  0.8812500238418579],
 'error_rate_per_type': [0.5945000052452087,
  0.5399999916553497,
  0.4595000147819519,
  0.668749988079071,
  0.265999972820282],
 'ece': 0.00032730889059603215,
 'entropy': 2.319790731048584}