In [1]:
!pip3 install torchinfo



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torchinfo
import matplotlib.pyplot as plt

## Patching + linear patch embedding

In [3]:
class Patching(nn.Module):
    """
    Split images into non-overlapping square patches and linearly embed each patch.

    Input:  images of shape (B, C, H, W); C must equal the ``C`` given at construction.
    Output: (B, N, d_model) with N = (H // patch_size) * (W // patch_size).
            Trailing rows/columns that do not fill a whole patch are dropped by
            ``Tensor.unfold``. No [CLS] token is added here.

    Patches are ordered row-major over the patch grid, and each patch is flattened
    in (C, patch_size, patch_size) order, so the projection input width is
    patch_size**2 * C.

    Raises:
        ValueError: in forward, if the image channel count differs from ``C``.
    """
    def __init__(self,C,patch_size=16,d_model=786):
        # NOTE(review): ViT-Base uses d_model=768; the 786 default looks like a typo,
        # but it is kept so default-constructed modules keep their current shape.
        super().__init__()
        self.patch_size=patch_size
        self.C=C
        self.in_=int((patch_size**2)*self.C)
        self.linear=nn.Linear(in_features=self.in_,out_features=d_model)

    def forward(self,img):
        B,C,H,W=img.shape
        if C!=self.C:
            # Without this check a wrong channel count can still reshape cleanly
            # into (B, -1, in_) and silently mix data from neighbouring patches.
            raise ValueError(f"expected {self.C} input channels, got {C}")
        p=self.patch_size
        # (B, C, H//p, W//p, p, p): element [..., i, j, a, b] = img[..., i*p + a, j*p + b]
        patches=img.unfold(2,p,p).unfold(3,p,p)
        # -> (B, H//p, W//p, C, p, p) -> (B, N, C*p*p)
        patches=patches.permute(0,2,3,1,4,5).reshape(B,-1,self.in_)
        return self.linear(patches) # (B, N, d_model)
        
       

In [4]:
# Shape check: 224x224 images cut into 16x16 patches give (224 // 16) ** 2 = 196 tokens.
Patch_test = Patching(C=3, patch_size=16, d_model=768)
img_test = torch.randn(10, 3, 224, 224)
Patch_test(img_test).shape

torch.Size([10, 196, 768])

In [5]:
class ViT_test(nn.Module):
    """
    Embedding front-end of a ViT, used here to sanity-check shapes.

    forward(imgs): (B, C, H, W) -> (B, N + 1, d_model), where
    N = (H // patch_size) * (W // patch_size) patch tokens, preceded by one learnable
    [CLS] token, with a learnable position embedding added to every token.
    """
    def __init__(self,H,W,C,patch_size=16,d_model=786):
        super().__init__()
        self.C=C
        self.patch_size=patch_size
        # Count whole patches per axis, matching what Patching's unfold produces;
        # int(H*W / patch_size**2) over-counts when H or W is not a multiple of
        # patch_size, which would make pos_embed the wrong length.
        self.num_patches=(H//self.patch_size)*(W//self.patch_size)
        self.d_model=d_model
        self.pos_embed=nn.Parameter(torch.randn(1,self.num_patches+1,d_model))
        self.cls_token=nn.Parameter(torch.zeros(1,d_model)) # (1, d_model)
        self.patching=Patching(C=self.C,patch_size=self.patch_size,d_model=self.d_model)

    def forward(self,imgs):
        imgs=self.patching(imgs) # (B, N, d_model)
        # expand adds a leading batch dim: (1, d_model) -> (B, 1, d_model)
        cls_token=self.cls_token.expand(imgs.shape[0],-1,-1)
        imgs=torch.cat((cls_token,imgs),dim=1) # (B, N+1, d_model)
        return imgs + self.pos_embed
        
        
    

In [6]:
# Expected: (batch, num_patches + 1 [CLS], d_model); d_model falls back to ViT_test's default.
vit_test = ViT_test(C=3, H=224, W=224)
vit_test(img_test).shape


torch.Size([10, 197, 786])

## Transformer block

In [7]:
import math
import copy
import torch.nn.functional as F
import torch.optim as optim


if torch.cuda.is_available():
    device = "cuda"  # a CUDA device is visible to PyTorch
else:
    device = "cpu"
def clones(module,N):
    """
    Return an nn.ModuleList holding N independent deep copies of ``module``.

    Each entry is a ``copy.deepcopy`` of the same, already-initialized module, so
    all copies start with identical parameter values but do not share storage.

    Args:
        module (nn.Module): module to replicate.
        N (int): number of copies.

    Returns:
        nn.ModuleList: the N copies, in order.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

def attention(query,key,value,dropout=0):
  """
  Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  Args:
      query (torch.Tensor): (..., T_q, d_k); in this notebook (batch, heads, tokens, d_k).
      key (torch.Tensor): (..., T_k, d_k).
      value (torch.Tensor): (..., T_k, d_v).
      dropout: optional callable (e.g. an nn.Dropout module) applied to the
          attention weights. ``0`` (the default) or ``None`` means no dropout.

  Returns:
      tuple: (output of shape (..., T_q, d_v), attention weights of shape (..., T_q, T_k)).

  Raises:
      TypeError: if ``dropout`` is a non-zero number rather than a callable; a bare
          probability cannot be applied here (pass nn.Dropout(p) instead).
  """
  d_k=query.size(-1)
  scores=torch.matmul(query,key.transpose(-1,-2))/math.sqrt(d_k) # (..., T_q, T_k)
  p_attn=scores.softmax(dim=-1)
  # Only call `dropout` when it is callable: the default 0 used to be called
  # directly and raised "'int' object is not callable".
  if callable(dropout):
    p_attn=dropout(p_attn)
  elif dropout is not None and dropout!=0:
    raise TypeError("dropout must be a callable such as nn.Dropout, None, or 0")
  return torch.matmul(p_attn,value), p_attn


class MultiHeadedAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The model width d_model is split evenly across h heads of size d_k = d_model // h.
  ``self.linears`` holds four projections, in this order: query, key, value, output.
  """
  def __init__(self,h,d_model,dropout=0):
    """
    Create a MultiHeadedAttention layer.

    Args:
        h (int): number of attention heads; must divide d_model.
        d_model (int): number of expected features in the input (last dimension).
        dropout (float, optional): dropout probability applied to the attention
            weights. Defaults to 0.

    Raises:
        ValueError: if d_model is not divisible by h.
    """
    super(MultiHeadedAttention,self).__init__()
    if d_model%h!=0:
      # Otherwise h*d_k != d_model and the per-head .view() in forward fails obscurely.
      raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads ({h})")
    self.d_k=d_model//h
    self.h=h
    # Build each projection independently. Deep-copying a single Linear would start
    # Q, K, V and the output projection from identical weights.
    self.linears=nn.ModuleList([nn.Linear(d_model,d_model) for _ in range(4)])
    self.dropout=nn.Dropout(p=dropout)

  def forward(self,query,key,value):
    """
    Compute the forward pass of the multi-headed attention layer.

    Args:
        query (torch.Tensor): (batch, T_q, d_model).
        key (torch.Tensor): (batch, T_k, d_model).
        value (torch.Tensor): (batch, T_k, d_model).

    Returns:
        torch.Tensor: (batch, T_q, d_model).
    """
    nbatches=query.shape[0]

    # Project, then split the width into heads: (batch, T, d_model) -> (batch, h, T, d_k).
    # zip stops after three items, so linears[0..2] are used here and linears[3] at the end.
    query,key,value=[
        lin(x).view(nbatches,-1,self.h,self.d_k).transpose(1,2)
        for lin,x in zip(self.linears,(query,key,value))
    ]

    x,_=attention(query,key,value,dropout=self.dropout)

    # Merge heads back: (batch, h, T_q, d_k) -> (batch, T_q, h*d_k).
    x=x.transpose(1,2).contiguous().view(nbatches,-1,self.h*self.d_k)
    return self.linears[-1](x)

### MLP

In [8]:
class MLP(nn.Module):
    """
    Token-wise feed-forward block of a Transformer encoder layer.

    Computes ``dropout(linear2(dropout(gelu(linear1(x)))))`` over the last dimension,
    so any leading shape (e.g. (B, N, d_model)) is preserved.

    Args:
        d_model (int): input and output feature size.
        dropout (float): dropout probability applied after the GELU and after
            linear2. Defaults to 0.
        hidden (int, optional): hidden-layer width. Defaults to 4 * d_model
            (the width this block originally hard-coded).
    """
    def __init__(self,d_model,dropout=0,hidden=None):
        super().__init__()
        if hidden is None:
            hidden=4*d_model
        self.dropout=nn.Dropout(p=dropout)
        self.linear1=nn.Linear(in_features=d_model,out_features=hidden)
        self.linear2=nn.Linear(in_features=hidden,out_features=d_model)

    def forward(self,x):
        x=self.dropout(F.gelu(self.linear1(x)))
        return self.dropout(self.linear2(x))
        

In [9]:
# Token-wise MLP: the (batch, tokens, width) layout must come back unchanged.
mlp_test = MLP(dropout=0.1, d_model=768)
x_test = torch.randn(10, 197, 768)  # 10 images x (196 patches + [CLS]) x 768 features
out = mlp_test(x_test)
print(out.shape)

torch.Size([10, 197, 768])


### Residual connection (pre-norm sublayer wrapper)

In [10]:
class SubLayerConnection(nn.Module):
  """Pre-norm residual wrapper: returns ``x + dropout(sublayer(LayerNorm(x)))``.

  Parameters:
  d_model (int): size of the last input dimension (the LayerNorm width).
  dropout (float): dropout probability applied to the sublayer's output.
  """

  def __init__(self,d_model,dropout):
    super().__init__()
    self.norm=nn.LayerNorm(d_model)
    self.dropout=nn.Dropout(dropout)

  def forward(self,x,sublayer):
    normed=self.norm(x)
    branch=self.dropout(sublayer(normed))
    return x+branch

In [11]:
class ViT_layer(nn.Module):
  """One pre-norm encoder layer: a self-attention residual branch, then an MLP residual branch.

  Each branch is a SubLayerConnection, i.e. ``x + dropout(sublayer(LayerNorm(x)))``.

  Parameters:
  self_attn (nn.Module): self-attention module, called as self_attn(x, x, x).
  mlp (nn.Module): feed-forward module applied to the attention branch's output.
  dropout (float): dropout probability used in both residual branches.
  hidden (int): unused; accepted only for signature compatibility. Defaults to 768*4.
  h (int): number of attention heads; only stored as self.h. Defaults to 12.
  d_model (int): width of the LayerNorm in both residual branches. Defaults to 768.
  """

  def __init__(self,self_attn,mlp,dropout,hidden=768*4,h=12,d_model=768):
    super().__init__()
    self.self_attn = self_attn
    self.h = h
    self.mlp = mlp
    self.sublayer = clones(SubLayerConnection(d_model, dropout), 2)
    self.d_model = d_model

  def forward(self,x):
    attn_residual, mlp_residual = self.sublayer

    def _attend(t):
      # Self-attention: the normalized tokens act as query, key and value.
      return self.self_attn(t, t, t)

    y = attn_residual(x, _attend)
    return mlp_residual(y, self.mlp)



In [12]:
layer_test = ViT_layer(
    self_attn=MultiHeadedAttention(h=12, d_model=768),
    mlp=MLP(d_model=768),
    dropout=0.1,
)
# An encoder layer maps (B, tokens, d_model) to the same shape.
layer_test(x_test).shape

torch.Size([10, 197, 768])

### Full Vision Transformer (ViT)

In [13]:
class ViT(nn.Module):
    """
    Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    position embeddings -> ``num_layers`` encoder layers -> LayerNorm -> linear head
    applied to the [CLS] token.

    Args:
        n_classes (int): number of output logits. Defaults to 1000.
        H (int): input image height in pixels. Defaults to 224.
        W (int): input image width in pixels. Defaults to 224.
        C (int): input channels. Defaults to 3.
        patch_size (int): side length of each square patch. Defaults to 16.
        d_model (int): token width. Defaults to 768.
        num_heads (int): attention heads per layer; must divide d_model. Defaults to 12.
        num_layers (int): number of encoder layers. Defaults to 12.
        dropout (float): dropout for each layer's residual branches. Defaults to 0.1.

    forward(x) takes (B, C, H, W) and returns ``(logits, cls_features)`` with shapes
    (B, n_classes) and (B, d_model); ``cls_features`` is the final-normed [CLS] token.
    """
    def __init__(self,n_classes=1000,H=224,W=224,C=3,patch_size=16,d_model=768,
                 num_heads=12,num_layers=12,dropout=0.1):
        super().__init__()
        self.H=H
        self.W=W
        self.C=C
        self.patch_size=patch_size
        self.d_model=d_model
        # Whole patches per axis (what Patching's unfold yields); int(H*W/patch_size**2)
        # over-counts when H or W is not a multiple of patch_size.
        self.num_patches=(H//self.patch_size)*(W//self.patch_size)
        # NOTE(review): torch.randn gives std-1 position embeddings; consider a
        # smaller-std init. Left unchanged here.
        self.pos_embed=nn.Parameter(torch.randn(1,self.num_patches+1,d_model))
        self.cls_token=nn.Parameter(torch.zeros(1,d_model))
        self.patching=Patching(C=self.C,patch_size=self.patch_size,d_model=self.d_model)
        # Construct each layer separately so each gets its own random init; deep-copying
        # one layer would start all of them with identical weights. d_model is passed
        # explicitly because ViT_layer's LayerNorms otherwise default to width 768.
        self.layers=nn.ModuleList([
            ViT_layer(
                self_attn=MultiHeadedAttention(h=num_heads,d_model=self.d_model),
                mlp=MLP(d_model=self.d_model),
                dropout=dropout,
                h=num_heads,
                d_model=self.d_model,
            )
            for _ in range(num_layers)
        ])
        self.norm=nn.LayerNorm(self.d_model)
        self.mlp_head=nn.Linear(in_features=self.d_model,out_features=n_classes)

    def forward(self,x):
        x=self.patching(x) # (B, N, d_model)
        cls=self.cls_token.expand(x.shape[0],-1,-1) # (B, 1, d_model)
        x=torch.cat((cls,x),dim=1)+self.pos_embed # (B, N+1, d_model)
        for layer in self.layers:
            x=layer(x)
        x=self.norm(x)
        cls_token_final=x[:,0] # (B, d_model)
        return self.mlp_head(cls_token_final),cls_token_final # (B, n_classes), (B, d_model)

In [14]:
model_test = ViT(
    n_classes=1000,
    H=224,
    W=224,
    C=3,
    patch_size=16,
    d_model=768
)

# Random test input: batch of 8 RGB images of size 224x224
x_test = torch.randn(8, 3, 224, 224)

# Forward pass. This is only a shape check, so skip building the autograd graph.
with torch.no_grad():
    logits, cls_token_output = model_test(x_test)

# Print shapes to verify
print("Logits shape:", logits.shape)               # Expected: [8, 1000]
print("CLS Token Output shape:", cls_token_output.shape)  # Expected: [8, 768]
assert logits.shape == (8, model_test.mlp_head.out_features)
assert cls_token_output.shape == (8, model_test.d_model)

Logits shape: torch.Size([8, 1000])
CLS Token Output shape: torch.Size([8, 768])


In [15]:
# Same check with a different batch size: the batch dimension must propagate.
x_test=torch.randn(4,3,224,224)
with torch.no_grad():  # shape check only; no autograd graph needed
    logits, cls_token_output=model_test(x_test)
print("Logits shape:", logits.shape)
print("CLS Token Output shape:", cls_token_output.shape)
assert logits.shape == (4, model_test.mlp_head.out_features)
assert cls_token_output.shape == (4, model_test.d_model)

Logits shape: torch.Size([4, 1000])
CLS Token Output shape: torch.Size([4, 768])


## Data transformations

In [16]:
from torchvision import transforms
from torchvision.transforms import v2
from torchvision.transforms.functional import InterpolationMode

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
# Per-channel ImageNet normalization statistics, used as the default
# mean/std for the Normalize step in the transform builders below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

In [18]:
# Scratch cell: displays the repr of the v2.RandomResizedCrop class.
# NOTE(review): leftover exploration; can be removed once no longer needed.
v2.RandomResizedCrop

torchvision.transforms.v2._geometry.RandomResizedCrop

In [19]:
def train_transformatons(image_size=(224,224),image_mean=IMAGENET_MEAN,image_std=IMAGENET_STD
                         , hflip_probab=0.5, interpolation=InterpolationMode.BILINEAR, random_aug_magnitude=9
                         ):
    """
    Build the training-time augmentation and normalization pipeline.

    Steps, in order:
      1. RandomResizedCrop to ``image_size`` (antialiased, resampled with ``interpolation``).
      2. RandomHorizontalFlip with probability ``hflip_probab`` (skipped if <= 0).
      3. RandAugment with magnitude ``random_aug_magnitude`` (skipped if <= 0).
      4. PIL image -> uint8 tensor -> float32 scaled to [0, 1].
      5. Per-channel Normalize with ``image_mean`` / ``image_std``.

    Args:
        image_size: output (height, width) of the random crop.
        image_mean: per-channel means for Normalize.
        image_std: per-channel standard deviations for Normalize.
        hflip_probab: probability of a horizontal flip.
        interpolation: resampling mode for the crop and for RandAugment.
        random_aug_magnitude: RandAugment magnitude.

    Returns:
        transforms.Compose: callable to apply to a single PIL image.
    """
    # NOTE(review): assumes RGB PIL inputs. A grayscale image yields a 1-channel tensor
    # that does not match a 3-value mean/std; convert upstream if the dataset has any.
    transformation_chain=[
        v2.RandomResizedCrop(image_size,interpolation=interpolation,antialias=True),
    ]

    if hflip_probab > 0:
        # Flip left-right with probability hflip_probab.
        transformation_chain.append(v2.RandomHorizontalFlip(p=hflip_probab))

    if random_aug_magnitude > 0:
        print("Enabling Random Augmentation!")
        # Randomly sampled photometric/geometric ops; magnitude sets their strength.
        transformation_chain.append(
            v2.RandAugment(magnitude=random_aug_magnitude, interpolation=interpolation)
        )

    transformation_chain += [
        v2.PILToTensor(),                       # PIL image -> uint8 tensor (C, H, W)
        v2.ToDtype(torch.float32,scale=True),   # [0, 255] -> [0.0, 1.0]
        v2.Normalize(mean=image_mean,std=image_std),  # (pixel - mean) / std per channel
    ]

    return transforms.Compose(transformation_chain)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
train_transformations = train_transformatons


    


In [20]:
def eval_transformations(image_size=(224,224), resize_image=(256,256), image_mean=IMAGENET_MEAN, image_std=IMAGENET_STD,
                         interpolation=InterpolationMode.BILINEAR
                         ):
    """
    Deterministic preprocessing for evaluation / inference (no random augmentation).

    Resize to ``resize_image``, center-crop to ``image_size``, convert the PIL image to
    a float32 tensor scaled to [0, 1], then normalize per channel with
    ``image_mean`` / ``image_std``.

    NOTE(review): with an (h, w) tuple, Resize produces exactly that size, so
    non-square images are stretched rather than aspect-preserved.
    """
    steps = [
        v2.Resize(resize_image, interpolation=interpolation, antialias=True),
        v2.CenterCrop(image_size),
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=image_mean, std=image_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline


In [21]:
from torch.utils.data import default_collate
# customizing the collat_fn of the dataloader
#blend images after theyve been batched together
#but before sending them to the model
def mixup_cutmix_collate_fn(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=1000):
    """
    Build a DataLoader ``collate_fn`` that applies MixUp or CutMix to each batch.

    These augmentations blend samples across a batch, so they must run after the
    samples are collated and before the batch reaches the model.

    Args:
        mixup_alpha: alpha of the Beta distribution MixUp samples its mixing weight
            from; a value <= 0 disables MixUp.
        cutmix_alpha: alpha of the Beta distribution for CutMix; <= 0 disables CutMix.
        num_classes: number of classes, used by MixUp/CutMix to one-hot encode labels.

    Returns:
        callable: ``collate_fn(batch)``. If both augmentations are enabled, one of them
        is picked at random for each batch (``v2.RandomChoice``); if neither is
        enabled it behaves exactly like ``default_collate``.
    """
    candidates=[]

    if mixup_alpha > 0:
        print("Enabling Mixup!")
        candidates.append(v2.MixUp(alpha=mixup_alpha,num_classes=num_classes))

    if cutmix_alpha > 0:
        print("Enabling Cutmix!")
        candidates.append(v2.CutMix(alpha=cutmix_alpha,num_classes=num_classes))

    mix_cut_transforms=v2.RandomChoice(candidates) if candidates else None

    def collate_fn(batch):
        collated=default_collate(batch)
        if mix_cut_transforms is not None:
            collated=mix_cut_transforms(collated)
        return collated

    # Previously the factory fell off the end and returned None; DataLoader then
    # used its default collate, silently disabling MixUp/CutMix.
    return collate_fn
        

In [22]:
from datasets import load_dataset

# Training split of Tiny-ImageNet from the Hugging Face Hub.
# NOTE(review): the saved output shows DatasetNotFoundError; check network access,
# the installed `datasets` version and the dataset id before relying on this cell.
ds = load_dataset(path="zh-plus/tiny-imagenet", split="train")
ds

DatasetNotFoundError: Dataset 'zh-plus/tiny-imagenet' doesn't exist on the Hub or cannot be accessed.