<a href="https://colab.research.google.com/github/foxtrotmike/CS909/blob/master/mnist_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [9]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from einops import rearrange

class Residual(nn.Module):
    # Implements a residual connection followed by a function fn.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # The function to apply before adding the input back.

    def forward(self, x, **kwargs):
        # Applies the function and adds the input x back to the result.
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    # Applies layer normalization before a function fn.
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # Normalizes the input across the specified dimension.
        self.fn = fn  # The function to apply after normalization.

    def forward(self, x, **kwargs):
        # Normalizes the input and then applies the function.
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    # Implements a feedforward neural network with one hidden layer.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),  # Linear transformation to the hidden dimension.
            nn.GELU(),  # GELU non-linearity.
            nn.Linear(hidden_dim, dim)  # Linear transformation back to the original dimension.
        )

    def forward(self, x):
        # Passes the input through the feedforward network.
        return self.net(x)

class Attention(nn.Module):
    # Implements the multi-head self-attention mechanism.
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads  # Number of attention heads.
        self.scale = dim ** -0.5  # Scaling factor for the dot products.

        # Linear layer to project the input into queries, keys, and values.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        # Linear layer to project the concatenated outputs back to the original dimension.
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        # Projects input to queries, keys, and values.
        qkv = self.to_qkv(x)
        # Rearranges the projections for multi-head attention processing.
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        # Calculates the dot products of queries and keys, scales the result.
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # Applies masking if provided, for handling variable sequence lengths.
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        # Applies softmax to get the attention weights.
        attn = dots.softmax(dim=-1)

        # Uses einsum to compute weighted sum of values.
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # Rearranges the output back to the original input shape.
        out = rearrange(out, 'b h n d -> b n (h d)')
        # Projects the output back to the original dimension.
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    # Implements a sequence of transformer blocks.
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Each block consists of a multi-head attention layer and a feedforward layer, both wrapped in residual and normalization layers.
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        # Processes the input through each block in sequence.
        for attn, ff in self.layers:
            x = attn(x, mask=mask)  # Applies attention.
            x = ff(x)  # Applies feedforward network.
        return x

class ViT(nn.Module):
    # Vision Transformer (ViT) for image classification.
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2  # Dimensionality of the flattened patch.

        self.patch_size = patch_size

        # Learnable position embeddings.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Linear layer to project flattened patches to the dimensionality of the transformer.
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        # Learnable class token.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # The transformer core.
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        # The "head" of the ViT to make final class predictions.
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        # Rearranges the input image into non-overlapping patches.
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        # Projects patches to embeddings.
        x = self.patch_to_embedding(x)

        # Prepends the class token to the sequence of embedded patches.
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Adds position embeddings to the patch embeddings.
        x += self.pos_embedding
        # Processes the sequence through the transformer.
        x = self.transformer(x, mask)

        # Extracts the class token and applies the classification head.
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)


In [10]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
# Define a transform to normalize the data
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # Ensure the image size is 28x28
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize((0.5,), (0.5,))  # Normalize the images
])

# Download and load the training data
trainset = MNIST('', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = MNIST('', download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
import torch.optim as optim

# Define the loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

epochs = 5
for epoch in range(epochs):
    running_loss = 0
    for images, labels in tqdm(trainloader):
        # Reset the gradients to zero
        optimizer.zero_grad()

        # Forward pass
        output = model(images)
        loss = loss_function(output, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    else:
        print(f"Training loss: {running_loss/len(trainloader)}")

correct = 0
total = 0
with torch.no_grad():
    model.eval()
    for images, labels in tqdm(testloader):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')


  0%|          | 0/938 [00:00<?, ?it/s]

Training loss: 0.34086699017694894


  0%|          | 0/938 [00:00<?, ?it/s]

Training loss: 0.13769662428869686


  0%|          | 0/938 [00:00<?, ?it/s]

Training loss: 0.11251812642276375


  0%|          | 0/938 [00:00<?, ?it/s]

Training loss: 0.0998129521939419


  0%|          | 0/938 [00:00<?, ?it/s]

Training loss: 0.08331751398621068


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97.64 %
