References:


https://github.com/gymoon10/Paper-Review/blob/main/NLP/Attention%20is%20All%20you%20Need%20-%20%EC%84%A4%EB%AA%85%26%EB%85%BC%EB%AC%B8%EC%9D%BD%EA%B8%B0.ipynb


https://github.com/gymoon10/Paper-Review/blob/main/NLP/BERT.md

In [1]:
!pip install einops

Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

In [3]:
# Input image (B, C, H, W)
x = torch.randn(1, 3, 224, 224)
x.shape

torch.Size([1, 3, 224, 224])

![image](https://user-images.githubusercontent.com/44194558/145760178-b6fffe96-8ed8-4525-aacf-d3bb43a9b231.png)

## Patch Embeddings

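From an NLP perspective, this step corresponds to the token embedding of the input words.

The input image is split into patches, and each patch is flattened (projected to a 1D vector).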

In [None]:
# Input image -> flattened patch sequence (BxCxHxW -> BxNx(P*P*C); P: patch size, N: number of patches = H*W/(P*P) = sequence length)
# Linear embedding (linearly projects the flattened patches into a lower-dimensional space)
patch_size = 16  # 16x16 pixels per patch, giving a 14x14 grid of patches
print('x :', x.shape)
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
print('patches :', patches.shape)  # 1x3x(14*16)x(14*16) -> 1x(14*14)x(16*16*3) / 196 tokens from an NLP perspective

x : torch.Size([1, 3, 224, 224])
patches : torch.Size([1, 196, 768])
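
The shape arithmetic behind this output can be checked by hand. The following is a minimal sketch, assuming a square patch size that divides both image sides; `patch_grid` is a hypothetical helper name and is not part of the notebook.

```python
def patch_grid(height: int, width: int, patch_size: int, channels: int):
    """Return (number of patches N, flattened patch length P*P*C)."""
    if height % patch_size or width % patch_size:
        raise ValueError("patch_size must divide both image sides")
    n_patches = (height // patch_size) * (width // patch_size)  # N = H*W / P^2
    patch_dim = patch_size * patch_size * channels              # P*P*C
    return n_patches, patch_dim


n_patches, patch_dim = patch_grid(224, 224, 16, 3)
print(n_patches, patch_dim)  # 196 768
```

For a 224x224 RGB image and 16x16 patches this gives a sequence of 196 tokens, each of length 768, which matches the `patches` shape printed above.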


In [None]:
# Input image -> flattened patches
# Apply a conv layer whose kernel size and stride both equal the patch size
patch_size = 16
in_channels = 3
emb_size = 768

projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),  # plays the role of the token embedding in NLP
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

# 196: the number of patches that make up the input image
# 768: each patch is a 1D vector of embedding length 768
projection(x).shape  

torch.Size([1, 196, 768])
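
A convolution whose kernel size and stride both equal the patch size computes the same result as cutting out the patches and applying one shared linear layer to each. The sketch below checks this numerically. It uses `F.unfold` instead of einops so that the flattening order (channel, row, column) matches the layout of the conv weights; the variable names are illustrative only.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
patch_size, in_channels, emb_size = 16, 3, 768
x = torch.randn(1, in_channels, 224, 224)

conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

with torch.no_grad():
    # Path 1: strided convolution, then flatten the 14x14 grid into a sequence
    conv_out = conv(x).flatten(2).transpose(1, 2)              # (B, 196, 768)

    # Path 2: extract patches explicitly, then reuse the conv weights as a linear map
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*P*P, 196)
    patches = patches.transpose(1, 2)                           # (B, 196, C*P*P)
    weight = conv.weight.reshape(emb_size, -1)                  # (768, C*P*P)
    linear_out = F.linear(patches, weight, conv.bias)           # (B, 196, 768)

max_diff = (conv_out - linear_out).abs().max().item()
print(conv_out.shape, linear_out.shape)
```

The two outputs should agree up to floating-point rounding, which is why the conv layer can stand in for "flatten, then project".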

In [None]:
# Patch + Position Embedding + [CLS]
emb_size = 768
img_size = 224
patch_size = 16

# Convert the input image to patches and flatten them
projected_x = projection(x)
print('Projected X shape :', projected_x.shape)

# Add [CLS] and initialize the position embedding
# Both are updated during training so that they capture the overall image representation and spatial information
cls_token = nn.Parameter(torch.randn(1,1, emb_size))
positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
print('Cls Shape :', cls_token.shape, ', Pos Shape :', positions.shape)  # 197 = 196 + 1 (one extra row for [CLS])

# Repeat [CLS] to match the batch size
batch_size = 1
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print('Repeated Cls shape :', cls_tokens.shape)

# Concatenate [CLS] and projected_x
cat_x = torch.cat([cls_tokens, projected_x], dim=1)

# Add the position embedding
cat_x += positions
print('output : ', cat_x.shape)  # 197 embedded tokens from an NLP perspective

Projected X shape : torch.Size([1, 196, 768])
Cls Shape : torch.Size([1, 1, 768]) , Pos Shape : torch.Size([197, 768])
Repeated Cls shape : torch.Size([1, 1, 768])
output :  torch.Size([1, 197, 768])


Reference: BERT


![image](https://user-images.githubusercontent.com/44194558/145760913-ad8f5a0c-09f1-4196-928c-b102c503de6d.png)

* Input: patches of the input image

* Token Embeddings: projected_x

* Position Embeddings: positions


The Position Embeddings are added to the Token Embeddings.

In [None]:
# Random values at initialization, but updated during training (a learned position embedding is added to the patch representations)
# The model learns to encode distance within the image in the similarity of position embeddings (closer patches end up with more similar embeddings)
# This is how it learns the 2D image topology
positions  

Parameter containing:
tensor([[ 1.6821, -1.7668,  0.5707,  ..., -0.3544,  1.4827,  2.4232],
        [-1.4466, -0.9558, -0.0930,  ...,  0.4000, -0.0330,  0.1503],
        [-1.8524, -1.4030, -1.2173,  ...,  1.1249, -0.2161,  0.7137],
        ...,
        [-0.9177,  1.9372, -1.6234,  ..., -0.1032, -0.8573,  1.3289],
        [ 0.2576,  0.5573, -1.9590,  ...,  0.2052,  0.7386,  0.2253],
        [ 2.0091, -0.3537,  1.5352,  ...,  0.3217,  0.8891, -1.2121]],
       requires_grad=True)
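
One way to inspect the "closer patches have more similar embeddings" claim is to compare one patch's position embedding against all others and view the result on the 14x14 grid. The sketch below uses a randomly initialized stand-in table, so it only demonstrates the mechanics; any spatial structure would appear only after training. The names `query_idx` and `sim_map` are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
grid, emb_size = 14, 768
positions = torch.randn(grid * grid + 1, emb_size)  # untrained stand-in for the position table

patch_pos = positions[1:]      # drop the [CLS] row -> (196, 768)
query_idx = 7 * grid + 7       # patch at row 7, column 7

# Cosine similarity between the chosen patch position and every patch position
query = patch_pos[query_idx].unsqueeze(0).expand_as(patch_pos)
sim = F.cosine_similarity(query, patch_pos, dim=-1)   # (196,)
sim_map = sim.reshape(grid, grid)                     # back to the 2D patch layout

print(sim_map.shape)
```

Passing `sim_map` to `plt.imshow` gives a heat map; with a trained table one would expect higher values near row 7, column 7.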

In [None]:
# The steps above as a single class
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()  # initialize nn.Module before setting any attributes
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

    def forward(self, x):  # x : Input image
        b, _, _, _ = x.shape
        x = self.projection(x) # Input image to flattened patch
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        x = torch.cat([cls_tokens, x], dim=1)  # prepend [CLS]
        x += self.positions  # add the position embedding (broadcast over the batch)

        return x

In [None]:
PatchEmbedding()(x).shape

torch.Size([1, 197, 768])

In [None]:
patches_embedded = PatchEmbedding()(x)  # plays the role of embedded word tokens in NLP (e.g. the input to BERT)

## MHA (Multi Head Attention)


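Q, K, and V are all computed from the same input tensor.

Three linear projections embed the input separately; each result is split into several heads, and scaled dot-product attention runs on each head.

Concretely, three layers are created, each taking the embedded input tensor and linearly projecting it back to the embedding size. (These layers, which turn the input tensor into Q, K, and V, are learned during training.)

Reference: Attention Is All You Need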
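
For reference, the operation described above can be written out as in the cited paper. Here $X$ is the embedded input sequence, $h$ the number of heads, and $d_k$ the per-head dimension. The concrete numbers assume an embedding size of 768 and 8 heads, so $d_k = 768 / 8 = 96$.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

$$
\mathrm{head}_i = \mathrm{Attention}\!\left(X W_i^{Q},\; X W_i^{K},\; X W_i^{V}\right), \qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
$$

The scores are divided by $\sqrt{d_k}$ before the softmax, which keeps the dot products from growing with the head dimension and saturating the softmax.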



<br/>


![image](https://user-images.githubusercontent.com/44194558/145761233-dca75a16-727f-43a3-9018-c379f9e1f95f.png)

<br/>

![image](https://user-images.githubusercontent.com/44194558/145761782-793c9358-667d-4f0b-9dba-79bec3b3b579.png)

In [None]:
patches_embedded.shape

torch.Size([1, 197, 768])


![image](https://user-images.githubusercontent.com/44194558/145762488-0490003c-ffee-460f-9efa-d12c8388beff.png)

In [None]:
emb_size = 768
num_heads = 8

# Linear projections applied to the embedded input tensor
keys = nn.Linear(emb_size, emb_size)
queries = nn.Linear(emb_size, emb_size)
values = nn.Linear(emb_size, emb_size)
print(keys, queries, values)

Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)


In [None]:
# QKV (Linear Projection)
queries = rearrange(queries(patches_embedded), "b n (h d) -> b h n d", h=num_heads)
keys = rearrange(keys(patches_embedded), "b n (h d) -> b h n d", h=num_heads)
values  = rearrange(values(patches_embedded), "b n (h d) -> b h n d", h=num_heads)

print('shape :', queries.shape, keys.shape, values.shape)  # (Batch, Heads, Seq_len, Head_dim); 768 = 8 x 96, with 8 heads

shape : torch.Size([1, 8, 197, 96]) torch.Size([1, 8, 197, 96]) torch.Size([1, 8, 197, 96])


In [None]:
# Scaled dot-product attention
# energy has shape (Batch, Heads, Query_len, Key_len)
# Queries * Keys
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # contracts over d, i.e. Q @ K^T without an explicit transpose
print('energy :', energy.shape)

# Attention weights: scale the scores by sqrt(d_k) before the softmax (d_k = per-head dimension)
scaling = (emb_size // num_heads) ** (1/2)
att = F.softmax(energy / scaling, dim=-1)
print('att :', att.shape)

# Attention weights * values
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out :', out.shape)

# Rearrange back to emb_size (concatenate the heads)
out = rearrange(out, "b h n d -> b n (h d)")
print('out2 : ', out.shape)  # MHA Output

energy : torch.Size([1, 8, 197, 197])
att : torch.Size([1, 8, 197, 197])
out : torch.Size([1, 8, 197, 96])
out2 :  torch.Size([1, 197, 768])
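
As a cross-check of the shapes above, here is a minimal standalone sketch of scaled dot-product attention written with plain matrix products. It assumes random inputs with the same shapes as in this notebook; the function name is illustrative and is not a library API.

```python
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v):
    """q, k, v: (batch, heads, seq_len, head_dim)."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5   # (batch, heads, q_len, k_len)
    weights = F.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v, weights


torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 197, 96) for _ in range(3))
out, weights = scaled_dot_product_attention(q, k, v)

# The einsum form contracts over the same axis as q @ k^T
scores_einsum = torch.einsum('bhqd,bhkd->bhqk', q, k)
scores_matmul = q @ k.transpose(-2, -1)

print(out.shape, weights.shape)
```

Each row of `weights` is a probability distribution over the 197 keys, and the einsum and matmul forms of the score matrix agree up to rounding.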


In [None]:
# The steps above as a single class
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Compute Q, K, V with a single fused linear layer
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # Split the fused projection into Q, K, V, each (Batch, Heads, Seq_len, Head_dim)
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # (Batch, Heads, Query_len, Key_len)

        # Positions where the boolean mask is False are ignored by attention
        if mask is not None:
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)  # masked_fill is out-of-place, so keep the result
            
        # Scale by sqrt(d_k) before the softmax (d_k = per-head dimension)
        scaling = (self.emb_size // self.num_heads) ** (1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
  
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)

        return out

In [None]:
MultiHeadAttention()(patches_embedded).shape

torch.Size([1, 197, 768])
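
The optional `mask` argument is not exercised in this notebook. The standalone sketch below shows how a boolean mask is usually applied to attention scores. It assumes the convention that `True` means "attend" and `False` means "ignore"; the tiny 4x4 score tensor is made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1, 1, 4, 4)  # (batch, heads, query_len, key_len)

# True = attend, False = ignore; here the last key is hidden from every query
mask = torch.tensor([True, True, True, False]).view(1, 1, 1, 4)

fill_value = torch.finfo(scores.dtype).min
masked_scores = scores.masked_fill(~mask, fill_value)  # returns a new tensor
weights = F.softmax(masked_scores, dim=-1)

print(weights[0, 0])
```

After the softmax, the masked key receives a weight of effectively zero and the remaining weights in each row still sum to 1. Note that `masked_fill` returns a new tensor, so its result has to be assigned; the in-place variant is `masked_fill_`.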

## Residuals

In [None]:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        
        return x

## MLP

Linear -> GELU -> Dropout -> Linear


![image](https://user-images.githubusercontent.com/44194558/145762755-a0a938f9-a406-4222-b89e-1c1e925fbd0b.png)

In [None]:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),  # expand
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),  # shrink back to emb_size
        )
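
The size of this block follows directly from the two `Linear` layers. Below is a small worked calculation, assuming the defaults used in this notebook (`emb_size=768`, `expansion=4`); `linear_params` is a hypothetical helper name.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


emb_size, expansion = 768, 4
hidden = expansion * emb_size                # 3072
expand = linear_params(emb_size, hidden)     # first Linear
shrink = linear_params(hidden, emb_size)     # second Linear
print(expand, shrink, expand + shrink)  # 2362368 2360064 4722432
```

The first number matches the `Linear-12` row (2,362,368 parameters) in the model summary at the end of this notebook.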

## Transformer Encoder Block

In [None]:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

In [None]:
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape

torch.Size([1, 197, 768])

## Transformer

Stack `depth` Transformer Encoder Blocks.

In [None]:
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
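
The list comprehension matters here: it builds a fresh block on every iteration, so each layer has its own weights. The sketch below illustrates the difference with small `nn.Linear` layers standing in for the encoder blocks; `count_params` is a hypothetical helper.

```python
from torch import nn


def count_params(module: nn.Module) -> int:
    # parameters() yields each distinct tensor once, so shared weights are not double-counted
    return sum(p.numel() for p in module.parameters())


depth = 3

# A fresh layer per iteration: every position owns its own weights
independent = nn.Sequential(*[nn.Linear(4, 4) for _ in range(depth)])

# The same layer object repeated: all positions share one set of weights
shared_layer = nn.Linear(4, 4)
shared = nn.Sequential(*[shared_layer] * depth)

print(count_params(independent), count_params(shared))  # 60 20
```

Writing `[block] * depth` would therefore produce a weight-tied stack, which is not what a standard Transformer encoder uses.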

## MLP Head

Classification Layer

In [None]:
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),  # average over the token axis to get one emb_size vector per image
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))
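
This head averages over all 197 tokens. The ViT paper instead classifies from the final state of the [CLS] token. The sketch below shows that alternative side by side with mean pooling, assuming [CLS] sits at index 0 as in the `PatchEmbedding` class; `ClsTokenHead` is a hypothetical name and not part of this notebook.

```python
import torch
from torch import nn


class ClsTokenHead(nn.Module):
    """Alternative head (assumption): classify from the [CLS] token only."""

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__()
        self.norm = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cls = x[:, 0]                      # (B, E): the token at index 0
        return self.fc(self.norm(cls))     # (B, n_classes)


tokens = torch.randn(2, 197, 768)          # stand-in for the encoder output
mean_pooled = tokens.mean(dim=1)           # what Reduce('b n e -> b e', 'mean') computes
logits = ClsTokenHead()(tokens)
print(mean_pooled.shape, logits.shape)
```

Both variants reduce the sequence to a single vector per image; which one works better is an empirical question and depends on the training setup.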

## ViT

In [None]:
class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
        

In [None]:
summary(ViT(), (3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19