<a href="https://colab.research.google.com/github/navrat/EVA_phase1_2022_23/blob/main/S10_VIT/CV_Patches_Are_All_You_Need_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://arxiv.org/abs/2201.09792

## Patches Are All You Need?
### Asher Trockman, J. Zico Kolter


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import argparse

In [3]:
torch.cuda.is_available()

True

In [10]:
class Residual(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x):
    return self.fn(x) + x
  
def convmixer(dim, depth, kernel_size = 5, patch_size=2, n_classes=10):
  return nn.Sequential(
      nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), # 16X16 | kernel 2X2 with stride = 2 | 8X8 output
      nn.GELU(),
      nn.BatchNorm2d(dim),
      *[
          nn.Sequential(
              Residual(
                  nn.Sequential(nn.Conv2d(dim, dim, kernel_size, groups=dim, padding= "same"),
                  nn.GELU(),
                  nn.BatchNorm2d(dim)
              )),          
          nn.Conv2d(dim, dim, kernel_size =1),
          nn.GELU(),
          nn.BatchNorm2d(dim)
          ) for i in range(depth)],
      nn.AdaptiveAvgPool2d((1,1)),
      nn.Flatten(),
      nn.Linear(dim, n_classes)
  )


In [5]:
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

epochs = 20
batch_size =64

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified


In [7]:
lr_schedule = lambda t: np.interp([t], [0, sample_epochs*2//5, sample_epochs*4//5, sample_epochs], 
                                  [0, 0.01, 0.01/20.0, 0])[0]
sample_epochs = 100
print([0, sample_epochs*2//5, sample_epochs*4//5, sample_epochs])
print(lr_schedule(10))
print(lr_schedule(20))
print(lr_schedule(40))
print(lr_schedule(70))

[0, 40, 80, 100]
0.0025
0.005
0.01
0.002875


In [11]:
lr_schedule = lambda t: np.interp([t], [0, epochs*2//5, epochs*4//5, epochs], [0,0.01, 0.01/20.0, 0])[0]

depth = 10
hdim = 256
psize = 2
conv_ks = 5
clip_norm = True

model_conv = convmixer(hdim, depth, patch_size=psize, kernel_size = conv_ks, n_classes=10)
model = nn.DataParallel(model_conv, device_ids = [0]).cuda()

opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
  start = time.time()
  train_loss, train_acc, n = 0, 0, 0
  for i, (X, y) in enumerate(trainloader):
    model.train()
    X, y = X.cuda(), y.cuda()
    lr = lr_schedule(epoch + (i+1)/len(trainloader))
    opt.param_groups[0].update(lr=lr)

    opt.zero_grad()

    with torch.cuda.amp.autocast():
      output = model(X)
      loss = criterion(output, y)

    scaler.scale(loss).backward()
    if clip_norm:
      scaler.unscale_(opt)
      nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      scaler.update()

    train_loss += loss.item() * y.size(0)
    train_acc += (output.max(1)[1] == y).sum().item()
    n += y.size(0)
  
  model.eval()
  test_acc, m = 0, 0
  with torch.no_grad():
    for i, (X, y) in enumerate(testloader):
      X, y = X.cuda(), y.cuda()
      with torch.cuda.amp.autocast():
        output = model(X)
        test_acc += (output.max(1)[1] == y).sum().item()
        m += y.size(0)

  print(f'ConvMixer: Epoch: {epoch} | Train Acc: {train_acc/n: .4f}, Test_Acc: {test_acc/m: .4f}, Time: {time.time() - start: .1f}, lr: {lr: .6f}')




ConvMixer: Epoch: 0 | Train Acc:  0.1071, Test_Acc:  0.1119, Time:  62.6, lr:  0.001250
ConvMixer: Epoch: 1 | Train Acc:  0.1068, Test_Acc:  0.1128, Time:  63.5, lr:  0.002500
ConvMixer: Epoch: 2 | Train Acc:  0.1090, Test_Acc:  0.1112, Time:  63.5, lr:  0.003750
ConvMixer: Epoch: 3 | Train Acc:  0.1078, Test_Acc:  0.1111, Time:  63.5, lr:  0.005000
ConvMixer: Epoch: 4 | Train Acc:  0.1068, Test_Acc:  0.1110, Time:  63.7, lr:  0.006250
ConvMixer: Epoch: 5 | Train Acc:  0.1090, Test_Acc:  0.1106, Time:  63.7, lr:  0.007500
ConvMixer: Epoch: 6 | Train Acc:  0.1085, Test_Acc:  0.1136, Time:  64.1, lr:  0.008750
ConvMixer: Epoch: 7 | Train Acc:  0.1087, Test_Acc:  0.1127, Time:  63.9, lr:  0.010000
ConvMixer: Epoch: 8 | Train Acc:  0.1094, Test_Acc:  0.1119, Time:  63.7, lr:  0.008813
ConvMixer: Epoch: 9 | Train Acc:  0.1075, Test_Acc:  0.1116, Time:  64.1, lr:  0.007625
ConvMixer: Epoch: 10 | Train Acc:  0.1113, Test_Acc:  0.1129, Time:  64.8, lr:  0.006438
ConvMixer: Epoch: 11 | Train Ac

KeyboardInterrupt: ignored

In [None]:
## The abvoe code has very limited accuracy boost per epoch

In [None]:
## The below code has been picked from the paper directly and shows immediate accuracy boost per epoch

In [12]:
class Residual(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn 
  def forward(self, x):
    return self.fn(x) + x


def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):
  return nn.Sequential(
      nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
      nn.GELU(),
      nn.BatchNorm2d(dim),
      *[nn.Sequential(
          Residual(nn.Sequential(
              nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
              nn.GELU(),
              nn.BatchNorm2d(dim)
          )),
          nn.Conv2d(dim, dim, kernel_size=1),
          nn.GELU(),
          nn.BatchNorm2d(dim)
      ) for i in range(depth)],
      nn.AdaptiveAvgPool2d((1, 1)),
      nn.Flatten(),
      nn.Linear(dim, n_classes)
  )

In [13]:
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

epochs = 25
batch_size =512

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
lr_schedule = lambda t: np.interp([t], [0, epochs*2//5, epochs*4//5, epochs], 
                                  [0, 0.01, 0.01/20.0, 0])[0]

depth = 10
hdim = 256
psize = 2
conv_ks = 5
clip_norm = True

model = ConvMixer(hdim, depth, patch_size=psize, kernel_size=conv_ks, n_classes=10)
model = nn.DataParallel(model, device_ids=[0]).cuda()

opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0
    for i, (X, y) in enumerate(trainloader):
        model.train()
        X, y = X.cuda(), y.cuda()

        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        opt.param_groups[0].update(lr=lr)

        opt.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
        
    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(testloader):
            X, y = X.cuda(), y.cuda()
            with torch.cuda.amp.autocast():
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'ConvMixer: Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')



ConvMixer: Epoch: 0 | Train Acc: 0.3425, Test Acc: 0.4941, Time: 62.8, lr: 0.001000
ConvMixer: Epoch: 1 | Train Acc: 0.5456, Test Acc: 0.5807, Time: 54.1, lr: 0.002000
ConvMixer: Epoch: 2 | Train Acc: 0.6424, Test Acc: 0.6467, Time: 54.9, lr: 0.003000
ConvMixer: Epoch: 3 | Train Acc: 0.7056, Test Acc: 0.7444, Time: 54.8, lr: 0.004000
ConvMixer: Epoch: 4 | Train Acc: 0.7443, Test Acc: 0.7637, Time: 55.3, lr: 0.005000
ConvMixer: Epoch: 5 | Train Acc: 0.7632, Test Acc: 0.7765, Time: 54.5, lr: 0.006000
ConvMixer: Epoch: 6 | Train Acc: 0.7796, Test Acc: 0.8000, Time: 54.8, lr: 0.007000
ConvMixer: Epoch: 7 | Train Acc: 0.7933, Test Acc: 0.7926, Time: 55.0, lr: 0.008000
ConvMixer: Epoch: 8 | Train Acc: 0.8048, Test Acc: 0.8024, Time: 59.4, lr: 0.009000
ConvMixer: Epoch: 9 | Train Acc: 0.8136, Test Acc: 0.8317, Time: 54.6, lr: 0.010000
ConvMixer: Epoch: 10 | Train Acc: 0.8273, Test Acc: 0.8463, Time: 55.3, lr: 0.009050
ConvMixer: Epoch: 11 | Train Acc: 0.8444, Test Acc: 0.8464, Time: 55.3, lr: