In [130]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Data

In [131]:
import torch
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(size=[16,16]),
                                transforms.Normalize((0.5,), (0.5,))])

batch_size = 64

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

# Exploration

# Transformer

## Patch embedding

Reshape the image of each image of size HxWxC to flattened 2D patches of size Nx(P^2*C) 
where P = new resolution (16x16)
and N = the new number of patches = input length for the transformer    
D = some constant, latent vector size to flatten and map the patches onto D dimensions

In [134]:
import torch
from torch import nn
from torchvision.transforms import Compose, Resize, ToTensor
from einops.layers.torch import Rearrange


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=28):
        super().__init__()
        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, (img_size // patch_size) ** 2 + 1, emb_size))
        self.projection = nn.Sequential(
            # Splitting image into patches and flattening
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x):
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = self.cls_token.expand(b, -1, -1)
        # Prepend the cls token to the input
        x = torch.cat((cls_tokens, x), dim=1)
        # Add positional embedding
        x += self.pos_embedding[:, :(x.size(1))]
        return x

## Transformer encoder

In [135]:
class TransformerEncoder(nn.Module):
    def __init__(self, emb_size=768, depth=12, heads=12, mlp_dim=3072):
        super().__init__()
        # The encoder stack
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.TransformerEncoderLayer(d_model=emb_size, nhead=heads, dim_feedforward=mlp_dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

## ViT Model

In [136]:
class ViT(nn.Module):
    def __init__(self, in_channels=1, patch_size=16, emb_size=256, img_size=28, depth=6, heads=8, mlp_dim=512, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.transformer_encoder = TransformerEncoder(emb_size, depth, heads, mlp_dim)
        self.classifier = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = self.transformer_encoder(x)
        # Classifier on the class token
        x = x[:, 0]
        x = self.classifier(x)
        return x

# Training

In [186]:
import torch

# Checks if CUDA is available, and sets device to CUDA if it is, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


For the Vision Transformer to function correctly, it doesn't necessarily need to change its architecture based on H and W directly, as long as the total number of patches (determined by H, W, and P) remains consistent with the model's expectations. The model's architecture (number of layers, heads, embedding dimensions) typically remains fixed once defined.

In [205]:
import torch
from torch import nn
from torch import optim

H, W = next(iter(trainloader))[0].shape[2], next(iter(trainloader))[0].shape[3]
C = next(iter(trainloader))[0].shape[1]
P = 16

model = ViT(
    in_channels=C, 
    patch_size=P, 
    emb_size= (P * P), 
    img_size=H, 
    depth=6, 
    heads=8, 
    mlp_dim=512, 
    num_classes=10).to(device)

criterion = nn.CrossEntropyLoss() #Classification 
optimizer = optim.Adam(model.parameters(), lr=0.003)

In [206]:
from tqdm import tqdm 

epochs = 5

for epoch in range(epochs):
    running_loss = 0
    for images, labels in tqdm(trainloader):  # Assuming trainloader is defined
        # Reset the gradients
        optimizer.zero_grad()
        
        # Forward pass: compute predicted outputs by passing inputs to the model
        outputs = model(images.to(device))
        
        # Calculate the loss
        loss = criterion(outputs, labels.to(device))
        
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        
        # Perform a single optimization step (parameter update)
        optimizer.step()
        
        running_loss += loss.item()
    else:
        print(f"Epoch {epoch+1}/{epochs} - Training loss: {running_loss/len(trainloader)}")


100%|██████████| 938/938 [00:47<00:00, 19.83it/s]


Epoch 1/5 - Training loss: 2.3268222409779074


100%|██████████| 938/938 [00:59<00:00, 15.77it/s]


Epoch 2/5 - Training loss: 2.3040264486504007


100%|██████████| 938/938 [00:54<00:00, 17.27it/s]


Epoch 3/5 - Training loss: 2.3030850340816764


100%|██████████| 938/938 [00:54<00:00, 17.22it/s]


Epoch 4/5 - Training loss: 2.303072214126587


100%|██████████| 938/938 [00:54<00:00, 17.12it/s]

Epoch 5/5 - Training loss: 2.3030997125833017





# Testing

In [185]:
true_y, y = [], []
for images, labels in tqdm(testloader):
    outputs = model(images)
    y.append(outputs.argmax(dim=1))
    true_y.append(labels)
acc = sum([(sum(torch.eq(y[i], true_y[i]))/64).item() for i in range(len(y))])/len(testloader)
print("Accuracy: {acc}".format(acc=acc))

100%|██████████| 157/157 [00:07<00:00, 19.86it/s]

Accuracy: 0.1129578025477707



