In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

### 1. Project input to patch
- 입력 이미지를 패치로 나눠준다

In [4]:
# Dummy input: a batch of 8 RGB images, 224x224 pixels each
x = torch.randn(8, 3, 224, 224)
print('x: ', x.shape)

# Cut every image into non-overlapping patch_size x patch_size tiles and
# flatten each tile (pixels and channels) into a single vector.
patch_size = 16  # 16x16 patches -> (224 / 16) ** 2 = 196 patches per image
patches = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)
print('patches: ', patches.shape)

x:  torch.Size([8, 3, 224, 224])
patches:  torch.Size([8, 196, 768])


위와 같이 rearrange를 통해 쉽게 reshape 할 수 있지만, 아래와 같이 kernel_size와 stride를 모두 patch_size로 둔 Conv2d를 이용하면 패치 분할과 선형 임베딩(projection)을 한 번의 연산으로 처리할 수 있어 성능상 이점이 있다고 한다.

In [5]:
patch_size, in_channels, embed_dim = 16, 3, 768

projection = nn.Sequential(
    # kernel_size == stride == patch_size: every output position covers exactly
    # one non-overlapping patch, i.e. patch split + linear projection in one op.
    # (b, in_channels, H, W) -> (b, embed_dim, H / patch_size, W / patch_size)
    nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
    # Flatten the patch grid into a token sequence: (b, e, h, w) -> (b, h*w, e)
    Rearrange('b e (h) (w) -> b (h w) e'),
)

summary(projection, input_size=x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


### 2. Patch embedding
- patches에 class token과 positional embedding을 추가한다.
- class token은 패치 시퀀스 맨 앞에 붙는 학습 가능한 파라미터로, 이미지 전체의 클래스 정보를 모으기 위해 사용된다.
- Positional embedding은 각 패치(와 class token)가 어떤 위치에 있는지 알려주기 위한 학습 가능한 파라미터이다.

In [6]:
embed_dim = 768
img_size = 224
patch_size = 16

# (b, 3, 224, 224) -> (b, 196, 768): one embedding vector per patch
projected_x = projection(x)
print(f'Projected X shape : {projected_x.shape}')

# One learnable [CLS] token (shared, later repeated over the batch) and one
# learnable position vector per patch plus one extra for the [CLS] token.
num_patches = (img_size // patch_size) ** 2
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
positions = nn.Parameter(torch.randn(num_patches + 1, embed_dim))
print(f"Cls Shape : {cls_token.shape}, Position Shape : {positions.shape}")

# Take the batch size from the data instead of hard-coding it, so the
# concatenation below stays valid if the input batch size changes.
batch_size = projected_x.shape[0]
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print(f"Repeated Cls shape : {cls_tokens.shape}")

# Prepend the [CLS] token to the patch sequence: (b, 196, e) -> (b, 197, e)
cat_x = torch.cat([cls_tokens, projected_x], dim=1)

# positions is (197, e) and broadcasts over the batch dimension
cat_x = cat_x + positions
print(f"Cat X shape : {cat_x.shape}")

Projected X shape : torch.Size([8, 196, 768])
Cls Shape : torch.Size([1, 1, 768]), Position Shape : torch.Size([197, 768])
Repeated Cls shape : torch.Size([8, 1, 768])
Cat X shape : torch.Size([8, 197, 768])


In [7]:
class PatchEmbedding(nn.Module):
    """Split images into patches, embed them, prepend a [CLS] token and add positions.

    Args:
        in_channels: number of channels of the input image.
        patch_size: side length of each square patch (used as conv kernel and stride).
        embed_dim: size of each token embedding.
        img_size: side length of the square input image; it fixes the number of
            learnable position embeddings to (img_size // patch_size) ** 2 + 1.

    Input:  (batch, in_channels, img_size, img_size)
    Output: (batch, (img_size // patch_size) ** 2 + 1, embed_dim)
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, embed_dim: int = 768, img_size: int = 224):
        # nn.Module.__init__ must run before any attribute assignment; assigning a
        # Parameter or sub-module before it raises, so keep this call first.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # each output position of the conv sees exactly one patch
            nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
            # (b, e, h, w) -> (b, h*w, e)
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # +1 for the [CLS] token
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, embed_dim))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 4-D image batch into a token sequence with a leading [CLS] token."""
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

        x = torch.cat([cls_tokens, x], dim=1)

        # position embeddings broadcast over the batch dimension
        return x + self.positions
    
# Build the patch embedding with its default settings and inspect its layers
PE = PatchEmbedding(in_channels=3, patch_size=16, embed_dim=768, img_size=224)
summary(PE, input_size=(3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


### 3. Multi-head Self-Attention(MHA)
- 패치들에 대해 셀프 어텐션 메커니즘을 적용함

In [8]:
embed_dim = 768
num_heads = 8

# Separate linear projections for keys, queries and values. They are named
# *_proj so the tensors below can be called `queries`/`keys`/`values` without
# shadowing the layers (otherwise re-running this cell fails).
key_proj   = nn.Linear(embed_dim, embed_dim)
query_proj = nn.Linear(embed_dim, embed_dim)
value_proj = nn.Linear(embed_dim, embed_dim)
print(f"{key_proj}, {query_proj}, {value_proj}")

# Keep the raw image batch `x` untouched so re-running this cell does not feed
# an already-embedded tensor back into PE.
embedded_x = PE(x)
print(f"{query_proj(embedded_x).shape}")

# Split embed_dim into num_heads heads of size head_dim = embed_dim // num_heads
queries = rearrange(query_proj(embedded_x), "b n (h d) -> b h n d", h=num_heads) # -> batch, heads, n, head_dim
keys    = rearrange(key_proj(embedded_x), "b n (h d) -> b h n d", h=num_heads)
values  = rearrange(value_proj(embedded_x), "b n (h d) -> b h n d", h=num_heads)

print(f"shape : {queries.shape}, {keys.shape}, {values.shape}")

Linear(in_features=768, out_features=768, bias=True), Linear(in_features=768, out_features=768, bias=True), Linear(in_features=768, out_features=768, bias=True)
torch.Size([8, 197, 768])
shape : torch.Size([8, 8, 197, 96]), torch.Size([8, 8, 197, 96]), torch.Size([8, 8, 197, 96])


In [9]:
# Queries * Keys: similarity of every query token with every key token, per head
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print(f"energy : {energy.shape}")

# Get Attention Score
# Scaled dot-product attention divides by sqrt(d_k), where d_k is the
# per-head dimension (embed_dim // num_heads = 96 here), not the full embed_dim.
head_dim = queries.shape[-1]
scaling = head_dim ** (1/2)
att = F.softmax(energy / scaling, dim=-1)
print(f"att : {att.shape}")

# Attention Score * values
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print(f"out : {out.shape}")

# Rearrange to embed_dim: concatenate the heads back together
out = rearrange(out, 'b h n d -> b n (h d)')
print(f"out2 : {out.shape}")

energy : torch.Size([8, 8, 197, 197])
att : torch.Size([8, 8, 197, 197])
out : torch.Size([8, 8, 197, 96])
out2 : torch.Size([8, 197, 768])


In [10]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        embed_dim: token embedding size; must be divisible by num_heads.
        num_heads: number of attention heads; each head works on
            embed_dim // num_heads features.
        dropout: dropout probability applied to the attention weights.

    Input:  (batch, n_tokens, embed_dim)
    Output: (batch, n_tokens, embed_dim)

    Raises:
        ValueError: if embed_dim is not divisible by num_heads.
    """

    def __init__(self, embed_dim: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One fused projection producing queries, keys and values at once
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: token sequence of shape (batch, n_tokens, embed_dim).
            mask: optional boolean tensor, True where attention is allowed and
                False where it is blocked; it must broadcast to
                (batch, num_heads, n_tokens, n_tokens).
        """
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)

        # Scaled dot-product attention: divide by sqrt(d_k), d_k = per-head dim
        energy = energy / self.head_dim ** 0.5
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            # The dtype's most negative value drives the softmax weight to ~0.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # concatenate heads: (b, h, n, d) -> (b, n, h*d)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.projection(out)

# Fresh image batch -> patch tokens -> inspect one attention layer
x = torch.randn(8, 3, 224, 224)
PE = PatchEmbedding()
x = PE(x)  # (8, 197, 768) token sequence
print(f"x shape : {x.shape}")

MHA = MultiHeadAttention()
summary(MHA, input_size=x.shape[1:], device='cpu')

x shape : torch.Size([8, 197, 768])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 197, 2304]       1,771,776
           Dropout-2          [-1, 8, 197, 197]               0
            Linear-3             [-1, 197, 768]         590,592
Total params: 2,362,368
Trainable params: 2,362,368
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.58
Forward/backward pass size (MB): 6.99
Params size (MB): 9.01
Estimated Total Size (MB): 16.57
----------------------------------------------------------------


### 4. Transformer Encoder Block
- MLP(feed forward) 블록을 만들어주고 MHA와 하나로 묶어준다.

In [13]:
class ResidualAdd(nn.Module):
    """Wrap a sub-module ``fn`` with a residual (skip) connection: ``fn(x) + x``.

    Extra keyword arguments given to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add on purpose: the previous `x += res` mutated fn's
        # output in place, which overwrote the caller's input whenever fn
        # returns its argument (e.g. nn.Identity, or nn.Dropout in eval mode)
        # and can trip autograd's in-place modification checks.
        return self.fn(x, **kwargs) + x
    
class FeedForwardBlock(nn.Sequential):
    """Token-wise MLP used inside each encoder block.

    Layers, in order: Linear(embed_dim -> expansion * embed_dim), GELU,
    Dropout(drop_p), Linear(expansion * embed_dim -> embed_dim).
    """

    def __init__(self, embed_dim: int, expansion: int = 4, drop_p: float = 0.):
        hidden_dim = expansion * embed_dim
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_dim, embed_dim),
        ]
        super().__init__(*layers)
        
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer encoder layer made of two residual sub-layers.

    x -> x + Dropout(MultiHeadAttention(LayerNorm(x)))
      -> x + Dropout(FeedForwardBlock(LayerNorm(x)))

    Extra keyword arguments are forwarded to MultiHeadAttention
    (e.g. num_heads, dropout).
    """

    def __init__(self,
                 embed_dim: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(embed_dim),
            MultiHeadAttention(embed_dim, **kwargs),
            nn.Dropout(drop_p),
        ))
        feed_forward_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(embed_dim),
            FeedForwardBlock(embed_dim, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_sublayer, feed_forward_sublayer)
        
# Shape check: image -> patch tokens -> attention output, then inspect one block
x = torch.randn(8, 3, 224, 224)
x = MHA(PE(x))
TE = TransformerEncoderBlock()
summary(TE, input_size=x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1             [-1, 197, 768]           1,536
            Linear-2            [-1, 197, 2304]       1,771,776
           Dropout-3          [-1, 8, 197, 197]               0
            Linear-4             [-1, 197, 768]         590,592
MultiHeadAttention-5             [-1, 197, 768]               0
           Dropout-6             [-1, 197, 768]               0
       ResidualAdd-7             [-1, 197, 768]               0
         LayerNorm-8             [-1, 197, 768]           1,536
            Linear-9            [-1, 197, 3072]       2,362,368
             GELU-10            [-1, 197, 3072]               0
          Dropout-11            [-1, 197, 3072]               0
           Linear-12             [-1, 197, 768]       2,360,064
          Dropout-13             [-1, 197, 768]               0
      ResidualAdd-14             [-1, 1

### 5. 모두 묶어서 ViT 빌드

In [17]:
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers applied in order.

    Keyword arguments are forwarded unchanged to every block.
    """

    def __init__(self, depth: int = 12, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)
        
class ClassificationHead(nn.Sequential):
    """Average all tokens (including the [CLS] token), normalize, and project to logits.

    Input:  (batch, n_tokens, embed_dim)
    Output: (batch, n_classes)
    """

    def __init__(self, embed_dim: int = 768, n_classes: int = 1000):
        token_pool = Reduce('b n e -> b e', reduction="mean")
        norm = nn.LayerNorm(embed_dim)
        classifier = nn.Linear(embed_dim, n_classes)
        super().__init__(token_pool, norm, classifier)
        
class ViT(nn.Sequential):
    """Vision Transformer: patch embedding -> encoder stack -> classification head.

    Args:
        in_channels: channels of the input image.
        patch_size: side length of each square patch.
        embed_dim: token embedding size used throughout the model.
        img_size: side length of the square input image.
        depth: number of encoder blocks.
        n_classes: number of output logits.
        **kwargs: forwarded to every TransformerEncoderBlock
            (e.g. drop_p, forward_expansion, num_heads).
    """

    def __init__(self, 
                 in_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, embed_dim, img_size)
        encoder = TransformerEncoder(depth, embed_dim=embed_dim, **kwargs)
        head = ClassificationHead(embed_dim, n_classes)
        super().__init__(embedding, encoder, head)
        
# Layer-by-layer summary of the full model with default hyper-parameters
summary(ViT(), input_size=(3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19