In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import os
import timm

from train_test_module import MyAugments, LossCalculatorCdeiT, TrainTestCdeiT

%load_ext autoreload
%autoreload 2

In [2]:
# setting seed 
torch.cuda.manual_seed(22)
random.seed(22)
torch.manual_seed(22)

device = "cuda" if torch.cuda.is_available() else "cpu"

corrupt_types = ["brightness", "defocus_blur", "zoom_blur", "motion_blur", "fog", "frost", "snow", "shot_noise", "gaussian_noise", "jpeg_compression"]

# Hyper-parameters
PATCH_SIZE = 4
IMG_SIZE = 32
EMBED_DIM = 192
NUM_HEADS = 3
NUM_ENCODERS = 12

NUM_IMG_TYPES = len(corrupt_types)+1
NUM_CLASSES = 10
DROPOUT = 0
DROP_PATH = 0.1

ERASE_P = 0.25
RANDAUG_P = 0.5
MIXUP_P = 0.3
CUTMIX_P = 0.5

BATCH_SIZE = 1024
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3

In [3]:
train_loader = DataLoader(dataset=TensorDataset(*torch.load("train_cifar10.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_cifar10.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=False)

deit3_teacher = timm.create_model('deit3_small_patch16_224').cuda()
deit3_teacher.head = nn.Linear(in_features=384, out_features=NUM_CLASSES, bias=True).cuda()
deit3_teacher.load_state_dict(torch.load("deit3HEAD_all0.pth"))

  deit3_teacher.load_state_dict(torch.load("deit3HEAD_all0.pth"))


<All keys matched successfully>

# Experiments

In [4]:
lr = 0.0005  * BATCH_SIZE/512

In [5]:
from my_transformers import CorruptDistillVisionTransformer
HEAD_STRATEGY = 3
cdeit_tiny = CorruptDistillVisionTransformer(
    EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS, dropout=0, drop_path=0.1, 
    num_img_types=NUM_IMG_TYPES, head_strategy=HEAD_STRATEGY
    ).to(device)

optimizer = optim.AdamW(cdeit_tiny.parameters(), lr=lr, weight_decay=0.05)
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-3, total_iters=WARMUP_EPOCHS)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS - WARMUP_EPOCHS)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[WARMUP_EPOCHS])

loss_calculator = LossCalculatorCdeiT(deit3_teacher, HEAD_STRATEGY)

augmenter = MyAugments(NUM_CLASSES, mixup_p=MIXUP_P, cutmix_p=0, randaug_p=0, erasing_p=ERASE_P)
deit_train_module = TrainTestCdeiT(cdeit_tiny, deit3_teacher, train_loader, test_loader, NUM_IMG_TYPES, device, HEAD_STRATEGY)
deit_train_module.train(optimizer, scheduler, augmenter, loss_calculator, "cdeit_softDistill", num_epochs=15, print_metrics=True)



------- Epoch 1 -------
train-loss: 2.328 -- train-acc: 0.151 -- test-loss: 2.214 -- test-acc: 0.207
Best model saved to cdeit_softDistill.pth
------- Epoch 2 -------
train-loss: 1.940 -- train-acc: 0.441 -- test-loss: 1.352 -- test-acc: 0.588
Best model saved to cdeit_softDistill.pth
------- Epoch 3 -------




train-loss: 1.673 -- train-acc: 0.599 -- test-loss: 0.992 -- test-acc: 0.771
Best model saved to cdeit_softDistill.pth
------- Epoch 4 -------
train-loss: 1.521 -- train-acc: 0.689 -- test-loss: 0.805 -- test-acc: 0.868
Best model saved to cdeit_softDistill.pth
------- Epoch 5 -------
train-loss: 1.341 -- train-acc: 0.794 -- test-loss: 0.831 -- test-acc: 0.888
------- Epoch 6 -------
train-loss: 1.252 -- train-acc: 0.833 -- test-loss: 0.839 -- test-acc: 0.930
------- Epoch 7 -------
train-loss: 1.153 -- train-acc: 0.853 -- test-loss: 0.522 -- test-acc: 0.943
Best model saved to cdeit_softDistill.pth
------- Epoch 8 -------
train-loss: 1.110 -- train-acc: 0.883 -- test-loss: 0.592 -- test-acc: 0.955
------- Epoch 9 -------
train-loss: 1.043 -- train-acc: 0.891 -- test-loss: 0.523 -- test-acc: 0.956
------- Epoch 10 -------
train-loss: 1.042 -- train-acc: 0.894 -- test-loss: 0.713 -- test-acc: 0.963
------- Epoch 11 -------
train-loss: 1.011 -- train-acc: 0.899 -- test-loss: 0.538 -- tes