In [11]:
from vit import VisionTransformer
import vit
import transform
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import numpy as np
import random 
import matplotlib.pyplot as plt

In [3]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Seed every RNG this notebook can draw from so Restart & Run All reproduces results.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices
random.seed(SEED)
np.random.seed(SEED)  # numpy is imported by this notebook; seed it too

In [5]:
# Hyperparameters
augmentation = False  # forwarded to transform.transform_settings for the training split

# Optimization
batch_size = 128
epoch = 10  # number of training epochs
learning_rate = 3e-4

# Input geometry and labels (CIFAR-10 RGB images)
patch_size = 4
n_classes = 10
img_size = 32
channels = 3

# Transformer architecture
embed_dim = 256
n_heads = 8  # number of heads in each multi-head attention layer
depth = 6  # number of transformer encoder blocks
mlp_dim = 512  # hidden width of each encoder block's MLP
drop_rate = 0.1

# Fail fast on configurations the model cannot be built from: multi-head attention
# splits embed_dim evenly across heads, and the strided patch projection should tile
# the image exactly (otherwise trailing pixels are silently dropped).
assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"

In [6]:
transform = transform.transform_settings(augmentation)

In [7]:
# Load datasets
train_data = datasets.CIFAR10(root='data', train= True, 
                              download= True, transform = transform)

In [8]:
test_data = datasets.CIFAR10(root='data', train=False,
                             download= True, transform= transform)

In [9]:
train_data  # show dataset summary: number of datapoints, root, split and transform pipeline

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=0.5, std=0.5)
           )

In [10]:
test_data  # show dataset summary: number of datapoints, root, split and transform pipeline

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=0.5, std=0.5)
           )

In [10]:
# Wrap each dataset in a DataLoader that yields batches of `batch_size` samples.
# Only the training split is reshuffled every epoch; the test split keeps a fixed
# order so evaluation passes are comparable.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [15]:
# Build the Vision Transformer from the hyperparameters cell; VisionTransformer takes
# these positionally, so keep this order in sync with vit.py.
model = VisionTransformer(
    img_size, patch_size, channels, n_classes,
    embed_dim, depth, n_heads, mlp_dim, drop_rate,
)
model = model.to(device)

In [16]:
 # Visualize model 
model  # show the module tree: patch embedding, encoder blocks (LayerNorm / attention / MLP)

VisionTransformer(
  (patch_embedding): PatchEmbedding(
    (proj): Conv2d(3, 256, kernel_size=(4, 4), stride=(4, 4))
  )
  (encoder): Sequential(
    (0): TransformerEncoder(
      (normalization1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (normalization2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=256, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): TransformerEncoder(
      (normalization1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (normalization2): LayerNorm((256,), eps=1e-05, e

In [17]:
 # Loss function and Optimizer

# Adam over every trainable tensor of the model (`optim` is the torch.optim alias).
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Multi-class classification loss over the model's n_classes outputs.
criterion = nn.CrossEntropyLoss()

In [None]:
# Training
# Per-epoch history, kept for plotting / inspection after the loop.
train_accuracy = []
test_accuracy = []
train_losses = []

for e in range(epoch):
    # One pass over the training set, then an evaluation pass over the test set.
    train_loss, train_acc = vit.train(model, train_loader, optimizer, criterion, device)
    test_acc = vit.eval(model, test_loader, device)
    train_losses.append(train_loss)
    train_accuracy.append(train_acc)
    test_accuracy.append(test_acc)
    # Report 1-based epoch numbers so the final line reads "Epoch {epoch} / {epoch}".
    print(f'Epoch {e + 1} / {epoch}: Train accuracy: {train_acc:.2f}%, Train loss: {train_loss:.2f}, Test accuracy: {test_acc:.2f}%')

KeyboardInterrupt: 