In [None]:
# Import the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim 

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt 
from collections import defaultdict

In [None]:
# Data location and compute device: use the GPU when one is visible, otherwise the CPU.
DATA_DIR = './data'
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device," is  available!")

cuda  is  available!


In [None]:



# Residual network to add at the end of output.
class Residual(nn.Module):
  """Wrap a stack of layers as a residual branch with a learnable scale.

  The given layers are chained into an ``nn.Sequential``. In ``forward`` the
  branch output is multiplied by ``gamma`` and added to the input. ``gamma``
  is a one-element tensor created as zero, so for finite branch outputs a
  freshly built block returns its input unchanged.
  """

  def __init__(self,*layers):
    super().__init__()
    self.residual = nn.Sequential(*layers)
    # Registered via nn.Parameter so it receives gradients and is updated by the optimizer.
    self.gamma = nn.Parameter(torch.zeros(1))

  def forward(self,x):
    """Return ``x`` plus the gamma-scaled output of the residual branch."""
    branch = self.residual(x)
    return x + self.gamma*branch


In [None]:
# Performing Layernormalization to the channels.
class LayerNormChannels(nn.Module):
  """LayerNorm applied along dimension 1 of the input.

  ``nn.LayerNorm`` normalizes the last dimension, so dimension 1 is swapped
  into the last position, normalized there, and swapped back.
  """

  def __init__(self,channels):
    super().__init__()
    self.norm = nn.LayerNorm(channels)

  def forward(self,x):
    """Normalize ``x`` over dimension 1; the result has the same shape as ``x``."""
    channels_last = x.transpose(1,-1)
    normalized = self.norm(channels_last)
    return normalized.transpose(-1,1)

In [None]:
# Self-attention over a 2-D feature map.
# __init__:
# STEP1: Derive the number of heads by dividing the output channels by the channels per head.
# STEP2: Define the 1x1 convolutions that produce keys, queries and values, plus the one that merges the heads.
# STEP3: Create a learnable relative position encoding (one value per head and per relative offset) and the index table used to look it up.
# Forward:
# STEP1: Compute keys, queries and values and view each as (batch_size, heads, head_channels, height*width).
# STEP2: Compute the scaled attention scores, add the relative position encoding, apply softmax, and use the result to weight the values before merging the heads.
# Static method get_indices: builds the relative-position index for every pair of positions.


class SelfAttention2d(nn.Module):
  """Multi-head self-attention over a 2-D feature map with a learned relative position bias.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced; split evenly across the heads.
    head_channels: channels per head; must divide ``out_channels``.
    shape: ``(height, width)`` of the feature maps this layer will receive.

  Raises:
    ValueError: if ``out_channels`` is not a multiple of ``head_channels``.
  """

  def __init__(self,in_channels,out_channels,head_channels,shape):
    super().__init__()
    # STEP1: number of heads and the per-head scaling factor.
    if out_channels % head_channels != 0:
      raise ValueError(
          f"out_channels ({out_channels}) must be a multiple of head_channels ({head_channels})")
    self.heads = out_channels // head_channels
    self.head_channels = head_channels
    self.scale = head_channels**-0.5
    # STEP2: 1x1 convolutions producing keys, queries and values.
    self.to_keys = nn.Conv2d(in_channels,out_channels,1)
    self.to_queries = nn.Conv2d(in_channels,out_channels,1)
    self.to_values = nn.Conv2d(in_channels,out_channels,1)
    # The merged heads carry out_channels channels, so the output projection
    # must read out_channels (reading in_channels only works when both are equal).
    self.unifyheads = nn.Conv2d(out_channels,out_channels,1)
    # STEP3: one learnable bias per head and per relative offset; there are
    # (2*height-1)*(2*width-1) distinct offsets on a height x width grid.
    height,width = shape
    # Start from zeros rather than uninitialized memory so the layer is usable
    # on its own; ViT.reset_parameters re-initializes this tensor anyway.
    self.pos_enc = nn.Parameter(torch.zeros(self.heads,(2*height-1)*(2*width-1)))
    # A buffer follows the module across devices and is saved in the state dict,
    # but it is not a parameter and is not updated by the optimizer.
    self.register_buffer("relative_indices",self.get_indices(height,width))

  def forward(self,x):
    """Apply self-attention to ``x`` of shape ``(batch, in_channels, h, w)``.

    ``(h, w)`` must equal the ``shape`` given at construction, because the
    relative position table is unflattened to ``(h*w, h*w)``.
    Returns a tensor of shape ``(batch, out_channels, h, w)``.
    """
    b,_,h,w = x.shape
    # STEP1: split channels into heads -> (b, heads, head_channels, h*w).
    keys = self.to_keys(x).view(b,self.heads,self.head_channels,-1)
    values = self.to_values(x).view(b,self.heads,self.head_channels,-1)
    queries = self.to_queries(x).view(b,self.heads,self.head_channels,-1)
    # STEP2: scores laid out as (b, heads, key_position, query_position).
    attention_score = keys.transpose(-2,-1)@queries
    indices = self.relative_indices.expand(self.heads,-1)
    rel_pos_enc = self.pos_enc.gather(-1,indices).unflatten(-1,(h*w,h*w))
    attention_score = attention_score*self.scale + rel_pos_enc
    # Normalize over the key axis so each query's weights sum to one.
    attention_score = F.softmax(attention_score,dim=-2)

    out = values @ attention_score
    out = out.view(b,-1,h,w)
    return self.unifyheads(out)

  @staticmethod
  def get_indices(h,w):
    """Return a flat LongTensor of length ``(h*w)**2`` of relative-offset indices.

    For every pair of grid positions ``(y1, x1)`` and ``(y2, x2)`` the entry is
    ``(y1-y2+h-1)*(2*w-1) + (x1-x2+w-1)``, i.e. the row-major index of the
    offset inside a ``(2*h-1) x (2*w-1)`` table.
    """
    y = torch.arange(h,dtype=torch.long)
    x = torch.arange(w,dtype=torch.long)
    y1,x1,y2,x2 = torch.meshgrid(y,x,y,x,indexing='ij')
    indices = (y1-y2+h-1)*(2*w-1) + x1-x2 + w-1
    return indices.flatten()



In [None]:
# Feed-forward network
# hidden_channels = in_channels * mult
# The first 1x1 convolution takes in_channels inputs and produces hidden_channels outputs.
# After a GELU, the second 1x1 convolution takes hidden_channels inputs and produces out_channels outputs.
class FeedForward(nn.Sequential):
  """Position-wise MLP built from two 1x1 convolutions with a GELU in between.

  The first convolution widens ``in_channels`` to ``in_channels * mult``;
  the second maps that width to ``out_channels``.
  """

  def __init__(self,in_channels,out_channels,mult=4):
    widened = in_channels*mult
    expand = nn.Conv2d(in_channels,widened,1)
    project = nn.Conv2d(widened,out_channels,1)
    super().__init__(expand,nn.GELU(),project)

In [None]:
# A Transformer Block Where we group:
## A Residual Block: Attention2d + residue 
## A Residual Block: LayerNormalization o/p + FeedForwarded O/p (Conv + Conv) (i,e,,class FeedForward)
class TransformerBlock(nn.Sequential):
  """One transformer layer: an attention residual followed by a feed-forward residual.

  Each residual branch is LayerNorm over channels -> sub-layer -> dropout,
  where the sub-layer is ``SelfAttention2d`` in the first branch and
  ``FeedForward`` in the second. Channel count is preserved throughout.
  """

  def __init__(self,channels,head_channels,shape,p_drop=0.):
    attention_branch = Residual(
        LayerNormChannels(channels),
        SelfAttention2d(channels,channels,head_channels,shape),
        nn.Dropout(p_drop),
    )
    feed_forward_branch = Residual(
        LayerNormChannels(channels),
        FeedForward(channels,channels),
        nn.Dropout(p_drop),
    )
    super().__init__(attention_branch,feed_forward_branch)

In [None]:
## Stacking the transformer block num_block times as a list 
class TransformerStack(nn.Sequential):
  """``num_blocks`` identically configured ``TransformerBlock`` layers applied in sequence."""

  def __init__(self,num_blocks,channels,head_channels,shape,p_drop=0.):
    super().__init__(*(
        TransformerBlock(channels,head_channels,shape,p_drop)
        for _ in range(num_blocks)
    ))

In [None]:
## To split the image in patches by convolving with patch_size and a stride of patch_size 
class ToPatches(nn.Sequential):
  """Convolutional stem that turns an image into a grid of patch features.

  A 3x3 convolution (padding 1) maps ``in_channels`` to ``hidden_channels``,
  followed by a GELU. A second convolution whose kernel size and stride both
  equal ``patch_size`` then produces one ``channels``-wide vector per
  non-overlapping patch.
  """

  def __init__(self,in_channels,channels,patch_size,hidden_channels=32):
    stem = nn.Conv2d(in_channels,hidden_channels,kernel_size=3,padding=1)
    patchify = nn.Conv2d(hidden_channels,channels,kernel_size=patch_size,stride=patch_size)
    super().__init__(stem,nn.GELU(),patchify)

In [None]:
## Adding Position Embedding to existing input of forward function 
class AddPositionEmbedding(nn.Module):
  """Add a learnable position embedding to the input.

  Args:
    channels: number of channels of the incoming feature map.
    shape: spatial shape of the incoming feature map; the embedding has shape
      ``(channels, *shape)`` and is broadcast over the leading batch dimension.
  """

  def __init__(self,channels,shape):
    super().__init__()
    # Start from zeros rather than uninitialized memory so the module is safe
    # to use on its own; ViT.reset_parameters re-initializes this tensor anyway.
    self.pos_embedding = nn.Parameter(torch.zeros(channels,*shape))

  def forward(self,x):
    """Return ``x + pos_embedding``; the output has the same shape as ``x``."""
    return x + self.pos_embedding

In [None]:
# This is a one of main block which involves the processes:
## A. Splitting the image into patches
## B. Adding Position Embedding
## This involves both codes and do operations
class ToEmbedding(nn.Sequential):
  """Image -> patch features -> plus position embedding -> dropout.

  For the model built in this notebook the values are in_channels = 3,
  channels = 32 and shape = (16,16).
  """

  def __init__(self,in_channels,channels,patch_size,shape,p_drop=0.):
    stages = [
        ToPatches(in_channels,channels,patch_size),
        AddPositionEmbedding(channels,shape),
        nn.Dropout(p_drop),
    ]
    super().__init__(*stages)
    

In [None]:
# Classification head: layer normalization over channels, GELU, global average pooling (GAP), flatten, dropout, and a final fully connected layer with one output per class.
class Head(nn.Sequential):
  """Classifier head mapping a feature map to ``classes`` logits.

  Pipeline: channel LayerNorm -> GELU -> adaptive average pool to 1x1 ->
  flatten -> dropout -> linear layer from ``in_channels`` to ``classes``.
  """

  def __init__(self,in_channels,classes,p_drop=0.):
    layers = (
        LayerNormChannels(in_channels),
        nn.GELU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Dropout(p_drop),
        nn.Linear(in_channels,classes),
    )
    super().__init__(*layers)

In [None]:
## Main Class where it takes all of inputs and involves all the codes above one way or other 
class ViT(nn.Sequential):
  """Vision Transformer classifier: patch embedding -> transformer stack -> head.

  Args:
    classes: number of output logits.
    image_size: side length of the (square) input image.
    channels: feature width used throughout the network.
    head_channels: channels per attention head.
    num_blocks: number of transformer blocks in the stack.
    patch_size: side length of each patch; the feature grid is
      ``image_size // patch_size`` on each side.
    in_channels: number of channels of the input image.
    emb_p_drop: dropout probability after the embedding.
    trans_p_drop: dropout probability inside the transformer blocks.
    head_p_drop: dropout probability before the final linear layer.
  """

  def __init__(self,classes,image_size,channels,head_channels,num_blocks,patch_size,in_channels=3,emb_p_drop=0.,trans_p_drop=0.,head_p_drop=0.):
    reduced_size = image_size//patch_size
    shape = (reduced_size,reduced_size)

    # Input to output: split into patches and embed -> transformer stack -> classification head.
    super().__init__(
        ToEmbedding(in_channels,channels,patch_size,shape,emb_p_drop),
        TransformerStack(num_blocks,channels,head_channels,shape,trans_p_drop),
        Head(channels,classes,head_p_drop))
    self.reset_parameters()

  def reset_parameters(self):
    """(Re-)initialize every parameter of the network in place.

    Conv/Linear weights use Kaiming-normal with zero biases, LayerNorm is set
    to weight 1 / bias 0, both position tensors are drawn from N(0, 0.02^2),
    and every residual ``gamma`` is set to zero.
    """
    for m in self.modules():
      if isinstance(m,(nn.Conv2d,nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif isinstance(m,nn.LayerNorm):
        nn.init.constant_(m.weight,1.)
        nn.init.zeros_(m.bias)
      elif isinstance(m,AddPositionEmbedding):
        nn.init.normal_(m.pos_embedding,mean=0.0,std=0.02)
      elif isinstance(m,SelfAttention2d):
        nn.init.normal_(m.pos_enc,mean=0.0,std=0.02)
      elif isinstance(m,Residual):
        nn.init.zeros_(m.gamma)

  def separate_parameters(self):
    """Split parameter names into weight-decay and no-weight-decay groups.

    Returns:
      ``(parameters_decay, parameters_no_decay)``: two disjoint sets of fully
      qualified parameter names (as produced by ``named_parameters``). Conv and
      Linear weights decay; biases, LayerNorm parameters, residual ``gamma``,
      and both position tensors do not.

    Raises:
      AssertionError: if a parameter lands in both sets or in neither.
    """
    parameters_decay = set()
    parameters_no_decay = set()
    modules_weight_decay = (nn.Linear,nn.Conv2d)
    modules_no_weight_decay = (nn.LayerNorm,)

    for m_name,m in self.named_modules():
      # recurse=False: each parameter is classified exactly once, by the module that owns it.
      for param_name,_ in m.named_parameters(recurse=False):
        full_param_name = f"{m_name}.{param_name}" if m_name else param_name

        if isinstance(m,modules_no_weight_decay):
          parameters_no_decay.add(full_param_name)
        elif param_name.endswith("bias"):
          parameters_no_decay.add(full_param_name)
        elif isinstance(m,Residual) and param_name.endswith("gamma"):
          parameters_no_decay.add(full_param_name)
        elif isinstance(m,AddPositionEmbedding) and param_name.endswith("pos_embedding"):
          parameters_no_decay.add(full_param_name)
        elif isinstance(m,SelfAttention2d) and param_name.endswith("pos_enc"):
          parameters_no_decay.add(full_param_name)
        elif isinstance(m,modules_weight_decay):
          parameters_decay.add(full_param_name)

    # Sanity check: the two groups are disjoint and together cover every parameter.
    all_parameter_names = {name for name,_ in self.named_parameters()}
    assert not (parameters_decay & parameters_no_decay)
    assert (parameters_decay | parameters_no_decay) == all_parameter_names

    return parameters_decay,parameters_no_decay

In [None]:
# Dataset-level constants: 10 target classes, 32x32 input images.
NUM_CLASSES = 10
IMAGE_SIZE = 32
# patch_size=2 on a 32x32 image gives a 16x16 feature grid; 32 channels at 8 per head gives 4 heads.
model = ViT(
    classes=NUM_CLASSES,
    image_size=IMAGE_SIZE,
    channels=32,
    head_channels=8,
    num_blocks=4,
    patch_size=2,
    emb_p_drop=0.,
    trans_p_drop=0.,
    head_p_drop=0.1,
)

In [None]:
# Move the model's parameters and buffers to the selected device.
# As the last expression of the cell, the returned module is displayed as the cell output.
model.to(device)

ViT(
  (0): ToEmbedding(
    (0): ToPatches(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): GELU(approximate='none')
      (2): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
    )
    (1): AddPositionEmbedding()
    (2): Dropout(p=0.0, inplace=False)
  )
  (1): TransformerStack(
    (0): TransformerBlock(
      (0): Residual(
        (residual): Sequential(
          (0): LayerNormChannels(
            (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          )
          (1): SelfAttention2d(
            (to_keys): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
            (to_queries): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
            (to_values): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
            (unifyheads): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (2): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Residual(
        (residual): Sequential(
          (0): La

In [None]:
# Report the total number of scalar elements across all parameter tensors of the model.
print("Number of parameters: {:,}".format(sum(map(torch.numel, model.parameters()))))

Number of parameters: 79,810


In [None]:
# ---- Hyper-parameters ----
IMAGE_SIZE = 32

NUM_CLASSES = 10
# Worker processes per DataLoader. The loaders used to ignore this constant and
# hard-code 4; the value is set to 4 here so that wiring it in does not change
# how the loaders run.
NUM_WORKERS = 4
BATCH_SIZE = 128
TEST_BATCH_SIZE = 64
EPOCHS = 25

# NOTE(review): the training cell hard-codes lr=0.01 and weight_decay=0.01 and
# does not read these two constants - confirm which values are intended.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1

# Per-channel statistics used to normalize the images.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# ---- Transforms ----
# Training: random augmentation on the PIL image, then tensor conversion,
# normalization, and random erasing on the tensor.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

# ---- Datasets and loaders ----
# DATA_DIR is defined in the device/configuration cell near the top of the notebook.
trainset = datasets.CIFAR10(root=DATA_DIR, train=True,
                            download=True, transform=train_transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=NUM_WORKERS)

testset = datasets.CIFAR10(root=DATA_DIR, train=False,
                           download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=TEST_BATCH_SIZE,
                        shuffle=False, num_workers=NUM_WORKERS)


Files already downloaded and verified




Files already downloaded and verified


In [None]:
import time
# GradScaler scales the loss before backward so that small half-precision gradients
# do not underflow, and unscales them again before the optimizer step.
# autocast runs eligible operations in a lower-precision floating point type to save
# memory and improve performance.

clip_norm = True
# Piecewise-linear learning-rate schedule over (fractional) epochs: 0 -> 0.01 at
# EPOCHS*2//5, down to 0.01/20 at EPOCHS*4//5, then to 0 at EPOCHS.
lr_schedule = lambda t: np.interp([t], [0, EPOCHS*2//5, EPOCHS*4//5, EPOCHS], 
                                  [0, 0.01, 0.01/20.0, 0])[0]

# Mixed precision and DataParallel are only used when the selected device is a GPU.
use_cuda = device.type == "cuda"
# Wrap only once, so re-running this cell does not nest DataParallel wrappers.
if use_cuda and not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model, device_ids=[0])
model = model.to(device)
opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
for epoch in range(EPOCHS):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0
    # Switch to training mode once per epoch (evaluation below switches it off).
    model.train()
    for i, (X, y) in enumerate(trainloader):
        X, y = X.to(device), y.to(device)
        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        # Apply the scheduled rate to every parameter group, not just the first.
        for group in opt.param_groups:
            group["lr"] = lr

        opt.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_cuda):
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            # Gradients must be unscaled before clipping so the threshold applies to true values.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        
        # Accumulate sample-weighted loss and the number of correct predictions.
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
        
    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(testloader):
            X, y = X.to(device), y.to(device)
            with torch.cuda.amp.autocast(enabled=use_cuda):
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'ViT: Epoch: {epoch} | Train Loss: {train_loss/n:.4f}, Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')


torch.Size([128, 32, 16, 16])
pos torch.Size([32, 16, 16])
after_pos torch.Size([128, 32, 16, 16])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
torch.Size([128, 32, 16, 16])
pos torch.Size([32, 16, 16])
after_pos torch.Size([128, 32, 16, 16])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
torch.Size([128, 32, 16, 16])
pos torch.Size([32, 16, 16])
after_pos torch.Size([128, 32, 16, 16])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Size([4, 256, 256])
re torch.Size([4, 65536])
re flat torch.Siz

KeyboardInterrupt: ignored