In [1]:
import sys
sys.path.append("..")

In [2]:
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn

from deit.dataset import iterator
from torch.utils.data import DataLoader

from deit.layer.classifier import DeiT

___

**how to prepare the teacher model?**
- I fine-tuned the pretrained resnet18 supported from torchvision.models using our own imagenet data
- please refer the "train benchmark for imagenet (resnet18).ipynb" file

In [3]:
import torchvision.models as models

teacher_model = models.resnet18(True)
teacher_model.fc = nn.Linear(512, 115).cuda()
teacher_model = nn.DataParallel(teacher_model)
teacher_model.load_state_dict(torch.load("resnet18_imagenet_pretrained.pt"))
teacher_model = teacher_model.module

___

In [4]:
train_iterator = iterator.DistilationImageNetIterator(is_train=True,teacher=teacher_model, device='cpu')
valid_iterator = iterator.DistilationImageNetIterator(is_train=False,teacher=teacher_model, device='cpu')

train_loader = DataLoader(train_iterator, batch_size=32*2, shuffle=True, num_workers=30)
valid_loader = DataLoader(valid_iterator, batch_size=32*2, shuffle=False, num_workers=30)

In [5]:
height = 224
width = 224
channel = 3
patch = 16
d_model = 256
d_ff = d_model * 4
ffn_typ = 'glu'
act_typ = 'GELU'
n_head = 8
dropout_p = 0.1
n_enc_layer = 3
output_dim = len(train_iterator.label_dict)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
model = nn.DataParallel(DeiT(
    height,
    width,
    channel,
    patch,
    d_model,
    d_ff,
    ffn_typ,
    act_typ,
    n_head,
    dropout_p,
    n_enc_layer,
    output_dim,)).to(device)

In [7]:
optimizer = optim.Adam(model.parameters())
criterion = nn.NLLLoss()

In [8]:
def train() : 
    model.train()
    losses = []
    accuracies = []

    for dict_ in tqdm(train_loader, desc='train') : 
        pred = model(dict_['input'].to(device))

        optimizer.zero_grad()
        loss = criterion(pred, dict_['label'].to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == dict_['label'].to(device)).sum()
        acc = correct.item() / dict_['label'].shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for dict_ in tqdm(valid_loader, desc='valid') : 
        pred = model(dict_['input'].to(device))

        loss = criterion(pred, dict_['label'].to(device))
        correct = (torch.argmax(pred, dim=1) == dict_['label'].to(device)).sum()
        acc = correct.item() / dict_['label'].shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [None]:
epoches = 20

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train: 100%|██████████| 1200/1200 [11:40<00:00,  1.71it/s] 
valid: 100%|██████████| 516/516 [05:10<00:00,  1.66it/s]



                === 1th Epoch ===
    
        Train Loss : 8.794 | Train Acc : 0.06
        Valid Loss : 7.949 | Valid Acc : 0.116
        
    


train: 100%|██████████| 1200/1200 [11:45<00:00,  1.70it/s] 
valid: 100%|██████████| 516/516 [05:11<00:00,  1.65it/s]



                === 2th Epoch ===
    
        Train Loss : 7.716 | Train Acc : 0.134
        Valid Loss : 7.537 | Valid Acc : 0.156
        
    


train: 100%|██████████| 1200/1200 [11:45<00:00,  1.70it/s]
valid: 100%|██████████| 516/516 [05:09<00:00,  1.67it/s]



                === 3th Epoch ===
    
        Train Loss : 7.368 | Train Acc : 0.164
        Valid Loss : 7.199 | Valid Acc : 0.175
        
    


train:  70%|███████   | 841/1200 [08:25<00:45,  7.95it/s] 