In [132]:
import torch
import torch.nn as nn
from torchsummary import summary
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from typing import Any,Callable,Tuple

In [78]:
vit_config = {
    "user" :
        {"num_layers" : 8, "embed_dim"  : 192, "mlp_dim" : 192*4, "num_heads" : 8},
    "base" :
        {"num_layers" : 12, "embed_dim"  : 768, "mlp_dim" : 3072, "num_heads" : 12},
    "large" :
        {"num_layers" : 24, "embed_dim"  : 1024, "mlp_dim" : 4096, "num_heads" : 16},
    "huge" :
        {"num_layers" : 32, "embed_dim"  : 1280, "mlp_dim" : 5120, "num_heads" : 16}
}

In [105]:
class Patch2Vector(nn.Module):
    def __init__ (self, patch_size, channel, embed_dim, num_patches, embed_type = 'conv', ):
        super().__init__()
        if embed_type =='conv':
            self.projection = nn.Sequential(
                nn.Conv2d(channel, embed_dim, kernel_size=patch_size, stride=patch_size),
                Rearrange('batch (embed_dim) h w -> batch (h w) embed_dim')
            )
        elif embed_type == 'flatten': #proposed method in original VIT
            self.projection = nn.Sequential(
                Rearrange('batch c (h p1) (w p2) -> batch (h w) (p1 p2)', p1 = patch_size, p2 = patch_size),
                nn.Linear(patch_size**2*channel, embed_dim),
            )
        else :
            raise NotImplementedError("embed_type only 'conv' or flat ")
        self.cls_token= nn.Parameter(torch.randn(1, 1, embed_dim)) #shape (1, 1, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + 1, embed_dim)) #shape (num_patches+1, embed_dim)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        x = self.projection(x)# b,c,h,w -> b (h*w)/p^2 c*p^2 : batch num_patches embed_dim
        batch = x.shape[0]
        x_cls = repeat(self.cls_token, '1 1 e -> b 1 e', b=batch) 
        embedding = torch.cat((x_cls, x), dim=1) + self.pos_embedding # b n+1 e. add cls_token 
        return embedding

In [106]:
class MultiheadSelfAtteintion(nn.Module):
    def __init__(self, embed_dim, num_heads, drop_out):
        super().__init__()
        self.num_heads = num_heads
        self.scaling = (embed_dim//num_heads)**-0.5
        
        self.qkv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*3), 
            Rearrange('b n (qkv d) -> qkv b n d' , qkv = 3),
            Rearrange('qkv b n (h d) -> qkv b h n d', h = num_heads)
            )
        self.dropout = nn.Dropout(drop_out)
        self.o = nn.Linear(embed_dim, embed_dim)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        q, k, v= self.qkv(x) #b h n d
        
        att_score = torch.matmul(q,k.transpose(-2,-1))*self.scaling # [b h (n d)]@[b h (d n)] -> b h n n, @ is mat mul  
        att_score = torch.softmax(att_score, dim = -1)
        att_score = self.dropout(att_score)
        
        att = torch.matmul(att_score, v) # [b h (n n)] @ [b h (n d)] -> b h n d
        att = rearrange(att, 'b h n d -> b n (h d)')
        
        return self.o(att)

In [107]:
class FeedForward(nn.Module):
    def __init__(self, embed_dim, mlp_dim, drop_out):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(drop_out),
            nn.Linear(mlp_dim,embed_dim),
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff(x)

In [118]:
class EncoderBlock(nn.ModuleList):
    def __init__(self, embed_dim, mlp_dim, num_heads, drop_out):
        super().__init__()
        self.MSA = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MultiheadSelfAtteintion(embed_dim, num_heads, drop_out),
            )
        self.FF = nn.Sequential(
            nn.LayerNorm(embed_dim),
            FeedForward(embed_dim, mlp_dim, drop_out),
        )
        self.dropout = nn.Dropout(drop_out)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.MSA(x) + x
        x = self.dropout(x)
        x = self.FF(x) + x
        x = self.dropout(x)
        return x

In [150]:
# class 
class VIT(nn.Module):
    def __init__(self,  image_size : Tuple[int, int, int] = (3, 224, 224), patch_size : int = 16, num_classes : int = 1000, embed_type :str = 'conv', vit_type :str = 'base' ,pred_type : str = 'mean', dropout = 0., embed_dropout = 0.):
        super().__init__()
        assert vit_type in ('user','base', 'large', 'huge'), f"vit_type must be 'usser' or 'base' or 'large' or 'huge'. but {vit_type}"
        assert pred_type in ('mean', 'cls_token'), f"pred_type must be 'mean' or 'cls_token'. but {pred_type}"
        self.pred_type = pred_type
        self.num_classes = num_classes
        self.channel , self.height,self.width = image_size
        assert not (self.height%patch_size or self.width%patch_size)  , f"image size must be divisible by patch size,({self.height},{self.width}) can't devide by {patch_size} "
        self.num_patches = (self.height*self.width)//patch_size**2
        self.patch_size = patch_size
        self.config = vit_config[vit_type]
        for k, v in self.config.items(): setattr(self, k, v)
        
        self.att_dropout_ratio = dropout
        self.embed_dropout_ratio = embed_dropout
        
        assert self.embed_dim==self.patch_size**2*self.channel, f"embed dimension of VIT-{vit_type} : {self.embed_dim}. but {self.patch_size**2*self.channel} " 
        ###patch embedding
        self.patch2vec = Patch2Vector(patch_size,self.channel, self.embed_dim, self.num_patches, embed_type)
        self.embed_dropout = nn.Dropout(embed_dropout)
        ###Tansformer Encoder
        self.encoders = nn.ModuleList([EncoderBlock(self.embed_dim, self.mlp_dim, self.num_heads, dropout) for _ in range(self.num_layers)])
        ###classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, num_classes))
    def forward(self,x : torch.Tensor) -> torch.Tensor:
        """
        x : (batch, channel, height, width)
        """
        if len(x.shape)==4:
            pass
        elif len(x.shape)==3:
            x = x.unsqueeze(dim=0)
        else :
            raise ValueError(f"input dimension only allowed by (batch, channel, height, width) or (channel, height, width) but {x.shape}")
        assert x.shape[-2]==self.height and x.shape[-1]==self.width, f"expected height and width are {(self.height, self.width)} but {x.shape}"
        x = self.patch2vec(x) 
        x = self.embed_dropout(x) #batch num_patches+1 embed_dim
        
        for encoder in self.encoders:
            x = encoder(x)+x#batch num_patches+1 embed_dim
        x = x.mean(dim=1) if self.pred_type=='mean' else x[:,0] # b n+1 e -> b e
        #class token을 이용해 예측하는 경우 b n e 중 n의 첫번째에 해당하는 벡터이므로 이를 이용, 평균을 이용하는 경우 n에 대한 평균
        '''
        Both during pre-training and fine-tuning, a classification head is attached to z^0_L
        The classification head is implemented by a MLP with one hidden layer at pre-training
        time and by a single linear layer at fine-tuning time.
        '''
        return self.classifier(x)

In [159]:
model = VIT(image_size=(3,32,32), patch_size=8, num_classes = 10,vit_type='user' , embed_type = 'conv', embed_dropout=0.1,dropout=0.1)
model(torch.randn(1,3,32,32))
summary(model, (3,32,32),device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 192, 4, 4]          37,056
         Rearrange-2              [-1, 16, 192]               0
      Patch2Vector-3              [-1, 17, 192]               0
           Dropout-4              [-1, 17, 192]               0
         LayerNorm-5              [-1, 17, 192]             384
            Linear-6              [-1, 17, 576]         111,168
         Rearrange-7           [-1, 2, 17, 192]               0
         Rearrange-8         [-1, 2, 8, 17, 24]               0
           Dropout-9            [-1, 8, 17, 17]               0
           Linear-10              [-1, 17, 192]          37,056
MultiheadSelfAtteintion-11              [-1, 17, 192]               0
          Dropout-12              [-1, 17, 192]               0
        LayerNorm-13              [-1, 17, 192]             384
           Linear-14             

# Train CIFAR 10

In [155]:
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor,Compose,Normalize
from torch import optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

In [156]:
train_transform= Compose(
    [
     transforms.RandomHorizontalFlip(),
     ToTensor(),
     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
test_transform  = Compose(
    [
     transforms.ToTensor(),
     Normalize((0.491, 0.482 ,0.447), (0.247, 0.243, 0.262))
     ])
train_set = datasets.CIFAR10(root='./data/',train=True,download=True,transform = train_transform)
test_set = datasets.CIFAR10(root='./data/',train=False,download=True,transform = test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [157]:
BATCH_SIZE = 64 
EPOCHS = 10
LR = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR ,weight_decay=0.01)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max= EPOCHS, eta_min=LR/20)
train_loader = DataLoader(train_set, BATCH_SIZE)
test_loader = DataLoader(test_set, BATCH_SIZE)