In [53]:
import torch
from torch.nn import Unfold, Linear

# Toy image whose pixel values are simply their flat indices (0..71), so the
# contents of every extracted patch can be traced back to the input by eye.
channels = 3
height = 4
width = 6
kernel_size = (2, 2)  # (patch_height, patch_width); stride == kernel_size -> non-overlapping patches
embed_dim = 8  # width of each patch embedding
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# in case other cells reference it.
input = torch.arange(end=1 * channels * height * width, dtype=torch.float).view(1, channels, height, width) # batch_size, channels, height, width
print(input, input.shape)
unfold = Unfold(kernel_size=kernel_size, stride=kernel_size)
patches: torch.Tensor = unfold(input) # batch_size, channels*kernel_size[0]*kernel_size[1], num_patches
print(patches, patches.shape)
# Move the patch axis in front of the feature axis so each row is one flattened patch.
patches = patches.permute(0, 2, 1) # batch_size, num_patches, channels*kernel_size[0]*kernel_size[1]
print(patches, patches.shape)
# Size the projection from the actual patch width instead of a hard-coded 12,
# so changing `channels` or `kernel_size` above cannot cause a shape mismatch.
patch_dim = patches.shape[-1]
embed = Linear(patch_dim, embed_dim)(patches) # batch_size, num_patches, embed_dim

tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.],
          [ 6.,  7.,  8.,  9., 10., 11.],
          [12., 13., 14., 15., 16., 17.],
          [18., 19., 20., 21., 22., 23.]],

         [[24., 25., 26., 27., 28., 29.],
          [30., 31., 32., 33., 34., 35.],
          [36., 37., 38., 39., 40., 41.],
          [42., 43., 44., 45., 46., 47.]],

         [[48., 49., 50., 51., 52., 53.],
          [54., 55., 56., 57., 58., 59.],
          [60., 61., 62., 63., 64., 65.],
          [66., 67., 68., 69., 70., 71.]]]]) torch.Size([1, 3, 4, 6])
tensor([[[ 0.,  2.,  4., 12., 14., 16.],
         [ 1.,  3.,  5., 13., 15., 17.],
         [ 6.,  8., 10., 18., 20., 22.],
         [ 7.,  9., 11., 19., 21., 23.],
         [24., 26., 28., 36., 38., 40.],
         [25., 27., 29., 37., 39., 41.],
         [30., 32., 34., 42., 44., 46.],
         [31., 33., 35., 43., 45., 47.],
         [48., 50., 52., 60., 62., 64.],
         [49., 51., 53., 61., 63., 65.],
         [54., 56., 58., 66., 68., 70.],
         [55.

In [61]:
import torch
from torch.nn import Unfold, Linear, MultiheadAttention

# Suppose we have a (1, 3, 4, 6) input tensor (batch_size, channels, height, width).
# Splitting it into non-overlapping 2x2 patches gives a (1, 6, 12) tensor
# (batch_size, num_patches, channels*kernel_size[0]*kernel_size[1]).
# Embedding every patch gives a (1, 6, 8) tensor (batch_size, num_patches, embed_dim).

channels = 3
height = 4
width = 6
kernel_size = (2, 2)  # (patch_height, patch_width); stride == kernel_size -> non-overlapping patches
embed_dim = 8  # width of each patch embedding
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# in case other cells reference it.
input = torch.arange(end=1 * channels * height * width, dtype=torch.float).view(1, channels, height, width) # batch_size, channels, height, width
unfold_layer = Unfold(kernel_size=kernel_size, stride=kernel_size)
patches: torch.Tensor = unfold_layer(input) # batch_size, channels*kernel_size[0]*kernel_size[1], num_patches
patches = patches.permute(0, 2, 1) # batch_size, num_patches, channels*kernel_size[0]*kernel_size[1]
# Read the sizes off the tensor so the layers below always match the data.
num_patches, patch_dim = patches.shape[1], patches.shape[2]
embed_layer = Linear(patch_dim, embed_dim)
embed: torch.Tensor = embed_layer(patches) # batch_size, num_patches, embed_dim
flattened_embed = embed.flatten(start_dim=1, end_dim=2) # batch_size, num_patches * embed_dim

# Baseline: one dense layer over the flattened sequence, i.e. every embedding
# feature of every patch is connected to every output feature.
linear_layer = Linear(num_patches * embed_dim, num_patches * embed_dim)
linear = linear_layer(flattened_embed)

# Create a MultiHeadAttention layer
# Note: MultiHeadAttention applies linear projections for query, key, and value, plus an output projection.
# batch_first=True is required here: `embed` is (batch_size, num_patches, embed_dim),
# while the layer's default layout is (seq_len, batch_size, embed_dim). Without it
# the 6 patches would be treated as 6 independent length-1 sequences and would
# never attend to one another.
multi_head_attention_layer = MultiheadAttention(embed_dim=embed_dim, num_heads=2, batch_first=True)
multi_head_attention = multi_head_attention_layer(embed, embed, embed, need_weights=False)[0]

# Count parameters in each module
params_linear = sum(p.numel() for p in linear_layer.parameters())
params_mha = sum(p.numel() for p in multi_head_attention_layer.parameters())

print("Parameters in Linear:", params_linear)
print("Parameters in MultiHeadAttention:", params_mha)

print("Output of Linear:", linear.shape)
print("Output of MultiHeadAttention:", multi_head_attention.shape)

Parameters in Linear: 2352
Parameters in MultiHeadAttention: 288
Output of Linear: torch.Size([1, 48])
Output of MultiHeadAttention: torch.Size([1, 6, 8])
