# Import Libraries

In [14]:
!pip -q install vit_pytorch

In [15]:
# import all libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import random_split

import torchvision
import torchvision.transforms as transforms

import os
import argparse


from vit_pytorch import ViT
from tqdm.notebook import tqdm
from torch.optim.lr_scheduler import StepLR


# Load Data

In [17]:
# Seed torch's global RNG so the random train/validation split below is
# reproducible across fresh-kernel runs.
torch.manual_seed(42)

# Training-time pipeline: random 32x32 crop (after 4 pixels of padding) and a
# random horizontal flip for augmentation, then tensor conversion and
# per-channel normalization with the mean/std constants given below.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation-time pipeline: same tensor conversion and normalization, but no
# random augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 training split, wrapped with the augmenting transform.
full_data = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

# Split training and validation data set (80% / 20%).
train_size = int(0.8 * len(full_data))
val_size = len(full_data) - train_size
train_dataset, temp_validation_dataset = random_split(full_data, [train_size, val_size])

print(f"Training size: {train_size}")
print(f"Validation size: {val_size}")

# `temp_validation_dataset` still points at `full_data`, i.e. it would apply
# the augmenting transform. Open the same training split a second time with
# the evaluation transform and re-select the identical indices, so validation
# images are not augmented.
full_data_no_aug = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_test)

# Use the same indices from the split to get validation subset
validation_dataset = torch.utils.data.Subset(full_data_no_aug, temp_validation_dataset.indices)

# Training batches are reshuffled on every pass over the loader.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2)

# Evaluation does not benefit from shuffling. A fixed order keeps the batch
# composition identical from epoch to epoch and avoids drawing from the global
# RNG during validation (which would otherwise shift the random stream seen by
# subsequent training shuffles).
validationloader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=128, shuffle=False, num_workers=2)

# Held-out CIFAR-10 test split, evaluation transform only.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# we can use a larger batch size during test, because we do not save
# intermediate variables for gradient computation, which leaves more memory
testloader = torch.utils.data.DataLoader(
    testset, batch_size=256, shuffle=False, num_workers=2)

print(f"Test size: {len(testset)}")

# Human-readable class names, listed in label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Training size: 40000
Validation size: 10000
Test size: 10000


# ViT

In [18]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at
# the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [19]:
# Build the Vision Transformer and place it on `device`.
# Input geometry: 32x32 images with 3 channels, split using patch_size=4.
# Output: num_classes=10 logits per image.
model = ViT(
    # input / output geometry
    image_size=32,
    patch_size=4,
    channels=3,
    num_classes=10,
    # transformer size
    dim=128,
    depth=6,
    heads=16,
    mlp_dim=1024,
)
model = model.to(device)

# Training

In [20]:
# Hyperparameters for the training run, gathered in one place for tuning.
epochs = 20        # number of full passes over the training loader
lr = 3e-5          # initial learning rate handed to the optimizer
gamma = 0.7        # multiplicative learning-rate decay factor for the scheduler

# NOTE(review): in the cells visible here the data loaders hard-code their own
# batch sizes and the RNG is seeded with a literal, so these two values do not
# appear to be read anywhere - confirm before relying on them.
batch_size = 64
seed = 42

In [21]:
# Objective: cross-entropy between the model's logits and the integer labels
# (default reduction, i.e. the mean over the batch).
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, starting from learning rate `lr`.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Step decay: each call to `scheduler.step()` multiplies the optimizer's
# learning rate by `gamma` (step_size=1 means every call triggers a decay).
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)

In [None]:
# Train for `epochs` passes over `trainloader`, evaluating on
# `validationloader` after every pass and printing the epoch-level metrics.
#
# Metrics are accumulated as per-sample sums and divided by the number of
# samples seen, so a smaller final batch is weighted correctly (averaging
# per-batch means would over-weight it).
for epoch in range(epochs):
    # ---- training pass ----
    model.train()  # enable training-mode behaviour of any dropout/norm layers

    # Running sums live on `device` as 0-dim tensors; accumulating detached
    # values keeps them free of autograd history and avoids a host sync per step.
    epoch_loss = torch.zeros((), device=device)
    epoch_accuracy = torch.zeros((), device=device)
    train_seen = 0

    for data, label in tqdm(trainloader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = label.size(0)
        # `criterion` returns the batch mean, so mean * batch size = batch sum.
        epoch_loss += loss.detach() * batch_n
        epoch_accuracy += (output.argmax(dim=1) == label).float().sum()
        train_seen += batch_n

    # max(..., 1) guards against an empty loader.
    epoch_loss /= max(train_seen, 1)
    epoch_accuracy /= max(train_seen, 1)

    # ---- validation pass ----
    model.eval()  # switch layers such as dropout to inference behaviour

    epoch_val_accuracy = torch.zeros((), device=device)
    epoch_val_loss = torch.zeros((), device=device)
    val_seen = 0

    with torch.no_grad():
        for data, label in validationloader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            batch_n = label.size(0)
            epoch_val_loss += val_loss * batch_n
            epoch_val_accuracy += (val_output.argmax(dim=1) == label).float().sum()
            val_seen += batch_n

    epoch_val_loss /= max(val_seen, 1)
    epoch_val_accuracy /= max(val_seen, 1)

    # Advance the learning-rate schedule once per epoch. Without this call the
    # scheduler is never applied and the learning rate stays at its initial
    # value for the whole run. If the resulting decay is too aggressive, tune
    # the scheduler's gamma/step_size rather than dropping the call.
    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
