<h1>PyTorch MNIST Classification Using a Transformer</h1>
Welcome to my implementation of MNIST classification using PyTorch. In this notebook I apply the Vision Transformer (ViT; Dosovitskiy et al., 2020) to MNIST classification, running everything in Google Colab.


In [24]:
!pip install vit-pytorch
!pip install -U fvcore


In [None]:
# Check which GPU Google has generously provided us :)
!nvidia-smi -L

In [22]:
import torch
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision.models import resnet50 as rn50
from vit_pytorch import ViT
from fvcore.nn import flop_count, flop_count_str, flop_count_table

import time

In [10]:
# Declare Vision Transformer model
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
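To see what these settings mean in practice: ViT cuts each image into non-overlapping square patches, flattens each patch and projects it to a `dim`-sized token, and the original ViT also prepends one learnable class token. Below is a minimal pure-Python sketch of the resulting token counts for the example above and for the MNIST model trained later. `patch_stats` is a hypothetical helper (not part of `vit-pytorch`), and it assumes square images and a single class token.

```python
from typing import Tuple


def patch_stats(image_size: int, patch_size: int, channels: int) -> Tuple[int, int, int]:
    """Return (num_patches, patch_dim, seq_len) for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each flattened patch holds channels * patch_size * patch_size raw values
    patch_dim = channels * patch_size ** 2
    # One extra position for the prepended class token
    seq_len = num_patches + 1
    return num_patches, patch_dim, seq_len


print(patch_stats(256, 32, 3))  # (64, 3072, 65): the 256x256 RGB example above
print(patch_stats(28, 7, 1))    # (16, 49, 17): the 28x28 grayscale MNIST model
```

The patch size has to divide the image size evenly, which is why the MNIST model uses 7×7 patches (28 = 4 × 7).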

In [11]:
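# Hyperparameters and random seed
# Note: the ViT training run further down uses its own N_EPOCHS (25) and Adam
# learning rate (0.003), not n_epochs, learning_rate or momentum from this cell.

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False  # avoid cuDNN's nondeterministic kernels, for reproducibility
torch.manual_seed(random_seed)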


Now let's load the training and test sets: 60,000 images for training and 10,000 images for testing.
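With the batch sizes chosen earlier (64 for training, 1000 for testing), it helps to know how many mini-batches one epoch contains. A small sketch of the arithmetic, assuming the `DataLoader` default `drop_last=False`; `batches_per_epoch` is a hypothetical helper, not a PyTorch function.

```python
def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches yielded by one pass of a DataLoader over the data."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    # drop_last=False (the DataLoader default): keep the smaller final batch
    return (num_samples + batch_size - 1) // batch_size


train_batches = batches_per_epoch(60_000, 64)   # 938; the last batch holds 32 images
test_batches = batches_per_epoch(10_000, 1000)  # 10
# Logging every 100th batch (i = 0, 100, ..., 900) records 10 training losses per epoch
logged_per_epoch = len(range(0, train_batches, 100))
print(train_batches, test_batches, logged_per_epoch)  # 938 10 10
```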

In [13]:
# Load the dataset
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)
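`ToTensor()` scales 8-bit pixels from [0, 255] to [0.0, 1.0], and `Normalize((0.1307,), (0.3081,))` then subtracts 0.1307 and divides by 0.3081, the commonly quoted mean and standard deviation of the MNIST training pixels. A one-pixel sketch in plain Python; `normalize_pixel` is a hypothetical helper that mirrors those two transforms.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(raw: int) -> float:
    """Mimic ToTensor() followed by Normalize((0.1307,), (0.3081,)) for one 8-bit pixel."""
    scaled = raw / 255.0                      # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize: (x - mean) / std


print(round(normalize_pixel(0), 4))    # -0.4242 (black background)
print(round(normalize_pixel(255), 4))  # 2.8215 (white pen stroke)
```

After normalization the inputs are roughly zero-mean with unit variance, which usually makes optimization better behaved than feeding raw [0, 1] values.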

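Next, we train on the training set we just loaded. The run below makes 25 passes (epochs) over the 60,000 training images, updating the weights with Adam after every mini-batch of 64 images and evaluating on the test set at the end of each epoch.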

In [15]:
# Train the model for one epoch

def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    device = next(model.parameters()).device  # run on whichever device the model lives on
    model.train()  # training mode (enables dropout layers, if any)

    for i, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)  # log-probabilities over the 10 digits
        loss = F.nll_loss(output, target)           # mean negative log-likelihood of the batch
        loss.backward()
        optimizer.step()

        # Log progress every 100 mini-batches
        if i % 100 == 0:
            print(f'[{i * len(data):5}/{total_samples:5} '
                  f'({100 * i / len(data_loader):3.0f}%)]  Loss: {loss.item():6.4f}')
            loss_history.append(loss.item())
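`train_epoch` applies `F.log_softmax` and then `F.nll_loss`; together these compute the standard cross-entropy loss, which `F.cross_entropy` also computes in a single call on raw logits. A pure-Python sketch of the math on a toy batch (the logits are made-up numbers, not model output, and the helper names are my own):

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """log(softmax(z)) computed stably via the log-sum-exp trick."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def nll_loss(log_probs: List[List[float]], targets: List[int]) -> float:
    """Mean negative log-probability of the true class (like reduction='mean')."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Textbook cross-entropy: mean of -log(softmax(z)[t]) over the batch."""
    total = 0.0
    for row, t in zip(logits, targets):
        prob_true = math.exp(row[t]) / sum(math.exp(z) for z in row)
        total -= math.log(prob_true)
    return total / len(targets)


# Toy batch: 2 samples, 3 classes
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
batch_targets = [0, 2]

two_step = nll_loss([log_softmax(row) for row in batch_logits], batch_targets)
direct = cross_entropy(batch_logits, batch_targets)
print(math.isclose(two_step, direct))  # True
```

Keeping the two steps separate, as the notebook does, makes the log-probabilities available for prediction in `evaluate`; calling `F.cross_entropy(model(data), target)` would give the same loss value.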

In [16]:
# Evaluate the model on our test set

def evaluate(model, data_loader, loss_history):
    model.eval()  # evaluation mode (disables dropout)
    device = next(model.parameters()).device

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0.0

    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = F.log_softmax(model(data), dim=1)
            # Sum (not average) so the total can be divided by the sample count below
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)  # index of the most likely digit

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print(f'\nAverage test loss: {avg_loss:.4f}  Accuracy:{correct_samples:5}/{total_samples:5} '
          f'({100.0 * correct_samples / total_samples:4.2f}%)\n')
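`evaluate` sums the per-sample losses with `reduction='sum'` and divides by the number of test images, so the reported loss is the mean per image; accuracy counts how often the highest-scoring class (the indices from `torch.max(output, dim=1)`) matches the label. A pure-Python sketch of that bookkeeping on made-up log-probabilities; `evaluate_batches` and `logs` are hypothetical helpers.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[List[List[float]], List[int]]  # (log-probabilities, true labels)


def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry, like the indices returned by torch.max(output, dim=1)."""
    return max(range(len(row)), key=lambda j: row[j])


def evaluate_batches(batches: List[Batch]) -> Tuple[float, float]:
    """Return (mean loss per sample, accuracy in percent) over all batches."""
    total_loss, correct, total = 0.0, 0, 0
    for log_probs, targets in batches:
        # reduction='sum': accumulate each sample's loss rather than the batch mean
        total_loss += -sum(row[t] for row, t in zip(log_probs, targets))
        correct += sum(argmax(row) == t for row, t in zip(log_probs, targets))
        total += len(targets)
    # Dividing by the sample count weights every image equally, even with a short last batch
    return total_loss / total, 100.0 * correct / total


def logs(probs: List[float]) -> List[float]:
    return [math.log(p) for p in probs]


# Two batches of unequal size, 3 classes
batches = [
    ([logs([0.7, 0.2, 0.1]), logs([0.1, 0.8, 0.1])], [0, 1]),  # both predictions correct
    ([logs([0.6, 0.3, 0.1])], [1]),                            # predicts 0, label is 1
]
avg_loss, accuracy = evaluate_batches(batches)
print(f'{avg_loss:.4f} {accuracy:.2f}%')  # 0.5946 66.67%
```

Averaging per-batch means instead would overweight a short final batch; with 10,000 test images in batches of 1000 the two approaches agree, but the summed form stays correct for any batch size.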

In [23]:
# Run training
N_EPOCHS = 25

start_time = time.time()
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = optim.Adam(model.parameters(), lr=0.003)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)

print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')

Epoch: 1

Average test loss: 0.1710  Accuracy: 9472/10000 (94.72%)

Epoch: 2

Average test loss: 0.1323  Accuracy: 9579/10000 (95.79%)

Epoch: 3

Average test loss: 0.1536  Accuracy: 9542/10000 (95.42%)

Epoch: 4

Average test loss: 0.1163  Accuracy: 9633/10000 (96.33%)

Epoch: 5

Average test loss: 0.1060  Accuracy: 9655/10000 (96.55%)

Epoch: 6

Average test loss: 0.0855  Accuracy: 9732/10000 (97.32%)

Epoch: 7

Average test loss: 0.0835  Accuracy: 9735/10000 (97.35%)

Epoch: 8

Average test loss: 0.0899  Accuracy: 9730/10000 (97.30%)

Epoch: 9

Average test loss: 0.0837  Accuracy: 9739/10000 (97.39%)

Epoch: 10

Average test loss: 0.0727  Accuracy: 9777/10000 (97.77%)

Epoch: 11

Average test loss: 0.0719  Accuracy: 9783/10000 (97.83%)

Epoch: 12

Average test loss: 0.0743  Accuracy: 9789/10000 (97.89%)

Epoch: 13

Average test loss: 0.0804  Accuracy: 9757/10000 (97.57%)

Epoch: 14

Average test loss: 0.0733  Accuracy: 9788/10000 (97.88%)

Epoch: 15

Average test loss: 0.0717  Accur