In [1]:
import torch
import torch.nn as nn

### PatchEmbedding

In [2]:
# Toy configuration: a batch of one RGB image, 8x8 pixels, cut into 4x4 patches.
SEED = 0         # fixed seed so the random image and layer weights are reproducible
n_samples = 1    # batch size
in_channels = 3  # RGB image
img_size = 8     # image height and width in pixels
patch_size = 4   # side length of each square patch
embed_dim = 6    # output embedding dimension per patch

# Seed PyTorch's global RNG before any random tensor or layer is created, so
# a Restart & Run All reproduces the same values.
torch.manual_seed(SEED)

# The patch projection below uses kernel_size == stride == patch_size; if the
# image size were not a multiple of the patch size, the leftover border pixels
# would be dropped silently. Fail loudly instead.
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"


In [3]:
# Random stand-in for an image batch, laid out as
# (batch, channels, height, width); values are drawn from a standard normal.
x_shape = (n_samples, in_channels, img_size, img_size)
x = torch.randn(x_shape)
x

tensor([[[[-2.3085,  2.9301, -0.2634,  0.0959,  0.9889,  0.9827,  0.4778,
            0.2156],
          [-0.2989, -0.7796, -0.4471,  1.0615, -0.1459,  0.3371, -0.6267,
            0.6362],
          [ 0.8384,  1.5411,  1.8485,  0.0526, -0.2272, -0.5812,  0.1585,
            1.2235],
          [ 2.6966,  0.0080, -1.4554, -0.3657,  0.9881, -0.6463,  0.5210,
           -0.5178],
          [ 1.0496,  1.0235,  0.4265, -0.3547,  0.6718,  1.2416,  0.8128,
            1.3974],
          [-0.1225,  1.5216,  1.5826, -0.3536,  1.4035,  1.3632, -0.2961,
           -0.2238],
          [ 0.5618, -1.2004, -0.0455, -0.7475, -1.2903,  0.8388, -0.7475,
           -0.8898],
          [ 1.1827, -0.6823,  0.3822,  1.0224,  1.6121, -1.5723,  2.4434,
           -0.9830]],

         [[ 1.8518,  1.2762,  0.7064, -0.2057,  0.9397, -0.6804, -0.0820,
           -1.0543],
          [ 0.8019,  0.1416, -0.8532,  1.0030,  0.5649, -0.2463,  0.7291,
           -1.1164],
          [-0.5022,  0.8088,  1.2477,  0.4497, -

In [4]:
# Patch embedding as a strided convolution: because kernel_size == stride,
# the kernel slides over non-overlapping patch_size x patch_size windows and
# maps each window to a vector of embed_dim channels.
patch_embedding = nn.Conv2d(
    in_channels=in_channels,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
patch_embedding

Conv2d(3, 6, kernel_size=(4, 4), stride=(4, 4))

In [6]:
# Project every patch of the image; the result has one spatial position per
# patch and embed_dim channels at each position.
x_proj = patch_embedding(x)
x_proj.size()

torch.Size([1, 6, 2, 2])

In [7]:
# Inspect the projected values: one grid of per-patch activations for each embedding channel.
x_proj

tensor([[[[-0.6274,  0.3757],
          [ 0.4737,  0.6531]],

         [[ 0.0518,  0.1551],
          [ 0.1834, -0.2474]],

         [[-0.1822,  0.5464],
          [ 0.7082,  0.6695]],

         [[ 0.1504, -0.7843],
          [ 0.3537, -0.8814]],

         [[-0.0249,  0.5595],
          [-1.3911, -1.0484]],

         [[ 0.1888,  0.2588],
          [ 0.5194,  0.6042]]]], grad_fn=<ConvolutionBackward0>)

In [None]:
# Collapse the two spatial axes (dims 2 and 3) into a single patch axis:
# (batch, embed_dim, grid_h, grid_w) -> (batch, embed_dim, n_patches).
x_flatten = torch.flatten(x_proj, start_dim=2)
x_flatten.shape

torch.Size([1, 6, 4])

In [17]:
# Swap the channel and patch axes so each row holds one patch's embedding:
# (batch, embed_dim, n_patches) -> (batch, n_patches, embed_dim).
torch.transpose(x_flatten, 1, 2)

tensor([[[-0.6274,  0.0518, -0.1822,  0.1504, -0.0249,  0.1888],
         [ 0.3757,  0.1551,  0.5464, -0.7843,  0.5595,  0.2588],
         [ 0.4737,  0.1834,  0.7082,  0.3537, -1.3911,  0.5194],
         [ 0.6531, -0.2474,  0.6695, -0.8814, -1.0484,  0.6042]]],
       grad_fn=<TransposeBackward0>)