In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size`` maps
    every ``patch_size x patch_size`` patch to an ``embed_dim``-dimensional
    vector. A learnable positional embedding (one vector per patch) is then
    added to the flattened patch sequence.

    Args:
        in_channels: Number of channels in the input image.
        patch_size: Side length, in pixels, of each square patch.
        embed_dim: Dimension of each patch embedding.
        img_size: Expected input image size, either an int (square image) or
            an ``(height, width)`` tuple. Defaults to 224, the size this module
            previously assumed unconditionally.

    Shape:
        Input: ``(batch_size, in_channels, height, width)``, where the patch
        grid produced from ``(height, width)`` equals ``grid_size``.
        Output: ``(batch_size, num_patches, embed_dim)``.
    """

    def __init__(self, in_channels=3, patch_size=16, embed_dim=768, img_size=224):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        height, width = img_size
        self.img_size = (height, width)
        self.patch_size = patch_size
        # kernel_size == stride, so each output location covers exactly one patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Conv2d drops partial patches at the border, so the grid uses floor division.
        self.grid_size = (height // patch_size, width // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # One position vector per patch; the leading 1 broadcasts over the batch.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        """Embed the patches of ``x`` and add positional embeddings.

        Args:
            x: Image tensor of shape ``(batch_size, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch_size, num_patches, embed_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-D, or if its spatial size yields a
                patch grid different from the one ``pos_embed`` was built for.
                Without this check a mismatched grid can silently broadcast
                (e.g. a 1x1 grid against 196 positions) or misalign positions.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, channels, height, width) tensor, got shape {tuple(x.shape)}"
            )
        # (B, C, H, W) -> (B, embed_dim, H // patch_size, W // patch_size)
        x = self.proj(x)
        grid = tuple(x.shape[-2:])
        if grid != self.grid_size:
            raise ValueError(
                f"input produces a {grid[0]}x{grid[1]} patch grid, but positional "
                f"embeddings were built for {self.grid_size[0]}x{self.grid_size[1]} "
                f"(img_size={self.img_size}, patch_size={self.patch_size})"
            )
        # (B, embed_dim, gh, gw) -> (B, gh * gw, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        return x + self.pos_embed

In [2]:
from PIL import Image
from torchvision import transforms

# Demo configuration. IMG_SIZE must stay 224 while PatchEmbedding is built
# with its default positional-embedding grid (sized for 224x224 inputs).
IMAGE_PATH = 'data/dog.jpeg'
IMG_SIZE = 224
PATCH_SIZE = 16
EMBED_DIM = 768
SEED = 0

# PatchEmbedding initialises its conv and positional embeddings randomly;
# seed so Restart & Run All reproduces the same values.
torch.manual_seed(SEED)

# Load the image and resize it to IMG_SIZE x IMG_SIZE (aspect ratio is not preserved).
# The context manager closes the underlying file once the RGB copy has been made.
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert('RGB')
resized_image = image.resize((IMG_SIZE, IMG_SIZE))

# PIL image -> float tensor of shape (3, IMG_SIZE, IMG_SIZE) with values in [0, 1].
tensor_image = transforms.ToTensor()(resized_image)

# Add a batch dimension: (1, 3, IMG_SIZE, IMG_SIZE).
batched_tensor_image = tensor_image.unsqueeze(0)

patch_embedding = PatchEmbedding(patch_size=PATCH_SIZE, embed_dim=EMBED_DIM)

# Call the module itself rather than .forward() so nn.Module hooks are run.
patch_embeddings = patch_embedding(batched_tensor_image)

expected_num_patches = (IMG_SIZE // PATCH_SIZE) ** 2
assert patch_embeddings.shape == (1, expected_num_patches, EMBED_DIM)

# Displayed as the cell output: (batch_size, num_patches, embed_dim).
patch_embeddings.shape

torch.Size([1, 196, 768])
