# Data

- [Download corrupted Cifar10 .tar](https://zenodo.org/records/2535967) and extract it into a folder.

- Cifar10 is downloaded using torch.datasets

- [Download corrupted Tiny-ImageNet](https://zenodo.org/records/2536630) and extract it into a folder.

- [Download Tiny-ImageNet](https://www.kaggle.com/datasets/akash2sharma/tiny-imagenet) and extract it into a folder.

# Imports


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
import timm

from my_transformers import CorruptDistillVisionTransformer
from utils import load_experimental_TinyImageNet

  from .autonotebook import tqdm as notebook_tqdm


# Loading Data

In [8]:
corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# corrupt_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-C\Tiny-ImageNet-C"
# normal_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-Normal"
# train_data, test_data = load_experimental_TinyImageNet(normal_path, corrupt_path, corrupt_types, num_train_imgs=20)


# Visualisation TinyImageNet

In [None]:
def tensor_to_img(tensor):
    mean = torch.tensor([0.4802, 0.4481, 0.3975])
    std = torch.tensor([0.2302, 0.2265, 0.2262])
    
    # denormalize
    img = tensor.clone()
    img = img * std[:, None, None] + mean[:, None, None]
    img = torch.clamp(img, 0, 1)
    
    # convert to PIL image
    to_pil = T.ToPILImage()
    return to_pil(img)

# plot
imgs_to_display = [random.randint(0, len(train_data[0])-1) for i in range(9)]
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
axes = axes.flatten()

for i in range(9):
    img, label = train_data[0][imgs_to_display[i]], train_data[2][imgs_to_display[i]].item()
    img = tensor_to_img(img)
    if label == len(corrupt_types):
        axes[i].set_title(f"normal", fontsize=8)
    else:
        axes[i].set_title(corrupt_types[label], fontsize=8)
    axes[i].imshow(img)
    axes[i].axis("off")

plt.tight_layout()
plt.show()

# Model config for TinyImageNet

In [2]:
# setting seed 
torch.cuda.manual_seed(22)
random.seed(22)
torch.manual_seed(22)

device = "cuda" if torch.cuda.is_available() else "cpu"

corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# Hyper-parameters
PATCH_SIZE = 8
IMG_SIZE = 64
EMBED_DIM = 192
NUM_HEADS = 3
NUM_IMG_TYPES = len(corrupt_types)+1
NUM_ENCODERS = 12
NUM_CLASSES = 200
DROPOUT = 0
DROP_PATH = 0.1
ERASE_PROB = 0.25

In [3]:
# train_data = torch.load("train_data.pt", weights_only=True)
# test_data = torch.load("test_data.pt", weights_only=True)

BATCH_SIZE = 128
# train_loader = DataLoader(dataset=TensorDataset(*torch.load("train_data.pt", weights_only=True)), 
#                                   batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_data.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=True)
itr = iter(test_loader)
train_batch = [next(itr)]
test_batch = [next(itr)]

# deit3_small_patch16_224.fb_in22k_ft_in1k -- 22M
# convnext_tiny.fb_in22k_ft_in1k -- 28M
deit3_small = timm.create_model('deit3_small_patch16_224.fb_in22k_ft_in1k', pretrained=True).cuda()
# teacher_model = timm.create_model('vit_small_patch16_224.augreg_in21k_ft_in1k', pretrained=True).cuda()
deit3_small.head = nn.Linear(in_features=384, out_features=NUM_CLASSES, bias=True).cuda()
num_params = sum(p.numel() for p in deit3_small.parameters())
print(num_params)
print(deit3_small.head)

21751496
Linear(in_features=384, out_features=200, bias=True)


In [4]:
from train_test_module import compute_ece
import json

class TrainTestBaseline:
    def __init__(self, model, train_batches, test_batches, num_img_types, device):
        self.model = model.to(device)
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.num_img_types = num_img_types
        self.device = device

        self.all_test_metrics, self.all_train_metrics = [], []

    def test(self, print_metrics=False):
        top1_correct_preds, top5_correct_preds = 0, 0
        total_samples, total_ece, total_entropy = 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                preds = self.model(x_batch)
                del x_batch 
                
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE loss
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                del y_batch, c_batch
                
        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples, 
            "top5_acc" : top5_correct_preds / total_samples, 
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(), 
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics
    
    def train(self, optimizer, scheduler, save_path, num_epochs=1, label_smoothing=0.1, print_metrics=False):
        best_acc, epochs_no_improve, min_delta, patience = 0, 0, 0.003, 4
        self.model.train()
        for epoch in range(1, num_epochs+1):
            print(f"------- Epoch {epoch} -------")
            total_samples, top1_correct_preds, loss_total = 0, 0, 0

            for x_batch, y_batch, _ in self.test_batches:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                preds = self.model(x_batch)

                del x_batch                 # free memory
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_correct_preds += (torch.argmax(preds, dim=1) == y_batch).sum().item()
                #loss
                loss = F.cross_entropy(preds, y_batch, label_smoothing=label_smoothing)
                # backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step() 

                loss_total += loss.item() * y_batch.size(0)
                del y_batch, preds, loss            # free memory

            test_metrics = self.test()
            train_metrics = {
                "loss_total": loss_total/total_samples,
                "top1_acc" : top1_correct_preds / total_samples
                }
            current_acc = test_metrics['top1_acc']
            if print_metrics : 
                print(f"train-loss: {train_metrics['loss_total']:.2f} -- train-acc: {train_metrics['top1_acc']:.2f} -- "
                      f"train-acc: {current_acc:.2f}"
                      )
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)
            # early stopping
            if current_acc - best_acc > min_delta:
                best_acc = current_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} — no improvement in {patience} epochs.")
                break
        
        # save trained model and metrics
        if save_path:
            torch.save(self.model.state_dict(), f"{save_path}.pth")
            print(f"Model saved to {save_path}")
            
            with open(f"{save_path}_train_metrics.json", "w") as f1:
                json.dump(self.all_train_metrics, f1, indent=4)
            with open(f"{save_path}_test_metrics.json", "w") as f2:
                json.dump(self.all_test_metrics, f2, indent=4)

In [None]:
# === LR Scheduler ===
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3
total_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = WARMUP_EPOCHS * len(train_loader)

optimizer = optim.AdamW(deit3_small.parameters(), lr=5e-5, weight_decay=0.05)
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-5, total_iters=warmup_steps)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS-WARMUP_EPOCHS)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
)

In [11]:
from train_test_module import LossCalculator
class TrainTestStudentDeiT:
    def __init__(self, model, train_batches, test_batches, head_strategy, num_img_types,
                 device, save_path:str, teacher_model=None
                 ):
        assert head_strategy > 0 and head_strategy <= 3
        self.model = model.to(device)
        if teacher_model: self.teacher_model = teacher_model.to(device)
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.head_strategy = head_strategy
        self.num_img_types = num_img_types
        self.save_path = save_path
        self.device = device

        self.all_test_metrics, self.all_train_metrics = [], []
    
    # testing function
    def test(self, print_metrics=False):
        top1_correct_preds, top5_correct_preds = 0, 0
        top1_correct_corruptions, total_samples, total_ece, total_entropy = 0, 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                del x_batch
                
                if len(tokens) == 3 and self.head_strategy != 1:
                    preds = (tokens[0] + tokens[1] + self.model.output_head.corrupt_connection(tokens[2])) / 3
                else:
                    preds = (tokens[0] + tokens[1]) / 2
                    print(tokens[0].shape, tokens[1].shape, preds.shape)
                
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE loss
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                # top-1 corruption classification acc
                if len(tokens) == 3 : top1_correct_corruptions += (torch.argmax(tokens[2], dim=1) == c_batch).sum().item()              
                del y_batch, c_batch, preds

        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples, 
            "top5_acc" : top5_correct_preds / total_samples, 
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(), 
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if top1_correct_corruptions : test_metrics["top1_corrupt_acc"] = top1_correct_corruptions / total_samples
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics

    # ------------ training function ------------
    def train_cdeit_model(self, optimizer, num_epochs=1):
        self.model.train()
        for epoch in range(1, num_epochs+1):
            loss_total, loss_corrupt, loss_distill, loss_cls = 0, 0, 0, 0
            total_samples, sim_cls_distill, sim_cls_corr = 0, 0, 0

            print(f"------- Epoch {epoch} -------")
            for x_batch, y_batch, c_batch in self.train_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                
                loss_calculater = LossCalculator(self.teacher_model)
                L_corrupt = None
                if len(tokens) == 3:
                    L_total, L_cls, L_distill, L_corrupt = loss_calculater((third_batch, y_batch, c_batch), tokens)
                else:
                    L_total, L_cls, L_distill = loss_calculater((third_batch, y_batch), tokens)

                # backprop
                optimizer.zero_grad()
                L_total.backward()
                optimizer.step()

                total_samples += y_batch.size(0)
                # losses
                loss_total += L_total.item() * y_batch.size(0)
                loss_cls += L_cls.item() * y_batch.size(0)
                loss_distill += L_distill.item() * y_batch.size(0)
                if L_corrupt : loss_corrupt += L_corrupt.item() * y_batch.size(0)

                # cosine similarity
                sim_cls_distill += self.model.sim_cls_distill_end
                sim_cls_corr += self.model.sim_cls_corr_end

            train_metrics = {
                "loss_total": loss_total/total_samples,
                "loss_cls": loss_cls/total_samples,
                "loss_distill": loss_distill/total_samples,
                "sim_cls_distill" : sim_cls_distill/total_samples,
                "sim_cls_corr" : sim_cls_corr/total_samples
            }
            if loss_corrupt : train_metrics["loss_corrupt"] = loss_corrupt/total_samples
            if self.print_metrics : print(f"Train-Loss: {train_metrics['loss_total']:.3f}")
            test_metrics = self.test_cdeit_model()
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)
        
        # save trained model
        torch.save(self.model.state_dict(), self.save_path)
        print(f"\nModel saved to {self.save_path}")


In [12]:
from my_transformers import DistillVisionTransformer

deit_small = DistillVisionTransformer(
    EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS, dropout=DROPOUT, drop_path=DROP_PATH, erase_prob=ERASE_PROB 
    )

train_student_module = TrainTestStudentDeiT(deit_small, train_batch, test_batch, 1, NUM_IMG_TYPES, device, "")
train_student_module.test()

torch.Size([200]) torch.Size([200]) torch.Size([200])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)