# Data

- [Download corrupted Cifar10 .tar](https://zenodo.org/records/2535967) and extract it into a folder.

- Cifar10 is downloaded using torch.datasets

- [Download corrupted Tiny-ImageNet](https://zenodo.org/records/2536630) and extract it into a folder.

- [Download Tiny-ImageNet](https://www.kaggle.com/datasets/akash2sharma/tiny-imagenet) and extract it into a folder.

# Imports


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os

from my_transformers import CorruptDistillVisionTransformer
from utils import load_experimental_TinyImageNet

  from .autonotebook import tqdm as notebook_tqdm


# Loading Data

In [2]:
corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# corrupt_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-C\Tiny-ImageNet-C"
# normal_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-Normal"
# train_data, test_data = load_experimental_TinyImageNet(normal_path, corrupt_path, corrupt_types, num_train_imgs=20)


# Visualisation TinyImageNet

In [None]:
def tensor_to_img(tensor):
    mean = torch.tensor([0.4802, 0.4481, 0.3975])
    std = torch.tensor([0.2302, 0.2265, 0.2262])
    
    # denormalize
    img = tensor.clone()
    img = img * std[:, None, None] + mean[:, None, None]
    img = torch.clamp(img, 0, 1)
    
    # convert to PIL image
    to_pil = T.ToPILImage()
    return to_pil(img)

# plot
imgs_to_display = [random.randint(0, len(train_data[0])-1) for i in range(9)]
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
axes = axes.flatten()

for i in range(9):
    img, label = train_data[0][imgs_to_display[i]], train_data[2][imgs_to_display[i]].item()
    img = tensor_to_img(img)
    if label == len(corrupt_types):
        axes[i].set_title(f"normal", fontsize=8)
    else:
        axes[i].set_title(corrupt_types[label], fontsize=8)
    axes[i].imshow(img)
    axes[i].axis("off")

plt.tight_layout()
plt.show()

# Model config for TinyImageNet

In [3]:
# setting seed 
torch.cuda.manual_seed(22)
random.seed(22)
torch.manual_seed(22)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters
PATCH_SIZE = 8
IMG_SIZE = 64
EMBED_DIM = 192
NUM_HEADS = 3
IMG_TYPES = len(corrupt_types)+1
NUM_ENCODERS = 12
NUM_CLASSES = 200
DROPOUT = 0.1
DROP_PATH = 0.1
ERASE_PROB = 0

In [None]:
from train_test_module import LossCalculator
class TrainTestCdeit:
    def __init__(self, model, teacher_model, train_batches, test_batches, head_strategy, num_img_types,
                 device, save_path:str, print_metrics=False, n_bins=15
                 ):
        assert head_strategy > 0 and head_strategy <= 3
        self.model = model.to(device)
        self.teacher_model = teacher_model.to(device)
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.head_strategy = head_strategy
        self.num_img_types = num_img_types
        self.save_path = save_path

        self.device = device
        self.print_metrics = print_metrics
        self.n_bins = n_bins

        self.all_test_metrics, self.all_train_metrics = [], []

    # compute ece function
    def __compute_ece(self, preds, labels):
        preds = torch.softmax(preds, dim=1)
        confidences, preds = preds.max(dim=1)
        acc = preds.eq(labels)
        
        bins = torch.linspace(0,1,self.n_bins+1, device=self.device)
        ece = torch.zeros(1, device=self.device)
        for i in range(self.n_bins):
            mask = confidences.gt(bins[i]) & confidences.le(bins[i+1])
            if mask.sum() > 0:
                ece += (mask.sum().float() / labels.size(0)) * torch.abs(acc[mask].float().mean() - confidences[mask].mean())
        return ece.item()
    
    # testing function
    def test_cdeit_model(self):
        top1_correct_preds, top5_correct_preds = 0, 0
        top1_correct_corrs, total_samples, total_ece = 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                
                if len(tokens) == 3 and self.head_strategy != 1:
                    preds = (tokens[0] + tokens[1] + self.model.output_head.corrupt_connection(tokens[2])) / 3
                else:
                    preds = (tokens[0] + tokens[1]) / 2
                    print("done preds!")
                
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE loss
                total_ece += self.__compute_ece(preds, y_batch) * y_batch.size(0)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # top-1 corruption classification acc
                if len(tokens) == 3 : top1_correct_corrs += (torch.argmax(tokens[2], dim=1) == c_batch).sum().item()                

        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples, 
            "top5_acc" : top5_correct_preds / total_samples, 
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(), 
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples
        }
        if top1_correct_corrs : test_metrics["top1_corrupt_acc"] = top1_correct_corrs / total_samples
        if self.print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics

    # ------------ training function ------------
    def train_cdeit_model(self, optimizer, num_epochs=5):
        self.model.train()
        for epoch in range(1, num_epochs+1):
            loss_total, loss_corrupt, loss_distill, loss_cls = 0, 0, 0, 0
            total_samples, sim_cls_distill, sim_cls_corr = 0, 0, 0

            print(f"------- Epoch {epoch} -------")
            for x_batch, y_batch, c_batch in self.train_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                
                loss_calculater = LossCalculator(self.teacher_model)
                L_corrupt = None
                if len(tokens) == 3:
                    L_total, L_cls, L_distill, L_corrupt = loss_calculater((third_batch, y_batch, c_batch), tokens)
                else:
                    L_total, L_cls, L_distill = loss_calculater((third_batch, y_batch), tokens)

                # backprop
                optimizer.zero_grad()
                L_total.backward()
                optimizer.step()

                total_samples += y_batch.size(0)
                # losses
                loss_total += L_total.item() * y_batch.size(0)
                loss_cls += L_cls.item() * y_batch.size(0)
                loss_distill += L_distill.item() * y_batch.size(0)
                if L_corrupt : loss_corrupt += L_corrupt.item() * y_batch.size(0)

                # cosine similarity
                sim_cls_distill += self.model.sim_cls_distill_end
                sim_cls_corr += self.model.sim_cls_corr_end

            train_metrics = {
                "loss_total": loss_total/total_samples,
                "loss_cls": loss_cls/total_samples,
                "loss_distill": loss_distill/total_samples,
                "sim_cls_distill" : sim_cls_distill/total_samples,
                "sim_cls_corr" : sim_cls_corr/total_samples
            }
            if loss_corrupt : train_metrics["loss_corrupt"] = loss_corrupt/total_samples
            if self.print_metrics : print(f"Train-Loss: {train_metrics['loss_total']:.3f}")
            test_metrics = self.test_cdeit_model()
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)
        
        # save trained model
        torch.save(self.model.state_dict(), self.save_path)
        print(f"\nModel saved to {self.save_path}")


In [4]:
# train_data = torch.load("train_data.pt", weights_only=True)
# test_data = torch.load("test_data.pt", weights_only=True)

BATCH_SIZE = 128
train_loader = DataLoader(dataset=TensorDataset(*torch.load("train_data.pt", weights_only=True)), 
                                  batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_data.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=True)

In [None]:
import timm, torch
# deit3_small_patch16_224.fb_in22k_ft_in1k -- 22M
# convnext_tiny.fb_in22k_ft_in1k -- 28M
teacher_model = timm.create_model('deit3_small_patch16_224.fb_in22k_ft_in1k', pretrained=True).cuda()
teacher_model.head = nn.Linear(in_features=384, out_features=NUM_CLASSES, bias=True).cuda()
num_params = sum(p.numel() for p in teacher_model.parameters())
print(num_params)

21751496


In [6]:
from train_test_module import train_test_teacher_head, test_teacher_head
optimizer = optim.AdamW(teacher_model.parameters(), lr=1e-4)
teacher_metrics = train_test_teacher_head(teacher_model, train_loader, test_loader, 
                                          optimizer, device, save_path="my_model.pth", num_epochs=2, print_metrics=True)

------- Epoch 1 -------
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217,

In [None]:

# --------------------------------------------------------------------------------
tiny_vit_model = CorruptDistillVisionTransformer(
    EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS, dropout=DROPOUT, drop_path = DROP_PATH, erase_prob=ERASE_PROB,
    num_img_types=IMG_TYPES, head_strategy=2
    ).to(device)
# --------------------------------------------------------------------------------

train_test_module = TrainTestCdeit(
    tiny_vit_model, teacher_model, first_batch, second_batch, 1, len(corrupt_types)+1,
    device, "savehere"
    )
optimizer = optim.AdamW(tiny_vit_model.parameters(), lr=1e-4)
train_test_module.train_cdeit_model(optimizer, num_epochs=1)

22059496
------- Epoch 1 -------


AssertionError: 

In [None]:
train_test_module.all_train_metrics

[{'loss_total': 0.0,
  'loss_cls': 2.5580086708068848,
  'loss_distill': 5.887204170227051,
  'sim_cls_distill': tensor(5.9064e-05, device='cuda:0', grad_fn=<DivBackward0>),
  'sim_cls_corr': tensor(0.0003, device='cuda:0', grad_fn=<DivBackward0>),
  'loss_corrupt': 1.7868216037750244}]

In [7]:
timm.list_models(pretrained=True)

['aimv2_1b_patch14_224.apple_pt',
 'aimv2_1b_patch14_336.apple_pt',
 'aimv2_1b_patch14_448.apple_pt',
 'aimv2_3b_patch14_224.apple_pt',
 'aimv2_3b_patch14_336.apple_pt',
 'aimv2_3b_patch14_448.apple_pt',
 'aimv2_huge_patch14_224.apple_pt',
 'aimv2_huge_patch14_336.apple_pt',
 'aimv2_huge_patch14_448.apple_pt',
 'aimv2_large_patch14_224.apple_pt',
 'aimv2_large_patch14_224.apple_pt_dist',
 'aimv2_large_patch14_336.apple_pt',
 'aimv2_large_patch14_336.apple_pt_dist',
 'aimv2_large_patch14_448.apple_pt',
 'bat_resnext26ts.ch_in1k',
 'beit_base_patch16_224.in22k_ft_in22k',
 'beit_base_patch16_224.in22k_ft_in22k_in1k',
 'beit_base_patch16_384.in22k_ft_in22k_in1k',
 'beit_large_patch16_224.in22k_ft_in22k',
 'beit_large_patch16_224.in22k_ft_in22k_in1k',
 'beit_large_patch16_384.in22k_ft_in22k_in1k',
 'beit_large_patch16_512.in22k_ft_in22k_in1k',
 'beitv2_base_patch16_224.in1k_ft_in1k',
 'beitv2_base_patch16_224.in1k_ft_in22k',
 'beitv2_base_patch16_224.in1k_ft_in22k_in1k',
 'beitv2_large_patc