# Vision Transformer

## Setup

In [None]:
import argparse
import datetime
import os
import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from timm.models import create_model

from engine import train_one_epoch, train_one_epoch_distillation, evaluate
from utils import get_training_dataloader, get_test_dataloader
import models

## Question 1

To display the training statistics, I also ran the `evaluate` function on the training set (in addition to the test set). Since `evaluate` already prints the top-1 accuracy and the loss, I removed the extra print statements. I also had to modify the model file, because it ignored the `num_classes` argument we pass in when `pretrained=False`.

In [None]:
# Configuration shared by the rest of the notebook, plus the Question 1 baseline:
# ViT-Base trained from scratch (no ImageNet weights) on a few-shot CIFAR-10 split.
#
# NOTE(review): these MEAN/STD values appear to be the widely published CIFAR-100
# training-set statistics. CIFAR-10's are roughly (0.4914, 0.4822, 0.4465) /
# (0.2470, 0.2435, 0.2616). Confirm which ones the assignment intends.
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000  # training images per class (checked by the assert below)
SEED = 0

# Seed the RNGs so that Restart & Run All repeats the same weight init and
# shuffling order (up to nondeterministic GPU kernels).
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fall back to the CPU when no GPU is present so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=num_classes,
        img_size=224)
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

# The few-shot training set must contain exactly `shots` images of every class.
assert shots * num_classes == len(cifar10_training_loader.dataset), \
    "few-shot training set has an unexpected size"

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)


Creating model: vit_base_patch16_224
Files already downloaded and verified
Files already downloaded and verified
number of params: 85806346




In [None]:
# Train the from-scratch ViT-Base, evaluating on both splits after every epoch.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS + 1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)

    # `evaluate` prints its own loss / accuracy summary for each split.
    print("TRAINING STATISTICS")
    training_stats = evaluate(cifar10_training_loader, model, criterion, device)
    print("TEST STATISTICS")
    test_stats = evaluate(cifar10_test_loader, model, criterion, device)
    # len(loader) counts batches; the number of test images comes from the dataset.
    print(f"Epoch {epoch}: accuracy of the network on the "
          f"{len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:17:00  loss: 2.3859 (2.3859)  time: 1.6321  data: 0.3980  max mem: 2092
Epoch: [1]  [100/625]  eta: 0:01:46  loss: 2.1168 (2.3098)  time: 0.1674  data: 0.0053  max mem: 3055
Epoch: [1]  [200/625]  eta: 0:01:19  loss: 2.0270 (2.1972)  time: 0.1697  data: 0.0052  max mem: 3055
Epoch: [1]  [300/625]  eta: 0:00:59  loss: 2.0101 (2.1558)  time: 0.1753  data: 0.0069  max mem: 3055
Epoch: [1]  [400/625]  eta: 0:00:41  loss: 2.0280 (2.1319)  time: 0.2072  data: 0.0115  max mem: 3055
Epoch: [1]  [500/625]  eta: 0:00:22  loss: 2.0186 (2.1131)  time: 0.1790  data: 0.0059  max mem: 3055
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 2.0206 (2.0919)  time: 0.1742  data: 0.0051  max mem: 3055
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.8799 (2.0871)  time: 0.1810  data: 0.0098  max mem: 3055
Epoch: [1] Total time: 0:01:53 (0.1818 s / it)
Averaged stats: loss: 1.8799 (2.0871)
TRAINING STATISTICS
Test:  [  0/625]  eta: 0:05:43  loss: 1.8474 (1.847

In [None]:
# Calculate throughput: test images processed per second by one `evaluate` pass.
# CUDA kernels run asynchronously, so synchronize before each clock reading.
# perf_counter is a monotonic clock meant for timing short intervals.
# The figure is end-to-end: it also includes data loading and evaluate's logging.
timing_device = torch.device(device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
start_time = time.perf_counter()
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:02:09  loss: 1.6802 (1.6802)  acc1: 39.4531 (39.4531)  acc5: 88.2812 (88.2812)  time: 3.2348  data: 2.2521  max mem: 3367
Test:  [39/40]  eta: 0:00:00  loss: 1.7428 (1.7195)  acc1: 34.7656 (34.7900)  acc5: 87.8906 (87.4900)  time: 0.8659  data: 0.0563  max mem: 3367
Test: Total time: 0:00:38 (0.9591 s / it)
* Acc@1 34.790 Acc@5 87.490 loss 1.719
Throughput: 260.63119392691556


## Question 2

This model (the ImageNet-pretrained vit_base, fine-tuned on our CIFAR-10 split) is the teacher that we'll use later in the assignment.

In [None]:
# Step 1: Train the teacher: ViT-Base initialised from pretrained weights,
# then fine-tuned on the same few-shot CIFAR-10 split as the baseline.

MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0

# Fall back to the CPU when no GPU is present so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(f"Creating model: {MODEL_NAME}")
teacher = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,
        img_size=224)
teacher = teacher.to(device)

criterion = nn.CrossEntropyLoss()
# A fresh optimizer over the teacher's parameters (the previous one tracked `model`).
optimizer = optim.Adam(teacher.parameters(), lr=LR, weight_decay=WD)

n_parameters = sum(p.numel() for p in teacher.parameters() if p.requires_grad)
print('number of params:', n_parameters)


Creating model: vit_base_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:02<00:00, 167MB/s]


number of params: 85806346


In [None]:
# Fine-tune the teacher, evaluating on both splits after every epoch.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS + 1):
    train_stats = train_one_epoch(
        teacher, criterion, cifar10_training_loader,
        optimizer, device, epoch)

    # `evaluate` prints its own loss / accuracy summary for each split.
    print("TRAINING STATISTICS")
    training_stats = evaluate(cifar10_training_loader, teacher, criterion, device)
    print("TEST STATISTICS")
    test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
    # len(loader) counts batches; the number of test images comes from the dataset.
    print(f"Epoch {epoch}: accuracy of the network on the "
          f"{len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:16:01  loss: 2.6647 (2.6647)  time: 1.5387  data: 0.5186  max mem: 2420
Epoch: [1]  [100/625]  eta: 0:01:36  loss: 2.0907 (2.2181)  time: 0.1678  data: 0.0051  max mem: 3379
Epoch: [1]  [200/625]  eta: 0:01:16  loss: 1.7393 (2.0804)  time: 0.1772  data: 0.0096  max mem: 3379
Epoch: [1]  [300/625]  eta: 0:00:58  loss: 1.3831 (1.9155)  time: 0.1730  data: 0.0066  max mem: 3379
Epoch: [1]  [400/625]  eta: 0:00:40  loss: 1.2427 (1.7650)  time: 0.1707  data: 0.0049  max mem: 3379
Epoch: [1]  [500/625]  eta: 0:00:22  loss: 1.0401 (1.6468)  time: 0.1729  data: 0.0048  max mem: 3379
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 0.8741 (1.5349)  time: 0.1817  data: 0.0086  max mem: 3379
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 0.8170 (1.5098)  time: 0.1778  data: 0.0056  max mem: 3379
Epoch: [1] Total time: 0:01:51 (0.1779 s / it)
Averaged stats: loss: 0.8170 (1.5098)
TRAINING STATISTICS
Test:  [  0/625]  eta: 0:04:27  loss: 0.6830 (0.683

In [None]:
# Persist the fine-tuned teacher's weights (its state_dict), not the module object.
torch.save(obj=teacher.state_dict(), f='./teacher.pth')

In [None]:
# Rebuild the teacher architecture and restore the fine-tuned weights, then re-evaluate.
# pretrained=False: the strict load below overwrites every tensor, so downloading
# the pretrained weights first would be wasted work. This relies on the model file
# honouring num_classes when pretrained=False (the fix described in Question 1).
teacher = create_model(
        'vit_base_patch16_224',
        pretrained=False,
        num_classes=num_classes,
        img_size=224)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
teacher = teacher.to(device)
# map_location lets a checkpoint saved on the GPU load on whichever device is in use.
teacher.load_state_dict(torch.load('./teacher.pth', map_location=device))

test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)

Test:  [ 0/40]  eta: 0:02:03  loss: 0.5970 (0.5970)  acc1: 80.0781 (80.0781)  acc5: 98.8281 (98.8281)  time: 3.0875  data: 2.0882  max mem: 4284
Test:  [39/40]  eta: 0:00:00  loss: 0.5428 (0.5722)  acc1: 82.4219 (82.5700)  acc5: 99.2188 (99.2800)  time: 0.8880  data: 0.0627  max mem: 4284
Test: Total time: 0:00:39 (0.9950 s / it)
* Acc@1 82.570 Acc@5 99.280 loss 0.572


In [None]:
# Calculate throughput: test images processed per second by one `evaluate` pass.
# CUDA kernels run asynchronously, so synchronize before each clock reading.
# perf_counter is a monotonic clock meant for timing short intervals.
# The figure is end-to-end: it also includes data loading and evaluate's logging.
timing_device = torch.device(device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
start_time = time.perf_counter()
test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:02:33  loss: 0.5970 (0.5970)  acc1: 80.0781 (80.0781)  acc5: 98.8281 (98.8281)  time: 3.8428  data: 2.8354  max mem: 4284
Test:  [39/40]  eta: 0:00:00  loss: 0.5428 (0.5722)  acc1: 82.4219 (82.5700)  acc5: 99.2188 (99.2800)  time: 0.8813  data: 0.0632  max mem: 4284
Test: Total time: 0:00:39 (0.9883 s / it)
* Acc@1 82.570 Acc@5 99.280 loss 0.572
Throughput: 252.89932980445153


Pretraining on ImageNet produces significantly better results than training from scratch on CIFAR-10. After the same 5 epochs of training, the pretrained vit_base reaches 82.6% top-1 test accuracy, versus 34.8% for the model trained from scratch. Pretraining gives the model good general-purpose visual features to start from, so it does not have to learn them from only 1,000 images per class. This is the essence of transfer learning. It is especially useful when we lack the compute or the data to train a large model from scratch.

## Question 3

Let's fine-tune the vit_tiny model first as a baseline, before we try knowledge distillation in the next question.

In [None]:
# Train the tiny model: a pretrained ViT-Tiny fine-tuned with the same data, loss
# and hyperparameters (EPOCHS, LR, WD) as the teacher. This is the baseline for Question 4.
MODEL_NAME = 'vit_tiny_patch16_224'

# Fall back to the CPU when no GPU is present so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,
        img_size=224)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS + 1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)

    # `evaluate` prints its own loss / accuracy summary for each split.
    print("TRAINING STATISTICS")
    training_stats = evaluate(cifar10_training_loader, model, criterion, device)
    print("TEST STATISTICS")
    test_stats = evaluate(cifar10_test_loader, model, criterion, device)

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth" to /root/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
100%|██████████| 21.9M/21.9M [00:00<00:00, 80.1MB/s]


number of params: 5526346
Start training for 5 epochs


  return F.conv2d(input, weight, bias, self.stride,


Epoch: [1]  [  0/625]  eta: 0:03:00  loss: 2.6602 (2.6602)  time: 0.2880  data: 0.1580  max mem: 4284
Epoch: [1]  [100/625]  eta: 0:00:39  loss: 2.1288 (2.2799)  time: 0.0558  data: 0.0049  max mem: 4284
Epoch: [1]  [200/625]  eta: 0:00:28  loss: 2.0735 (2.1918)  time: 0.0557  data: 0.0046  max mem: 4284
Epoch: [1]  [300/625]  eta: 0:00:22  loss: 1.9165 (2.1087)  time: 0.0804  data: 0.0087  max mem: 4284
Epoch: [1]  [400/625]  eta: 0:00:14  loss: 1.6505 (2.0478)  time: 0.0587  data: 0.0050  max mem: 4284
Epoch: [1]  [500/625]  eta: 0:00:08  loss: 1.7300 (1.9995)  time: 0.0927  data: 0.0146  max mem: 4284
Epoch: [1]  [600/625]  eta: 0:00:01  loss: 1.6376 (1.9572)  time: 0.0573  data: 0.0048  max mem: 4284
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.7289 (1.9464)  time: 0.0542  data: 0.0045  max mem: 4284
Epoch: [1] Total time: 0:00:40 (0.0656 s / it)
Averaged stats: loss: 1.7289 (1.9464)
TRAINING STATISTICS
Test:  [  0/625]  eta: 0:02:11  loss: 2.1653 (2.1653)  acc1: 6.2500 (6.2500)  a

In [None]:
# Persist the fine-tuned ViT-Tiny's weights (its state_dict), not the module object.
torch.save(obj=model.state_dict(), f='./tiny.pth')

In [None]:
# Rebuild the ViT-Tiny architecture and restore the fine-tuned weights, then re-evaluate.
# pretrained=False: the strict load below overwrites every tensor, so downloading
# the pretrained weights first would be wasted work. This relies on the model file
# honouring num_classes when pretrained=False (the fix described in Question 1).
tiny = create_model(
        'vit_tiny_patch16_224',
        pretrained=False,
        num_classes=num_classes,
        img_size=224)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tiny = tiny.to(device)
# map_location lets a checkpoint saved on the GPU load on whichever device is in use.
tiny.load_state_dict(torch.load('./tiny.pth', map_location=device))

test_stats = evaluate(cifar10_test_loader, tiny, criterion, device)

Test:  [ 0/40]  eta: 0:01:38  loss: 0.9032 (0.9032)  acc1: 71.4844 (71.4844)  acc5: 96.8750 (96.8750)  time: 2.4577  data: 2.1075  max mem: 4284
Test:  [39/40]  eta: 0:00:00  loss: 0.8065 (0.8204)  acc1: 71.4844 (71.3000)  acc5: 98.4375 (98.2400)  time: 0.4410  data: 0.1676  max mem: 4284
Test: Total time: 0:00:19 (0.4860 s / it)
* Acc@1 71.300 Acc@5 98.240 loss 0.820


In [None]:
# Calculate throughput: test images processed per second by one `evaluate` pass.
# CUDA kernels run asynchronously, so synchronize before each clock reading.
# perf_counter is a monotonic clock meant for timing short intervals.
# The figure is end-to-end: it also includes data loading and evaluate's logging.
timing_device = torch.device(device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
start_time = time.perf_counter()
test_stats = evaluate(cifar10_test_loader, tiny, criterion, device)
if timing_device.type == 'cuda':
    torch.cuda.synchronize(timing_device)
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:01:32  loss: 0.9032 (0.9032)  acc1: 71.4844 (71.4844)  acc5: 96.8750 (96.8750)  time: 2.3009  data: 1.9487  max mem: 4284
Test:  [39/40]  eta: 0:00:00  loss: 0.8065 (0.8204)  acc1: 71.4844 (71.3000)  acc5: 98.4375 (98.2400)  time: 0.3628  data: 0.1148  max mem: 4284
Test: Total time: 0:00:19 (0.4864 s / it)
* Acc@1 71.300 Acc@5 98.240 loss 0.820
Throughput: 513.8633316699122


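The accuracy of our pretrained vit_tiny model (71.3% top-1) is lower than that of the pretrained vit_base (82.6%), as expected. However, it still performs much better than the vit_base trained from scratch (34.8%), which again shows the value of transfer learning. Throughput is roughly twice as high for the tiny model (≈514 vs. ≈253 test images/s). This is because the tiny model is much smaller (≈5.5M vs. ≈85.8M parameters), so each image requires far less computation. The speed-up is much smaller than the ~15× reduction in parameters. That suggests costs outside the model itself, such as data loading, make up a noticeable share of the measured time.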

## Question 4

Now we can perform knowledge distillation and compare the result to the fine-tuned vit_tiny baseline; we expect to see an improvement. I implemented the distillation loss as follows:

- **Soften the logits.** Both sets of logits are divided by the temperature $T$ and passed through a softmax (a log-softmax for the student).
- **KL term.** I computed the KL divergence between the teacher and student distributions and multiplied it by $T^2$. This is the commonly recommended scaling, so that the term's gradients keep a comparable magnitude as $T$ changes.
- **Cross-entropy term.** I computed the cross-entropy loss between the student logits and the true labels.

The total loss is the weighted sum

$$\mathcal{L} = \alpha \, T^2 \, \mathrm{KL}\big(\mathrm{softmax}(z_t/T) \,\|\, \mathrm{softmax}(z_s/T)\big) + \beta \, \mathrm{CE}(z_s, y),$$

where $z_t$ and $z_s$ are the teacher and student logits and $y$ is the label. $\alpha$ and $\beta$ are hyperparameters; I chose $\alpha = 0.3$ and $\beta = 0.7$. The run below uses $T = 1.0$, so the logits are not actually softened.

In [None]:
# Train the distilled student: a fresh pretrained ViT-Tiny trained against both
# the true labels and the frozen, fine-tuned teacher.
ALPHA = 0.3  # weight of the distillation (KL) term, as described in the markdown above
TEMP = 1.0   # softmax temperature; 1.0 means no extra softening of the logits

# Freeze the teacher and switch it to eval mode explicitly, so its forward pass does
# not depend on whichever mode an earlier cell left it in.
# NOTE(review): assumes train_one_epoch_distillation never calls teacher.train();
# confirm in engine.py.
for p in teacher.parameters():
    p.requires_grad = False
teacher.eval()

MODEL_NAME = 'vit_tiny_patch16_224'

# Fall back to the CPU when no GPU is present so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
distilled = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,
        img_size=224)
distilled = distilled.to(device)

optimizer = optim.Adam(distilled.parameters(), lr=LR, weight_decay=WD)

n_parameters = sum(p.numel() for p in distilled.parameters() if p.requires_grad)
print('number of params:', n_parameters)

print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS + 1):
    train_stats = train_one_epoch_distillation(
        teacher, distilled, criterion, cifar10_training_loader,
        optimizer, device, epoch, alpha=ALPHA, temp=TEMP)

    # `evaluate` prints its own loss / accuracy summary for each split.
    print("TRAINING STATISTICS")
    training_stats = evaluate(cifar10_training_loader, distilled, criterion, device)
    print("TEST STATISTICS")
    test_stats = evaluate(cifar10_test_loader, distilled, criterion, device)

number of params: 5526346
Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:03:44  loss: 2.6307 (2.6307)  time: 0.3592  data: 0.1519  max mem: 4284
Epoch: [1]  [100/625]  eta: 0:00:57  loss: 2.0912 (2.1580)  time: 0.0997  data: 0.0053  max mem: 4284
Epoch: [1]  [200/625]  eta: 0:00:45  loss: 1.8580 (2.0526)  time: 0.1004  data: 0.0051  max mem: 4284
Epoch: [1]  [300/625]  eta: 0:00:34  loss: 1.8292 (1.9928)  time: 0.1151  data: 0.0099  max mem: 4284
Epoch: [1]  [400/625]  eta: 0:00:23  loss: 1.6827 (1.9295)  time: 0.0983  data: 0.0046  max mem: 4284
Epoch: [1]  [500/625]  eta: 0:00:13  loss: 1.5522 (1.8825)  time: 0.0984  data: 0.0049  max mem: 4284
Epoch: [1]  [600/625]  eta: 0:00:02  loss: 1.4754 (1.8387)  time: 0.1021  data: 0.0059  max mem: 4284
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.6401 (1.8307)  time: 0.0976  data: 0.0045  max mem: 4284
Epoch: [1] Total time: 0:01:05 (0.1052 s / it)
Averaged stats: loss: 1.6401 (1.8307)
TRAINING STATISTICS
Test:  [  0/625]  eta: 0:

In [None]:
# Persist the distilled student's weights (its state_dict), not the module object.
torch.save(obj=distilled.state_dict(), f='./distilled.pth')

In [None]:
# Rebuild the ViT-Tiny architecture and restore the distilled weights, then re-evaluate.
# pretrained=False: the strict load below overwrites every tensor, so downloading
# the pretrained weights first would be wasted work. This relies on the model file
# honouring num_classes when pretrained=False (the fix described in Question 1).
distilled = create_model(
        'vit_tiny_patch16_224',
        pretrained=False,
        num_classes=num_classes,
        img_size=224)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
distilled = distilled.to(device)
# map_location lets a checkpoint saved on the GPU load on whichever device is in use.
distilled.load_state_dict(torch.load('./distilled.pth', map_location=device))

test_stats = evaluate(cifar10_test_loader, distilled, criterion, device)

Test:  [ 0/40]  eta: 0:01:36  loss: 0.8084 (0.8084)  acc1: 71.8750 (71.8750)  acc5: 98.0469 (98.0469)  time: 2.4157  data: 2.0551  max mem: 4284
Test:  [39/40]  eta: 0:00:00  loss: 0.7412 (0.7383)  acc1: 74.6094 (74.0900)  acc5: 98.8281 (98.5800)  time: 0.4288  data: 0.1700  max mem: 4284
Test: Total time: 0:00:19 (0.4958 s / it)
* Acc@1 74.090 Acc@5 98.580 loss 0.738


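Knowledge distillation did improve our model, but only slightly. Top-1 test accuracy rose from 71.3% for the fine-tuned vit_tiny to 74.1% for the distilled one (about +2.8 points), with the same architecture and parameter count. With this, we could run a better tiny model on something like a phone. I believe that tuning $\alpha$ and the temperature (only $T = 1.0$ was tried here) could improve the results further, if we are willing to spend the compute to train the student several times.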