# Data

- [Download the corrupted CIFAR-10 dataset (CIFAR-10-C) as a .tar](https://zenodo.org/records/2535967) and extract it into a folder.

- The clean CIFAR-10 dataset is downloaded automatically via `torchvision.datasets`.

- [Download the corrupted Tiny-ImageNet dataset (Tiny-ImageNet-C)](https://zenodo.org/records/2536630) and extract it into a folder.

- [Download Tiny-ImageNet](https://www.kaggle.com/datasets/akash2sharma/tiny-imagenet) and extract it into a folder.

# Imports


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
import timm

from my_transformers import CorruptDistillVisionTransformer
from utils import load_experimental_TinyImageNet
from train_test_module import compute_ece
import json

  from .autonotebook import tqdm as notebook_tqdm


# Loading Data

In [None]:
# Corruption types used in the experiments. A sample's type label is its index
# in this list; label len(corrupt_types) marks a clean (uncorrupted) image.
corrupt_types = "motion_blur shot_noise jpeg_compression fog".split()

# One-off data preparation, kept commented out: build the experimental
# Tiny-ImageNet splits and cache them to disk for the cells below.
# corrupt_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-C\Tiny-ImageNet-C"
# normal_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-Normal"
# train_data, test_data = load_experimental_TinyImageNet(normal_path, corrupt_path, corrupt_types, num_train_imgs=20)

# torch.save(train_data, "train_data.pt")
# torch.save(test_data, "test_data.pt")


# Visualisation TinyImageNet

In [None]:
def tensor_to_img(tensor, mean=(0.4802, 0.4481, 0.3975), std=(0.2302, 0.2265, 0.2262)):
    """Undo per-channel normalisation of a (C, H, W) image tensor and return a PIL image.

    Args:
        tensor: normalised image tensor of shape (C, H, W). It is not modified
            and may live on any device.
        mean: per-channel means used for the original normalisation. The
            default is the Tiny-ImageNet statistics previously hard-coded here.
        std: per-channel standard deviations used for the original
            normalisation (same default source as ``mean``).

    Returns:
        A PIL image built from the de-normalised tensor after clamping it to [0, 1].
    """
    # Create the statistics on the tensor's device so CUDA tensors also work.
    mean = torch.as_tensor(mean, device=tensor.device)
    std = torch.as_tensor(std, device=tensor.device)

    # de-normalise (broadcast the per-channel stats over H and W) and clamp
    img = tensor * std[:, None, None] + mean[:, None, None]
    img = torch.clamp(img, 0, 1)

    # ToPILImage needs host memory
    return T.ToPILImage()(img.cpu())

# Plot a 3x3 grid of random training images titled with their image type.
# `train_data` is otherwise only created by a later cell (the preparation code
# above is commented out), so load the cached tensors here. This keeps the cell
# runnable on a fresh kernel.
train_data = torch.load("train_data.pt", weights_only=True)
images, img_types = train_data[0], train_data[2]

imgs_to_display = [random.randint(0, len(images) - 1) for _ in range(9)]
fig, axes = plt.subplots(3, 3, figsize=(4, 4))

for ax, idx in zip(axes.flatten(), imgs_to_display):
    type_label = img_types[idx].item()
    # type index len(corrupt_types) denotes a clean image
    title = "normal" if type_label == len(corrupt_types) else corrupt_types[type_label]
    ax.set_title(title, fontsize=8)
    ax.imshow(tensor_to_img(images[idx]))
    ax.axis("off")

plt.tight_layout()
plt.show()

# Model config for TinyImageNet

In [2]:
# setting seed
SEED = 22
random.seed(SEED)
np.random.seed(SEED)     # numpy is imported too; seed it so any NumPy-based randomness is reproducible
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators, so no separate torch.cuda call is needed

device = "cuda" if torch.cuda.is_available() else "cpu"

# A sample's type label is its index here; label len(corrupt_types) means "clean".
corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# Hyper-parameters
PATCH_SIZE = 8
IMG_SIZE = 64
EMBED_DIM = 192
NUM_HEADS = 3
NUM_IMG_TYPES = len(corrupt_types) + 1  # corruption types + clean images
NUM_ENCODERS = 12
NUM_CLASSES = 200
DROPOUT = 0
DROP_PATH = 0.1
ERASE_PROB = 0.25

In [None]:
BATCH_SIZE = 128
# Cached tensors produced by the (commented-out) data-preparation cell.
train_data = torch.load("train_data.pt", weights_only=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_data.pt", weights_only=True)), batch_size=BATCH_SIZE, shuffle=True)

# Quick sanity-check setup: one batch for training and one for evaluation, both
# drawn from the *test* loader.
# NOTE(review): this trains on test images; use batches from `train_data`
# for a real experiment.
itr = iter(test_loader)
train_batch = [next(itr)]
test_batch = [next(itr)]

# deit3_small_patch16_224.fb_in22k_ft_in1k -- 22M
# convnext_tiny.fb_in22k_ft_in1k -- 28M
# Use `device` rather than .cuda() so the cell also runs on CPU-only machines.
deit3_small = timm.create_model('deit3_small_patch16_224.fb_in22k_ft_in1k', pretrained=True).to(device)
# teacher_model = timm.create_model('vit_small_patch16_224.augreg_in21k_ft_in1k', pretrained=True).to(device)
# Replace the pretrained classifier with a NUM_CLASSES-way head of the same input width.
deit3_small.head = nn.Linear(in_features=deit3_small.head.in_features, out_features=NUM_CLASSES, bias=True).to(device)
num_params = sum(p.numel() for p in deit3_small.parameters())
print(num_params)
print(deit3_small.head)
print(deit3_small.head)

21751496
Linear(in_features=384, out_features=200, bias=True)


In [None]:
# === LR Scheduler ===
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3
# Scheduler steps are counted in units of one element of `train_batch` (the
# list of pre-drawn batches).
# NOTE(review): TrainTestBaseline.train calls scheduler.step() once per epoch,
# which matches these units only while len(train_batch) == 1.
steps_per_epoch = len(train_batch)
total_steps = NUM_EPOCHS * steps_per_epoch
warmup_steps = WARMUP_EPOCHS * steps_per_epoch

optimizer = optim.AdamW(deit3_small.parameters(), lr=5e-5, weight_decay=0.05)
# Linear warm-up from lr * 1e-5 up to lr, then cosine decay over the remaining
# steps. T_max uses the same step units as the warm-up milestone.
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-5, total_iters=warmup_steps)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
)

from train_test_module import MyAugments
augmenter = MyAugments(NUM_CLASSES, mixup_p=0.2)

# ------------------------ Baseline ------------------------------

In [None]:
class TrainTestBaseline:
    """Train and evaluate a plain image classifier (no distillation tokens).

    Inputs are bilinearly resized to 224x224 before every forward pass, so the
    wrapped model must accept 224x224 images.

    Args:
        model: classifier whose output is a (batch, num_classes) logit tensor.
        train_data: sequence of tensors ``(images, labels, img_types)``; it is
            wrapped in a ``TensorDataset`` for training.
        test_loader: iterable of ``(images, labels, img_types)`` batches.
        augmenter: callable ``(x, y) -> (x, y)`` applied to each training batch.
        num_img_types: number of image types (corruptions + clean) used for the
            per-type metrics. ``img_types`` values are expected in
            ``range(num_img_types)``.
        batch_size: training batch size.
        device: device the model and all batches are moved to.
    """

    def __init__(self, model:nn.Module, train_data, test_loader,
                 augmenter, num_img_types, batch_size, device):
        self.model = model.to(device)
        self.train_data = train_data
        self.test_loader = test_loader
        self.num_img_types = num_img_types
        self.device = device
        self.batch_size = batch_size
        self.augmenter = augmenter

        # one metrics dict appended per completed epoch
        self.all_test_metrics, self.all_train_metrics = [], []

    def test(self, print_metrics=False):
        """Evaluate the model on ``test_loader`` and return a dict of metrics.

        Reports:
            - top-1 / top-5 accuracy, overall and per image type;
            - per-type top-1 error rate;
            - ``compute_ece`` and predictive entropy, both summed over batches
              and divided by the number of samples.

        A per-type entry is NaN when that type does not occur in the test data
        (0 / 0).
        """
        top1_correct_preds, top5_correct_preds = 0, 0
        total_samples, total_ece, total_entropy = 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_loader:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                preds = self.model(x_batch)
                del x_batch  # free memory

                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE
                # NOTE(review): dividing by total_samples below assumes compute_ece
                # returns a per-batch sum rather than a mean; confirm.
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy of the softmax distribution, summed over the batch
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                del y_batch, c_batch

        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples,
            "top5_acc" : top5_correct_preds / total_samples,
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(),
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics

    def train(self, optimizer, scheduler, save_path,
              num_epochs=1, label_smoothing=0.1, print_metrics=False
              ):
        """Train for up to ``num_epochs`` epochs with early stopping on test top-1 accuracy.

        ``scheduler.step()`` is called once per epoch. Training stops early once
        test top-1 accuracy has failed to improve by more than 0.003 for 5
        consecutive epochs.

        If ``save_path`` is non-empty, the final (not best) weights are written
        to ``{save_path}.pth``. The metric histories go to
        ``{save_path}_train_metrics.json`` and ``{save_path}_test_metrics.json``.
        """
        best_acc, epochs_no_improve, min_delta, patience = 0, 0, 0.003, 5
        # The dataset never changes, so build the loader once. shuffle=True still
        # reshuffles on every new iteration (i.e. every epoch).
        train_loader = DataLoader(dataset=TensorDataset(*self.train_data), batch_size=self.batch_size, shuffle=True)

        for epoch in range(1, num_epochs+1):
            self.model.train()
            print(f"------- Epoch {epoch} -------")
            total_samples, top1_correct_preds, loss_total = 0, 0, 0

            for x_batch, y_batch, _ in train_loader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                x_batch, y_batch = self.augmenter(x_batch, y_batch)
                x_batch = F.interpolate(x_batch, size=(224, 224), mode='bilinear', align_corners=False)
                preds = self.model(x_batch)
                del x_batch                 # free memory

                # If the augmenter returns soft (e.g. mixup) targets of shape
                # (batch, num_classes), score accuracy against their argmax.
                hard_targets = y_batch.argmax(dim=1) if y_batch.dim() > 1 else y_batch
                total_samples += y_batch.size(0)
                # top-1 acc
                top1_correct_preds += (torch.argmax(preds, dim=1) == hard_targets).sum().item()
                # loss (cross_entropy accepts class indices or class probabilities)
                loss = F.cross_entropy(preds, y_batch, label_smoothing=label_smoothing)
                # backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_total += loss.item() * y_batch.size(0)
                del y_batch, hard_targets, preds, loss      # free memory

            scheduler.step()

            test_metrics = self.test()
            train_metrics = {
                "loss_total": loss_total/total_samples,
                "top1_acc" : top1_correct_preds / total_samples
                }
            current_acc = test_metrics['top1_acc']
            if print_metrics :
                print(f"train-loss: {train_metrics['loss_total']:.2f} -- train-acc: {train_metrics['top1_acc']:.2f} -- "
                      f"test-acc: {current_acc:.2f}"
                      )
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)
            # early stopping
            if current_acc - best_acc > min_delta:
                best_acc = current_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} — no improvement in {patience} epochs.")
                break

        # save trained model and metrics
        if save_path:
            torch.save(self.model.state_dict(), f"{save_path}.pth")
            print(f"Model saved to {save_path}")

            with open(f"{save_path}_train_metrics.json", "w") as f1:
                json.dump(self.all_train_metrics, f1, indent=4)
            with open(f"{save_path}_test_metrics.json", "w") as f2:
                json.dump(self.all_test_metrics, f2, indent=4)

In [6]:
# Baseline run: fine-tune the pretrained DeiT-III on the single pre-drawn
# training batch and evaluate on the single pre-drawn test batch.
baseline_module = TrainTestBaseline(
    model=deit3_small,
    train_data=train_batch[0],
    test_loader=test_batch,
    augmenter=augmenter,
    num_img_types=NUM_IMG_TYPES,
    batch_size=BATCH_SIZE,
    device=device,
)
# an empty save_path means nothing is written to disk
baseline_module.train(optimizer, scheduler, save_path="")

------- Epoch 1 -------


In [None]:
from train_test_module import LossCalculatorDeiT
class TrainTestDeiT:
    """Train and evaluate a distillation model that may also emit a corruption token.

    The model's forward returns a sequence of outputs:
        - ``tokens[0]`` and ``tokens[1]`` are class logits that are averaged
          for prediction.
        - If a third entry is present, it is scored against the image-type
          labels. For ``head_strategy >= 2`` it is also mapped through
          ``model.output_head.ffn`` and averaged into the class prediction.

    After each forward, the model is read for ``sim_cls_distill`` and, when a
    corruption loss is present, ``sim_cls_corrupt``.

    Args:
        model: student model.
        teacher_model: teacher handed to ``LossCalculatorDeiT``.
        train_batches: iterable of ``(images, labels, img_types)`` batches.
        test_batches: iterable of ``(images, labels, img_types)`` batches.
        head_strategy: integer in 1..3 selecting how predictions are combined.
        num_img_types: number of image types for the per-type metrics.
        device: device the models and batches are moved to.
    """

    def __init__(self, model:nn.Module, teacher_model:nn.Module, train_batches, test_batches, head_strategy, num_img_types,
                 device
                 ):
        assert head_strategy > 0 and head_strategy <= 3
        self.model = model.to(device)
        self.teacher_model = teacher_model.to(device)
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.head_strategy = head_strategy
        self.num_img_types = num_img_types
        self.device = device

        self.loss_calculater = LossCalculatorDeiT(self.teacher_model)
        # one metrics dict appended per completed epoch
        self.all_test_metrics, self.all_train_metrics = [], []

    # testing function
    def test(self, print_metrics=False):
        """Evaluate on ``test_batches`` and return a dict of metrics.

        ``"top1_corrupt_acc"`` is included whenever the model outputs a
        corruption token, even when that accuracy is 0. Per-type entries are
        NaN for a type absent from the test data.
        """
        top1_correct_preds, top5_correct_preds = 0, 0
        top1_correct_corruptions, total_samples, total_ece, total_entropy = 0, 0, 0, 0
        # whether the model emits a corruption token (decides top1_corrupt_acc)
        has_corrupt_token = False

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                del x_batch

                has_corrupt_token = len(tokens) == 3
                if has_corrupt_token and self.head_strategy >= 2:
                    preds = (tokens[0] + tokens[1] + self.model.output_head.ffn(tokens[2])) / 3
                else:
                    preds = (tokens[0] + tokens[1]) / 2

                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE
                # NOTE(review): dividing by total_samples below assumes compute_ece
                # returns a per-batch sum rather than a mean; confirm.
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy of the softmax distribution, summed over the batch
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                # top-1 corruption classification acc
                if has_corrupt_token:
                    top1_correct_corruptions += (torch.argmax(tokens[2], dim=1) == c_batch).sum().item()
                del y_batch, c_batch, preds

        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples,
            "top5_acc" : top5_correct_preds / total_samples,
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(),
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if has_corrupt_token:
            test_metrics["top1_corrupt_acc"] = top1_correct_corruptions / total_samples
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics

    # ------------ training function ------------
    def train(self, optimizer, scheduler, save_path, num_epochs=1, print_metrics=False):
        """Train for up to ``num_epochs`` epochs with early stopping on test top-1 accuracy.

        ``scheduler.step()`` is called after every batch. Training stops early
        once test top-1 accuracy has failed to improve by more than 0.003 for 4
        consecutive epochs.

        If ``save_path`` is non-empty, the final weights go to
        ``{save_path}.pth`` and the metric histories go to JSON files next to it.
        """
        best_acc, epochs_no_improve, min_delta, patience = 0, 0, 0.003, 4

        for epoch in range(1, num_epochs+1):
            self.model.train()
            loss_total, loss_cls, loss_distill, loss_corrupt, loss_cls_corruptFFN = 0, 0, 0, 0, 0
            total_samples, sim_cls_distill, sim_cls_corrupt, top1_correct_preds = 0, 0, 0, 0
            # Track which optional losses were returned, so genuinely-zero
            # values are still reported.
            has_corrupt_loss, has_ffn_loss = False, False

            print(f"------- Epoch {epoch} -------")
            for x_batch, y_batch, c_batch in self.train_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)

                if len(tokens) == 3:
                    losses = self.loss_calculater(tokens, (x_batch, y_batch, c_batch))
                    if self.head_strategy >= 2: preds = (tokens[0] + tokens[1] + self.model.output_head.ffn(tokens[2])) / 3
                    else: preds = (tokens[0] + tokens[1]) / 2
                else:
                    losses = self.loss_calculater(tokens, (x_batch, y_batch))
                    preds = (tokens[0] + tokens[1]) / 2

                # top-1 acc
                top1_correct_preds += (torch.argmax(preds, dim=1) == y_batch).sum().item()

                del x_batch, preds, c_batch
                # backprop
                optimizer.zero_grad()
                losses[0].backward()
                optimizer.step()
                scheduler.step()

                n = y_batch.size(0)
                total_samples += n
                # losses are accumulated as (total, cls, distill[, corrupt[, cls_corruptFFN]])
                loss_total += losses[0].item() * n
                loss_cls += losses[1].item() * n
                loss_distill += losses[2].item() * n
                # the corruption loss is present both with 4 and with 5 losses
                if len(losses) >= 4:
                    has_corrupt_loss = True
                    loss_corrupt += losses[3].item() * n
                    sim_cls_corrupt += self.model.sim_cls_corrupt.item() * n
                if len(losses) == 5:
                    has_ffn_loss = True
                    loss_cls_corruptFFN += losses[4].item() * n
                # cosine similarity
                sim_cls_distill += self.model.sim_cls_distill.item() * n

                del y_batch

            train_metrics = {
                "top1_acc" : top1_correct_preds / total_samples,
                "loss_total": loss_total/total_samples,
                "loss_cls": loss_cls/total_samples,
                "loss_distill": loss_distill/total_samples,
                "sim_cls_distill" : sim_cls_distill/total_samples,
            }
            if has_corrupt_loss:
                train_metrics["loss_corrupt"] = loss_corrupt/total_samples
                train_metrics["sim_cls_corrupt"] = sim_cls_corrupt/total_samples
            if has_ffn_loss:
                train_metrics["loss_cls_corruptFFN"] = loss_cls_corruptFFN/total_samples
            test_metrics = self.test()

            current_acc = test_metrics['top1_acc']
            if print_metrics :
                print(f"train-loss: {train_metrics['loss_total']:.2f} -- train-acc: {train_metrics['top1_acc']:.2f} -- "
                      f"test-acc: {current_acc:.2f}"
                      )
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)

            # early stopping
            if current_acc - best_acc > min_delta:
                best_acc = current_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} — no improvement in {patience} epochs.")
                break

        # save trained model
        if save_path:
            torch.save(self.model.state_dict(), f"{save_path}.pth")
            print(f"\nModel saved to {save_path}")
            with open(f"{save_path}_train_metrics.json", "w") as f1:
                json.dump(self.all_train_metrics, f1, indent=4)
            with open(f"{save_path}_test_metrics.json", "w") as f2:
                json.dump(self.all_test_metrics, f2, indent=4)


In [4]:
class TrainTestCdeiT:
    """Train and evaluate a distillation model that always emits a corruption token.

    The model's forward returns three outputs:
        - ``tokens[0]`` and ``tokens[1]`` are class logits.
        - ``tokens[2]`` is scored against the image-type labels. For
          ``head_strategy >= 2`` it is also mapped through
          ``model.output_head.ffn`` and averaged into the class prediction.

    After each forward, the model is read for ``sim_cls_distill`` and
    ``sim_cls_corrupt``.

    Args:
        model: student model.
        teacher_model: teacher handed to ``LossCalculatorDeiT``.
        train_batches: iterable of ``(images, labels, img_types)`` batches.
        test_batches: iterable of ``(images, labels, img_types)`` batches.
        head_strategy: integer in 1..3 selecting how predictions are combined.
        num_img_types: number of image types for the per-type metrics.
        device: device the models and batches are moved to.
    """

    def __init__(self, model:nn.Module, teacher_model:nn.Module, train_batches, test_batches, head_strategy, num_img_types,
                 device
                 ):
        assert head_strategy > 0 and head_strategy <= 3
        self.model = model.to(device)
        self.teacher_model = teacher_model.to(device)
        self.train_batches = train_batches
        self.test_batches = test_batches
        self.head_strategy = head_strategy
        self.num_img_types = num_img_types
        self.device = device

        self.loss_calculater = LossCalculatorDeiT(self.teacher_model)
        # one metrics dict appended per completed epoch
        self.all_test_metrics, self.all_train_metrics = [], []

    # testing function
    def test(self, print_metrics=False):
        """Evaluate on ``test_batches`` and return a dict of metrics.

        Includes class top-1/top-5 accuracy (overall and per image type) and
        corruption-type top-1 accuracy. Per-type entries are NaN for a type
        absent from the test data.
        """
        top1_correct_preds, top5_correct_preds = 0, 0
        top1_correct_corruptions, total_samples, total_ece, total_entropy = 0, 0, 0, 0

        total_per_type = torch.zeros(self.num_img_types, device=self.device)
        top1_correct_per_type = torch.zeros(self.num_img_types, device=self.device)
        top5_correct_per_type = torch.zeros(self.num_img_types, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for x_batch, y_batch, c_batch in self.test_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)
                del x_batch

                if self.head_strategy >= 2:
                    preds = (tokens[0] + tokens[1] + self.model.output_head.ffn(tokens[2])) / 3
                else:
                    preds = (tokens[0] + tokens[1]) / 2

                total_samples += y_batch.size(0)
                # top-1 acc
                top1_right = (torch.argmax(preds, dim=1) == y_batch)
                top1_correct_preds += top1_right.sum().item()
                # top-5 acc
                top5_right = torch.topk(preds, 5, dim=1).indices.eq(y_batch.unsqueeze(1)).any(dim=1)
                top5_correct_preds += top5_right.sum().item()
                # ECE
                # NOTE(review): dividing by total_samples below assumes compute_ece
                # returns a per-batch sum rather than a mean; confirm.
                total_ece += compute_ece(preds, y_batch, self.device)
                # Per-type acc
                for t in range(self.num_img_types):
                    mask = (c_batch == t)
                    total_per_type[t] += mask.sum()
                    top1_correct_per_type[t] += top1_right[mask].sum()
                    top5_correct_per_type[t] += top5_right[mask].sum()
                # entropy of the softmax distribution, summed over the batch
                total_entropy += -torch.sum(torch.softmax(preds, dim=1) * torch.log_softmax(preds, dim=1), dim=1).sum().item()
                # top-1 corruption classification acc
                top1_correct_corruptions += (torch.argmax(tokens[2], dim=1) == c_batch).sum().item()
                del y_batch, c_batch, preds

        top1_acc_per_type = (top1_correct_per_type / total_per_type).tolist()
        test_metrics = {
            "top1_acc" : top1_correct_preds / total_samples,
            "top5_acc" : top5_correct_preds / total_samples,
            "top1_corrupt_acc" : top1_correct_corruptions / total_samples,
            "top1_acc_per_type" : top1_acc_per_type,
            "top5_acc_per_type" :(top5_correct_per_type / total_per_type).tolist(),
            "error_rate_per_type" : [1 - acc for acc in top1_acc_per_type],
            "ece" : total_ece / total_samples,
            "entropy" : total_entropy / total_samples
        }
        if print_metrics : print(f"Test-Accuracy:{test_metrics['top1_acc']:.2f}")
        return test_metrics

    # ------------ training function ------------
    def train(self, optimizer, scheduler, save_path, num_epochs=1, print_metrics=False):
        """Train for up to ``num_epochs`` epochs with early stopping on test top-1 accuracy.

        ``scheduler.step()`` is called after every batch. Training stops early
        once test top-1 accuracy has failed to improve by more than 0.003 for 4
        consecutive epochs.

        If ``save_path`` is non-empty, the final weights go to
        ``{save_path}.pth`` and the metric histories go to JSON files next to it.
        """
        best_acc, epochs_no_improve, min_delta, patience = 0, 0, 0.003, 4
        for epoch in range(1, num_epochs+1):
            self.model.train()
            loss_total, loss_cls, loss_distill, loss_corrupt, loss_cls_corruptFFN = 0, 0, 0, 0, 0
            total_samples, sim_cls_distill, sim_cls_corrupt, top1_correct_preds = 0, 0, 0, 0
            # whether the optional 5th loss was returned, so a genuine 0 is still reported
            has_ffn_loss = False

            print(f"------- Epoch {epoch} -------")
            for x_batch, y_batch, c_batch in self.train_batches:
                x_batch, y_batch, c_batch = x_batch.to(self.device), y_batch.to(self.device), c_batch.to(self.device)
                tokens = self.model(x_batch)

                losses = self.loss_calculater(tokens, (x_batch, y_batch, c_batch))
                if self.head_strategy >= 2: preds = (tokens[0] + tokens[1] + self.model.output_head.ffn(tokens[2])) / 3
                else: preds = (tokens[0] + tokens[1]) / 2

                # top-1 acc
                top1_correct_preds += (torch.argmax(preds, dim=1) == y_batch).sum().item()

                del x_batch, preds, c_batch
                # backprop
                optimizer.zero_grad()
                losses[0].backward()
                optimizer.step()
                scheduler.step()

                n = y_batch.size(0)
                total_samples += n
                # losses are accumulated as (total, cls, distill, corrupt[, cls_corruptFFN])
                loss_total += losses[0].item() * n
                loss_cls += losses[1].item() * n
                loss_distill += losses[2].item() * n
                loss_corrupt += losses[3].item() * n
                if len(losses) == 5:
                    has_ffn_loss = True
                    loss_cls_corruptFFN += losses[4].item() * n
                # cosine similarity
                sim_cls_distill += self.model.sim_cls_distill.item() * n
                sim_cls_corrupt += self.model.sim_cls_corrupt.item() * n
                del y_batch

            train_metrics = {
                "top1_acc" : top1_correct_preds / total_samples,
                "loss_total": loss_total/total_samples,
                "loss_cls": loss_cls/total_samples,
                "loss_distill": loss_distill/total_samples,
                "sim_cls_distill" : sim_cls_distill/total_samples,
                "loss_corrupt" : loss_corrupt/total_samples,
                "sim_cls_corrupt" : sim_cls_corrupt/total_samples
            }
            if has_ffn_loss : train_metrics["loss_cls_corruptFFN"] = loss_cls_corruptFFN/total_samples
            test_metrics = self.test()

            current_acc = test_metrics['top1_acc']
            if print_metrics :
                print(f"train-loss: {train_metrics['loss_total']:.2f} -- train-acc: {train_metrics['top1_acc']:.2f} -- "
                      f"test-acc: {current_acc:.2f}"
                      )
            self.all_train_metrics.append(train_metrics)
            self.all_test_metrics.append(test_metrics)

            # early stopping
            if current_acc - best_acc > min_delta:
                best_acc = current_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} — no improvement in {patience} epochs.")
                break

        # save trained model
        if save_path:
            torch.save(self.model.state_dict(), f"{save_path}.pth")
            print(f"\nModel saved to {save_path}")
            with open(f"{save_path}_train_metrics.json", "w") as f1:
                json.dump(self.all_train_metrics, f1, indent=4)
            with open(f"{save_path}_test_metrics.json", "w") as f2:
                json.dump(self.all_test_metrics, f2, indent=4)


In [6]:
from my_transformers import DistillVisionTransformer, CorruptDistillVisionTransformer

# Encoder hyper-parameters shared by both student variants.
shared_kwargs = dict(
    attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS,
    dropout=DROPOUT,
    drop_path=DROP_PATH,
)

# student without the image-type branch
deit_small = DistillVisionTransformer(EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, **shared_kwargs)

# student with the image-type branch (num_img_types outputs, head_strategy 3)
Cdeit_small = CorruptDistillVisionTransformer(
    EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, **shared_kwargs,
    num_img_types=NUM_IMG_TYPES, head_strategy=3,
)

In [9]:

# TrainTestBaseline takes (model, train_data, test_loader, augmenter,
# num_img_types, batch_size, device). train_data must be the (x, y, c) tensors
# themselves, i.e. train_batch[0], since they are unpacked into a TensorDataset.
train_student_module = TrainTestBaseline(
    deit3_small, train_batch[0], test_batch, augmenter, NUM_IMG_TYPES, BATCH_SIZE, device
)
train_student_module.train(optimizer, scheduler, "")
train_student_module.all_test_metrics

------- Epoch 1 -------


[{'top1_acc': 0.0234375,
  'top5_acc': 0.0703125,
  'top1_acc_per_type': [0.032258063554763794,
   0.0,
   0.06896551698446274,
   0.0,
   0.0],
  'top5_acc_per_type': [0.06451612710952759,
   0.0357142873108387,
   0.13793103396892548,
   0.05000000074505806,
   0.05000000074505806],
  'error_rate_per_type': [0.9677419364452362,
   1.0,
   0.9310344830155373,
   1.0,
   1.0],
  'ece': 0.0004661552084144205,
  'entropy': 4.456512451171875}]

In [8]:
# Display the per-epoch test metrics collected by train_student_module's training runs.
train_student_module.all_test_metrics

[{'top1_acc': 0.015625,
  'top5_acc': 0.0390625,
  'top1_acc_per_type': [0.0, 0.0, 0.06896551698446274, 0.0, 0.0],
  'top5_acc_per_type': [0.032258063554763794,
   0.0,
   0.1034482792019844,
   0.0,
   0.05000000074505806],
  'error_rate_per_type': [1.0, 1.0, 0.9310344830155373, 1.0, 1.0],
  'ece': 0.0006139964680187404,
  'entropy': 4.389969348907471}]