<a href="https://colab.research.google.com/github/gremlin97/EVA-8/blob/main/S10/Implementations/ConvMixer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementation: Patches are all you Need?

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import argparse

In [2]:
# Residual Class Template created for further use in the network
class Residual(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn
  
  def forward(self, x):
    return self.fn(x) + x

In [3]:
def ConvMixer(dim, depth, kernel_size=3, patch_size=2, n_classes=10):
  return nn.Sequential(
      nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), # Initial Convolution to create the patches as well as the embeddings
      nn.GELU(), # Batchnorm and Gelu (~ViT activation function) activation used
      nn.BatchNorm2d(dim),
      *[  # Unroll (list comprehension) Sequential block consisting of conv-gelu-bn => point-wise convolution to specified depth
          nn.Sequential(
              Residual( # Residual block to add residue to high connections x=f(x)+x
                  nn.Sequential( 
                      nn.Conv2d(dim, dim, kernel_size, groups=dim, padding = "same"), 
                      nn.GELU(),
                      nn.BatchNorm2d(dim)
                  ) # Logic: Inital conv2d creates our value matrices (256) from our patch embeddings. We use the kernel as our effective attention matrix (q x k.T) and repeat this structure as successive attention blocks. The channels represent multi-head attention 
              ),
              nn.Conv2d(dim, dim, kernel_size=1), # Pointwise convolution to mix and merge channels (dim is constant)
              nn.GELU(),
              nn.BatchNorm2d(dim)
          )for i in range(depth) # Unroll to depth d
      ],
      nn.AdaptiveAvgPool2d((1,1)), # GAP Layer
      nn.Flatten(),
      nn.Linear(dim, n_classes) # FC layer=num of classes / MLP head
  )

In [4]:
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

epochs = 25
batch_size = 512

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified


In [13]:
depth = 10
hdim = 256
psize = 2
conv_ks =  5
clip_norm =  True

model = ConvMixer(hdim, depth, psize, conv_ks, 10)
model = nn.DataParallel(model, device_ids = [0]).cuda()

opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

lr_schedule = lambda t: np.interp([t],[0, epochs*2//5, epochs*4//5, epochs], [0, 0.01, 0.01/20, 0])[0]

In [14]:
from torchsummary import summary
summary(model, input_size=(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 256, 6, 6]          19,456
              GELU-2            [-1, 256, 6, 6]               0
       BatchNorm2d-3            [-1, 256, 6, 6]             512
            Conv2d-4            [-1, 256, 6, 6]           1,280
              GELU-5            [-1, 256, 6, 6]               0
       BatchNorm2d-6            [-1, 256, 6, 6]             512
          Residual-7            [-1, 256, 6, 6]               0
            Conv2d-8            [-1, 256, 6, 6]          65,792
              GELU-9            [-1, 256, 6, 6]               0
      BatchNorm2d-10            [-1, 256, 6, 6]             512
           Conv2d-11            [-1, 256, 6, 6]           1,280
             GELU-12            [-1, 256, 6, 6]               0
      BatchNorm2d-13            [-1, 256, 6, 6]             512
         Residual-14            [-1, 25

In [17]:
# Checking lr_schedule and numpy interpolation
np.interp([1],[0, epochs*2//5, epochs*4//5, epochs], [0, 0.01, 0.01/20, 0])[0]
lr_schedule(1)

0.001

In [18]:
for epoch in range(epochs):
  start = time.time()

  train_loss, train_acc, n = 0, 0 ,0

  for i, (X,y) in enumerate(trainloader):
    model.train()
    X, y = X.cuda(), y.cuda()
    lr = lr_schedule(epoch + (i+1)/len(trainloader))
    opt.param_groups[0].update(lr=lr)
    opt.zero_grad()

    with torch.cuda.amp.autocast():
      output = model(X)
      loss = criterion(output,y)
    
    scaler.scale(loss).backward()

    if clip_norm:
      scaler.unscale_(opt)
      nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    
    scaler.step(opt)
    scaler.update()

    # print(y.size(0))
    # print(loss.item())
    train_loss += loss.item()*y.size(0)
    train_acc += (output.max(1)[1]==y).sum().item()
    n+=y.size(0)

  model.eval()
  test_acc, m = 0, 0

  with torch.no_grad():
    for i, (X,y) in enumerate(testloader):
      X, y = X.cuda(), y.cuda()
      with torch.cuda.amp.autocast():
        output = model(X)
      
      test_acc += (output.max(1)[1] == y).sum().item()
      m += y.size(0)

  print(f'ConvMixer: Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')



ConvMixer: Epoch: 0 | Train Acc: 0.2859, Test Acc: 0.4253, Time: 58.9, lr: 0.001000
ConvMixer: Epoch: 1 | Train Acc: 0.4454, Test Acc: 0.5106, Time: 59.4, lr: 0.002000
ConvMixer: Epoch: 2 | Train Acc: 0.5193, Test Acc: 0.5741, Time: 58.8, lr: 0.003000
ConvMixer: Epoch: 3 | Train Acc: 0.5661, Test Acc: 0.6028, Time: 58.2, lr: 0.004000
ConvMixer: Epoch: 4 | Train Acc: 0.6049, Test Acc: 0.6550, Time: 57.7, lr: 0.005000
ConvMixer: Epoch: 5 | Train Acc: 0.6319, Test Acc: 0.6848, Time: 60.2, lr: 0.006000
ConvMixer: Epoch: 6 | Train Acc: 0.6512, Test Acc: 0.6913, Time: 57.7, lr: 0.007000
ConvMixer: Epoch: 7 | Train Acc: 0.6682, Test Acc: 0.6978, Time: 58.3, lr: 0.008000
ConvMixer: Epoch: 8 | Train Acc: 0.6772, Test Acc: 0.7096, Time: 57.1, lr: 0.009000
ConvMixer: Epoch: 9 | Train Acc: 0.6884, Test Acc: 0.7225, Time: 59.4, lr: 0.010000
ConvMixer: Epoch: 10 | Train Acc: 0.7058, Test Acc: 0.7280, Time: 57.1, lr: 0.009050
ConvMixer: Epoch: 11 | Train Acc: 0.7228, Test Acc: 0.7479, Time: 57.1, lr: