In [10]:
from torch.nn import Linear, MultiheadAttention

# Suppose we have a 1,3,128,128 input tensor (batch size, channels, height, width).
# Suppose we split it into 16x16 patches: the output is a 1,64,768 tensor
# (batch size, num_patches, channels*patch_height*patch_width), since (128/16)**2 = 64 and 3*16*16 = 768.
# Suppose we embed it into a 1,64,512 tensor (batch size, num_patches, embed_dim).

def _count_parameters(module):
    """Return the total number of elements over all of ``module.parameters()``."""
    total = 0
    for tensor in module.parameters():
        total += tensor.numel()
    return total


# Baseline for the comparison: a single dense layer built as Linear(512, 512).
linear_single = Linear(512, 512)

# Attention layer with embed_dim=512 and 2 heads.
# Note: it applies linear projections for query, key, and value, plus an
# output projection, which is why it is compared against one Linear layer.
mha_layer = MultiheadAttention(embed_dim=512, num_heads=2)

# Tally every parameter element held by each module.
params_linear = _count_parameters(linear_single)
params_mha = _count_parameters(mha_layer)

# Report both totals side by side.
print("Parameters in Linear:", params_linear)
print("Parameters in MultiHeadAttention:", params_mha)

Parameters in Linear: 262656
Parameters in MultiHeadAttention: 1050624
