In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self,  embed_dim=512, output_dim=512, num_heads=8):
        """
        Multi-head scaled dot-product self-attention on channel-first sequences.

        Args:
            embed_dim: number of input channels N. They are split evenly across heads,
                so the dim of each head = embed_dim // num_heads (i.e. 64 as default).
            output_dim: number of channels produced by the final 1x1 output projection.
            num_heads: number of attention heads.

        Raises:
            ValueError: if embed_dim is not divisible by num_heads.
        """
        super(MultiHeadSelfAttention, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.dim_each_heads = embed_dim // num_heads

        # 1x1 convolutions act as per-token linear projections on (B, N, L) input.
        self.to_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)
        self.to_k = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)
        self.to_v = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)

        # Attention scale 1 / sqrt(d_k), where d_k is the per-head dim of q and k.
        # The attribute name `dk` is kept for backward compatibility.
        self.dk = self.dim_each_heads ** -0.5

        self.out_proj = nn.Conv1d(embed_dim, output_dim, kernel_size=1, bias=True)

    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        returns: (B, output_dim, L)
        """
        assert x.ndim == 3

        B, _, L = x.shape
        H, D = self.num_heads, self.dim_each_heads

        # Split the channel axis into heads: (B, N, L) -> (B, H, D, L)
        q = self.to_q(x).view(B, H, D, L)
        k = self.to_k(x).view(B, H, D, L)
        v = self.to_v(x).view(B, H, D, L)

        scores = torch.einsum('bhdl,bhdm->bhlm', q, k) * self.dk        # (B, H, Lq, Lk)
        # Normalize over the KEY axis so each query's weights sum to 1.
        attention_weight = torch.softmax(scores, dim=-1)                 # (B, H, Lq, Lk)

        # Weighted sum of values for every query: (B, H, D, Lq)
        heads = torch.einsum('bhlm,bhdm->bhdl', attention_weight, v)

        # Concatenate the heads back into the channel axis: (B, H*D, Lq)
        return self.out_proj(heads.reshape(B, H * D, L))
    
    
class TransformerEncoder(nn.Module):
    """Post-LN Transformer encoder layer on channel-first sequences (B, N, L).

        x = LN(x + MHSA(x))
        x = LN(x + FFN(x))

    The LayerNorms have no learnable affine parameters (elementwise_affine=False).
    """
    def __init__(self,  embed_dim=512,  num_heads=8, ffn_dim=3072):
        super(TransformerEncoder, self).__init__()
        self.mhsa = MultiHeadSelfAttention(embed_dim, embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)

        # Position-wise feed-forward network as two 1x1 convolutions.
        self.ffn1 = nn.Conv1d(embed_dim, ffn_dim, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.ffn2 = nn.Conv1d(ffn_dim, embed_dim, kernel_size=1, bias=True)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply a LayerNorm over the channel axis of a (B, N, L) tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        returns: (B, N, L), output of the second LayerNorm
        """
        assert x.ndim == 3
        x = self._channel_norm(self.norm1, x + self.mhsa(x))
        # The residual is added once, inside the norm; it is not re-added afterwards.
        x = self._channel_norm(self.norm2, x + self.ffn2(self.relu(self.ffn1(x))))
        return x
    
class VisionTransformerEncoder(nn.Module):
    """Pre-LN Transformer encoder block (ViT style) on channel-first sequences (B, N, L).

        x = x + MHSA(LN(x))
        x = x + MLP(LN(x))

    NOTE: the MLP uses ReLU; the original ViT uses GELU.
    """
    def __init__(self,  embed_dim=512,  num_heads=8, mlp_dim=3072):
        super(VisionTransformerEncoder, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.mhsa = MultiHeadSelfAttention(embed_dim, embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)

        # Position-wise MLP as two 1x1 convolutions (state-dict keys mlp.0 / mlp.2).
        self.mlp = nn.Sequential(
            nn.Conv1d(embed_dim, mlp_dim, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv1d(mlp_dim, embed_dim, kernel_size=1, bias=True),
        )

    @staticmethod
    def _channel_norm(norm, x):
        """Apply a LayerNorm over the channel axis of a (B, N, L) tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        returns: (B, N, L)
        """
        assert x.ndim == 3
        x = x + self.mhsa(self._channel_norm(self.norm1, x))
        # The MLP output is used (it was previously discarded), and the skip is added once.
        x = x + self.mlp(self._channel_norm(self.norm2, x))
        return x
    
class ImageEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and linearly embed each patch.

    A Conv2d whose kernel size and stride both equal `patch_size` applies one
    shared linear map to every patch.
    """
    def __init__(self,  image_dim=3, embed_dim=512,  patch_size=(16, 16)):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(
            image_dim,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
            bias=True,
        )

    def forward(self, x):
        """
        x: batch size, image_dim, height, width   (B, C, H, W)
        returns: (B, embed_dim, num_patches), patches in row-major grid order
        """
        batch_size = x.shape[0]
        patch_grid = self.proj(x)          # (B, embed_dim, H // ph, W // pw)
        return patch_grid.view(batch_size, self.embed_dim, -1)
    
class VisionTransformer(nn.Module):
    def __init__(self,  
                 image_size=(256, 256), 
                 patch_size=(16, 16), 
                 image_dim=3, 
                 embed_dim=384, 
                 mlp_dim=1536, 
                 num_heads=8,
                 num_layers=8,
                 global_pool='avg'):
        """
        Vision Transformer backbone that returns one feature vector per image.
        No dropout implementation for simplicity.

        Args:
            image_size: (H, W) of the inputs. Together with patch_size it fixes
                the number of positional embeddings.
            patch_size: (ph, pw) of the non-overlapping patches.
            image_dim: number of input image channels.
            embed_dim: token width.
            mlp_dim: hidden width of each block's MLP.
            num_heads: attention heads per block.
            num_layers: number of pre-LN encoder blocks.
            global_pool: 'avg' averages the patch tokens (cls token excluded, as in
                timm). Any other value returns the cls token.
        """
        super(VisionTransformer, self).__init__()
        embed_len = (image_size[0]//patch_size[0]) * (image_size[1]//patch_size[1])
        # One positional embedding for the cls token plus one per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, embed_len + 1))
        self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, 1))

        nn.init.normal_(self.pos_embed, std=.02)
        nn.init.normal_(self.cls_token, std=1e-6)

        self.proj = ImageEmbedding(image_dim, embed_dim, patch_size)

        self.transformer = nn.ModuleList([VisionTransformerEncoder(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)])

        # The blocks are pre-LN, so the residual stream is normalized once at the end
        # (ViT paper, Eq. 4). No affine params, so the state_dict is unchanged.
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)

        self.global_pool = global_pool

    def forward(self, x):
        """
        x: images (B, C, H, W). They must yield the same number of patches as image_size.
        returns: (B, embed_dim)

        Raises:
            ValueError: if the number of patches does not match the positional embedding.
        """
        tokens = self.proj(x)                                               # (B, E, P)
        expected = self.pos_embed.size(2) - 1
        if tokens.size(2) != expected:
            raise ValueError(
                f"expected {expected} patches (set by image_size/patch_size), got {tokens.size(2)}")

        cls_tokens = self.cls_token.expand(tokens.size(0), -1, -1)
        x = torch.cat([cls_tokens, tokens], dim=2) + self.pos_embed         # (B, E, P+1)
        for layer in self.transformer:
            x = layer(x)
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)

        if self.global_pool == 'avg':
            return x[:, :, 1:].mean(dim=2)   # average patch tokens only
        return x[:, :, 0]                    # token  in timm
        

In [67]:
# Seed so the randomly initialized weights and the random input are reproducible.
torch.manual_seed(0)

# Architecture presets: keyword arguments for VisionTransformer.
ViT_S_config = {'num_layers':12, 'embed_dim': 384, 'mlp_dim': 1536, 'num_heads': 6} 
ViT_B_config = {'num_layers':12, 'embed_dim': 768, 'mlp_dim': 3072, 'num_heads':12} 

HIPT256_config = {'num_layers':8, 'embed_dim': 384, 'mlp_dim': 1536,'num_heads':6} 
HIPT4096_config = {'num_layers':4, 'embed_dim': 192, 'mlp_dim': 768, 'num_heads':6} 

tfe = TransformerEncoder(512, 8)
vtfe = VisionTransformerEncoder(512, 8)
ie = ImageEmbedding(image_dim=3, embed_dim=512,  patch_size=(16, 16))
vit = VisionTransformer(**HIPT4096_config, global_pool='token')
x = torch.rand(4, 3, 256, 256)
with torch.no_grad():
    tokens = ie(x)   # patch embedding computed once and fed to both encoders
    print( tfe( tokens ).size() )
    print( vtfe( tokens ).size() )
    print( vit( x ).size() )

torch.Size([4, 512, 256])
torch.Size([4, 512, 256])
torch.Size([4, 192])


In [68]:
# Total number of learnable parameters of `vit`, reported in millions.
num_params = sum(p.numel() for p in vit.parameters())
print(num_params / 10**6)

tensor(1.9736)
