In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, h, edim):
        super().__init__()

        self.h = h
        self.edim = edim
        self.dk = self.edim//self.h
        self.key = nn.Linear(self.edim,self.edim)
        self.query = nn.Linear(self.edim,self.edim)
        self.value = nn.Linear(self.edim,self.edim)
        self.linear = nn.Linear(self.edim,self.edim)
        

    def forward(self, key, value,query):

        bs = key.shape[0]
        nwords_key = key.shape[1]
        nwords_query = query.shape[1]

        k = self.key(key).reshape(bs, nwords_key, self.h, self.dk).transpose(1,2)
        q = self.query(query).reshape(bs, nwords_query, self.h, self.dk).transpose(1,2)
        v = self.value(value).reshape(bs, nwords_key, self.h, self.dk).transpose(1,2)
        x = torch.einsum('bhmd,bhnd -> bhmn',(q,k))

        x = F.softmax(x/(self.dk)**0.5, dim=3)

        x = torch.einsum('bhmn,bhnv -> bhmv', (x,v))
        x = x.transpose(1,2)

        x = x.reshape(bs, nwords_query, -1)
        x = self.linear(x)
        return x


class EncoderBlock(nn.Module):
    def __init__(self, edim, h, hdim, dropout):
        super().__init__()

        self.multiHeadAttention = MultiHeadAttention(h, edim)
        self.norm1 = nn.LayerNorm(edim)
        self.norm2 = nn.LayerNorm(edim)
        self.fc1 = nn.Linear(edim, 4*hdim)
        self.fc2 = nn.Linear(4*hdim, edim)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)


    def forward(self, embed):

        x = self.multiHeadAttention(embed, embed, embed)
        x = self.dropout1(x)
        subLayer1 = self.norm1(x + embed)

        x = self.fc2(self.relu(self.fc1(subLayer1)))
        x = self.dropout2(x)
        subLayer2 = self.norm2(x + subLayer1)

        return subLayer2


class Encoder(nn.Module):
    '''
    Encoder
    nx: number of transformer blocks
    edim: embedding dimension
    h: number of heads
    hdim: hidden dimension
    '''
    def __init__(self, nx, edim, h, hdim, dropout):
        super().__init__()

        self.transformers = nn.ModuleList([EncoderBlock(edim, h, hdim,dropout) for _ in range(nx)])

    def forward(self, embed):
        for block in self.transformers:
            embed = block(embed)
        return embed


class PositionalEmbedding(nn.Module):
    '''
    Positional Embedding
    '''
    def __init__(self, edim, npatches):
        super().__init__()
        self.embedding = nn.Embedding(npatches, edim)

    def forward(self, x):
        return self.embedding(x)

class ClassificationHead(nn.Module):

    '''
    Classification Head
    Takes mean of the output of the transformer and passes it through a linear layer
    '''
    def __init__(self, edim, n_classes):
        super().__init__()
        self.linear = nn.Linear(edim, n_classes)
    def forward(self, x):
        x = torch.mean(x, dim=1)
        return self.linear(x)

class ViT(nn.Module):
    '''
    ViT
    nx: number of transformer blocks
    edim: embedding dimension
    h: number of heads
    hdim: hidden dimension
    dropout: dropout probability
    n_classes: number of classes
    patch_dim: dimension of the patch
    '''

    def __init__(self, nx, edim, h, hdim, dropout, n_classes, patch_dim, npatches):
        super().__init__()
        self.posEmbedding = PositionalEmbedding(edim, npatches)
        self.embedding = nn.Linear(patch_dim,edim)
        self.encoder = Encoder(nx, edim, h, hdim,dropout)
        self.classificationHead = ClassificationHead(edim, n_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches):

        pos_embed = self.posEmbedding(torch.arange(patches.shape[1], device=patches.device))
        src_embed = self.embedding(patches) + pos_embed
        src_embed = self.dropout(src_embed)
        encoded = self.encoder(src_embed)
        output = self.classificationHead(encoded)

        return output

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

img_size = 64
patch_size = 16
n_channels = 3
patch_dim = n_channels * patch_size * patch_size
npatches = img_size // patch_size


vit = ViT(n_classes=1000, nx=6, edim=768, h=8, hdim=1024, dropout=0.1, patch_dim=patch_dim, npatches=npatches).to(device)

print("number of parameters in Million: ", sum(p.numel() for p in vit.parameters() if p.requires_grad)/1e6)
## try on random data

bs = 2
n_classes = 1000
x = torch.randn(bs, npatches, patch_dim)

## positional embedding is same for all patches
pos_embed = torch.randn(npatches, patch_dim)
## expand to batch size
pos_embed = pos_embed.unsqueeze(0).expand(bs, npatches, patch_dim)

## forward pass
x = x.to(device)
pos_embed = pos_embed.to(device)
output = vit(x)
print(output.shape)



number of parameters in Million:  53.333224
torch.Size([2, 1000])
