In [6]:
import numpy as np
import torch

In [7]:
!pip install einops



In [8]:
import einops

In [9]:
from einops import rearrange
from torch import nn

In [10]:
class PatchEmbed(nn.Module):
    """
    Split images into non-overlapping patches and project each patch to an embedding.

    Parameters:
    --------
    img_size: int or tuple(int, int)
        Size of a square image, or ``(height, width)`` for a rectangular image.

    patch_size: int or tuple(int, int)
        Size of a square patch, or ``(patch_height, patch_width)``.

    in_chans: int
        Number of input channels.

    embedding_dims: int
        Dimension of each patch embedding.


    Attributes:
    -------
    img_size: int or tuple
        The ``img_size`` argument, stored as given.

    patch_size: int or tuple
        The ``patch_size`` argument, stored as given.

    n_patches : int
        Number of patches per image, ``(H // pH) * (W // pW)``. Trailing rows or
        columns that do not fill a whole patch are dropped by the convolution.

    proj: nn.Conv2d
        Convolution with ``kernel_size == stride == patch_size`` that performs both
        the split into patches and their linear embedding in a single step.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embedding_dims=768):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size

        img_h, img_w = self._as_pair(img_size)
        patch_h, patch_w = self._as_pair(patch_size)
        # For square inputs this equals (img_size // patch_size) ** 2.
        self.n_patches = (img_h // patch_h) * (img_w // patch_w)

        # kernel_size == stride -> non-overlapping patches, each mapped to an embedding.
        self.proj = nn.Conv2d(
            in_chans,
            embedding_dims,
            kernel_size=(patch_h, patch_w),
            stride=(patch_h, patch_w),
        )

    @staticmethod
    def _as_pair(value):
        """Return ``tuple(value)`` for a tuple/list of length 2, otherwise ``(value, value)``."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"expected a scalar or a pair, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """
        Embed every patch of a batch of images.

        Parameters:
        ---------
        x: torch.Tensor   Shape '(n_samples, in_chans, H, W)'

        Returns
        --------
        torch.Tensor  Shape '(n_samples, n_patches, embedding_dims)', where
        n_patches = (H // pH) * (W // pW) for the actual input (equal to
        ``self.n_patches`` when the input has size ``img_size``).
        """
        x = self.proj(x)  # (n_samples, embedding_dims, H // pH, W // pW)
        x = x.flatten(2)  # (n_samples, embedding_dims, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embedding_dims)
        return x
    
    
        
        
        
     

# Experiment: patch-embedding convolution on square vs. rectangular inputs

In [15]:
# Scratch inputs: two all-ones 3-channel "images" without a batch axis --
# one square (5x5) and one rectangular (7x4).
x = torch.ones((3, 5, 5))
x2 = torch.ones((3, 7, 4))
(x.shape, x2.shape)

(torch.Size([3, 5, 5]), torch.Size([3, 7, 4]))

In [12]:
test_conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=2, stride=2)  # 2x2 non-overlapping patches -> 768-dim

In [14]:
# Add a leading batch axis (x[None] is the same as x.unsqueeze(0)) before the conv.
out = test_conv(x[None])
out.shape

torch.Size([1, 768, 2, 2])

In [16]:
# Same conv on the rectangular input, again with a leading batch axis.
out2 = test_conv(x2[None])
out2.shape

torch.Size([1, 768, 3, 2])

In [17]:
# Collapse the spatial grid into a single patch axis.
out2 = torch.flatten(out2, start_dim=2)
out2.shape

torch.Size([1, 768, 6])

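* **So the convolution-based patch embedding also works with rectangular images:** a 7×4 input with 2×2 patches gives a 3×2 grid, i.e. 6 patches (the leftover 7th row is dropped because it does not fill a whole patch).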

# Implementing Attention

In [18]:
class Attention(nn.Module):
    """
    Multi-head self-attention over a sequence of tokens (e.g. patch embeddings).

    Parameters:
    -------
    dim: int
        Input and output dimension of the per-token (per-patch) features.
        Must be divisible by ``n_heads``.

    n_heads: int
        Number of attention heads.

    qkv_bias: bool
        If True, include a bias in the query, key and value projections.

    attn_dp: float
        Dropout probability applied to the attention weights (softmax output).

    fin_proj_dp: float
        Dropout probability applied to the output tensor.

    Attributes:
    --------
    head_dim: int
        Per-head feature size, ``dim // n_heads``.

    scale: float
        ``head_dim ** -0.5``; scales the query-key dot products before softmax.

    qkv: nn.Linear
        Joint projection producing queries, keys and values (``dim -> 3 * dim``).

    fin_proj: nn.Linear
        Output projection applied to the concatenated heads (``dim -> dim``).

    attn_dp, fin_proj_dp: nn.Dropout
        Dropout layers for the attention weights and the output.

    Raises:
    --------
    ValueError
        If ``dim`` is not divisible by ``n_heads`` (the head reshape would fail).
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_dp=0.0, fin_proj_dp=0.0):
        super().__init__()

        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_dp = nn.Dropout(attn_dp)
        self.fin_proj = nn.Linear(dim, dim)
        self.fin_proj_dp = nn.Dropout(fin_proj_dp)

    def forward(self, x):
        """
        Run the forward pass.

        Parameters:
        ---------
        x: torch.Tensor   Shape '(n_samples, n_tokens, dim)'
            (n_tokens is presumably n_patches + 1 when a class token is prepended).

        Returns
        --------
        torch.Tensor  Shape '(n_samples, n_tokens, dim)' -- same as the input.
        """
        n_samples, n_tokens, _ = x.shape

        qkv = self.qkv(x)  # (n_samples, n_tokens, 3 * dim)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, n_samples, n_heads, n_tokens, head_dim)
        # Split the stacked projections into queries, keys and values.
        q, k, v = qkv.unbind(0)  # each (n_samples, n_heads, n_tokens, head_dim)

        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_tokens)
        dp = (q @ k_t) * self.scale  # (n_samples, n_heads, n_tokens, n_tokens)
        attn = dp.softmax(dim=-1)  # each query's weights over keys sum to 1
        attn = self.attn_dp(attn)

        weighted_avg = attn @ v  # (n_samples, n_heads, n_tokens, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)  # (n_samples, n_tokens, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (n_samples, n_tokens, dim) -- heads concatenated

        x = self.fin_proj(weighted_avg)
        x = self.fin_proj_dp(x)

        # Output has the same shape as the input.
        return x
        
        
        