## Patch Embedding

<img src="../assets/patch-embedding.png" width="700px" height="400px">

In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to
    ``embed_dim`` features, a learnable classification token is prepended,
    learnable position embeddings are added and dropout is applied.

    Args:
        embed_dim: Size of each token embedding (conv output channels).
        patch_size: Side length of a square patch; used as both kernel size
            and stride of the convolution.
        num_patches: Number of patches the convolution produces per image.
            Must match the actual patch count of the inputs, otherwise the
            position embeddings cannot be added in ``forward``.
        dropout: Dropout probability applied to the final token sequence.
        in_channels: Number of channels of the input images.
    """

    def __init__(
        self,
        embed_dim: int,
        patch_size: int,
        num_patches: int,
        dropout: float,
        in_channels: int,
    ) -> None:
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,  # stride == kernel size -> non-overlapping patches
            ),
            nn.Flatten(2),  # (B, E, H', W') -> (B, E, H' * W')
        )

        # Special classification token: exactly ONE extra token per sequence.
        # Its shape must not depend on in_channels; otherwise multi-channel
        # inputs would get in_channels extra tokens and no longer line up
        # with the (num_patches + 1) position embeddings below.
        self.special_classification_token = nn.Parameter(
            torch.randn(size=(1, 1, embed_dim)), requires_grad=True
        )

        # Position embeddings are randomly initialised and learned; fixed
        # sin-cos positional encodings would be an alternative.
        self.position_embeddings = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, num_patches + 1, embed_dim)``; index 0 along
            dimension 1 is the classification token.
        """
        # Share the single learned token across the batch without copying it.
        cls_token = self.special_classification_token.expand(x.shape[0], -1, -1)  # (B, 1, E)

        x = self.patcher(x).permute(0, 2, 1)  # (B, E, N) -> (B, N, E)
        x = torch.cat([cls_token, x], dim=1)  # (B, N + 1, E)
        x = self.position_embeddings + x  # broadcast over the batch dimension
        x = self.dropout(x)
        return x

In [3]:
# Hyperparameters for the shape check below (28x28 single-channel images).
IMG_SIZE = 28  # side length of the square input image, in pixels
PATCH_SIZE = 4  # side length of each square patch, in pixels
IN_CHANNELS = 1
BATCH_SIZE = 512
DROPOUT = 0.001

# One embedding feature per pixel value in a patch: 1 * 4 * 4 = 16.
EMBED_DIM = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE
# Patches per side, squared: (28 // 4) * (28 // 4) = 49.
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) * (IMG_SIZE // PATCH_SIZE)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Shape check: build the module from the hyperparameters above and push one
# random batch through it.
model = PatchEmbedding(
    embed_dim=EMBED_DIM,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    dropout=DROPOUT,
    in_channels=IN_CHANNELS,
).to(device)

# Derive the input shape from the same constants the model was built with, so
# changing a hyperparameter cannot desynchronise model and input. The tensor is
# allocated directly on the target device.
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE, device=device)
print(model(x).shape)  # (BATCH_SIZE, NUM_PATCHES + 1, EMBED_DIM)

model.special_classification_token.shape

torch.Size([512, 50, 16])


torch.Size([1, 1, 16])