In [6]:
import numpy as np
import torch

In [7]:
!pip install einops



In [8]:
import einops

In [9]:
from einops import rearrange
from torch import nn

In [30]:
class PatchEmbed(nn.Module):
    """
    Cut an image into fixed-size patches and project every patch to a vector.

    Parameters:
    --------
    img_size: int   (Side length of the square input image)

    patch_size: int  (Side length of each square patch)

    in_chans: int (Number of input channels)

    embed_dims: int (Size of the embedding produced for every patch)


    Attributes:
    -------
    n_patches : int   (Patches per image, i.e. (img_size // patch_size) ** 2)

    proj: nn.Conv2d
        Convolution whose kernel size and stride both equal ``patch_size``;
        it performs the patch split and the embedding in a single step.

    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dims=768):
        super().__init__()

        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side

        # kernel_size == stride, so each output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dims,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Embed a batch of images.

        parameters:
        ---------
        x: torch.Tensor   Shape '(n_samples, in_chans, img_size, img_size)'

        Returns
        --------
        torch.Tensor  Shape '(n_samples, n_patches, embed_dims)'

        """
        # conv -> (n_samples, embed_dims, n_patches ** 0.5, n_patches ** 0.5)
        # flatten the two spatial axes -> (n_samples, embed_dims, n_patches)
        # swap the last two axes -> (n_samples, n_patches, embed_dims)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)
    
    
        
        
        
     

# Experimental

In [15]:
# Scratch tensors (to delete later): a square and a rectangular 3-channel image.
x = torch.ones((3, 5, 5))
x2 = torch.ones((3, 7, 4))
(x.shape, x2.shape)

(torch.Size([3, 5, 5]), torch.Size([3, 7, 4]))

In [12]:
# Patch-style convolution: kernel size equals stride, so the windows do not overlap.
test_conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=2, stride=2)

In [14]:
# Add a leading batch axis to the square image, then run it through the conv.
out = test_conv(torch.unsqueeze(x, dim=0))
out.shape

torch.Size([1, 768, 2, 2])

In [16]:
# Same convolution applied to the rectangular image (batch axis added first).
out2 = test_conv(torch.unsqueeze(x2, dim=0))
out2.shape

torch.Size([1, 768, 3, 2])

In [17]:
# Merge every axis from index 2 onwards into a single patch axis.
out2 = torch.flatten(out2, start_dim=2)
out2.shape

torch.Size([1, 768, 6])

* **So the same patch-embedding approach also works for rectangular (non-square) images: the 3 × 2 grid produced by the convolution is simply flattened into 6 patches.**

# Implementing Attention

In [18]:
class Attention(nn.Module):
    """
    Multi-head self-attention.

    Parameters:
    -------
    dim: int  Input and output dimension of the per-patch (or per-token) features

    n_heads: int   Number of attention heads

    qkv_bias: bool  If True, the query, key and value projections include a bias

    attn_dp: float  Dropout probability applied to the attention weights

    fin_proj_dp: float  Dropout probability applied to the output tensor

    Attributes:
    --------
    head_dim: int  Feature size handled by each head (dim // n_heads)

    scale: float  Factor head_dim ** -0.5 applied to the query-key dot products

    qkv: nn.Linear  Single projection producing queries, keys and values at once

    fin_proj: nn.Linear  Projection applied to the concatenated head outputs

    attn_dp, fin_proj_dp: nn.Dropout  Dropout layers described above

    Raises:
    --------
    ValueError  If ``dim`` is not divisible by ``n_heads``

    """

    def __init__(self,dim,n_heads=12,qkv_bias=True, attn_dp=0.0, fin_proj_dp=0.0):
        super().__init__()

        # The reshape in forward() needs n_heads * head_dim == dim; fail early
        # with a clear message instead of a reshape error at call time.
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )

        self.dim=dim
        self.n_heads=n_heads
        self.head_dim= dim//n_heads
        self.scale=self.head_dim ** -0.5

        self.qkv=nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_dp=nn.Dropout(attn_dp)
        self.fin_proj=nn.Linear(dim,dim)
        self.fin_proj_dp=nn.Dropout(fin_proj_dp)

    def forward(self,x):
        """
        Run Forward Pass

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_tokens, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_tokens, dim)`, i.e. the same shape as the input.
        """

        n_samples,n_tokens,_=x.shape
        qkv=self.qkv(x) # (n_samples,n_tokens,dim*3)
        qkv=qkv.reshape(
                        n_samples,n_tokens,3,self.n_heads,self.head_dim
        )
        qkv=qkv.permute(2,0,3,1,4)  # (3,n_samples,n_heads,n_tokens,head_dim)

        # Split the leading axis into queries, keys and values; each is
        # (n_samples,n_heads,n_tokens,head_dim).
        q,k,v=qkv[0],qkv[1],qkv[2]

        k_t=k.transpose(-2,-1) # (n_samples,n_heads,head_dim,n_tokens)
        dp=(q @ k_t) * self.scale   # (n_samples,n_heads,n_tokens,n_tokens)
        attn=dp.softmax(dim=-1)  # each row of weights sums to 1
        attn=self.attn_dp(attn)

        weighted_avg=attn @ v   # (n_samples,n_heads,n_tokens,head_dim)
        weighted_avg=weighted_avg.transpose(1,2)  # (n_samples,n_tokens,n_heads,head_dim)

        # Concatenate the heads back into one feature axis.
        weighted_avg=weighted_avg.flatten(2)  # (n_samples,n_tokens,dim)
        x=self.fin_proj(weighted_avg)
        x=self.fin_proj_dp(x)

        # The output has the same shape as the input.
        return x
        
        
        

# Writing MLP

In [19]:
class MLP(nn.Module):
    """Two-layer feed-forward network: linear, GELU, dropout, linear, dropout.

    Parameters
    ----------
    in_features : int
        Size of the input features.
    hidden_features : int
        Size of the hidden layer.
    out_features : int
        Size of the output features.
    p : float
        Dropout probability; the same dropout layer is applied after the
        activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()

        # Sub-modules are registered in the order fc1, act, fc2, drop.
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)
        self.drop = nn.Dropout(p=p)

    def forward(self, x):
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches +1, out_features)`
        """
        # First stage ends at (n_samples, n_patches + 1, hidden_features).
        hidden = self.drop(self.act(self.fc1(x)))
        # Second stage ends at (n_samples, n_patches + 1, out_features).
        return self.drop(self.fc2(hidden))

In [36]:
class Block(nn.Module):
    """Transformer block: pre-norm self-attention and pre-norm MLP, each with a residual connection.

    Parameters
    ----------
    dim : int
        Embedding dimension of every token.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        The MLP hidden size is ``int(dim * mlp_ratio)``.
    qkv_bias : bool
        If True, the query, key and value projections include a bias.
    p : float
        Dropout probability used for the attention output projection and
        inside the MLP.
    attn_p : float
        Dropout probability applied to the attention weights.

    Attributes
    ----------
    norm1, norm2 : nn.LayerNorm
        Layer normalisations applied before the attention and the MLP.
    attn : Attention
        Self-attention module.
    mlp : MLP
        Feed-forward module mapping ``dim -> hidden -> dim``.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
                dim,
                n_heads=n_heads,
                qkv_bias=qkv_bias,
                attn_dp=attn_p,
                fin_proj_dp=p
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(
                in_features=dim,
                hidden_features=hidden_features,
                out_features=dim,
                # Forward the dropout probability; without it the MLP always
                # used its default of 0 and `p` was ignored for this branch.
                p=p,
        )

    def forward(self, x):
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        # Residual connections: each sub-layer's output is added to its input.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x


# Vision Transformer

In [37]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer classifier built from PatchEmbed and a stack of Blocks.

    Parameters
    ----------
    img_size : int
        Side length of the square input image.
    patch_size : int
        Side length of each square patch.
    in_chans : int
        Number of input channels.
    n_classes : int
        Number of output classes.
    embed_dims : int
        Embedding dimension of every token.
    depth : int
        Number of transformer blocks.
    n_heads : int
        Number of attention heads per block.
    mlp_ratio : float
        Ratio of the MLP hidden size to ``embed_dims``.
    qkv_bias : bool
        If True, the query, key and value projections include a bias.
    p : float
        Dropout probability used after adding the position embeddings and
        passed on to every block.
    attn_p : float
        Dropout probability applied to the attention weights.

    Attributes
    ----------
    patch_embed : PatchEmbed
        Patch-splitting and embedding layer.
    cls_token : nn.Parameter
        Learnable token prepended to the patch sequence, shape
        `(1, 1, embed_dims)`, initialised to zeros.
    pos_embeds : nn.Parameter
        Learnable position embeddings, shape `(1, 1 + n_patches, embed_dims)`,
        initialised to zeros.
    pos_drop : nn.Dropout
        Dropout applied after the position embeddings are added.
    blocks : nn.ModuleList
        The ``depth`` transformer blocks.
    norm : nn.LayerNorm
        Final layer normalisation.
    head : nn.Linear
        Classification layer mapping ``embed_dims -> n_classes``.
    """
    def __init__(self,
                img_size=384,patch_size=16,
                in_chans=3,n_classes=1000,
                embed_dims=768,depth=12,
                n_heads=12,mlp_ratio=4,
                qkv_bias=True,p=0.0,attn_p=0.0):
        super().__init__()

        self.patch_embed=PatchEmbed(img_size=img_size,
                                   patch_size=patch_size,
                                   in_chans=in_chans,
                                   embed_dims=embed_dims)
        self.cls_token=nn.Parameter(torch.zeros(1,1,embed_dims))
        self.pos_embeds=nn.Parameter(torch.zeros(1,1+self.patch_embed.n_patches,embed_dims))
        self.pos_drop=nn.Dropout(p=p)

        self.blocks=nn.ModuleList(
                [
                    Block(
                        dim=embed_dims,
                        n_heads=n_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        p=p,
                        attn_p=attn_p

                    )
                    for _ in range(depth)
                ]
        )

        self.norm=nn.LayerNorm(embed_dims,eps=1e-6)
        self.head=nn.Linear(embed_dims,n_classes)


    def forward(self,x):
        """
        Run the forward pass
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.
        Returns
        -------
        logits : torch.Tensor
            Logits over all the classes - `(n_samples, n_classes)`.
        """
        n_samples=x.shape[0]
        x=self.patch_embed(x) # (n_samples,n_patches,embed_dims)
        # Repeat the single learnable class token for every sample in the batch.
        cls_token=self.cls_token.expand(n_samples,-1,-1) # (n_samples,1,embed_dims)
        x=torch.cat((cls_token,x),dim=1)  # (n_samples,1+n_patches,embed_dims)
        x=x+ self.pos_embeds  # broadcast over the batch axis
        # pos_drop was constructed in __init__ but never applied; use it here.
        x=self.pos_drop(x)

        for block in self.blocks:
            x=block(x)

        x=self.norm(x)   # LayerNorm normalises over the last dimension

        cls_token_final=x[:, 0]  # Just the class token, (n_samples,embed_dims)
        x=self.head(cls_token_final)  # (n_samples,n_classes)

        return x

# Testing our Vision Transformer against the pretrained timm reference model

In [38]:
# Hyper-parameters for the model under test; arguments not listed here keep
# the constructor defaults.
config = dict(
    img_size=384,
    in_chans=3,
    patch_size=16,
    embed_dims=768,
    depth=12,
    n_heads=12,
    qkv_bias=True,
    mlp_ratio=4,
)

In [39]:
# Build our implementation from `config`; arguments it does not list
# (n_classes, p, attn_p) keep the constructor defaults.
our_model=VisionTransformer(**config)

In [41]:
# Switch to evaluation mode (dropout layers become no-ops); the trailing `;`
# suppresses the cell's output.
our_model.eval();

In [43]:
!pip install timm -q



In [44]:
import timm

In [45]:
# Load the pretrained reference model from timm and switch it to evaluation
# mode; `eval()` returns the module itself, so the call can be chained.
model_name = "vit_base_patch16_384"
model_official = timm.create_model(model_name, pretrained=True).eval()
print(type(model_official))

<class 'timm.models.vision_transformer.VisionTransformer'>


In [46]:
# Show the class of our own model for comparison with the timm model's class.
print(type(our_model))

<class '__main__.VisionTransformer'>


In [48]:
# Helpers
def get_n_params(module):
    """Return the number of trainable scalar parameters in ``module``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def assert_tensors_equal(t1, t2):
    """Assert that two tensors hold (numerically) the same values.

    Both tensors are detached from the autograd graph and converted to NumPy
    arrays, then compared with ``np.testing.assert_allclose`` using its
    default tolerances. Raises ``AssertionError`` on mismatch.
    """
    arrays = [tensor.detach().numpy() for tensor in (t1, t2)]
    np.testing.assert_allclose(*arrays)


In [49]:

# Copy the reference weights into our model.
# Parameters are paired purely by position in `named_parameters()`, not by
# name, so both models must register their parameters in the same order.
# Note that `zip` stops at the shorter of the two sequences.
for (n_o, p_o), (n_c, p_c) in zip(
        model_official.named_parameters(), our_model.named_parameters()
):
    # Sanity check: paired parameters must hold the same number of elements.
    assert p_o.numel() == p_c.numel()
    # Print the pairing so name mismatches can be spotted by eye.
    print(f"{n_o} | {n_c}")

    # In-place overwrite of our parameter's values with the reference values.
    p_c.data[:] = p_o.data

    # Confirm the copy succeeded.
    assert_tensors_equal(p_c.data, p_o.data)

cls_token | cls_token
pos_embed | pos_embeds
patch_embed.proj.weight | patch_embed.proj.weight
patch_embed.proj.bias | patch_embed.proj.bias
blocks.0.norm1.weight | blocks.0.norm1.weight
blocks.0.norm1.bias | blocks.0.norm1.bias
blocks.0.attn.qkv.weight | blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias | blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight | blocks.0.attn.fin_proj.weight
blocks.0.attn.proj.bias | blocks.0.attn.fin_proj.bias
blocks.0.norm2.weight | blocks.0.norm2.weight
blocks.0.norm2.bias | blocks.0.norm2.bias
blocks.0.mlp.fc1.weight | blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias | blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight | blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias | blocks.0.mlp.fc2.bias
blocks.1.norm1.weight | blocks.1.norm1.weight
blocks.1.norm1.bias | blocks.1.norm1.bias
blocks.1.attn.qkv.weight | blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias | blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight | blocks.1.attn.fin_proj.weight
blocks.1.attn.proj.bias | blocks.