[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/devdastl/EVA-8_Phase-1_Assignment-10/blob/main/ViT_with_convolution/ViT_with_convoluton.ipynb)

In [1]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision

In [2]:
# Directory (relative to the notebook) where datasets are stored.
DATA_DIR = "./data"

In [3]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", DEVICE)

device: cuda


In [4]:
!pip install einops

Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 374 kB/s eta 0:00:01
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [5]:
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes
class ModifyConv2d(nn.Module):
    """Bias-free 1x1 ``nn.Conv2d`` applied along the last axis of a 3-D tensor.

    The input of shape ``(batch, tokens, in_channel)`` is rearranged to
    ``(batch, in_channel, tokens, 1)``, passed through a 1x1 convolution and
    rearranged back to ``(batch, tokens, out_channel)``. It is used in this
    notebook in the places where the reference ViT uses ``nn.Linear``.

    Args:
        in_channel: size of the last axis of the input.
        out_channel: size of the last axis of the output.
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.my_conv2d = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, bias=False)

    def forward(self, x):
        """Project ``x`` from ``(batch, tokens, in_channel)`` to ``(batch, tokens, out_channel)``."""
        # (batch, tokens, channels) -> (batch, channels, tokens, 1) so Conv2d sees channels first.
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)
        x = self.my_conv2d(x)
        # Drop ONLY the dummy trailing axis. A bare ``squeeze()`` would also remove
        # the batch (or token) axis whenever its size is 1 and break the permute below.
        return x.squeeze(dim=-1).permute(0, 2, 1)


class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before calling the wrapped ``fn``.

    Args:
        dim: size of the last axis that is normalised.
        fn: callable invoked on the normalised tensor.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it (plus any keyword arguments) to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP block: project to ``hidden_dim``, GELU, project back to ``dim``.

    Both projections are ``ModifyConv2d`` layers and each is followed
    (after the activation for the first one) by ``nn.Dropout(dropout)``.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        widen = ModifyConv2d(dim, hidden_dim)
        activation = nn.GELU()
        drop_hidden = nn.Dropout(dropout)
        narrow = ModifyConv2d(hidden_dim, dim)
        drop_out = nn.Dropout(dropout)
        stages = [widen, activation, drop_hidden, narrow, drop_out]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential MLP."""
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention using ``ModifyConv2d`` projections.

    Args:
        dim: size of the last axis of the input and of the output.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout probability applied after the output projection.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only for a single head whose width already equals ``dim``.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # One fused projection produces queries, keys and values.
        self.to_qkv = ModifyConv2d(dim, inner_dim * 3)

        if needs_projection:
            self.to_out = nn.Sequential(
                ModifyConv2d(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over the second axis of ``x`` and return a tensor projected by ``to_out``."""
        fused = self.to_qkv(x)
        # Split the fused projection in three and move the head axis in front of the token axis.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in fused.chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks, each an attention and a feed-forward sub-layer.

    Every sub-layer is wrapped in ``PreNorm`` and added back to its input
    (residual connection) in ``forward``.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()

        def build_block():
            # Attention is constructed before the feed-forward layer, as a pair per depth level.
            attention = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_block() for _ in range(depth)])

    def forward(self, x):
        """Apply every (attention, feed-forward) pair with residual connections."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class ViT(nn.Module):
    """Vision Transformer whose patch embedding and inner projections use ``ModifyConv2d``.

    Args (keyword-only):
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch; must divide the image size.
        num_classes: number of outputs of the classification head.
        dim: token embedding width.
        depth: number of transformer blocks.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward sub-layers.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to average all tokens.
        channels: number of image channels.
        dim_head: width of each attention head.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability applied after adding the position embedding.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        img_h, img_w = pair(image_size)
        p_h, p_w = pair(patch_size)

        divides_evenly = img_h % p_h == 0 and img_w % p_w == 0
        assert divides_evenly, 'Image dimensions must be divisible by the patch size.'

        grid_rows = img_h // p_h
        grid_cols = img_w // p_w
        num_patches = grid_rows * grid_cols
        patch_dim = channels * p_h * p_w
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Cut the image into flattened patches, then project each patch to ``dim``.
        patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p_h, p2 = p_w)
        project = ModifyConv2d(patch_dim, dim)
        self.to_patch_embedding = nn.Sequential(patchify, project)

        # One learned position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        """Return the ``mlp_head`` output for a batch of images."""
        tokens = self.to_patch_embedding(img)
        batch, n_patches, _ = tokens.shape

        # Prepend one copy of the class token to every sample.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        latent = self.to_latent(pooled)
        return self.mlp_head(latent)

In [6]:
# Instantiate the ViT for 32x32 images with 10 output classes.
model = ViT(
    image_size = 32,      # input height and width
    patch_size = 4,       # 4x4 patches -> (32 // 4) ** 2 = 64 patches per image
    num_classes = 10,     # size of the classification head output
    dim = 512,            # token embedding width
    depth = 6,            # number of transformer blocks
    heads = 8,            # attention heads per block
    mlp_dim = 512,        # hidden width of each feed-forward sub-layer
    dropout = 0.1,        # dropout inside the transformer blocks
    emb_dropout = 0.1     # dropout applied after adding the position embedding
)

In [7]:
# Print a layer-by-layer summary of `model` for a 3x32x32 input.
from torchsummary import summary

# NOTE(review): `device` duplicates the `DEVICE` constant defined in an earlier cell.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
# NOTE(review): the five values below are only referenced by the commented-out
# constructor call further down; no visible cell reads them — confirm before removing.
depth = 10
hdim = 256
psize = 2
conv_ks = 5
clip_norm = True

#model = ViT(hdim, depth, patch_size=psize, kernel_size=conv_ks, n_classes=10).get_model()


# Moves the model to `device` and prints the per-layer output shapes and parameter counts.
summary(model.to(device), input_size=(3, 32, 32))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 64, 48]               0
            Conv2d-2           [-1, 512, 64, 1]          24,576
      ModifyConv2d-3              [-1, 64, 512]               0
           Dropout-4              [-1, 65, 512]               0
         LayerNorm-5              [-1, 65, 512]           1,024
            Conv2d-6          [-1, 1536, 65, 1]         786,432
      ModifyConv2d-7             [-1, 65, 1536]               0
           Softmax-8            [-1, 8, 65, 65]               0
            Conv2d-9           [-1, 512, 65, 1]         262,144
     ModifyConv2d-10              [-1, 65, 512]               0
          Dropout-11              [-1, 65, 512]               0
        Attention-12              [-1, 65, 512]               0
          PreNorm-13              [-1, 65, 512]               0
        LayerNorm-14              

In [8]:
# Move the model to DEVICE; as the cell's last expression, the module repr is displayed.
model.to(DEVICE)

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): ModifyConv2d(
      (my_conv2d): Conv2d(48, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): ModifyConv2d(
              (my_conv2d): Conv2d(512, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (to_out): Sequential(
              (0): ModifyConv2d(
                (my_conv2d): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              )
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine

In [9]:
# Total number of scalar elements across every parameter tensor of the model.
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))

Number of parameters: 9,513,994


In [10]:
# ---- Training configuration -------------------------------------------------
IMAGE_SIZE = 32

NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 128 * 2
EPOCHS = 25

# NOTE(review): LEARNING_RATE and WEIGHT_DECAY are declared here, but the training
# cell builds its optimizer with a hard-coded lr=1e-4 — confirm which is intended.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1

# Per-channel statistics used to normalise CIFAR-10 images.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# ---- Transforms -------------------------------------------------------------
# Training: random crop/flip/augment on the PIL image, then tensor conversion,
# normalisation and random erasing on the tensor.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

# ---- Datasets and loaders ---------------------------------------------------
# DATA_DIR and NUM_WORKERS are the notebook-level constants; they replace the
# previously duplicated literals ('./data', 4*2) without changing their values.
trainset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False,
                                       download=True, transform=test_transform)
# The evaluation loader uses half as many workers (4) and a smaller batch.
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=NUM_WORKERS // 2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [11]:
import time

# Build a fresh model for training; it replaces the instance created in the earlier cell.
model = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)
model.to(DEVICE)

# Kept for experimentation only: the loop below uses neither of these names
# (the learning rate is driven by the CosineAnnealingLR scheduler).
clip_norm = True
lr_schedule = lambda t: np.interp([t], [0, EPOCHS*2//5, EPOCHS*4//5, EPOCHS], 
                                  [0, 0.01, 0.01/20.0, 0])[0]

# Use up to two GPUs as before, but only the ones that actually exist, so the
# cell also runs on a single-GPU runtime. On CPU the model is left unwrapped.
gpu_count = torch.cuda.device_count()
if gpu_count > 0:
    model = nn.DataParallel(model, device_ids=list(range(min(2, gpu_count))))

opt = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Mixed precision is only meaningful on CUDA; disable it cleanly on CPU.
use_amp = DEVICE.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, EPOCHS)


for epoch in range(EPOCHS):
    start = time.time()

    # ---- training pass ----
    # train_loss is accumulated (sample-weighted) but not printed; train_loss / n
    # gives the mean loss of the last epoch if it is needed after the loop.
    train_loss, train_acc, n = 0, 0, 0
    model.train()  # once per epoch is enough; eval() below switches it back
    for X, y in trainloader:
        X, y = X.to(DEVICE), y.to(DEVICE)

        opt.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(X)
            loss = criterion(output, y)

        # Scale the loss to avoid fp16 gradient underflow, then step and update the scale.
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)

    # ---- evaluation pass ----
    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'ViT: Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}')
    # Advance the cosine schedule by exactly one epoch. Passing an explicit
    # (and offset) epoch index is deprecated and made the schedule lag behind.
    scheduler.step()


ViT: Epoch: 0 | Train Acc: 0.3036, Test Acc: 0.4442, Time: 49.6




ViT: Epoch: 1 | Train Acc: 0.4160, Test Acc: 0.5140, Time: 45.7
ViT: Epoch: 2 | Train Acc: 0.4574, Test Acc: 0.5311, Time: 45.4
ViT: Epoch: 3 | Train Acc: 0.4842, Test Acc: 0.5488, Time: 46.4
ViT: Epoch: 4 | Train Acc: 0.5062, Test Acc: 0.5696, Time: 45.7
ViT: Epoch: 5 | Train Acc: 0.5207, Test Acc: 0.5680, Time: 45.3
ViT: Epoch: 6 | Train Acc: 0.5365, Test Acc: 0.5981, Time: 45.7
ViT: Epoch: 7 | Train Acc: 0.5482, Test Acc: 0.5996, Time: 45.3
ViT: Epoch: 8 | Train Acc: 0.5617, Test Acc: 0.6154, Time: 44.9
ViT: Epoch: 9 | Train Acc: 0.5685, Test Acc: 0.6183, Time: 45.8
ViT: Epoch: 10 | Train Acc: 0.5748, Test Acc: 0.6297, Time: 45.5
ViT: Epoch: 11 | Train Acc: 0.5892, Test Acc: 0.6447, Time: 45.2
ViT: Epoch: 12 | Train Acc: 0.5946, Test Acc: 0.6405, Time: 45.4
ViT: Epoch: 13 | Train Acc: 0.5988, Test Acc: 0.6618, Time: 45.6
ViT: Epoch: 14 | Train Acc: 0.6068, Test Acc: 0.6532, Time: 45.0
ViT: Epoch: 15 | Train Acc: 0.6123, Test Acc: 0.6639, Time: 45.4
ViT: Epoch: 16 | Train Acc: 0.6210