In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random 
import matplotlib.pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
torch.manual_seed(42)
random.seed(42)

In [None]:
# Hyperparmeters
augmentation = False

batch_size = 128
epoch = 10
learning_rate = 3e-4
patch_size = 4
n_classes = 10 
img_size = 32
channels = 3
embed_dim = 256
n_heads = 8  # Number of multi-headed attention
depth = 6 # Number of transformer blocks
mlp_dim = 512
drop_rate = 0.1

In [None]:
def transform_settings(aug= False):
    if aug == False:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)) # Normalize images, converge faster and more stable
            ])
        return transform
    else:
        transform_augmentation = transforms.Compose([
            transforms.RandomCrop(32, padding= 4), # Square crop of size x size, pad all border
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), # Randomly change the brightness
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])     
        return transform_augmentation
    
transform = transform_settings(augmentation)

In [6]:
# Load datasets
train_data = datasets.CIFAR10(root='data', train= True, 
                              download= True, transform = transform)

In [7]:
test_data = datasets.CIFAR10(root='data', train=False,
                             download= True, transform= transform)

In [8]:
train_data

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.2, 0.2))
               ToTensor()
               Normalize(mean=0.5, std=0.5)
           )

In [9]:
test_data

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.2, 0.2))
               ToTensor()
               Normalize(mean=0.5, std=0.5)
           )

In [10]:
# Convert to DataLoader (turn data into batches)
train_loader = DataLoader(dataset=train_data,
                          batch_size= batch_size,
                          shuffle= True)

test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size,
                         shuffle=False)

In [None]:
# Building Vision transformer
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels=channels,
                              out_channels=embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        n_patches = (img_size // patch_size) **2 # define the number of patches 
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, 1+n_patches, embed_dim))

    def forward(self, x: torch.Tensor):
        B = x.size(0)
        x = self.proj(x) # (Batch, Embed, Height/Patch_size, Width/Patch_size)
        x = x.flatten(2).transpose(1,2) # (Batch, N_patch, Embed) 
        # flatten from the dimension 2 onward, switch dim2 and dim2 
        # because transformer expect (Batch, Seq_len, Embed_size)
        cls_token = self.cls_token.expand(B, -1, -1) # Give me B copies of the CLS token -> (Batch, 1, Embed_size)
        x = torch.cat((cls_token, x), dim=1) # -> (Batch, 1 + N_patch, Embed_size) 
        x = x + self.pos_embed

        return x 

In [12]:
class MLP(nn.Module):
    def  __init__(self, in_features, hidden_dim, drop_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_features= in_features)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.dropout(F.gelu(self.fc1(x)))  # F to replace nn.
        x = self.dropout(self.fc2(x))

        return x

In [None]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, n_heads, mlp_dim, drop_rate):
        super().__init__()
        self.normalization1 = nn.LayerNorm(embed_dim) # applied per token, normalize the embedding dimension
        self.attention = nn.MultiheadAttention(embed_dim, n_heads, drop_rate, batch_first= True)
        self.normalization2= nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, drop_rate)

    def forward(self, x):
        x_norm = ....
        attention, _ = self.attention(self.normalization1(x), self.normalization1(x), self.normalization1(x))  
        x = x + attention 
        x = x + self.mlp(self.normalization2(x)) 
        return x 

In [14]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, channels, n_classes, embed_dim, depth, n_heads, mlp_dim, drop_rate):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, channels, embed_dim)
        self.encoder = nn.Sequential(*[
            TransformerEncoder(embed_dim, n_heads, mlp_dim, drop_rate)
            for _ in range(depth) 
            ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = self.encoder(x)
        x = self.norm(x)
        cls_token = x[:, 0]

        return self.head(cls_token)


In [15]:
model = VisionTransformer(img_size, patch_size, channels, n_classes, embed_dim, depth, n_heads, mlp_dim, drop_rate).to(device)

In [16]:
 # Visualize model 
model 

VisionTransformer(
  (patch_embedding): PatchEmbedding(
    (proj): Conv2d(3, 256, kernel_size=(4, 4), stride=(4, 4))
  )
  (encoder): Sequential(
    (0): TransformerEncoder(
      (normalization1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (normalization2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=256, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): TransformerEncoder(
      (normalization1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (normalization2): LayerNorm((256,), eps=1e-05, e

In [17]:
 # Loss function and Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr= learning_rate)

In [123]:
# Training model 

def train(model, loader, optimizer, criterion):
    model.train()

    total_loss, correct = 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # forward pass: model output raw logit
        out = model(x) # -> (examples in a batch, logit for all classes)
        # calculate the mean loss per batch (per sample)
        loss = criterion(out, y) # -> a scalar, mean loss for that batch 
        # backpropagation
        loss.backward()   
        # gradient descent
        optimizer.step()

        total_loss += loss.item() * x.size(0)  # mean loss per sample * batch size = total loss for this batch
        correct += (out.argmax(1) == y).sum().item() #  finds the index of the max value along dimension 1
    return total_loss / len(loader.dataset), correct / len(loader.dataset) # averge loss per sample & accuracy

# Evaluate model
def eval(model, loader):
    model.eval()
    correct = 0 

    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            correct += (out.argmax(dim=1) == y).sum().item()
    return correct / len(loader.dataset)

In [124]:
# Training 
train_accuracy = []
test_accuracy = []

for e in range(epoch):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion) 
    test_acc = eval(model, test_loader)
    train_accuracy.append(train_acc)
    test_accuracy.append(test_acc)
    print(f'Epoch {e} / {epoch}: Train accuracy: {train_acc:.2f}%, Train loss: {train_loss:.2f}, Test accuracy: {test_acc:.2f}%')

KeyboardInterrupt: 