# 00.Introduction
- ViT에 대해 Library나 Sample로 구현된 것들은 있으나, 하나하나 설명된 것은 없어 이해하기 위해 각 구성요소들을 직접 구현해본다.  
    Although libraries and sample code for ViT exist, none of them explains each component step by step, so I implement every component myself to understand it — for my own study and for others.

- 기본적으로 ViT는 BERT와 같이 Transformer의 Encoder만을 가져와 사용하는 방식이며, 이는 데이터를 추상화하고 이에 대한 최종 해석값을 내놓는 과정으로 이해할 수 있다.  
  네 발로 걷는 것, 나는 것, 날개가 달린 것, 털이 있는 것이나 비늘이 있는 것으로 모든 동물을 분류하고, 각 객체들을 판단할 때 어떤 요소에 집중했는지를 파악할 수 있어 일종의 XAI와 같은 역할을 하기도 한다.
- ViT는 크게 세 부분으로 구성돼 있으며, 각각 Input Layer / Encoder / MLP Head 이다.  
  이 중 Encoder는 여러 개의 Encoder Block으로 구성돼 있으며, 각각의 Encoder Block은 Self-Attention과 MLP로 구성돼 있다.
  Self-Attention은 일반적으로 Multi-head Attention 방식으로 구현되며, MLP는 다른 사전학습 레이어(ex. ResNet)로 대체될 수 있으나,  
  복수의 Encoder Block이 들어가는 과정에서 대량의 메모리와 처리속도를 고려하면 간단한 레이어를 사용하는 것이 현재로선 좋아보인다.

# 01.Input Layer

In [5]:
import torch

# Quick look at a random tensor shaped (1, 1, 8).
a = torch.randn((1, 1, 8))
print(a.shape, a)

torch.Size([1, 1, 8]) tensor([[[-1.4656, -0.0863, -1.3829,  0.7328, -0.2383,  1.3856,  0.3638,
          -0.4458]]])


In [6]:
import torch
import torch.nn as nn

class VitInputLayer(nn.Module):
    """Convert a batch of images into the token sequence fed to the ViT encoder.

    The image is divided into ``num_patch_row x num_patch_row`` non-overlapping
    square patches. Each patch is linearly projected to ``emb_dim`` values by a
    ``Conv2d`` whose kernel size and stride both equal the patch size. A
    learnable class token is prepended and a learnable position embedding is
    added to every token.

    Args:
        in_channels: number of channels of the input image.
        emb_dim: length D of each token embedding.
        num_patch_row: number of patches along one side of the image (the image
            is treated as square, giving ``num_patch_row ** 2`` patches).
        image_size: side length of the input image in pixels; must be divisible
            by ``num_patch_row`` (checked with ``assert``).
    """

    def __init__(self,
                 in_channels: int = 3,
                 emb_dim: int = 384,
                 num_patch_row: int = 2,
                 image_size: int = 32,
                 ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.emb_dim = emb_dim
        self.num_patch_row = num_patch_row
        self.image_size = image_size

        # Number of patches, e.g. 2 x 2 = 4 with the default settings.
        self.num_patch = num_patch_row ** 2

        # Each patch is image_size / num_patch_row pixels on a side, so the
        # image has to split evenly.
        assert image_size % num_patch_row == 0, "patch size doesn't match with image size"
        self.patch_size = int(image_size / num_patch_row)

        # Patch embedding: kernel == stride == patch_size, so every patch is
        # visited exactly once and mapped to one emb_dim-channel output pixel.
        self.patch_emb_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=emb_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        # Learnable class token of shape (1, 1, emb_dim), prepended to the patches.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))

        # Learnable position embedding, one vector per token (num_patch patches
        # plus the class token). Self-attention by itself ignores token order,
        # so this is how patch-location information enters the model.
        self.pos_emb = nn.Parameter(torch.randn(1, self.num_patch + 1, emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed images as a token sequence.

        Args:
            x: images of shape (B, C, H, W).

        Returns:
            Tokens of shape (B, N, D), where D = emb_dim and, when
            H = W = image_size, N = num_patch + 1 (the class token is first).
        """
        batch_size = x.size(0)

        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, D, Np) -> (B, Np, D)
        patches = self.patch_emb_layer(x).flatten(2).transpose(1, 2)

        # One copy of the class token per sample: (1, 1, D) -> (B, 1, D)
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)

        # Prepend the class token, (B, Np, D) -> (B, Np + 1, D), then add positions.
        tokens = torch.cat([cls_tokens, patches], dim=1)
        return tokens + self.pos_emb

In [8]:
# Input Layer Check

import torch

# Two random 3x32x32 images, split into a 2x2 grid of patches.
batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
input_layer = VitInputLayer(num_patch_row=2)
z_0 = input_layer(x)

# Shape is (B, N, D); the recorded run printed torch.Size([2, 5, 384]).
print(z_0.shape, z_0, sep='\n')

torch.Size([2, 5, 384])
tensor([[[ 0.8181,  1.8900, -0.0716,  ...,  1.2480,  1.6079,  2.0446],
         [ 0.7724, -1.4483,  0.2489,  ..., -2.6664, -0.4604,  0.9533],
         [-0.9500, -2.0801, -1.5512,  ...,  0.3810, -0.2633,  1.0653],
         [-0.4632,  0.8257, -1.2454,  ...,  2.3331, -0.8414, -0.5579],
         [-1.2910,  1.3327,  0.4927,  ...,  0.2185,  0.2061,  0.5208]],

        [[ 0.8181,  1.8900, -0.0716,  ...,  1.2480,  1.6079,  2.0446],
         [ 1.1584, -1.3264, -1.2515,  ..., -3.3588, -0.2809,  1.4500],
         [-1.5001, -1.0683, -1.7576,  ...,  1.7348,  0.2732,  1.6792],
         [-1.0329, -0.0263, -0.3200,  ...,  0.7118, -0.8149, -0.2893],
         [-1.2318,  0.8022, -0.8656,  ...,  0.5152,  0.8738,  0.1135]]],
       grad_fn=<AddBackward0>)


# 02. Multi-Head Self-Attention

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (MHSA).

    Every token is projected to query, key and value vectors of length
    ``emb_dim``. These are split into ``head`` chunks of length
    ``emb_dim // head``, and attention is computed independently per head. The
    per-head results are concatenated back to ``emb_dim`` and passed through an
    output projection (Linear + Dropout).

    Args:
        emb_dim: length D of each token embedding.
        head: number of attention heads h; must be positive and divide ``emb_dim``.
        dropout: dropout rate applied to the attention weights and after the
            output projection.

    Raises:
        ValueError: if ``head`` is not positive or does not divide ``emb_dim``.
    """

    def __init__(self,
                 emb_dim: int = 384,   # embedding vector length D
                 head: int = 3,        # number of heads h
                 dropout: float = 0.1  # dropout rate
                 ) -> None:
        super(MultiHeadSelfAttention, self).__init__()

        # Reject configurations that would otherwise only fail later inside
        # forward() with an opaque ``view`` shape error (or a ZeroDivisionError).
        if head <= 0 or emb_dim % head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive number of heads (got head={head})"
            )

        self.emb_dim = emb_dim
        self.head = head
        self.head_dim = emb_dim // head
        # Scaled dot-product attention divides the scores by sqrt(head_dim) so
        # their magnitude does not grow with the head size and saturate softmax.
        self.sqrt_dh = self.head_dim ** 0.5

        # STEP 01. Define layers
        ## Projections that produce queries, keys and values
        self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)

        ## Dropout on the attention weights
        self.attn_drop = nn.Dropout(dropout)

        ## Output projection that mixes the concatenated heads
        self.w_o = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a token sequence.

        Args:
            z: input tokens of shape (B, N, D) with D == emb_dim.

        Returns:
            Tensor of shape (B, N, D).
        """
        batch_size, num_patch, _ = z.size()

        # STEP 02. Project to q, k, v: (B, N, D) -> (B, N, D)
        q = self.w_q(z)
        k = self.w_k(z)
        v = self.w_v(z)

        ## Split heads: (B, N, D) -> (B, N, h, D/h) -> (B, h, N, D/h)
        q = q.view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)

        ## Scaled scores: (B, h, N, D/h) @ (B, h, D/h, N) -> (B, h, N, N)
        dots = (q @ k.transpose(2, 3)) / self.sqrt_dh

        ## Softmax over the last axis (the keys), so each query's weights sum
        ## to 1 before dropout is applied
        attn = F.softmax(dots, dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values: (B, h, N, N) @ (B, h, N, D/h) -> (B, h, N, D/h)
        out = attn @ v

        ## Merge heads: (B, h, N, D/h) -> (B, N, h, D/h) -> (B, N, D)
        out = out.transpose(1, 2).reshape(batch_size, num_patch, self.emb_dim)

        ## Output projection: (B, N, D) -> (B, N, D)
        return self.w_o(out)

In [13]:
# MHSA Test

# Apply MHSA (default emb_dim=384, head=3) to the tokens produced above.
mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)

# Expected shape (B=2, N=5, D=384)
print(out.shape)
print(out)

torch.Size([2, 5, 384])
tensor([[[ 0.0944, -0.0101,  0.0938,  ...,  0.0000,  0.1515,  0.2264],
         [ 0.0476,  0.0749, -0.1221,  ...,  0.1416,  0.0000,  0.0889],
         [ 0.0370, -0.1745,  0.0087,  ...,  0.4039,  0.1903,  0.1085],
         [-0.0458, -0.0173,  0.0980,  ...,  0.0000,  0.2504, -0.0000],
         [ 0.0447, -0.0513, -0.0082,  ...,  0.4133,  0.1874,  0.0523]],

        [[ 0.0236,  0.2666, -0.0194,  ...,  0.2342,  0.1559,  0.0163],
         [-0.0915,  0.0011,  0.2990,  ...,  0.4258,  0.1204,  0.0000],
         [ 0.0296, -0.0000,  0.0014,  ...,  0.5210, -0.0426,  0.1725],
         [ 0.0365,  0.1460,  0.0578,  ...,  0.0000,  0.1876, -0.0705],
         [ 0.0720,  0.1341,  0.1422,  ...,  0.3279, -0.0400,  0.1927]]],
       grad_fn=<MulBackward0>)


# 03. Encoder
- 일반적으로 Encoder는 복수의 Encoder Block으로 구성되며, 각 Encoder Block은 **[LayerNorm->MHSA->LayerNorm->MLP]** 로 구성되어 있습니다. 이때 MHSA와 MLP의 출력에는 각 단계의 입력이 Residual connection으로 더해지며, MLP 단계의 입력은 MHSA 단계의 출력입니다.  
   

In [15]:
import torch.nn as nn

class VitEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes::

        h   = z + MHSA(LayerNorm1(z))
        out = h + MLP(LayerNorm2(h))

    Args:
        emb_dim: token embedding length D.
        head: number of attention heads passed to MultiHeadSelfAttention.
        hidden_dim: width of the MLP's hidden layer.
        dropout: dropout rate used in the attention module and in the MLP.
    """

    def __init__(self,
                 emb_dim: int = 384,
                 head: int = 8,
                 hidden_dim: int = 384 * 4,
                 dropout: float = 0.1,
                 ) -> None:
        super(VitEncoderBlock, self).__init__()

        # STEP 01. Define Encoder Block Layers
        ## LayerNorm applied before the attention sub-layer
        self.ln1 = nn.LayerNorm(emb_dim)

        ## Multi-head self-attention
        self.mhsa = MultiHeadSelfAttention(
            emb_dim=emb_dim,
            head=head,
            dropout=dropout,
        )

        ## LayerNorm applied before the MLP sub-layer
        self.ln2 = nn.LayerNorm(emb_dim)

        ## Two-layer MLP: D -> hidden_dim -> D, with GELU and dropout
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Run the block on a token sequence.

        Args:
            z: input tokens of shape (B, N, D).

        Returns:
            Tensor of shape (B, N, D).
        """
        # STEP 02. Two residual sub-layers applied in sequence.
        ## Part 1: attention sub-layer with residual connection
        out = self.mhsa(self.ln1(z)) + z

        ## Part 2: the MLP must read the attention sub-layer's output, not the
        ## block input ``z``. Otherwise the attention result never reaches the
        ## MLP and the two sub-layers run side by side instead of in sequence.
        out = self.mlp(self.ln2(out)) + out

        return out

In [16]:
# Encoder Block Test

# Pass the input-layer tokens through a single encoder block (default settings).
vit_enc = VitEncoderBlock()
z_1 = vit_enc(z_0)

# The block preserves shape: (B=2, N=5, D=384)
print(z_1.shape)
print(z_1)

torch.Size([2, 5, 384])
tensor([[[ 0.2280,  2.2770,  0.0095,  ...,  1.6610,  1.1083,  3.0569],
         [ 0.2297, -1.3305,  0.4470,  ..., -2.5948, -1.0139,  1.5729],
         [-1.5734, -1.7607, -1.6213,  ...,  0.2934, -0.5840,  1.1330],
         [-1.1030,  0.4799, -1.5101,  ...,  2.5298, -0.8372, -0.9231],
         [-1.6123,  1.7179,  0.4211,  ...,  0.1317, -0.3342,  0.8397]],

        [[ 0.2411,  1.9997, -0.0132,  ...,  1.3102,  1.3436,  3.0914],
         [ 1.0210, -1.3893, -0.8584,  ..., -3.6101, -0.7249,  2.0603],
         [-2.2505, -1.1484, -1.9839,  ...,  1.3386, -0.0072,  2.0938],
         [-1.3323, -0.0995, -0.7837,  ...,  0.3104, -0.4181, -0.6488],
         [-1.5087,  1.0819, -0.3002,  ...,  0.3063,  0.7543,  0.6095]]],
       grad_fn=<AddBackward0>)


# 04. Vision Transformer
- 이번에는 이전 단계에서 구축한 Input Layer/Encoder에 MLP Head를 더해 전체 ViT를 구현해보겠습니다.

In [22]:
import torch.nn as nn

class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: VitInputLayer (patch tokens + class token + position embedding)
    -> ``num_blocks`` stacked VitEncoderBlock modules -> LayerNorm + Linear
    head applied to the class-token output.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output classes M.
        emb_dim: token embedding length D.
        num_patch_row: number of patches along one image side.
        image_size: side length of the input image (treated as square).
        num_blocks: number of encoder blocks.
        head: attention heads per encoder block.
        hidden_dim: MLP hidden width inside each encoder block.
        dropout: dropout rate used inside the encoder blocks.
    """

    def __init__(self,
                 in_channels: int = 3,
                 num_classes: int = 10,
                 emb_dim: int = 384,
                 num_patch_row: int = 2,
                 image_size: int = 32,
                 num_blocks: int = 7,
                 head: int = 8,
                 hidden_dim: int = 384 * 4,
                 dropout: float = 0.1,
                 ) -> None:
        super().__init__()

        # Images -> token sequence
        self.input_layer = VitInputLayer(
            in_channels=in_channels,
            emb_dim=emb_dim,
            num_patch_row=num_patch_row,
            image_size=image_size,
        )

        # Encoder: num_blocks independent encoder blocks applied one after another
        blocks = [
            VitEncoderBlock(
                emb_dim=emb_dim,
                head=head,
                hidden_dim=hidden_dim,
                dropout=dropout,
            )
            for _ in range(num_blocks)
        ]
        self.encoder = nn.Sequential(*blocks)

        # Classification head on the class-token vector
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: images of shape (B, C, H, W).

        Returns:
            Unnormalized class scores of shape (B, M), M = num_classes.
        """
        # (B, C, H, W) -> (B, N, D)
        tokens = self.input_layer(x)
        # (B, N, D) -> (B, N, D)
        encoded = self.encoder(tokens)
        # Keep only the class token (index 0): (B, N, D) -> (B, D),
        # then project to class scores: (B, D) -> (B, M)
        return self.mlp_head(encoded[:, 0])

In [23]:
import torch

# End-to-end check: random images in, class scores out.
num_classes = 10
batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
vit = ViT(in_channels=channel, num_classes=num_classes)
pred = vit(x)

# Expected shape (B=2, M=10)
print(pred.shape)
print(pred)

torch.Size([2, 10])
tensor([[-0.4547,  0.4218, -0.1759, -0.9137,  0.2284, -0.4648,  0.0166, -0.8592,
         -0.4822, -0.8670],
        [-0.6414,  0.1693, -0.1344, -0.8408,  0.2250, -0.3415,  0.3283, -1.1152,
         -0.4313, -0.7437]], grad_fn=<AddmmBackward0>)
