##  Image Classification with Transformer

In [41]:
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """
    Convert input image into a sequence of flattened 'patches'
    which will be equivalent to tokens
    """
    def __init__(self, img_size: int = 32, patch_size:int = 4, in_channels: int = 3, embed_dim : int = 128):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Linear(patch_size * patch_size * in_channels, embed_dim)

    def forward(self, x):
        """
        x : (B, C, H, W) -> (B, N, embed_dim)
        """
        B, C, H, W = x.shape
        P = self.patch_size
        N = self.num_patches

        # Reshape into patches (B, C, H/P, P, W/P, P)
        x = x.unfold(2, P, P).unfold(3, P, P)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous() # (B, N_h, N_w, C, P, P)
        x = x.view(B, N, -1)
        x = self.projection(x)
        return x


In [43]:
# Sanity Test
IMAGE_SIZE = 32
PATCH_SIZE = 4
EMBED_DIM = 8
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
model = PatchEmbedding(IMAGE_SIZE, PATCH_SIZE, 3, EMBED_DIM)
x = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)  # Batch size of 2
output = model(x)
assert output.shape == (2, NUM_PATCHES, EMBED_DIM), f"Unexpected shape: {output.shape}"
print("PatchEmbedding test passed")

PatchEmbedding test passed


In [36]:
class ViTTransformerEncoderCell(nn.Module):
    """
    Pre-Norm Transformer ecoder cell
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        x: (B, num_patches+1, embed_dim)
        """
        # Pre-Norm Self-attention
        residual = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x, need_weights=False)  # Self-attention
        x = residual + x  # Add residual connection

        # Pre-Norm feedforward
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x  # add residual

        return x

In [40]:
# Sanity test
batch_size = 2
encoder = ViTTransformerEncoderCell(EMBED_DIM, 8, 256)
x = torch.randn(batch_size, NUM_PATCHES + 1, EMBED_DIM)

output = encoder(x)
assert output.shape == x.shape, f"Expected shape {x.shape}, but got {output.shape}"
print("ViTTransformerEncoderCell test passed")

In [44]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT)
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4, num_layers=4, ff_dim=512, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable CLS Token & Positional encoding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_encoding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # nn.init.trunc_normal_(self.pos_encoding, std=0.02)  # Stable init

        self.dropout = nn.Dropout(dropout)

        # Transformer Encoder Stack (as ModuleList for flexibility)
        self.encoder_stack = nn.ModuleList([ViTTransformerEncoderCell(embed_dim, num_heads, ff_dim, dropout)
                                      for _ in range(num_layers)])

        # Classification Head
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Append CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches+1, embed_dim)

        # Add Positional Encoding with Dropout
        x = self.dropout(x + self.pos_encoding[:, :x.shape[1], :])

        # Transformer Encoding
        for layer in self.encoder_stack:
            x = layer(x)

        # Use CLS token for classification
        x = self.norm(x[:, 0])  # Extract CLS token output
        y =
        return self.fc(x)


In [46]:
# sanity check just to be sure lol debugging is killing me please somebody release me from this suffering
batch_size = 2
IMAGE_SIZE = 32
PATCH_SIZE = 4
num_classes = 10
embed_dim = 128
num_heads = 4
num_layers = 4
ff_dim = 512

# Initialize model
model = VisionTransformer(IMAGE_SIZE, PATCH_SIZE, 3, num_classes,
                            embed_dim, num_heads, num_layers, ff_dim)

# Create a dummy batch of images (B, C, H, W)
x = torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)

# Forward pass
output = model(x)

# Check output shape
assert output.shape == (batch_size, num_classes), f"Expected shape {(batch_size, num_classes)}, but got {output.shape}"
print("Sanity test passed so I don't have to")

Sanity test passed so I don't have to


### Retrieve and load CIFAR10

In [3]:
import time
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

def rel_error(x, y):
    """ returns relative error """
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))


In [4]:
# let's download the data
# !mkdir ../datasets
# !cd ../datasets

# 1 -- Linux
# 2 -- MacOS
# 3 -- Command Prompt on Windows
# 4 -- manually downloading the data
choice = 1


if choice == 1:
    # should work well on Linux and in Powershell on Windows
    !wget http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
elif choice == 2 or choice ==3:
    # if wget is not available for you, try curl
    # should work well on MacOS
    !curl http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz --output cifar-10-python.tar.gz
else:
    print('Please manually download the data from http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz and put it under the datasets folder.')
!tar -xzvf cifar-10-python.tar.gz

if choice==3:
    !del cifar-10-python.tar.gz
else:
    !rm cifar-10-python.tar.gz

--2025-03-17 22:03:34--  http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2025-03-17 22:03:39 (33.4 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]

cifar-10-batches-py/
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1


In [5]:
# helpful functions to process and load the data
from six.moves import cPickle as pickle
import numpy as np
import os
from imageio import imread
import platform

def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
  """ load single batch of cifar """
  with open(filename, 'rb') as f:
    datadict = load_pickle(f)
    X = datadict['data']
    Y = datadict['labels']
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    Y = np.array(Y)
    return X, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = []
  ys = []
  for b in range(1,6):
    f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X)
    ys.append(Y)
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  del X, Y
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(cifar10_dir, num_training=49000, num_validation=1000, num_test=1000,
                     subtract_mean=True):
    """
    Load the CIFAR-10 dataset from disk and perform preprocessing to prepare
    it for classifiers. These are the same steps as we used for the SVM, but
    condensed to a single function.
    """
    # Load the raw CIFAR-10 data
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = list(range(num_training, num_training + num_validation))
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = list(range(num_training))
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = list(range(num_test))
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Normalize the data: subtract the mean image
    if subtract_mean:
      mean_image = np.mean(X_train, axis=0)
      X_train -= mean_image
      X_val -= mean_image
      X_test -= mean_image

    # Transpose so that channels come first
    X_train = X_train.transpose(0, 3, 1, 2).copy()
    X_val = X_val.transpose(0, 3, 1, 2).copy()
    X_test = X_test.transpose(0, 3, 1, 2).copy()

    # Package data into a dictionary
    return {
      'X_train': X_train, 'y_train': y_train,
      'X_val': X_val, 'y_val': y_val,
      'X_test': X_test, 'y_test': y_test,
    }

In [6]:
# Load the (preprocessed) CIFAR10 data.
cifar10_dir = './cifar-10-batches-py'

data = get_CIFAR10_data(cifar10_dir, subtract_mean=True)

pix_mean = (0.485, 0.456, 0.406)
pix_std = (0.229, 0.224, 0.225)

for c in range(3):
    data['X_train'][:, c] = (data['X_train'][:, c] / 255 - pix_mean[c]) / pix_std[c]
    data['X_val'][:, c] = (data['X_val'][:, c] / 255 - pix_mean[c]) / pix_std[c]
    data['X_test'][:, c] = (data['X_test'][:, c] / 255 - pix_mean[c]) / pix_std[c]

for split in ['train', 'val', 'test']:
    print('===\nFor the split {}'.format(split))
    print('shape: {}'.format(data['X_{}'.format(split)].shape))
    print('data value range, min: {}, max: {}\n'.format(data['X_{}'.format(split)].min(), data['X_{}'.format(split)].max()))

===
For the split train
shape: (49000, 3, 32, 32)
data value range, min: -4.489820571085577, max: 0.8966644435551998

===
For the split val
shape: (1000, 3, 32, 32)
data value range, min: -4.489820571085577, max: 0.8966644435551998

===
For the split test
shape: (1000, 3, 32, 32)
data value range, min: -4.489820571085577, max: 0.8966644435551998



In [7]:
# no need to implement anything here
def set_up_cifar10_data_loader(images, labels, batch_size, shuffle=True):
    dataset = torch.utils.data.TensorDataset(torch.Tensor(images), torch.Tensor(labels))
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    return data_loader

### Implement training and testing function

In [51]:
def train_val_model(model, train_data_loader, val_data_loader, loss_fn, optimizer, lr_scheduler, num_epochs, print_freq=50):
    """
    Training and validating a the image classification model

    Inputs:
      - model: An image classification model
      - data_loader: A data loader that will provide batched images and labels
      - loss_fn: A loss function
      - optimizer: optimizer lol
      - lr_scheduler: Learning rate scheduler
      - num_epochs: Number of epochs in total
      - print_freq: Frequency to print training statistics

    Output:
      - model: Trained CNN model
    """

    for epoch_i in range(num_epochs):
        # set the model in the train mode so the batch norm layers will behave correctly
        model.train()

        running_loss = 0.0
        running_total = 0.0
        running_correct = 0.0
        for i, batch_data in enumerate(train_data_loader):
            # Every data instance is an image + label pair
            images, labels = batch_data
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            # forward pass
            logits = model(images)
            predicted = torch.argmax(logits, dim=1)

            # backward pass
            target_labels = labels.type(torch.LongTensor)
            target_labels = target_labels.cuda()
            loss = loss_fn(logits, target_labels)
            loss.backward()

            # optimize
            optimizer.step()


            # print statistics
            running_loss += loss.item()
            running_total += labels.size(0)
            running_correct += (predicted == labels).sum().item()
            if i % print_freq == 0:    # print every certain number of mini-batches
                running_loss = running_loss / print_freq
                running_acc = running_correct / running_total * 100
                last_lr = lr_scheduler.get_last_lr()[0]
                print(f'[{epoch_i + 1}/{num_epochs}, {i + 1:5d}/{len(train_data_loader)}] loss: {running_loss:.3f} acc: {running_acc:.3f} lr: {last_lr:.5f}')
                running_loss = 0.0
                running_total = 0.0
                running_correct = 0.0

        # adjust the learning rate
        lr_scheduler.step()

        val_acc = test_model(model, val_data_loader)
        print(f'[{epoch_i + 1}/{num_epochs}] val acc: {val_acc:.3f}')

    return model

In [50]:
# Function to test an already trained model
def test_model(model, data_loader):
    """
    Compute accuracy of the model.

    Inputs:
      - model: An image classification model
      - data_loader: A data loader that will provide batched images and labels
    """

    # set the model in evaluation mode so the batch norm layers will behave correctly
    model.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for batch_data in data_loader:
            images, labels = batch_data
            images = images.cuda()
            labels = labels.cuda()

            logits = model(images)
            predicted = torch.argmax(logits, dim=1) #softmax preserves ranking anyway


            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc = 100 * correct // total
    return acc

### Train the ViT Image Classifier

In [None]:
num_epochs = 3

model = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3,
    num_classes=10,
    embed_dim=128, num_heads=4, num_layers=4, ff_dim=512,
    dropout=0.1)

batch_size = 64
learning_rate = 0.0005
momentum = 0.98
lr_gamma = 0.1

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

In [48]:
# set up the data loaders
# note the usage of the batch_size hyperparameter here
train_loader = set_up_cifar10_data_loader(data['X_train'], data['y_train'], batch_size, shuffle=True)
print("There are {} batches in the training set.".format(len(train_loader)))

val_loader = set_up_cifar10_data_loader(data['X_val'], data['y_val'], batch_size, shuffle=False)
print("There are {} batches in the validation set.".format(len(val_loader)))

test_loader = set_up_cifar10_data_loader(data['X_test'], data['y_test'], batch_size, shuffle=False)
print("There are {} batches in the testing set.".format(len(test_loader)))

num_params = sum(p.numel() for p in model.parameters())
print('Number of parameters: {:.3f}K'.format(num_params / 1000))

There are 766 batches in the training set.
There are 16 batches in the validation set.
There are 16 batches in the testing set.
Number of parameters: 809.354K


In [52]:
model = model.cuda()
model = train_val_model(model, train_loader, val_loader, loss_fn, optimizer, scheduler, num_epochs)
test_acc = test_model(model, test_loader)
print(f"testing accuracy: {test_acc:.3f}")

[1/3,     1/766] loss: 0.048 acc: 9.375 lr: 0.00050
[1/3,    51/766] loss: 2.220 acc: 16.562 lr: 0.00050
[1/3,   101/766] loss: 2.076 acc: 21.250 lr: 0.00050
[1/3,   151/766] loss: 1.941 acc: 24.844 lr: 0.00050
[1/3,   201/766] loss: 1.863 acc: 28.844 lr: 0.00050
[1/3,   251/766] loss: 1.857 acc: 29.844 lr: 0.00050
[1/3,   301/766] loss: 1.784 acc: 33.156 lr: 0.00050
[1/3,   351/766] loss: 1.763 acc: 33.719 lr: 0.00050
[1/3,   401/766] loss: 1.747 acc: 35.312 lr: 0.00050
[1/3,   451/766] loss: 1.696 acc: 36.969 lr: 0.00050
[1/3,   501/766] loss: 1.695 acc: 36.406 lr: 0.00050
[1/3,   551/766] loss: 1.677 acc: 37.250 lr: 0.00050
[1/3,   601/766] loss: 1.622 acc: 40.625 lr: 0.00050
[1/3,   651/766] loss: 1.601 acc: 40.125 lr: 0.00050
[1/3,   701/766] loss: 1.579 acc: 41.938 lr: 0.00050
[1/3,   751/766] loss: 1.540 acc: 44.469 lr: 0.00050
[1/3] val acc: 41.000
[2/3,     1/766] loss: 0.031 acc: 39.062 lr: 0.00049
[2/3,    51/766] loss: 1.563 acc: 42.844 lr: 0.00049
[2/3,   101/766] loss: 1.