In [1]:
import torch
from torch import nn
from torchsummary import summary
import math

In [15]:
class PatchEmbedding(nn.Module):
    def __init__(self, channel, img_size, patch_size, d_dim):
        super().__init__()
        # img_size: 정사각 이미지인 것으로 정한다
        # patch_size: 원하는 patch size
        # d_dim: Linear Projection을 했을 때, 원하는 차원의 크기

        self.img_size = img_size
        self.channel = channel
            
        self.patch_size = patch_size
        
        self.n_patches = int((img_size * img_size) / (patch_size * patch_size))
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        
        self.linear_projs = nn.ModuleList([nn.Linear(patch_size*patch_size*channel, d_dim) for _ in range(self.n_patches)])

    def crop_img(self, x):
        patches = []
        row = int(math.sqrt(self.n_patches))
        col = int(math.sqrt(self.n_patches))
        p_len = self.patch_size
        
        
        for i in range(row):
            for j in range(col):
                temp = x[:, :, i*p_len:(i+1)*p_len, j*p_len:(j+1)*p_len]
                patches.append(temp)
                # one patch에 대한 시작점 (y, x)와 끝점 (y, x) 조사
                # print(f'{i*p_len} {(i+1)*p_len} {j*p_len} {(j+1)*p_len}')
                
        return patches
        
    def forward(self, x):
        # Patch 생성
        patches = self.crop_img(x)

        # Flatten
        patches_flatten = [self.flatten(patches[i]).unsqueeze(dim=-2) for i in range(self.n_patches)]
        print(patches[0].shape)
        print(patches_flatten[0].shape)

        # Linear Projection
        embedding_li = [self.linear_projs[i](patches_flatten[i]) for i in range(self.n_patches)]
        print(embedding_li[0].shape)

        # 모든 embedding concatenate
        embeddings = torch.concat(embedding_li, dim=-2)
        print(embeddings.shape)

        return embeddings

In [11]:
patch_embedding = PatchEmbedding(channel=3, img_size=224, patch_size=16, d_dim=30)

In [12]:
BATCH_SIZE = 16
summary(patch_embedding, (3, 224, 224), batch_size=-1)

torch.Size([2, 3, 16, 16])
torch.Size([2, 1, 768])
torch.Size([2, 1, 30])
torch.Size([2, 196, 30])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 768]               0
           Flatten-2                  [-1, 768]               0
           Flatten-3                  [-1, 768]               0
           Flatten-4                  [-1, 768]               0
           Flatten-5                  [-1, 768]               0
           Flatten-6                  [-1, 768]               0
           Flatten-7                  [-1, 768]               0
           Flatten-8                  [-1, 768]               0
           Flatten-9                  [-1, 768]               0
          Flatten-10                  [-1, 768]               0
          Flatten-11                  [-1, 768]               0
          Flatten-12                  [-1, 768]               0
    

In [14]:
x = torch.randn(4, 3, 224, 224, requires_grad=True)
pred = patch_embedding(x)
pred.shape

torch.Size([4, 3, 16, 16])
torch.Size([4, 1, 768])
torch.Size([4, 1, 30])
torch.Size([4, 196, 30])


torch.Size([4, 196, 30])