# Data

- [Download the corrupted CIFAR-10 (CIFAR-10-C) `.tar` archive](https://zenodo.org/records/2535967) and extract it into a folder.

- CIFAR-10 itself is downloaded automatically via `torchvision.datasets`.

- [Download corrupted Tiny-ImageNet](https://zenodo.org/records/2536630) and extract it into a folder.

- [Download Tiny-ImageNet](https://www.kaggle.com/datasets/akash2sharma/tiny-imagenet) and extract it into a folder.

# Imports


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
from torchvision.io import read_image
import matplotlib.pyplot as plt
from PIL import Image
import os

from utils import load_experimental_TinyImageNet
from my_transformers import CorruptDistillVisionTransformer

  from .autonotebook import tqdm as notebook_tqdm


# Loading Data

In [2]:
# Corruption families used in the experiment. The position of each name is the
# integer corruption label used when titling images in the visualisation cell
# (corrupt_types[label]); the index len(corrupt_types) is treated as "normal".
corrupt_types = ["motion_blur", "shot_noise", "jpeg_compression", "fog"]

# Step 1 (run once): build the train/test data from the extracted dataset folders.
# NOTE: the paths are machine-specific; point them at your own copies.
# corrupt_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-C\Tiny-ImageNet-C"
# normal_path = r"C:\Users\Hp\Desktop\Coding\Transformer-Thesis\Tiny-ImageNet-Normal"
# train_data, test_data = load_experimental_TinyImageNet(normal_path, corrupt_path, corrupt_types)

# Step 2 (run once): cache the loaded data on disk so step 1 need not be repeated.
# torch.save(train_data, "experiment_train_data.pt")
# torch.save(test_data, "experiment_test_data.pt")

# Step 3: reload the cached tensors and wrap them in TensorDataset objects.
# NOTE(review): later cells reference train_data, so one of the steps above has
# to be uncommented before they can run.
# train_data = TensorDataset(*torch.load("experiment_train_data.pt", weights_only=True))
# test_data = TensorDataset(*torch.load("experiment_test_data.pt", weights_only=True))

# Visualisation TinyImageNet

In [None]:
def tensor_to_img(tensor):
    """Turn a normalised image tensor back into a displayable PIL image.

    The per-channel normalisation is undone with the mean/std constants
    below, the result is clamped to [0, 1], and it is then converted with
    torchvision's ``ToPILImage``.

    Args:
        tensor: Normalised image tensor. The statistics are broadcast as
            (3, 1, 1), so this is presumably laid out as (3, H, W) --
            TODO confirm against the data loader. The input is not
            modified.

    Returns:
        The object produced by ``T.ToPILImage()`` for the de-normalised
        tensor.
    """
    # Create the statistics on the input's device; with CPU-only constants a
    # CUDA input would fail with a device-mismatch error in the arithmetic.
    # NOTE(review): values look like Tiny-ImageNet channel statistics -- confirm
    # they match the normalisation applied when the data was loaded.
    mean = torch.tensor([0.4802, 0.4481, 0.3975], device=tensor.device)
    std = torch.tensor([0.2302, 0.2265, 0.2262], device=tensor.device)

    # De-normalise. The arithmetic allocates a new tensor, so no explicit
    # clone is needed to keep the caller's data untouched.
    img = tensor * std[:, None, None] + mean[:, None, None]
    img = torch.clamp(img, 0, 1)

    # Convert to a PIL image.
    return T.ToPILImage()(img)

# Plot a 3x3 grid of randomly drawn training images, each titled with its
# corruption type ("normal" for the index one past the end of corrupt_types).
# NOTE(review): the indexing assumes train_data is a sequence of tensors with
# images at [0] and corruption labels at [2] (i.e. the un-wrapped data, not a
# TensorDataset) -- confirm against load_experimental_TinyImageNet.
num_images = len(train_data[0])
imgs_to_display = [random.randint(0, num_images - 1) for _ in range(9)]
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
axes = axes.flatten()

for ax, idx in zip(axes, imgs_to_display):
    label = train_data[2][idx].item()
    img = tensor_to_img(train_data[0][idx])
    title = "normal" if label == len(corrupt_types) else corrupt_types[label]
    ax.set_title(title, fontsize=8)
    ax.imshow(img)
    ax.axis("off")

plt.tight_layout()
plt.show()

# Model config for TinyImageNet

In [4]:
# Seed the Python and PyTorch random number generators with one shared value.
SEED = 22
torch.cuda.manual_seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hyper-parameters (consumed by the model constructor in the next cell;
# BATCH_SIZE is used by the data loaders below).
BATCH_SIZE = 128
PATCH_SIZE = 8
IMG_SIZE = 64
EMBED_DIM = 192
NUM_HEADS = 3
# One class per corruption type plus one extra for uncorrupted ("normal") images.
IMG_TYPES = 1 + len(corrupt_types)
NUM_ENCODERS = 12
NUM_CLASSES = 200
DROPOUT = 0.1
DROP_PATH = 0.1
ERASE_PROB = 0

# Data loaders stay disabled until train_data / test_data have been loaded.
# train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Build the transformer from the hyper-parameters configured in the previous
# cell, then move it to the selected device.
# NOTE(review): the meaning of head_strategy=2 is defined in my_transformers --
# confirm there before changing it.
tiny_vit_model = CorruptDistillVisionTransformer(
    EMBED_DIM,
    IMG_SIZE,
    PATCH_SIZE,
    NUM_CLASSES,
    attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS,
    dropout=DROPOUT,
    drop_path=DROP_PATH,
    erase_prob=ERASE_PROB,
    num_img_types=IMG_TYPES,
    head_strategy=2,
)
tiny_vit_model = tiny_vit_model.to(device)

# Bare expression: shown as the cell's output when run in the notebook.
tiny_vit_model.output_head.W