In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from torch import nn, optim
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [2]:
# Fix the random seed for every RNG used in this notebook (PyTorch CPU, CUDA, NumPy)
seed = 17
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

In [3]:
# set device
# GPU index 3 is the preferred card, but it is only requested when the machine
# actually exposes that index; otherwise fall back to the first GPU, and to the
# CPU when CUDA is not available at all.
if torch.cuda.is_available():
    device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# load MNIST dataset, train + test as dataset


# convert the images to tensors (no mean/std normalization is applied here)
transform_ = transforms.Compose([transforms.ToTensor()])

# download (if needed) and load the official train and test splits
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform_)
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform_)

# merge both official splits so that we can re-split them ourselves
data = torch.utils.data.ConcatDataset([trainset, testset])

# make custom train, val and test sets from data in ratio 0.7, 0.1 and 0.2
train_size = int(0.7 * len(data))
val_size = int(0.1 * len(data))
# the test set takes the remainder so the three sizes always add up to len(data)
test_size = len(data) - train_size - val_size

# Use a dedicated, explicitly seeded generator: the split is then the same every
# time this cell runs, regardless of how much of the global RNG stream has been
# consumed beforehand (e.g. when the cell is re-run without re-running the seed cell).
train_data, val_data, test_data = torch.utils.data.random_split(
    data,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(seed),
)

# prepare data loaders
# Only the training loader is shuffled. Evaluation loaders keep a fixed order so
# that validation/test batches are deterministic and iterating them does not
# advance the global RNG in the middle of training.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

## ViT

In [5]:
class ViT(nn.Module):
    """Small Vision Transformer classifier.

    Pipeline of ``forward``:

    1. cut the image into non-overlapping ``patch_size`` x ``patch_size`` patches
       and flatten each one,
    2. project every flattened patch to ``d_model`` with a linear layer,
    3. prepend a learnable class token and add learnable position embeddings,
    4. run the sequence through a stack of ``nn.TransformerEncoderLayer``,
    5. map the encoded class token to ``num_classes`` logits.

    Args:
        img_width: image side length used to size the position embedding; the
            sequence length is ``(img_width // patch_size) ** 2 + 1``.
        img_channels: number of input channels.
        patch_size: side length of each square patch.
        d_model: embedding width used throughout the transformer.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: size of the output logit vector.
        ff_dim: hidden width of each encoder layer's feed-forward sub-block.
    """

    def __init__(self, img_width, img_channels, patch_size, d_model, num_heads, num_layers, num_classes, ff_dim):
        super().__init__()

        self.patch_size = patch_size

        flat_patch_dim = img_channels * patch_size * patch_size
        patches_per_side = img_width // patch_size
        # one position per patch, plus one for the class token
        seq_len = patches_per_side * patches_per_side + 1

        # NOTE: the modules/parameters below are created in a fixed order so that a
        # given random seed always produces the same initial weights.

        # flattened patch -> d_model embedding
        self.patch_embedding = nn.Linear(flat_patch_dim, d_model)

        # single learnable token, shared by every sample, prepended to the patch sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # learnable position embedding, broadcast over the batch dimension
        self.position_embedding = nn.Parameter(torch.rand(1, seq_len, d_model))

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True
            ),
            num_layers=num_layers,
        )

        # classification head applied to the encoded class token
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for images ``x`` of shape ``(N, C, H, W)``."""
        batch, channels, _, _ = x.shape
        p = self.patch_size

        # (N, C, H, W) -> (N, C, rows, cols, p, p): non-overlapping p x p windows
        windows = x.unfold(2, p, p).unfold(3, p, p)
        # move channels last and flatten: (N, rows * cols, p * p * C)
        tokens = windows.permute(0, 2, 3, 4, 5, 1).reshape(batch, -1, channels * p * p)

        # linear projection of every flattened patch
        tokens = self.patch_embedding(tokens)

        # prepend the class token (one copy per sample), then add position information
        cls = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat((cls, tokens), dim=1) + self.position_embedding

        encoded = self.transformer_encoder(sequence)

        # classify from the encoded class token only (sequence position 0)
        return self.fc(encoded[:, 0])


In [6]:
def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return ``(mean_batch_loss, accuracy_percent)``.

    If ``optimizer`` is given, a gradient step is taken after every batch;
    otherwise the pass runs with gradients disabled. The caller is responsible
    for switching the model between ``train()`` and ``eval()``.
    """
    training = optimizer is not None
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for data, target in batches:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            num_batches += 1
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    # the loss is averaged over batches, the accuracy over individual samples
    return loss_sum / num_batches, 100 * correct / total


def trainViT(model, train_loader, val_loader, epochs, lr, output_path):
    """Train ``model`` with Adam + cross-entropy and log per-epoch metrics.

    Args:
        model: classifier mapping a batch of inputs to per-class logits.
        train_loader: iterable of ``(data, target)`` training batches.
        val_loader: iterable of ``(data, target)`` validation batches.
        epochs: number of passes over ``train_loader``.
        lr: learning rate passed to ``optim.Adam``.
        output_path: directory for the outputs (created if missing). A trailing
            path separator is optional.

    Side effects:
        * prints one summary line per epoch and writes the same line to
          ``<output_path>/log.txt`` (flushed every epoch);
        * saves ``best_model.pth`` whenever the validation loss improves;
        * saves ``last_model.pth`` after the final epoch.

    Returns:
        None.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Move batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` global being defined and in sync with the model.
    device = next(model.parameters()).device

    # create output directory (no error if it already exists)
    os.makedirs(output_path, exist_ok=True)
    # os.path.join keeps the files inside ``output_path`` whether or not the
    # caller included a trailing slash
    log_path = os.path.join(output_path, 'log.txt')
    best_path = os.path.join(output_path, 'best_model.pth')
    last_path = os.path.join(output_path, 'last_model.pth')

    best_loss = float('inf')

    # The context manager guarantees the log is closed even when training is
    # interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
    with open(log_path, 'w') as log_file:
        for epoch in range(epochs):
            model.train()
            train_loss, train_acc = _run_epoch(model, tqdm(train_loader), criterion, device, optimizer)

            model.eval()
            val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

            summary = f'Epoch {epoch + 1}/{epochs} Training Loss: {train_loss:.3f} Validation Loss: {val_loss:.3f} Training Acc: {train_acc:.2f} Validation Acc: {val_acc:.2f}'
            print(summary)
            log_file.write(summary + '\n')
            # flush so that finished epochs survive an abrupt stop
            log_file.flush()

            # keep the checkpoint with the lowest validation loss seen so far
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), best_path)

    # save last model
    torch.save(model.state_dict(), last_path)


        

In [7]:
# --- optimisation ---
num_epochs = 15
lr = 3e-4
# NOTE(review): batch_size and weight_decay are not referenced by any cell visible
# in this notebook (the data loaders are built with a literal batch size and the
# optimizer is created with lr only) - confirm whether they are meant to be wired in.
batch_size = 128
weight_decay = 1e-4

# --- input / task ---
img_width = 28
img_channels = 1
num_classes = 10

# --- model architecture ---
patch_size = 7       # 28 // 7 = 4 patches per side
embedding_dim = 64   # passed to ViT as d_model
num_heads = 8
num_layers = 3
ff_dim = 2048

In [8]:
# Build the ViT from the hyperparameters above and move it to the selected device
model = ViT(
    img_width=img_width,
    img_channels=img_channels,
    patch_size=patch_size,
    d_model=embedding_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    num_classes=num_classes,
    ff_dim=ff_dim,
).to(device)

In [9]:
# Train the model; the log and checkpoints are written under output/
trainViT(
    model,
    train_loader,
    val_loader,
    epochs=num_epochs,
    lr=lr,
    output_path='output/',
)

100%|██████████| 766/766 [00:08<00:00, 89.88it/s]


Epoch 1/15 Training Loss: 0.630 Validation Loss: 0.189 Training Acc: 79.72 Validation Acc: 94.50


100%|██████████| 766/766 [00:08<00:00, 91.45it/s]


Epoch 2/15 Training Loss: 0.198 Validation Loss: 0.131 Training Acc: 94.02 Validation Acc: 96.13


 19%|█▉        | 147/766 [00:01<00:06, 91.37it/s]


KeyboardInterrupt: 