In [1]:
import torch
from torch import nn

In [8]:
class PatchEmbeddings(nn.Module):
    """Split square images into non-overlapping patches and embed each patch.

    A single ``nn.Conv2d`` whose ``kernel_size`` and ``stride`` both equal
    ``patch_size`` tiles the image without overlap, so it acts as a linear
    projection applied independently to every patch.

    Args:
        img_size: Side length (pixels) of the square input images. Only used to
            compute ``num_patches``; the convolution itself accepts any size.
        patch_size: Side length (pixels) of each square patch.
        hidden_dim: Size of the embedding produced for each patch.
        in_channels: Number of input image channels. Defaults to 3 (RGB), which
            was previously hard-coded.
    """

    def __init__(
        self,
        img_size: int = 96,
        patch_size: int = 16,
        hidden_dim: int = 512,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        # Store the configuration so downstream modules can query it.
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.in_channels = in_channels

        # Total number of patches: patches along one side, squared.
        # Floor division mirrors the convolution below, which drops trailing
        # pixels that do not fill a whole patch.
        self.num_patches = (self.img_size // self.patch_size) ** 2

        # kernel_size == stride == patch_size ensures each patch is embedded
        # separately; out_channels == hidden_dim gives one hidden_dim-sized
        # vector per patch.
        self.conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Embed the patches of an image batch (or a single unbatched image).

        Args:
            X: Images of shape (batch_size, in_channels, H, W), or a single
                image of shape (in_channels, H, W). ``num_patches`` matches the
                output only when H == W == img_size.

        Returns:
            Patch embeddings of shape (batch_size, num_patches, hidden_dim), or
            (num_patches, hidden_dim) for unbatched input.
        """
        # (..., C, H, W) -> (..., hidden_dim, H // patch_size, W // patch_size)
        X = self.conv(X)

        # Flatten the 2-D patch grid into a single patch axis (row-major order).
        # Negative dims keep this correct for both batched and unbatched input.
        # -> (..., hidden_dim, num_patches)
        X = X.flatten(-2)

        # Put the patch axis before the embedding axis.
        # -> (..., num_patches, hidden_dim)
        X = X.transpose(-2, -1)

        return X

In [9]:
# Smoke test: one random 96x96 RGB image should yield
# (img_size // patch_size) ** 2 = 36 patch tokens of size hidden_dim = 512.
torch.manual_seed(0)  # seeds both the conv weight init and the random input
patchifier = PatchEmbeddings()
X = torch.randn(1, 3, 96, 96)
patches = patchifier(X)
assert patches.shape == (1, patchifier.num_patches, patchifier.hidden_dim)
patches.shape