In [2]:
import torch
import torch.nn as nn

## Patch embedding

In [3]:
# Configuration for the PatchEmbedding toy example.
# Seed torch's RNG so that Restart & Run All reproduces the random tensors below.
SEED = 0
torch.manual_seed(SEED)

# Toy input: a batch of 1 RGB image (3 channels) of size 8x8
n_samples = 1
in_channels = 3  # RGB image
img_size = 8     # height == width
patch_size = 4   # 8x8 image split into 4x4 patches -> 2x2 = 4 patches
embed_dim = 6    # output embedding dimension per patch


In [4]:
# Random "image" batch of shape (n_samples, in_channels, img_size, img_size) = (1, 3, 8, 8)
image_shape = (n_samples, in_channels, img_size, img_size)
x = torch.randn(image_shape)
x

tensor([[[[ 0.5579,  0.4759,  1.8356,  0.4074,  0.7715, -0.2655,  1.2270,
           -1.1819],
          [ 2.9807,  0.1846, -0.5290, -1.1848,  1.2755,  1.3569, -1.8161,
            1.0120],
          [ 0.3811,  1.1015, -1.4335,  2.1419, -1.0955,  0.7276, -0.2383,
            0.3690],
          [-0.7204,  0.2786, -1.4674,  0.4720,  0.8239, -0.8816, -0.7099,
            0.8566],
          [ 0.8738, -1.2502, -0.1535, -1.9817,  1.5161,  1.1610,  0.5154,
            1.8058],
          [ 0.3842, -0.3454,  0.6362,  0.4555, -1.3266,  1.3281,  0.7418,
           -0.3243],
          [-1.2001, -0.8990,  1.4726,  1.3036,  0.9414,  0.1963, -0.6406,
            0.2436],
          [ 0.4218,  0.4888, -1.2154, -0.8616,  0.6857,  0.2105, -1.9911,
           -1.8040]],

         [[-0.4941, -0.2721, -0.6848,  1.4417,  1.8179,  1.3482,  0.6135,
           -0.5813],
          [-0.4211, -0.5812,  0.1442,  0.4282, -0.1659,  0.2754, -0.3848,
            0.6700],
          [ 1.4737,  0.0664,  0.8370, -1.0639, -

In [5]:
# Patch embedding as a single convolution. Because kernel_size == stride == patch_size,
# the kernel visits non-overlapping patch_size x patch_size patches, and each patch is
# linearly projected from in_channels * patch_size**2 values to embed_dim outputs.
patch_embedding = nn.Conv2d(
    in_channels=in_channels,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
patch_embedding

Conv2d(3, 6, kernel_size=(4, 4), stride=(4, 4))

In [6]:
# Apply the convolution: (batch, in_channels, H, W) -> (batch, embed_dim, H/patch, W/patch)
x_proj = patch_embedding(x)
x_proj.size()

torch.Size([1, 6, 2, 2])

In [7]:
# One 2x2 grid of patch features per embedding channel (embed_dim channels)
x_proj

tensor([[[[-0.9781, -1.1017],
          [ 0.3744,  0.2382]],

         [[ 0.4445, -0.2569],
          [ 0.1958, -0.1500]],

         [[ 0.7448,  0.5789],
          [ 0.3149,  0.8070]],

         [[-0.1469,  0.2645],
          [ 0.3128, -0.3775]],

         [[ 0.1720, -0.5700],
          [ 0.2468,  0.1501]],

         [[ 0.1384, -0.4763],
          [ 0.6453, -0.1547]]]], grad_fn=<ConvolutionBackward0>)

In [8]:
# Flatten the spatial grid: merge dims 2 and 3 (grid height and width) into a single
# patch dimension, (batch, embed_dim, 2, 2) -> (batch, embed_dim, n_patches=4).
x_flatten = x_proj.flatten(start_dim=2)
x_flatten.shape

torch.Size([1, 6, 4])

In [9]:
# Swap the embedding and patch dims so each row is one patch token:
# (batch, embed_dim, n_patches) -> (batch, n_patches, embed_dim), the same layout
# used for `patch_embeddings` in the CLS token section. Keep it under a name so
# it can be reused instead of being discarded after display.
x_patches = x_flatten.transpose(1, 2)
x_patches

tensor([[[-0.9781,  0.4445,  0.7448, -0.1469,  0.1720,  0.1384],
         [-1.1017, -0.2569,  0.5789,  0.2645, -0.5700, -0.4763],
         [ 0.3744,  0.1958,  0.3149,  0.3128,  0.2468,  0.6453],
         [ 0.2382, -0.1500,  0.8070, -0.3775,  0.1501, -0.1547]]],
       grad_fn=<TransposeBackward0>)

## LayerNorm

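In PyTorch: `nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True)`. It normalizes each input over its last `normalized_shape` dimensions using their mean and biased (population) variance. Then, if `elementwise_affine=True`, it applies a learnable per-element scale $\gamma$ and shift $\beta$:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$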

In [10]:
# Toy 3x2 float32 input: 3 rows (samples), each with 2 features.
# NOTE(review): the name `input` shadows the Python builtin; kept because later cells use it.
input = torch.tensor([[0.0, 4.0],
                      [-1.0, 7.0],
                      [3.0, 5.0]])
input

tensor([[ 0.,  4.],
        [-1.,  7.],
        [ 3.,  5.]])

In [11]:
# Unpack rows/features without overwriting the `n_samples` config constant from the
# PatchEmbedding section (it would otherwise silently change from 1 to 3).
n_rows, n_features = input.shape

In [12]:
# No learnable scale/shift: the layer only normalizes each row over its n_features values.
layernorm = nn.LayerNorm(normalized_shape=n_features, elementwise_affine=False)

In [13]:
# Count the trainable parameters of `layernorm` (0 here: elementwise_affine=False registers none)
sum(param.numel() for param in layernorm.parameters() if param.requires_grad)

0

In [14]:
# With elementwise_affine=False the module has no weight or bias; both attributes are None.
layernorm.weight, layernorm.bias

(None, None)

In [15]:
# Per-row mean over the feature (last) dimension: (0+4)/2, (-1+7)/2, (3+5)/2
input.mean(dim=-1)

tensor([2., 3., 4.])

In [16]:
"""
If unbiased=False, the standard deviation is computed using N (population standard deviation).
If unbiased=True, the standard deviation is computed using N-1 (sample standard deviation, also called Bessel’s correction).

When calculating the standard deviation of a sample, dividing by N-1 corrects the bias in estimating the population standard deviation.
This is useful in statistics when working with small sample sizes.

When to Use Each?
Use unbiased=True (default) when working with samples and need an unbiased estimator of population std.
Use unbiased=False when working with the full dataset (population statistics)
"""
input.std(-1, unbiased=False)

tensor([2., 4., 1.])

In [17]:
# Apply LayerNorm once, then verify that every row now has mean 0 and population std 1
# (LayerNorm normalizes with the biased variance, hence unbiased=False).
normalized = layernorm(input)
normalized.mean(-1), normalized.std(-1, unbiased=False)

(tensor([0., 0., 0.]), tensor([1.0000, 1.0000, 1.0000]))

In [18]:
# Same normalization, plus a learnable per-feature scale (weight) and shift (bias).
layernorm2 = nn.LayerNorm(normalized_shape=n_features, elementwise_affine=True)

In [19]:
# Count the trainable parameters of `layernorm2`: weight and bias each have n_features entries
sum(param.numel() for param in layernorm2.parameters() if param.requires_grad)

4

In [20]:
# The affine parameters start as the identity transform: weight = ones, bias = zeros
layernorm2.weight, layernorm2.bias

(Parameter containing:
 tensor([1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0.], requires_grad=True))

In [21]:
# With the initial weight=1 / bias=0, the output matches the parameter-free layernorm.
# Both results carry a grad_fn: they are part of PyTorch's autograd graph because the
# layer's weight and bias require gradients (unlike `layernorm` above).
normalized2 = layernorm2(input)
normalized2.mean(-1), normalized2.std(-1, unbiased=False)

(tensor([0., 0., 0.], grad_fn=<MeanBackward1>),
 tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward0>))

## CLS token

In [None]:
# Parameters for the CLS-token demo: images per batch, patches per image, dim per token.
# NOTE: this redefines embed_dim (6 in the PatchEmbedding section) as 8.
batch_size, n_patches, embed_dim = 4, 10, 8


In [23]:
# Simulate patch embeddings for a batch of images.
# Shape: (batch_size, n_patches, embed_dim)
patch_embeddings = torch.randn(batch_size, n_patches, embed_dim)
print("Patch embeddings shape:", patch_embeddings.shape)  # (4, 10, 8)

Patch embeddings shape: torch.Size([4, 10, 8])


In [24]:
# Full (batch_size, n_patches, embed_dim) tensor of random patch embeddings
patch_embeddings

tensor([[[-5.9520e-01,  9.1205e-01, -4.5434e-02, -5.9831e-01,  1.6236e-01,
           1.0749e+00, -1.4761e+00,  2.2378e+00],
         [ 1.4578e+00, -5.9560e-01, -9.2054e-01,  1.5206e+00, -6.8571e-01,
           9.9370e-01, -3.3203e-01,  2.0326e-01],
         [-1.7800e-01, -1.1020e+00, -1.0643e+00, -4.3989e-02, -5.7469e-01,
          -8.7325e-01, -2.7625e-01, -1.9118e+00],
         [-9.3121e-01,  6.4652e-01, -7.8528e-01,  6.9302e-01, -1.9613e-01,
           2.5961e+00,  3.1525e-01,  1.2773e+00],
         [-1.3783e+00, -5.0352e-01, -8.8609e-01, -3.3275e-02, -1.1520e+00,
           5.3384e-01,  2.9944e+00,  1.9708e+00],
         [-5.7775e-01, -9.0095e-03,  2.1483e+00,  8.6338e-01,  2.1738e+00,
           5.8588e-01, -4.5173e-01,  4.7345e-01],
         [-9.1210e-01, -1.6091e+00, -1.0199e-01,  1.5419e+00,  8.1252e-01,
           1.9598e-01, -2.1218e+00, -1.8783e-01],
         [ 6.9087e-01,  1.4726e-01,  1.0598e-01,  1.1172e-01,  1.0222e+00,
           4.6652e-02, -8.6382e-01,  1.0793e+00],


In [25]:
# Initialize the learnable classification token (cls_token).
# Its shape (1, 1, embed_dim) means:
#   1: batch dimension -- a single copy shared by all images, broadcast to
#      batch_size with expand() before concatenation
#   1: sequence (token) dimension -- exactly one token, the CLS token itself
#   embed_dim: the token's embedding dimension, matching the patch embeddings
# It starts at zeros; as an nn.Parameter it requires grad, so it can be learned.
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
print("cls_token shape:", cls_token.shape)  # (1, 1, 8)


cls_token shape: torch.Size([1, 1, 8])


In [None]:
cls_token # one token (the classification token itself), zero-initialized, shape (1, 1, embed_dim)

Parameter containing:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.]]], requires_grad=True)

In [27]:
# Each image needs its own CLS token in front of its patches, so broadcast the
# size-1 batch dim to batch_size; -1 leaves the remaining dims unchanged.
# expand() copies no memory: every batch entry views the same learnable parameter.
expanded_cls_token = cls_token.expand((batch_size, -1, -1))
print("Expanded cls_token shape:", expanded_cls_token.shape)  # (4, 1, 8)

Expanded cls_token shape: torch.Size([4, 1, 8])


In [None]:
# Concatenate the CLS token with the patch embeddings along the token dimension (dim=1).
# The CLS token is prepended to each image's sequence of patches (it sits at index 0),
# so every image now has n_patches + 1 tokens, each of size embed_dim:
#   (batch_size, 1, embed_dim) + (batch_size, n_patches, embed_dim)
#     -> (batch_size, n_patches + 1, embed_dim)
tokens = torch.cat([expanded_cls_token, patch_embeddings], dim=1)
assert tokens.shape == (batch_size, n_patches + 1, embed_dim)
print("Tokens shape after concatenation:", tokens.shape)  # (4, 11, 8)

Tokens shape after concatenation: torch.Size([4, 11, 8])


In [33]:
# Learnable positional embeddings: one vector per token position (CLS + patches),
# to be added element-wise to `tokens`.
# The shape (n_positions, embed_dim) has no batch dim, so when added to tokens of shape
# (batch_size, n_positions, embed_dim) the same embeddings broadcast over every image,
# i.e. they are shared across the batch.
n_positions = n_patches + 1  # number of tokens: 1 CLS token + n_patches
positional_embeddings = nn.Parameter(torch.randn(n_positions, embed_dim))
assert (tokens + positional_embeddings).shape == tokens.shape  # broadcasting sanity check
print("Positional embeddings shape:", positional_embeddings.shape)  # (11, 8)

Positional embeddings shape: torch.Size([11, 8])
