# 5. Transformers on MNIST

### About this notebook

This notebook was used in the 50.039 Deep Learning course at the Singapore University of Technology and Design.

**Author:** Matthieu DE MARI (matthieu_demari@sutd.edu.sg)

**Version:** 1.0 (28/02/2023)

**Requirements:**
- Python 3 (tested on v3.9.6)
- Torch (tested on v1.10.1)
- Torchvision (tested on v0.11.2)
- We also strongly recommend setting up CUDA on your machine!

### Imports and CUDA

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

### Load MNIST

In [2]:
# Load the MNIST dataset and prepare dataloaders as usual
mnist_train = datasets.MNIST(root='.', train = True, download = True,
                             transform = transforms.ToTensor())
mnist_test = datasets.MNIST(root='.', train = False, download = True,
                            transform = transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size = 32, shuffle = True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 32, shuffle = False)

### Define self-attention layer, and Transformer model

We will have to flatten the images to process them with Linear operations and attention operations.

In [3]:
# Define a self-attention layer implementation
class SelfAttentionLayer(nn.Module):
    def __init__(self, in_features):
        super(SelfAttentionLayer, self).__init__()
        self.in_features = in_features
        self.query = nn.Linear(in_features, in_features)
        self.key = nn.Linear(in_features, in_features)
        self.value = nn.Linear(in_features, in_features)

    def forward(self, x):
        batch_size = x.size(0)
        query = self.query(x).view(batch_size, -1, self.in_features)
        key = self.key(x).view(batch_size, -1, self.in_features)
        value = self.value(x).view(batch_size, -1, self.in_features)
        attention_weights = F.softmax(torch.bmm(query, key.transpose(1, 2))/(self.in_features**0.5), dim = 2)
        out = torch.bmm(attention_weights, value).view(batch_size, -1)
        return out

In [4]:
# Neural network definition using self-attention
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.attention1 = SelfAttentionLayer(128)
        self.fc2 = nn.Linear(128, 64)
        self.attention2 = SelfAttentionLayer(64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = self.attention1(x)
        x = F.relu(self.fc2(x))
        x = self.attention2(x)
        x = self.fc3(x)
        return x

### Try out our model

In [5]:
# Create model
model = Transformer()
print(model)

Transformer(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (attention1): SelfAttentionLayer(
    (query): Linear(in_features=128, out_features=128, bias=True)
    (key): Linear(in_features=128, out_features=128, bias=True)
    (value): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (attention2): SelfAttentionLayer(
    (query): Linear(in_features=64, out_features=64, bias=True)
    (key): Linear(in_features=64, out_features=64, bias=True)
    (value): Linear(in_features=64, out_features=64, bias=True)
  )
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)


### Simple trainer like before

In [6]:
# Create model
model = Transformer()
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

# Train the model
n_epochs = 5
for epoch in range(n_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten image
        images = images.reshape(-1, 28*28)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Display
        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(epoch + 1, \
                                                                     n_epochs, \
                                                                     i + 1, \
                                                                     len(train_loader), \
                                                                     loss.item()))

Epoch [1/10], Step [100/1875], Loss: 0.6941
Epoch [1/10], Step [200/1875], Loss: 0.2044
Epoch [1/10], Step [300/1875], Loss: 0.7688
Epoch [1/10], Step [400/1875], Loss: 0.4087
Epoch [1/10], Step [500/1875], Loss: 0.3414
Epoch [1/10], Step [600/1875], Loss: 0.0827
Epoch [1/10], Step [700/1875], Loss: 0.1491
Epoch [1/10], Step [800/1875], Loss: 0.5172
Epoch [1/10], Step [900/1875], Loss: 0.1956
Epoch [1/10], Step [1000/1875], Loss: 0.2963
Epoch [1/10], Step [1100/1875], Loss: 0.1300
Epoch [1/10], Step [1200/1875], Loss: 0.4159
Epoch [1/10], Step [1300/1875], Loss: 0.0367
Epoch [1/10], Step [1400/1875], Loss: 0.3448
Epoch [1/10], Step [1500/1875], Loss: 0.2060
Epoch [1/10], Step [1600/1875], Loss: 0.0879
Epoch [1/10], Step [1700/1875], Loss: 0.3513
Epoch [1/10], Step [1800/1875], Loss: 0.2111
Epoch [2/10], Step [100/1875], Loss: 0.2226
Epoch [2/10], Step [200/1875], Loss: 0.1874
Epoch [2/10], Step [300/1875], Loss: 0.0531
Epoch [2/10], Step [400/1875], Loss: 0.0739
Epoch [2/10], Step [500

### Evaluate model

We get a 97% test accuracy, after only 5 iterations of training!

Better performance could be obtained by combining Convolutional operations and Attention ones?

In [7]:
# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # Flatten images
        images = images.reshape(-1, 28 * 28)
        # Forward pass and accuracy calculation
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Final display
    print("Test Accuracy: {} %".format(100*correct/total))

Test Accuracy: 97.04 %
