In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import os
import timm

from train_test_module import MyAugments, LossCalculatorDeiT, TrainTestDeiTModule

%load_ext autoreload
%autoreload 2

In [2]:
# setting seed 
torch.cuda.manual_seed(22)
random.seed(22)
torch.manual_seed(22)

device = "cuda" if torch.cuda.is_available() else "cpu"

corrupt_types = ["brightness", "defocus_blur", "zoom_blur", "motion_blur", "fog", "frost", "snow", "shot_noise", "gaussian_noise", "jpeg_compression"]

# Hyper-parameters
PATCH_SIZE = 4
IMG_SIZE = 32
EMBED_DIM = 192
NUM_HEADS = 3
NUM_ENCODERS = 12

NUM_IMG_TYPES = len(corrupt_types)+1
NUM_CLASSES = 10
DROPOUT = 0
DROP_PATH = 0.1

ERASE_P = 0.25
RANDAUG_P = 0.5
MIXUP_P = 0.3
CUTMIX_P = 0.5

BATCH_SIZE = 1024
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3

In [3]:
train_loader = DataLoader(dataset=TensorDataset(*torch.load("train_cifar10.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(*torch.load("test_cifar10.pt", weights_only=True)), 
                                 batch_size=BATCH_SIZE, shuffle=False)

deit3_teacher = timm.create_model('deit3_small_patch16_224').cuda()
deit3_teacher.head = nn.Linear(in_features=384, out_features=NUM_CLASSES, bias=True).cuda()
deit3_teacher.load_state_dict(torch.load("deit3HEAD_all0.pth"))

from soft_distillation import SoftLossCalculatorDeiT
loss_calculator = SoftLossCalculatorDeiT(deit3_teacher, tau=3, alpha=0.5)

  deit3_teacher.load_state_dict(torch.load("deit3HEAD_all0.pth"))


# Experiment lr= 5e-4

In [4]:
lr = 0.0005  * BATCH_SIZE/512

In [5]:
from my_transformers import DistillVisionTransformer
deit_tiny = DistillVisionTransformer(
    EMBED_DIM, IMG_SIZE, PATCH_SIZE, NUM_CLASSES, attention_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS, dropout=0, drop_path=0.1
    ).to(device)

optimizer = optim.AdamW(deit_tiny.parameters(), lr=lr, weight_decay=0.05)
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-3, total_iters=WARMUP_EPOCHS)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS - WARMUP_EPOCHS)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[WARMUP_EPOCHS])

augmenter = MyAugments(NUM_CLASSES, mixup_p=MIXUP_P, cutmix_p=0, randaug_p=0, erasing_p=ERASE_P)
deit_train_module = TrainTestDeiTModule(deit_tiny, deit3_teacher, train_loader, test_loader, NUM_IMG_TYPES, device)
deit_train_module.train(optimizer, scheduler, augmenter, loss_calculator, "deit_softDistill", num_epochs=15, print_metrics=True)



------- Epoch 1 -------
train-loss: 1.235 -- train-acc: 0.144 -- test-loss: 2.188 -- test-acc: 0.213
Best model saved to deit_softDistill.pth
------- Epoch 2 -------
train-loss: 0.970 -- train-acc: 0.430 -- test-loss: 1.166 -- test-acc: 0.611
Best model saved to deit_softDistill.pth
------- Epoch 3 -------




train-loss: 0.788 -- train-acc: 0.611 -- test-loss: 0.722 -- test-acc: 0.796
Best model saved to deit_softDistill.pth
------- Epoch 4 -------
train-loss: 0.656 -- train-acc: 0.750 -- test-loss: 0.420 -- test-acc: 0.921
Best model saved to deit_softDistill.pth
------- Epoch 5 -------
train-loss: 0.599 -- train-acc: 0.813 -- test-loss: 0.379 -- test-acc: 0.942
Best model saved to deit_softDistill.pth
------- Epoch 6 -------
train-loss: 0.547 -- train-acc: 0.860 -- test-loss: 0.407 -- test-acc: 0.954
------- Epoch 7 -------
train-loss: 0.489 -- train-acc: 0.893 -- test-loss: 0.299 -- test-acc: 0.962
Best model saved to deit_softDistill.pth
------- Epoch 8 -------
train-loss: 0.486 -- train-acc: 0.896 -- test-loss: 0.308 -- test-acc: 0.961
------- Epoch 9 -------
train-loss: 0.471 -- train-acc: 0.905 -- test-loss: 0.247 -- test-acc: 0.963
Best model saved to deit_softDistill.pth
------- Epoch 10 -------
train-loss: 0.466 -- train-acc: 0.910 -- test-loss: 0.256 -- test-acc: 0.966
------- Ep