In [1]:
import torch
from torch import nn
from torchinfo import summary
from patch_embedding import PatchEmbedding

In [56]:
class InputEmbedding(nn.Module):
    def __init__(self, channel, img_size, patch_size, d_dim):
        '''
        Input Shape : (channel, img_size, img_size)
        Output Shape : (embedding_size + 1, d_dim)
        embedding_size : patch로 나누었을 때, embedding의 수
        '''
        
        super().__init__()
        self.embedding_size = int((img_size*img_size)/(patch_size*patch_size))

        self.classEmbedding = nn.Parameter(torch.randn(1, d_dim, requires_grad=True))
        self.patchEmbedding = PatchEmbedding(channel, img_size, patch_size, d_dim)
        self.positionalEmbedding = nn.Parameter(torch.randn(self.embedding_size + 1, d_dim), requires_grad=True)

    def forward(self, x):
        print(x.shape)
        x = self.patchEmbedding(x)
        print(x.shape)
        print(self.classEmbedding.shape)

        embedding_li = []
        for i in range(x.shape[0]): # Batch의 수에 따라 
            # 각 Batch에 대해서 Class Embedding과 Positional Embedding 더하기
            # 이 때, Embedding은 모두 같다.
            x_ = torch.concat((self.classEmbedding, x[i]), dim=-2)
            print(x_.shape)
            x_ = x_ + self.positionalEmbedding
            print(x_.shape)
            embedding_li.append(x_.unsqueeze(dim=0))
            
        embeddings = torch.concat(embedding_li, dim=0)
        print(embeddings.shape)
        return embeddings

In [57]:
embedd = InputEmbedding(channel=3, img_size=224, patch_size=16, d_dim=30)

In [59]:
summary(embedd, (4, 3, 224, 224))

torch.Size([4, 3, 224, 224])
torch.Size([4, 196, 30])
torch.Size([1, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([197, 30])
torch.Size([4, 197, 30])


Layer (type:depth-idx)                   Output Shape              Param #
InputEmbedding                           [4, 197, 30]              5,940
├─PatchEmbedding: 1-1                    [4, 196, 30]              --
│    └─Flatten: 2-1                      [4, 768]                  --
│    └─Flatten: 2-2                      [4, 768]                  --
│    └─Flatten: 2-3                      [4, 768]                  --
│    └─Flatten: 2-4                      [4, 768]                  --
│    └─Flatten: 2-5                      [4, 768]                  --
│    └─Flatten: 2-6                      [4, 768]                  --
│    └─Flatten: 2-7                      [4, 768]                  --
│    └─Flatten: 2-8                      [4, 768]                  --
│    └─Flatten: 2-9                      [4, 768]                  --
│    └─Flatten: 2-10                     [4, 768]                  --
│    └─Flatten: 2-11                     [4, 768]                  --
│    └─Flatt

In [53]:
x = torch.randn(4, 3, 224, 224)
embedd(x)

torch.Size([4, 196, 30])
torch.Size([4, 1, 30])
torch.Size([4, 197, 30])


tensor([[[-1.9994e+00, -2.0422e+00, -9.3259e-01,  ..., -2.2107e+00,
           4.6468e-02, -2.3437e+00],
         [ 4.6467e-01, -1.6237e+00,  7.1954e-01,  ...,  4.6099e-01,
          -6.1860e-01, -2.7626e+00],
         [-1.4207e+00,  9.8410e-01, -1.3526e+00,  ..., -7.6201e-01,
          -1.6425e-01,  8.7078e-01],
         ...,
         [ 1.2713e+00, -2.4886e-03, -1.2811e+00,  ...,  6.6824e-01,
          -8.4106e-01, -5.2546e-01],
         [ 9.5886e-01,  6.4856e-01,  1.7090e+00,  ..., -1.3175e+00,
           1.1656e+00, -6.1793e-01],
         [ 1.1312e+00, -1.1779e+00,  1.5561e+00,  ..., -5.5975e-01,
          -4.2126e-01,  1.5906e+00]],

        [[-9.4360e-01, -1.3111e+00, -1.5326e+00,  ..., -8.2530e-01,
           3.5700e-01,  2.7564e-01],
         [ 7.1929e-01,  1.4175e-01, -2.4479e-01,  ...,  1.0638e+00,
           1.0104e+00,  6.1744e-01],
         [ 2.1222e+00, -4.4997e-01, -7.0056e-01,  ...,  1.3328e+00,
           2.7297e-01, -1.6638e+00],
         ...,
         [-1.2373e+00, -1