In [17]:
import torch
import torch.nn as nn

# Define the dimensions
embed_size = 768  # Size of each token's embedding vector (E); must be divisible by num_heads
num_heads = 8     # Number of attention heads
seq_length = 32   # Number of tokens in the sequence (L)
batch_size = 4    # Number of sequences in the batched example at the end of the cell (N)

torch.manual_seed(0)  # make the random layer weights and inputs reproducible

# Initialize the MultiheadAttention layer.
# batch_first=True means *batched* (3-D) tensors use the layout (N, L, E).
# It has no effect on unbatched (2-D) inputs, which are always (L, E).
multihead_attn = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)

# Unbatched self-attention: a single sequence of shape (L, E) = (seq_length, embed_size).
inputs = torch.randn(seq_length, embed_size)

# Apply the multihead attention to the input data (query = key = value).
# attn_output shape: (seq_length, embed_size)
# attn_output_weights shape: (seq_length, seq_length) -- averaged over heads by default
attn_output, attn_output_weights = multihead_attn(inputs, inputs, inputs)

print("Attention output shape:", attn_output.shape)
print("Attention weights shape:", attn_output_weights.shape)

# Batched self-attention: with batch_first=True the input layout is (N, L, E).
# Checked with assertions rather than printed, so the cell's displayed output is unchanged.
batched_inputs = torch.randn(batch_size, seq_length, embed_size)
batched_output, batched_weights = multihead_attn(batched_inputs, batched_inputs, batched_inputs)
assert batched_output.shape == (batch_size, seq_length, embed_size)
assert batched_weights.shape == (batch_size, seq_length, seq_length)

Attention output shape: torch.Size([32, 768])
Attention weights shape: torch.Size([32, 32])


In [4]:
# A 6-layer Transformer encoder. batch_first is left at its default (False),
# so `src` uses the layout (seq_len, batch, d_model).
d_model = 512          # model / embedding width
n_encoder_heads = 8    # attention heads per encoder layer; d_model must be divisible by it
n_encoder_layers = 6   # number of stacked encoder layers
src_seq_len, src_batch = 10, 32

encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_encoder_heads)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoder_layers)
src = torch.rand(src_seq_len, src_batch, d_model)
out = transformer_encoder(src)
# The encoder maps each position to a d_model vector, so the output has the same shape as the input.
assert out.shape == (src_seq_len, src_batch, d_model)



In [5]:
out.size()  # same torch.Size as `out.shape`: (seq_len, batch, d_model) for the encoder output

torch.Size([10, 32, 512])

In [7]:
# Walk the module tree of a freshly built attention layer; the root is reported under the empty name "".
mha_named_modules = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True).named_modules()
for module_name, module in mha_named_modules:
    print(module_name, module)

 MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
out_proj NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)


In [14]:
# Wrap a single attention layer in a container so its submodule names gain a "0." prefix.
# NOTE: this container is only for inspecting module names. Calling model(x) would fail,
# because nn.Sequential passes a single argument while MultiheadAttention.forward needs query, key and value.
attention_layer = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)
model = nn.Sequential(attention_layer)

In [15]:
# Print each (qualified name, module) pair; nested names are dot-joined, e.g. "0.out_proj".
for entry in model.named_modules():
    print(*entry)

 Sequential(
  (0): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
)
0 MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
0.out_proj NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
