In [324]:
import math

import matplotlib.pyplot as pyplot
import numpy
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters shared by the model definitions and the training cell.
embed_size = 128                               # width of every token embedding
batch_size = 32                                # samples per optimisation step
patch_height = 7                               # side length, in pixels, of one square patch
number_of_patches = (28 // patch_height) ** 2  # 28x28 MNIST image -> 4x4 grid = 16 patches

def patchify(images, patch_size=None):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Only channel 0 of each image is used. Patches are ordered row-major over
    the patch grid (index ``row * grid_cols + col``) and each patch is
    flattened row-major as well.

    Args:
        images: tensor of shape ``(N, C, H, W)``.
        patch_size: side length of one patch in pixels. Defaults to the
            notebook-level ``patch_height`` when omitted.

    Returns:
        Tensor of shape ``(N, (H // patch_size) * (W // patch_size),
        patch_size * patch_size)`` on the same device as ``images``.
        Integer inputs are converted to the default floating dtype.

    Raises:
        ValueError: if ``images`` is not 4-D or its height/width is not a
            multiple of ``patch_size``.
    """
    if patch_size is None:
        patch_size = patch_height
    if images.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {images.dim()} dimensions")

    n, _, height, width = images.shape
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(
            f"image size {height}x{width} is not divisible by patch size {patch_size}"
        )
    grid_rows, grid_cols = height // patch_size, width // patch_size

    # (N, H, W) -> (N, grid_rows, p, grid_cols, p): split both spatial axes
    # into (which patch, pixel inside the patch).
    grid = images[:, 0].reshape(n, grid_rows, patch_size, grid_cols, patch_size)
    # Bring the two "which patch" axes together, then flatten each patch.
    patched = grid.permute(0, 1, 3, 2, 4).reshape(
        n, grid_rows * grid_cols, patch_size * patch_size
    )

    # The downstream linear embedding needs floats.
    if not patched.is_floating_point():
        patched = patched.to(torch.get_default_dtype())
    return patched

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    The table is stored as a buffer ``pe`` of shape ``(max_len, 1, d_model)``:
    even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. Positions are indexed along dimension 0 of the input,
    so callers holding batch-first data must transpose before calling.

    Args:
        d_model: feature width of the input (odd widths are supported).
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd index than even index, so
        # trim the angles to the number of cosine columns actually present.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:len])`` for ``x`` of shape ``(seq_len, batch, d_model)``."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a batch-first sequence.

    Args:
        embed_dim: feature width of the input. Defaults to the notebook-level
            ``embed_size`` when omitted.
        num_heads: number of attention heads; must divide ``embed_dim``.
        num_patches: expected number of image patches. Defaults to the
            notebook-level ``number_of_patches``. It is only recorded in
            ``patches_with_class``; ``forward`` reads the real sequence length
            from its input.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=None, num_heads=4, num_patches=None):
        super().__init__()

        embed_dim = embed_size if embed_dim is None else embed_dim
        num_patches = number_of_patches if num_patches is None else num_patches
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.numheads = num_heads
        self.head_dim = embed_dim // num_heads
        self.patches_with_class = num_patches + 1  # patches plus the class token

        self.qlinear = nn.Linear(embed_dim, embed_dim)
        self.klinear = nn.Linear(embed_dim, embed_dim)
        self.vlinear = nn.Linear(embed_dim, embed_dim)

        self.headconcatlinear = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch, tokens):
        """Reshape ``(batch, tokens, embed)`` into ``(batch, heads, tokens, head_dim)``."""
        return t.reshape(batch, tokens, self.numheads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x, verbose=False):
        """Apply self-attention to ``x`` of shape ``(batch, tokens, embed_dim)``.

        Batch size and token count are taken from ``x`` itself, so partial
        batches and other sequence lengths work. Returns a tensor of the same
        shape as ``x``. With ``verbose=True`` intermediate shapes are printed.
        """
        batch, tokens, _ = x.shape

        q = self.qlinear(x)
        k = self.klinear(x)
        v = self.vlinear(x)
        if verbose: print("q,v,k shape", q.shape)

        q = self._split_heads(q, batch, tokens)
        k = self._split_heads(k, batch, tokens)
        v = self._split_heads(v, batch, tokens)
        if verbose: print("After split the heads q, k,v shape:", q.shape, "\n")

        # Scale by sqrt(head_dim) so the logits do not grow with head width.
        matmulqk = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if verbose: print("multiplied key with query, shape:", matmulqk.shape, "\n")

        attention_weights = torch.softmax(matmulqk, dim=-1)
        if verbose: print("attention weights after softmax: \n", attention_weights.shape)

        result = torch.matmul(attention_weights, v)
        if verbose: print("Multiply attention weights and values shape:", result.shape)

        result = result.permute(0, 2, 1, 3)
        if verbose: print("Bring heads together:", result.shape, "\n")

        result = result.reshape(batch, tokens, self.embed_dim)
        if verbose: print("concatenate shape:", result.shape, " data:\n")

        return self.headconcatlinear(result)
    
class TransformerEncoder(nn.Module):
    """One pre-norm encoder block: self-attention, then a position-wise MLP.

    Each sub-layer is wrapped as ``x + sublayer(norm(x))``. The two sub-layers
    use separate ``LayerNorm`` modules so their scale and shift parameters are
    learned independently.

    Args:
        hidden_size: width of the hidden layer in the feed-forward block.
    """

    def __init__(self, hidden_size=2048):
        super().__init__()
        self.layer_normalization = nn.LayerNorm(embed_size)  # norm before attention
        self.attention = MultiHeadAttention()
        self.ffblock = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        # Dedicated norm before the feed-forward block, so it does not share
        # affine parameters with the attention norm.
        self.ff_layer_normalization = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Return the block output; the shape matches ``x``."""
        attended = x + self.attention(self.layer_normalization(x))
        return attended + self.ffblock(self.ff_layer_normalization(attended))

In [327]:
class ViT(nn.Module):
    """Small Vision Transformer classifier producing 10 class logits.

    Pipeline: linear patch embedding -> prepend a learned class token -> add
    sinusoidal positional encoding -> two encoder blocks -> classification
    head applied to the class token.
    """

    def __init__(self):
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1,1,embed_size))
        self.embedder = nn.Linear(patch_height*patch_height, embed_size)
        # Built once so its buffer follows the model across devices and
        # train/eval switches, instead of being rebuilt on every forward pass.
        self.positional_encoding = PositionalEncoding(embed_size)

        self.encoder1 = TransformerEncoder()
        self.encoder2 = TransformerEncoder()

        self.classification_head = nn.Sequential(nn.LayerNorm(embed_size), nn.Linear(embed_size, 10))

    def forward(self, x, verbose=False):
        """Classify a batch of flattened patches.

        Args:
            x: tensor of shape ``(batch, patches, patch_height * patch_height)``.
            verbose: accepted for interface compatibility; currently unused.

        Returns:
            Logits of shape ``(batch, 10)``.
        """
        x = self.embedder(x)

        # Size the class token from the actual batch, so a final partial batch
        # does not break the concatenation.
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat( (cls_tokens, x), dim=1 )

        # PositionalEncoding indexes positions along dim 0. Go sequence-first
        # for the addition, then back to batch-first, so each token gets the
        # encoding of its own position.
        x = self.positional_encoding(x.transpose(0, 1)).transpose(0, 1).contiguous()

        result = self.encoder1(x)
        result = self.encoder2(result)

        result = result[:,0]  # keep only the class token
        result = self.classification_head(result)

        return result


In [328]:
# Forward slash avoids the invalid "\d" escape and works on every OS.
dataset = datasets.MNIST(root="./data", transform=transforms.ToTensor(), train=True, download=True)
# Shuffle each epoch for SGD-style training; drop_last keeps every batch at
# exactly batch_size samples.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

model = ViT()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

num_epochs = 5
log_every = 100  # record the loss once per this many batches

losses = []
model.train()
for i in range(num_epochs):
    print(f"Epoch {i}:\n")
    for j, (x, y) in enumerate(dataloader):
        patches = patchify(x)
        yhat = model(patches)

        loss = criterion(yhat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if j % log_every == 0:
            print(loss.item())
            losses.append(loss.item())

pyplot.plot(losses)
pyplot.xlabel(f"logged step (every {log_every} batches)")
pyplot.ylabel("cross-entropy loss")
pyplot.title("ViT training loss on MNIST")
pyplot.show()

Epoch {i}:

2.4513885974884033
1.8432921171188354


KeyboardInterrupt: 

torch.Size([32, 1, 128])
