In [None]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import b16_config
from loss import DistillationLoss
from model import DeiT
from teacher_model import ResNet50
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model, build_teacher_model

In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise fall back to CPU.
# Kept as a plain string because it is passed straight to .to(...) and to Trainer below.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Seed torch's global RNG (weight initialisation, DataLoader shuffling) so runs are reproducible.
# torch.manual_seed also seeds CUDA devices; bit-exact GPU runs may additionally need
# deterministic cuDNN settings.
seed = 42
torch.manual_seed(seed)

# Training hyperparameters used by the loaders, optimizer, scheduler and Trainer below.
learning_rate = 1e-3
batch_size = 32
num_epochs = 20

In [None]:
# Upsample CIFAR-10 images to 224x224 and map pixel values from [0, 1] (ToTensor) to [-1, 1]
# via (x - 0.5) / 0.5 per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# download=True lets the notebook run on a fresh checkout; when the files are already present
# torchvision only verifies them and does not re-download.
train_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='../datasets', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling validation data gives no benefit; a fixed order keeps evaluation reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Distillation settings consumed by DistillationLoss and Trainer in later cells.
# distil_type selects the distillation variant; distil_token_type is forwarded to Trainer.
distil_type, distil_token_type = 'hard', 'cls'
# Presumably the loss-mixing weight and the softmax temperature; confirm in DistillationLoss
# (with hard distillation the temperature may be unused).
alpha, tau = 0.5, 3.0

In [None]:
num_classes = 10  # CIFAR-10

# Student: b16 DeiT config with num_classes overridden for CIFAR-10.
model_config = b16_config()
model_config.update({'num_classes': num_classes})
# Build from the *updated* config. A fresh b16_config() call may not carry the override,
# leaving the student with the default class count.
model = build_model(DeiT, model_config).to(device)

# Teacher: ResNet-50; the second argument is presumably the number of output classes.
# NOTE(review): confirm build_teacher_model loads trained weights; distilling from an
# untrained teacher gives no useful signal.
teacher_model = build_teacher_model(ResNet50, num_classes).to(device)
teacher_type = 'resnet'

In [None]:
# Loss: cross-entropy wrapped by DistillationLoss, which also receives the teacher and the
# distillation settings (presumably adding a teacher-based term weighted by alpha).
base_criterion = nn.CrossEntropyLoss()
criterion = DistillationLoss(base_criterion, teacher_model, distil_type, alpha, tau, teacher_type)

# Only the student's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Warmup for num_epochs // 5, then cosine decay over num_epochs. Assumes the signature is
# (optimizer, warmup_steps, total_steps); whether a "step" is an epoch or a batch depends on
# how Trainer calls scheduler.step() - TODO confirm.
scheduler = WarmupCosineSchedule(optimizer, num_epochs // 5, num_epochs)

In [None]:
# Derive the checkpoint name from the distillation settings so editing them cannot silently
# overwrite or mislabel a checkpoint (current settings -> 'pretrained/cifar10_cls_hard.pth').
checkpoint_path = f'pretrained/cifar10_{distil_token_type}_{distil_type}.pth'

trainer = Trainer(
    model,
    {'train': train_loader, 'validation': val_loader},
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1, 5),          # presumably the top-k values reported as accuracy - confirm in Trainer
    checkpoint_path,
    device,
    teacher_model,
    'loss',          # presumably the validation metric used to select the checkpoint - confirm
    True,            # NOTE(review): meaning not visible here; check Trainer's signature
    distil_token_type,
)

In [None]:
# Run the training loop configured in the previous cell (implementation lives in utils.Trainer).
trainer.train()