# Reference video: https://www.youtube.com/watch?v=ovB0ddFtzzA
#
# Reference implementation (timm, PyTorch Image Models):
# https://github.com/rwightman/pytorch-image-models

# In [1]:
import torch

# In [7]:
class PatchEmbed(torch.nn.Module):
    '''Cut a square image into non-overlapping square patches and embed each one.

    Parameters
    --------------
    img_size: int
        Side length of the (square) input image.
    patch_size: int
        Side length of each (square) patch.
    in_channels: int
        Number of channels in the input image.
    embed_dimensions: int
        Number of features produced for every patch.

    Attributes
    ---------------
    img_size, patch_size, in_channels: int
        The constructor arguments, stored unchanged.
    embed_dimension: int
        The value of ``embed_dimensions``.
    n_patches: int
        Patches per image, ``(img_size // patch_size) ** 2``.
    proj: torch.nn.Conv2d
        A single convolution that both splits the image into patches and
        projects each patch to ``embed_dimensions`` features.
    '''

    def __init__(self, img_size:int, patch_size:int,
                 in_channels:int = 3, embed_dimensions:int = 768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size, self.patch_size = img_size, patch_size
        self.in_channels = in_channels
        self.embed_dimension = embed_dimensions
        self.n_patches = patches_per_side ** 2

        # kernel_size == stride == patch_size: each output location covers
        # exactly one patch and neighbouring patches never overlap.
        self.proj = torch.nn.Conv2d(
            in_channels,
            embed_dimensions,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        '''Embed every patch of a batch of images.

        Parameters
        ----------------
        x: torch.Tensor
            Shape (batch_size, in_channels, img_size, img_size).

        Returns
        ----------------
        torch.Tensor
            Shape (batch_size, n_patches, embed_dimensions).
        '''
        # (B, E, H/P, W/P) -> (B, E, n_patches) -> (B, n_patches, E)
        return self.proj(x).flatten(2).transpose(1, 2)
    
    

# In [8]:
class SelfAttention(torch.nn.Module):
    '''Multi-headed Self Attention Mechanism

    Parameters:
    ---------------
    dim: int
        Input and output dimension of the per-token features. Must be
        divisible by ``n_heads``.

    n_heads: int
        Number of attention heads.

    qkv_bias: bool
        If True, include bias in the query, key and value projection.

    attn_p: float
        Dropout probability applied to the attention weights (after softmax).

    proj_p: float
        Dropout probability applied to the output tensor.

    Attributes:
    -----------------
    head_dim: int
        Per-head feature size, ``dim // n_heads``.
    scale: float
        Normalizing constant ``1 / sqrt(head_dim)`` for the dot product.
    qkv: torch.nn.Linear
        Joint linear projection producing the queries, keys and values.
    proj: torch.nn.Linear
        Linear mapping that takes the concatenated output of all attention
        heads and maps it back to ``dim`` features.
    attn_drop, proj_drop: torch.nn.Dropout
        Dropout layers.

    Raises:
    -----------------
    ValueError
        If ``dim`` is not divisible by ``n_heads``.
    '''

    def __init__(self,dim,n_heads=12,qkv_bias=True,attn_p=0.0,proj_p=0.0):
        super().__init__()
        if dim % n_heads != 0:
            # The qkv output is split evenly across heads in forward(); an
            # uneven split would otherwise only surface there as an opaque
            # reshape error.
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.dim = dim
        # Feature size of a single head; the heads together span `dim`.
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)

        # One Linear producing q, k and v at once (equivalent to 3 separate layers).
        self.qkv = torch.nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_p)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_p)

    def forward(self,x):
        '''Run forward pass

        Parameters
        ----------------
        x: torch.Tensor
            Shape (batch_size, n_tokens, dim). With a class token prepended,
            n_tokens is n_patches + 1. Input and output shapes are the same.

        Returns
        ----------------
        torch.Tensor
            Shape (batch_size, n_tokens, dim)

        Raises
        ----------------
        ValueError
            If the last dimension of ``x`` is not ``dim``.
        '''
        batch_size, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError(
                f"expected input feature dimension {self.dim}, got {dim}"
            )

        # nn.Linear acts on the last dimension; leading dimensions pass through.
        qkv = self.qkv(x)  # (batch_size, n_tokens, 3*dim)

        qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim)

        # Put the q/k/v axis first and the head axis before the token axis.
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, n_heads, n_tokens, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]

        # Transpose the keys so q @ k_t yields token-to-token scores.
        k_t = k.transpose(-2, -1)

        dp = (q @ k_t) * self.scale  # (batch_size, n_heads, n_tokens, n_tokens)

        attn = dp.softmax(dim=-1)  # each row sums to 1 over the key tokens

        attn = self.attn_drop(attn)

        weighted_avg = attn @ v  # (batch_size, n_heads, n_tokens, head_dim)

        weighted_avg = weighted_avg.transpose(1, 2)  # (batch_size, n_tokens, n_heads, head_dim)

        weighted_avg = weighted_avg.flatten(2)  # concatenate heads -> (batch_size, n_tokens, dim)

        x = self.proj(weighted_avg)  # (batch_size, n_tokens, dim)
        x = self.proj_drop(x)  # (batch_size, n_tokens, dim)

        return x

# In [9]:
class MLP(torch.nn.Module):
    '''Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Parameters
    --------------
    in_features: int
        Number of input features.
    hidden_features: int
        Number of units in the hidden layer.
    out_features: int
        Number of output features.
    p: float
        Dropout probability, used after the activation and after fc2.

    Attributes
    ---------------
    fc1: torch.nn.Linear
        First linear layer (in_features -> hidden_features).
    fc2: torch.nn.Linear
        Second linear layer (hidden_features -> out_features).
    drop: torch.nn.Dropout
        Dropout layer shared by both dropout steps.
    act: torch.nn.GELU
        Gaussian Error Linear Unit activation.
    '''
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc2 = torch.nn.Linear(hidden_features, out_features)
        self.drop = torch.nn.Dropout(p)
        self.act = torch.nn.GELU()

    def forward(self, x):
        '''Apply the feed-forward network to the last dimension of ``x``.

        Parameter
        -------------
        x: torch.Tensor
            (batch_size, n_patches + 1, in_features)

        Returns
        --------------
        torch.Tensor
            (batch_size, n_patches + 1, out_features)
        '''
        hidden = self.drop(self.act(self.fc1(x)))  # (..., hidden_features)
        return self.drop(self.fc2(hidden))  # (..., out_features)

# In [10]:
class TransformerBlock(torch.nn.Module):
    '''Pre-norm Transformer encoder block (LayerNorm -> sublayer -> residual).

    Parameters
    --------------
    dim: int
        Embedding dimension.
    n_heads: int
        Number of self-attention heads.
    mlp_ratio: float
        Hidden size of the MLP module relative to ``dim``.
    qkv_bias: bool
        If True, include bias in the query, key and value projection.
    p: float
        Dropout probability for the attention output projection and inside
        the MLP.
    attn_p: float
        Dropout probability for the attention weights.

    Attributes
    --------------
    norm1, norm2: torch.nn.LayerNorm
        Layer normalization applied before attention and before the MLP.
    attn: SelfAttention
        Self-attention module.
    MLP: MLP
        Feed-forward module.
    '''
    def __init__(self, dim,n_heads,mlp_ratio=4.0,qkv_bias=True,p=0.,attn_p=0.):
        super().__init__()
        # LayerNorm normalizes each token's feature vector over the last
        # dimension (`dim`), so unlike BatchNorm it does not depend on batch
        # statistics. elementwise_affine (default True) adds a learnable
        # scale and shift; set it to False to make the norm parameter-free.
        self.norm1 = torch.nn.LayerNorm(dim, eps=1e-6)
        self.attn = SelfAttention(
                    dim,
                    n_heads=n_heads,
                    qkv_bias=qkv_bias,
                    attn_p=attn_p,
                    proj_p=p
                    )
        self.norm2 = torch.nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim*mlp_ratio)
        # The attribute keeps the name `MLP` so existing state_dict keys still
        # match. `p` is forwarded so the documented dropout also applies here.
        self.MLP = MLP(
                    in_features=dim,
                    hidden_features=hidden_features,
                    out_features=dim,
                    p=p
                    )

    def forward(self,x):
        '''Run forward pass

        Parameters
        ----------------
        x: torch.Tensor
            Shape (batch_size, n_patches + 1, dim).
            The + 1 comes from the class token.
            Note that input and output shapes are the same.

        Returns
        ----------------
        torch.Tensor
            Shape (batch_size, n_patches + 1, dim)
        '''
        x = x + self.attn(self.norm1(x))  # residual around attention
        x = x + self.MLP(self.norm2(x))  # residual around the MLP
        return x

# In [None]:
class VisionTransformer(torch.nn.Module):
    '''Simplified Implementation of Vision Transformers

    Parameters
    --------------
    img_size: int
        Size of the image (it's a square).

    patch_size: int
        Size of the patch (it's a square).

    in_channels: int
        Number of input channels.

    n_classes: int
        Number of classes in the dataset.

    embed_dim: int
        The embedding dimension.

    depth: int
        Number of Transformer blocks.

    n_heads: int
        Number of self-attention heads.

    mlp_ratio: float
        Hidden size of each block's MLP relative to ``embed_dim``.

    qkv_bias: bool
        If True, include bias in the query, key and value projection.

    attn_p: float
        Dropout probability for the attention weights.

    p: float
        Dropout probability after the positional embedding and inside each block.

    Attributes:
    -------------
    patch_embed: PatchEmbed
        Instance of the ``PatchEmbed`` layer.

    cls_token: torch.nn.Parameter
        Learnable ``(1, 1, embed_dim)`` token prepended to the patch sequence.
        Its final representation is used for classification.

    pos_embed: torch.nn.Parameter
        Learnable positional embedding of shape
        ``(1, n_patches + 1, embed_dim)``, covering the cls token and all patches.

    pos_drop: torch.nn.Dropout
        Dropout applied after adding the positional embedding.

    blocks: torch.nn.ModuleList
        ``depth`` TransformerBlock modules.

    norm: torch.nn.LayerNorm
        Final layer normalization.

    head: torch.nn.Linear
        Classification head mapping ``embed_dim`` to ``n_classes`` logits.
    '''
    def __init__(self,img_size=384,
                 patch_size=16,
                 in_channels=3,
                 n_classes=1000,
                 embed_dim=768,
                 depth=12,
                 n_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 attn_p=0.0,
                 p=0.
                ):
        super().__init__()
        self.patch_embed = PatchEmbed(
                            img_size=img_size,
                            patch_size=patch_size,
                            in_channels=in_channels,
                            # PatchEmbed's keyword is `embed_dimensions`;
                            # `embed_dim=` raised a TypeError.
                            embed_dimensions=embed_dim
                            )
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = torch.nn.Parameter(
                                torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = torch.nn.Dropout(p=p)
        self.blocks = torch.nn.ModuleList(
                    [
                        TransformerBlock(
                        dim=embed_dim,
                        n_heads=n_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        p=p,
                        attn_p=attn_p)
                        for _ in range(depth)
                    ]
                        )
        self.norm = torch.nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = torch.nn.Linear(embed_dim, n_classes)

    # forward must be a method of the class (not nested inside __init__) so
    # that nn.Module.__call__ dispatches to it.
    def forward(self, x):
        '''Run the forward pass

        Parameters:
        ---------------
        x: torch.Tensor
            Shape (batch_size, in_channels, img_size, img_size)

        Returns:
        ---------------
        logits: torch.Tensor
            Logits over all the classes, shape (batch_size, n_classes)
        '''
        batch_size = x.shape[0]
        x = self.patch_embed(x)  # (batch_size, n_patches, embed_dim)
        # Broadcast the single learnable token across the batch.
        cls_token = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        # Concatenate along the token axis (dim=1), not the feature axis.
        x = torch.cat((cls_token, x), dim=1)  # (batch_size, n_patches+1, embed_dim)
        # pos_embed is a Parameter (a tensor), added via broadcasting over the batch.
        x = x + self.pos_embed  # (batch_size, n_patches+1, embed_dim)
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]  # only the class token: (batch_size, embed_dim)
        x = self.head(cls_token_final)  # (batch_size, n_classes)
        return x