# imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as T
import os
import timm
from timm.models.layers import DropPath

from train_test_module import FineTuningModule, MyAugments

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# ------------------- Template --------------------------

In [2]:
# Reproducibility: seed every RNG this notebook imports from one constant.
# numpy is imported above, so its global RNG is seeded as well in case any
# augmentation code draws from it (NOTE(review): confirm whether MyAugments
# uses numpy; seeding it is harmless either way).
SEED = 22
torch.cuda.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Covers every visible GPU; silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)

# Single source of truth for where models and tensors should live.
device = "cuda" if torch.cuda.is_available() else "cpu"

corrupt_types = ["brightness", "defocus_blur", "zoom_blur", "motion_blur", "fog", "frost", "snow", "shot_noise", "gaussian_noise", "jpeg_compression"]

# DEIT Hyper-parameters
# One image type per corruption plus one extra
# (presumably the clean/uncorrupted images — TODO confirm).
NUM_IMG_TYPES = len(corrupt_types)+1
NUM_CLASSES = 10
DROPOUT = 0
DROP_PATH = 0.1

# Augmentation probabilities
ERASE_P = 0.25
RANDAUG_P = 0.5
MIXUP_P = 0.3
CUTMIX_P = 0.3

# Training schedule
BATCH_SIZE = 1024
NUM_EPOCHS = 50
WARMUP_EPOCHS = 3

In [3]:
# Each .pt file is loaded with weights_only=True and unpacked straight into a
# TensorDataset, so it must hold a sequence of tensors
# (presumably (images, labels, ...) — TODO confirm the exact layout).
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    TensorDataset(*torch.load("train_cifar10.pt", weights_only=True)),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# Evaluation keeps a fixed order.
test_loader = DataLoader(
    TensorDataset(*torch.load("test_cifar10.pt", weights_only=True)),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

def set_drop_path(model, drop_path):
    """Overwrite the drop-path modules of every block in ``model.blocks``.

    Args:
        model: object exposing an indexable/iterable ``blocks`` collection whose
            items accept ``drop_path1`` and ``drop_path2`` attributes.
        drop_path: rate handed to ``DropPath``; any value ``<= 0`` installs
            ``nn.Identity`` instead.

    Returns:
        None. ``model`` is modified in place.
    """
    def _new_layer():
        # A fresh module per attribute, so the two slots never share an instance.
        if drop_path > 0:
            return DropPath(drop_path)
        return nn.Identity()

    for block in model.blocks:
        block.drop_path1 = _new_layer()
        block.drop_path2 = _new_layer()

def set_dropout(model, dropout):
    """Overwrite the MLP dropout modules of every block in ``model.blocks``.

    Args:
        model: object exposing a ``blocks`` collection whose items each have an
            ``mlp`` attribute accepting ``drop1`` and ``drop2``.
        dropout: probability handed to ``nn.Dropout``; any value ``<= 0``
            installs ``nn.Identity`` instead.

    Returns:
        None. ``model`` is modified in place.
    """
    def _new_layer():
        # A fresh module per slot, so drop1 and drop2 are never the same object.
        return nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    for block in model.blocks:
        for slot in ("drop1", "drop2"):
            setattr(block.mlp, slot, _new_layer())


# Experiment 1

In [4]:
# Base learning rate passed to the optimizer in the next cell.
# NOTE(review): the trailing 3e-5 looks like a previously tried value — confirm.
lr = 5e-5  #3e-5

In [None]:
# Experiment 1: train only a fresh classification head (freeze_body=True)
# with every augmentation probability set to 0.

# Load the pretrained DeiT-III small model and swap in a NUM_CLASSES-way head.
# The head width is read from the pretrained classifier instead of hard-coding
# 384, and the model goes to the notebook-wide `device` instead of .cuda(),
# so the cell also runs on CPU-only machines and matches the device given to
# FineTuningModule below.
deit3_small = timm.create_model('deit3_small_patch16_224.fb_in22k_ft_in1k', pretrained=True)
deit3_small.head = nn.Linear(in_features=deit3_small.head.in_features, out_features=NUM_CLASSES, bias=True)
deit3_small = deit3_small.to(device)

# Only the head's parameters are handed to the optimizer.
optimizer = optim.SGD(deit3_small.head.parameters(), lr=lr, weight_decay=0.05)

# Linear warm-up for WARMUP_EPOCHS steps, then cosine annealing.
# NOTE(review): T_max is derived from NUM_EPOCHS (50) while train() below runs
# num_epochs=20; if the scheduler is stepped once per epoch the cosine phase
# never completes — confirm this is intended.
warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-3, total_iters=WARMUP_EPOCHS)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS - WARMUP_EPOCHS)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[WARMUP_EPOCHS])

# "all0": mixup, cutmix, RandAugment and random erasing are all disabled.
augmenter = MyAugments(NUM_CLASSES, mixup_p=0, cutmix_p=0, randaug_p=0, erasing_p=0)
deit3_trainer_module = FineTuningModule(deit3_small, train_loader, test_loader, NUM_IMG_TYPES, device, freeze_body=True)
deit3_trainer_module.train(optimizer, scheduler, augmenter, "deit3HEAD_all0", num_epochs=20, print_metrics=True)

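# all0 — results with all augmentations disabled (mixup, cutmix, RandAugment and erasing probabilities set to 0)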

------- Epoch 1 -------
train-loss: 3.137 -- train-acc: 0.069 -- test-loss: 3.144 -- test-acc: 0.068
Best model saved to deit3HEAD_all0.pth
------- Epoch 2 -------
train-loss: 2.283 -- train-acc: 0.273 -- test-loss: 1.540 -- test-acc: 0.514
Best model saved to deit3HEAD_all0.pth
------- Epoch 3 -------
train-loss: 1.325 -- train-acc: 0.671 -- test-loss: 0.862 -- test-acc: 0.741
Best model saved to deit3HEAD_all0.pth
------- Epoch 4 -------




train-loss: 1.066 -- train-acc: 0.774 -- test-loss: 0.677 -- test-acc: 0.796
Best model saved to deit3HEAD_all0.pth
------- Epoch 5 -------
train-loss: 0.988 -- train-acc: 0.808 -- test-loss: 0.611 -- test-acc: 0.818
Best model saved to deit3HEAD_all0.pth
------- Epoch 6 -------
train-loss: 0.950 -- train-acc: 0.825 -- test-loss: 0.576 -- test-acc: 0.829
Best model saved to deit3HEAD_all0.pth
------- Epoch 7 -------
train-loss: 0.925 -- train-acc: 0.836 -- test-loss: 0.552 -- test-acc: 0.839
Best model saved to deit3HEAD_all0.pth
------- Epoch 8 -------
train-loss: 0.907 -- train-acc: 0.843 -- test-loss: 0.536 -- test-acc: 0.846
Best model saved to deit3HEAD_all0.pth
------- Epoch 9 -------
train-loss: 0.894 -- train-acc: 0.849 -- test-loss: 0.522 -- test-acc: 0.851
Best model saved to deit3HEAD_all0.pth
------- Epoch 10 -------
train-loss: 0.883 -- train-acc: 0.853 -- test-loss: 0.512 -- test-acc: 0.856
Best model saved to deit3HEAD_all0.pth
------- Epoch 11 -------
train-loss: 0.874 