In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat # einops 직관적으로 사용할 수 있는 차원관리 package
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

  warn(f"Failed to load image Python extension: {e}")


### Project input to Patches

In [2]:
patch_size = 16  # side length, in pixels, of each square patch

# Dummy image batch laid out as (batch_size, channel, height, width).
x = torch.randn(8, 3, 224, 224)
print(f'x: {x.shape}')

# Cut every image into non-overlapping patch_size x patch_size patches and flatten them:
#   'b c (h s1) (w s2)' - factor height into h blocks of s1 rows, width into w blocks of s2 columns
#   'b (h w) (s1 s2 c)' - keep the batch axis, lay the h * w = 14 * 14 = 196 patches out on one
#                         axis, and flatten each patch into s1 * s2 * c = 16 * 16 * 3 = 768 values
patches = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)
print(f"patches: {patches.shape}")

x: torch.Size([8, 3, 224, 224])
patches: torch.Size([8, 196, 768])


In [3]:
patch_size = 16
in_channels = 3
# Embedding dimension D. Here 3 * 16 * 16 = 768 happens to equal D, but the Conv2d below maps
# in_channels -> emb_size, so emb_size is not tied to the raw pixel count of a patch.
emb_size = 768

# A Conv2d with kernel_size == stride == patch_size visits each non-overlapping patch exactly once,
# so it acts as one shared linear projection of every flattened patch to emb_size features.
projection = nn.Sequential(
    nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),  # [-1, 768, 14, 14]
    Rearrange('b e (h) (w) -> b (h w) e'),  # [-1, 14*14, 768]
)
summary(projection, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


### Patch Embedding
- patches에 class token & positional embedding을 넣어주기

In [4]:
emb_size = 768
img_size = 224
patch_size = 16

# Split the image into patches and flatten them: (batch, num_patches, emb_size).
projection_x = projection(x)
print('Projection X shape: ', projection_x.shape)

# Learnable class token and positional-embedding parameters.
cls_token = nn.Parameter(torch.randn(1,1,emb_size))  # one emb_size-dim token, later tiled over the batch
# One position vector per patch plus one for the class token: shape (N + 1, D), N = (224 // 16) ** 2 = 196.
positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))
print('Class Shape: ', cls_token.shape, 'Position Shape: ', positions.shape)

# Tile the class token along the batch axis. The batch size is read from the projected input
# rather than hard-coded, so this cell keeps working if x's batch size changes.
batch_size = projection_x.shape[0]
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print('Repeated Class Shape: ', cls_tokens.shape)

# Prepend the class token to the patch sequence along the token axis.
cat_x = torch.cat([cls_tokens, projection_x], dim=1)  # [8, 197, 768]

# Add the positional embedding: (197, 768) broadcasts over the batch axis, so the shape is unchanged.
cat_x += positions
print('output: ', cat_x.shape)

Projection X shape:  torch.Size([8, 196, 768])
Class Shape:  torch.Size([1, 1, 768]) Position Shape:  torch.Size([197, 768])
Repeated Class Shape:  torch.Size([8, 1, 768])
output:  torch.Size([8, 197, 768])


In [5]:
# Quick check of how torch.cat joins tensors along dim=1:
# a (4, 1, 16) "class token" slice followed by a (4, 8, 16) "patch" slice gives (4, 9, 16).
tensor_1 = nn.Parameter(torch.randn(4, 1, 16))
tensor_2 = nn.Parameter(torch.randn(4, 8, 16))
cat_tensor = torch.cat(tensors=(tensor_1, tensor_2), dim=1)
print('cat shape: ', cat_tensor.shape)

cat shape:  torch.Size([4, 9, 16])


In [8]:
# The steps above packaged as a reusable nn.Module.
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence for a Vision Transformer.

    Splits each image into non-overlapping ``patch_size`` x ``patch_size`` patches, projects
    every patch to ``emb_size`` features, prepends a learnable class token, and adds a
    learnable positional embedding.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch, in pixels.
        emb_size: embedding dimension of every token.
        img_size: image side length used to size the positional table,
            which holds ``(img_size // patch_size) ** 2 + 1`` rows.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224) -> None:
        super().__init__()
        self.patch_size = patch_size
        # Strided conv = per-patch linear projection, then flatten the patch grid into a sequence.
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1,1,emb_size))
        # One row per patch plus one for the class token.
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Embed an image batch.

        Args:
            x: images of shape (batch, in_channels, H, W); H and W must produce
                ``(img_size // patch_size) ** 2`` patches so the positional table lines up.

        Returns:
            Tensor of shape (batch, num_patches + 1, emb_size) whose first token is the class token.
        """
        b, _, _, _ = x.shape  # batch size; the 4-way unpack also rejects non-4D input
        x = self.projection(x)  # (b, num_patches, emb_size)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)  # (b, 1, emb_size)
        x = torch.cat([cls_tokens, x], dim=1)  # class token goes first along the token axis
        x += self.positions  # (num_patches + 1, emb_size) broadcasts over the batch
        return x


# Backward-compatible alias for the original (misspelled) class name.
PathEmbedding = PatchEmbedding

# Build the module with its default settings and summarise it for one 3x224x224 image.
# NOTE: the reported "Total params" equals the Conv2d alone (590,592). The cls_token and positions
# parameters registered directly on the module are not counted by torchsummary.
PE = PathEmbedding(in_channels=3, patch_size=16, emb_size=768, img_size=224)
summary(PE, input_size=(3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


### Multi Head Attention
- 패치들에 대해서 self attention 적용

In [9]:
emb_size = 768
num_heads = 8
# Each head works on an equal slice of emb_size // num_heads features, so the split must be exact.
assert emb_size % num_heads == 0, "emb_size must be divisible by num_heads"

# Q, K, V projection layers. They get their own *_proj names so that `queries`, `keys` and
# `values` below can hold the projected tensors without shadowing the layers themselves.
keys_proj = nn.Linear(emb_size, emb_size)
queries_proj = nn.Linear(emb_size, emb_size)
values_proj = nn.Linear(emb_size, emb_size)
print(keys_proj, queries_proj, values_proj)

# Embed the image batch into tokens.
# NOTE: this rebinds x from the (8, 3, 224, 224) images to the (8, 197, 768) tokens, so this cell
# fails if re-run on its own; re-run the cell that creates the image tensor x first.
x = PE(x)
print(queries_proj(x).shape)  # (batch, n_tokens, emb_size)
# Split the feature axis into heads: emb_size = h * d, h = num_heads (8), d = emb_size / h (96).
queries = rearrange(queries_proj(x), "b n (h d) -> b h n d", h=num_heads)
keys = rearrange(keys_proj(x), "b n (h d) -> b h n d", h=num_heads)
values = rearrange(values_proj(x), "b n (h d) -> b h n d", h=num_heads)

print('shape: ', queries.shape, keys.shape, values.shape)

Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)
torch.Size([8, 197, 768])
shape:  torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96])


In [10]:
# Queries x Keys: similarity of every query token with every key token, per head.
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # (batch, heads, query_len, key_len)
print('energy: ', energy.shape)

# Attention weights = softmax(Q K^T / sqrt(d_k)).
# d_k is the per-head dimension (768 / 8 = 96 here), read from the tensor instead of using the
# full emb_size. The softmax runs over the key axis (last dim), so each query's weights over
# all keys sum to 1. Normalising over dim=1 would normalise across heads instead.
head_dim = queries.shape[-1]
scaling = head_dim ** (1/2)
att = F.softmax(energy / scaling, dim=-1)
print('att: ', att.shape)

# Attention weights x Values: weighted sum of value vectors for every query.
out = torch.einsum('bhal, bhlv -> bhav', att, values)  # (batch, heads, query_len, head_dim)
print('out: ', out.shape)

# Merge the heads back into one feature axis so the result has the same
# (batch, n_tokens, emb_size) shape as the embedded tokens fed into attention.
out_1 = rearrange(out, pattern="b h n d -> b n (h d)")
print('out_1: ', out_1.shape)

energy:  torch.Size([8, 8, 197, 197])
att:  torch.Size([8, 8, 197, 197])
out:  torch.Size([8, 8, 197, 96])
out_1:  torch.Size([8, 197, 768])
