In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt


In [None]:
# QMNIST train/test splits as float tensors in [0, 1] (ToTensor is stateless,
# so one instance can serve both datasets).
to_tensor = torchvision.transforms.ToTensor()
BATCH_SIZE = 60

train_dataset = torchvision.datasets.QMNIST('./data', train=True, download=True, transform=to_tensor)
test_dataset = torchvision.datasets.QMNIST('./data', train=False, download=True, transform=to_tensor)

# Only the training stream is shuffled; evaluation order does not affect accuracy.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training with {len(train_dataset)} samples")
print(f"Testing with {len(test_dataset)} samples")


In [None]:
# Image / patch geometry shared by the rest of the notebook.
PATCH_SIZE = 4
IMAGE_WIDTH = 28
IMAGE_HEIGHT = IMAGE_WIDTH
IMAGE_CHANNELS = 1

# Validate before deriving anything from the geometry. An explicit exception is used
# (rather than `assert ..., print(...)`, whose message evaluates to None) so the
# failure carries a readable message and is not stripped under `python -O`.
if IMAGE_WIDTH % PATCH_SIZE != 0 or IMAGE_HEIGHT % PATCH_SIZE != 0:
    raise ValueError(
        f"Image size {IMAGE_WIDTH}x{IMAGE_HEIGHT} is not divisible by patch size {PATCH_SIZE}"
    )

# Each patch flattened across channels gives the token embedding size.
EMBEDDING_DIMS = IMAGE_CHANNELS * PATCH_SIZE**2
# Non-overlapping patches per image; integer arithmetic avoids a float round-trip.
NUM_OF_PATCHES = (IMAGE_WIDTH // PATCH_SIZE) * (IMAGE_HEIGHT // PATCH_SIZE)


In [None]:
class PatchEmbeddingLayer(nn.Module):
    """Turn an image batch into a token sequence: class token + patch embeddings + positions.

    ``x`` is an image batch of shape (batch, in_channels, H, W), as consumed by
    ``nn.Conv2d``. The output has shape (batch, num_patches + 1, embedding_dim).
    Token 0 is a learnable class token. The remaining tokens are the patch
    embeddings in row-major patch order. A learnable positional embedding is added
    to every token.

    Args:
        in_channels: Number of image channels.
        patch_size: Side length of the square, non-overlapping patches.
        embedding_dim: Size of each token embedding.
        num_patches: Patches per image, used to size the positional table.
            Defaults to the notebook-level ``NUM_OF_PATCHES`` when omitted.
    """

    def __init__(self, in_channels, patch_size, embedding_dim, num_patches=None):
        super().__init__()
        if num_patches is None:
            num_patches = NUM_OF_PATCHES
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.in_channels = in_channels
        self.num_patches = num_patches
        # kernel == stride == patch_size: one linear projection per non-overlapping patch.
        self.conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # (B, H', W', D) -> (B, H' * W', D)
        self.flatten_layer = nn.Flatten(start_dim=1, end_dim=2)
        # A single class token shared by all images. It is broadcast to the batch
        # in forward(), so the layer works for any batch size. Both tables use the
        # constructor's embedding_dim rather than a global.
        self.class_token_embeddings = nn.Parameter(torch.rand((1, 1, embedding_dim)))
        self.position_embeddings = nn.Parameter(torch.rand((1, num_patches + 1, embedding_dim)))

    def forward(self, x):
        # (B, D, H', W') -> (B, H', W', D) -> (B, num_patches, D)
        patches = self.flatten_layer(self.conv_layer(x).permute((0, 2, 3, 1)))
        class_tokens = self.class_token_embeddings.expand(patches.shape[0], -1, -1)
        return torch.cat((class_tokens, patches), dim=1) + self.position_embeddings
    
# Smoke test: embed one training batch and inspect the resulting token tensor.
patch_embedding_layer = PatchEmbeddingLayer(
    in_channels=IMAGE_CHANNELS,
    patch_size=PATCH_SIZE,
    embedding_dim=EMBEDDING_DIMS,
)

sample_images = next(iter(train_loader))[0]
patch_embeddings = patch_embedding_layer(sample_images)
patch_embeddings.shape


In [None]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention sub-layer of a ViT encoder block.

    The input is layer-normalised, then used as query, key and value of a
    batch-first ``nn.MultiheadAttention``. Only the attention output is returned.
    No residual connection is added here; ``TransformerBlock`` adds it.

    Args:
        embedding_dims: Token embedding size D (768 = Hidden Size D in ViT paper Table 1).
            ``nn.MultiheadAttention`` requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads (12 = Heads in ViT paper Table 1).
        attn_dropout: Dropout on attention weights; 0.0 because the ViT paper uses
            no dropout in the MSA block.
    """

    def __init__(self, embedding_dims=768, num_heads=12, attn_dropout=0.0):
        super().__init__()

        # Hyper-parameters kept as attributes for inspection.
        self.embedding_dims, self.num_head, self.attn_dropout = embedding_dims, num_heads, attn_dropout

        # Registration order (norm first, then attention) fixes parameter order.
        self.layernorm = nn.LayerNorm(normalized_shape=embedding_dims)
        self.multiheadattention = nn.MultiheadAttention(
            embed_dim=embedding_dims,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

    def forward(self, x):
        normed = self.layernorm(x)
        attended, _unused_weights = self.multiheadattention(
            query=normed, key=normed, value=normed, need_weights=False
        )
        return attended


# Shape check: self-attention must preserve [batch_size, num_patches+1, embedding_dims].
multihead_self_attention_block = MultiHeadSelfAttentionBlock(embedding_dims=EMBEDDING_DIMS, num_heads=2)

print(f"Shape of the input Patch Embeddings => {list(patch_embeddings.shape)} <= [batch_size, num_patches+1, embedding_dims ]")
print(f"Shape of the output from MSA Block => {list(multihead_self_attention_block(patch_embeddings).shape)} <= [batch_size, num_patches+1, embedding_dims ]")


In [None]:
class MachineLearningPerceptronBlock(nn.Module):
    """Pre-norm feed-forward (multi-layer perceptron) sub-layer of a ViT encoder block.

    Computes LayerNorm -> Linear(D -> mlp_size) -> GELU -> Dropout
    -> Linear(mlp_size -> D) -> Dropout. The output has the same trailing size as
    the input. No residual connection is added here.

    Args:
        embedding_dims: Token embedding size D.
        mlp_size: Hidden width of the MLP.
        mlp_dropout: Dropout probability applied after each Linear stage.
    """

    def __init__(self, embedding_dims, mlp_size, mlp_dropout):
        super().__init__()
        self.embedding_dims = embedding_dims
        self.mlp_size = mlp_size
        self.dropout = mlp_dropout

        self.layernorm = nn.LayerNorm(normalized_shape=embedding_dims)
        # Layer order is kept so Sequential indices (mlp.0 / mlp.3) stay stable.
        stages = [
            nn.Linear(in_features=embedding_dims, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=mlp_dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dims),
            nn.Dropout(p=mlp_dropout),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        normed = self.layernorm(x)
        return self.mlp(normed)

# Standalone MLP sub-layer; 3072 is the ViT-Base hidden width.
mlp_block = MachineLearningPerceptronBlock(
    embedding_dims=EMBEDDING_DIMS,
    mlp_size=3072,
    mlp_dropout=0.1,
)


In [None]:
class TransformerBlock(nn.Module):
    """One ViT encoder layer: self-attention and MLP sub-layers, each with a residual.

    Computes ``y = x + MSA(x)`` followed by ``out = y + MLP(y)``. Each sub-layer
    applies its own LayerNorm to its input.

    Args:
        embedding_dims: Token embedding size D.
        mlp_dropout: Dropout probability inside the MLP sub-layer.
        attn_dropout: Dropout probability on attention weights.
        mlp_size: Hidden width of the MLP sub-layer.
        num_heads: Number of attention heads.
    """

    def __init__(self, embedding_dims=768, mlp_dropout=0.1, attn_dropout=0.0, mlp_size=3072, num_heads=12):
        super().__init__()

        # Attention is registered before the MLP; this fixes parameter order and
        # the sequence of random initialisations.
        self.msa_block = MultiHeadSelfAttentionBlock(
            embedding_dims=embedding_dims, num_heads=num_heads, attn_dropout=attn_dropout
        )
        self.mlp_block = MachineLearningPerceptronBlock(
            embedding_dims=embedding_dims, mlp_size=mlp_size, mlp_dropout=mlp_dropout
        )

    def forward(self, x):
        after_attention = x + self.msa_block(x)
        return after_attention + self.mlp_block(after_attention)
    
# Shape check: a full encoder layer must also preserve the token tensor shape.
transformer_block = TransformerBlock(
    embedding_dims=EMBEDDING_DIMS,
    num_heads=2,
    mlp_size=3072,
    mlp_dropout=0.1,
    attn_dropout=0.0,
)

print(f'Shape of the input Patch Embeddings => {list(patch_embeddings.shape)} <= [batch_size, num_patches+1, embedding_dims ]')
print(f'Shape of the output from Transformer Block => {list(transformer_block(patch_embeddings).shape)} <= [batch_size, num_patches+1, embedding_dims ]')


In [None]:
class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    The pipeline is:

    1. A patch embedding (class token + patch tokens + positional embeddings).
    2. ``num_transformer_layers`` stacked ``TransformerBlock`` encoders.
    3. A LayerNorm + Linear head applied to the class token (token 0).

    The defaults describe a small model for 28x28 single-channel images split
    into 4x4 patches.

    Args:
        img_size: Side length of the square input images; must be divisible by ``patch_size``.
        in_channels: Number of image channels.
        patch_size: Side length of each square patch.
        embedding_dims: Token embedding size; ``nn.MultiheadAttention`` requires it
            to be divisible by ``num_heads``.
        num_transformer_layers: Number of encoder blocks (12 as in ViT-Base, ViT paper Table 1).
        mlp_dropout: Dropout probability inside each MLP sub-layer.
        attn_dropout: Dropout probability on attention weights.
        mlp_size: Hidden width of each MLP sub-layer.
        num_heads: Number of attention heads.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``img_size`` is not a multiple of ``patch_size``. Also raised
            if the patch embedding's positional table does not have shape
            (1, num_patches + 1, embedding_dims) for the given geometry; otherwise
            the mismatch would only surface as a broadcasting error in forward().
    """

    def __init__(
        self,
        img_size=28,
        in_channels=1,
        patch_size=4,
        embedding_dims=16,
        num_transformer_layers=12,
        mlp_dropout=0.1,
        attn_dropout=0.0,
        mlp_size=64,
        num_heads=2,
        num_classes=10,
    ):
        super().__init__()

        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        num_patches = (img_size // patch_size) ** 2

        self.patch_embedding_layer = PatchEmbeddingLayer(
            in_channels=in_channels, patch_size=patch_size, embedding_dim=embedding_dims
        )

        # img_size is otherwise unused: fail fast if the embedding layer was sized
        # for a different image geometry or embedding width.
        expected_shape = (1, num_patches + 1, embedding_dims)
        actual_shape = tuple(self.patch_embedding_layer.position_embeddings.shape)
        if actual_shape != expected_shape:
            raise ValueError(
                f"Patch embedding positional table has shape {actual_shape}, expected "
                f"{expected_shape} for img_size={img_size}, patch_size={patch_size}, "
                f"embedding_dims={embedding_dims}"
            )

        self.transformer_encoder = nn.Sequential(
            *[
                TransformerBlock(
                    embedding_dims=embedding_dims,
                    mlp_dropout=mlp_dropout,
                    attn_dropout=attn_dropout,
                    mlp_size=mlp_size,
                    num_heads=num_heads,
                )
                for _ in range(num_transformer_layers)
            ]
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dims),
            nn.Linear(in_features=embedding_dims, out_features=num_classes),
        )

    def forward(self, x):
        tokens = self.patch_embedding_layer(x)
        encoded = self.transformer_encoder(tokens)
        # Classify from the class token only (position 0 of the sequence).
        return self.classifier(encoded[:, 0])


In [None]:
# Loss on raw logits (cross_entropy applies log-softmax internally).
criterion = F.cross_entropy
# Default ViT configuration: 28x28 single-channel input, 10 output classes.
model = ViT()


In [None]:
# Pick an accelerator only if it is actually available; a hard-coded 'mps'
# fails on machines without Apple-silicon support.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = model.to(device)
# The optimizer is built after the move so it tracks the on-device parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
import tqdm


In [None]:
for epoch in range(10):
    # Training mode enables dropout while optimizing.
    model.train()
    with tqdm.tqdm(train_loader, unit="batch", desc=f"Epoch {epoch + 1}") as tepoch:
        for data, target in tepoch:
            # Every batch is used. A previous `data.shape[0] != 64` filter never
            # matched the loader's batch size of 60 and skipped all training.
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            tepoch.set_postfix_str(f"Loss: {loss.item()}")

    # Evaluation mode disables dropout so the accuracy reflects the trained weights.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f"Accuracy: {correct / total}")
        

In [None]:
# Final held-out evaluation. eval() turns dropout off; without it, the dropout
# layers in the MLP sub-layers would perturb predictions.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    with tqdm.tqdm(test_loader, unit="batch") as pbar:
        for data, target in pbar:
            data, target = data.to(device), target.to(device)

            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            pbar.set_postfix_str(f"Accuracy: {correct / total}")


In [None]:
# Fraction of test samples classified correctly in the evaluation cell above.
accuracy = correct / total
print(f"Accuracy: {accuracy}")
