In [9]:
import torch

import torch.nn as nn

# Parameters
embed_dim = 16
num_heads = 4
batch_size = 2
seq_len = 10
seed = 0

# Seed the global RNG before anything random happens, so the layer's initial
# weights and the dummy inputs (and therefore every tensor printed below) are
# identical on each fresh "Restart & Run All".
torch.manual_seed(seed)

# Create a MultiheadAttention layer. batch_first is not passed, so the inputs
# below use the (sequence length, batch size, embedding dimension) layout.
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

# Dummy inputs (shape: sequence length, batch size, embedding dimension)
query = torch.randn(seq_len, batch_size, embed_dim)
key = torch.randn(seq_len, batch_size, embed_dim)
value = torch.randn(seq_len, batch_size, embed_dim)

# Two forward passes over the same inputs that differ only in need_weights:
#   need_weights=False -> the second return value is None
#   need_weights=True  -> the second return value holds the attention weights
attn_output_without, weights_without = mha(query, key, value, need_weights=False)
attn_output_with, weights_with = mha(query, key, value, need_weights=True)

# The point of this cell, checked rather than eyeballed: the flag only decides
# whether the weights are returned, it does not change the attention output.
# A small tolerance absorbs floating-point rounding between the two calls.
assert weights_without is None
assert torch.allclose(attn_output_without, attn_output_with, atol=1e-5)

# Index [0] selects the first sequence position, i.e. a (batch_size, embed_dim) slice.
print("Attention output without weights:", attn_output_without[0])
print("Attention output with weights:", attn_output_with[0])
print("Attention weights without weights:", weights_without)
print("Attention weights with weights:", weights_with)

Attention output without weights: tensor([[ 0.0271,  0.0995, -0.1134, -0.2398,  0.0464, -0.2873,  0.0098, -0.2535,
          0.0373, -0.0007,  0.1214, -0.0417, -0.0646, -0.1667, -0.0210, -0.0912],
        [ 0.1833,  0.0517, -0.4616, -0.1873,  0.2027, -0.2449, -0.0279, -0.3226,
          0.2955, -0.1358, -0.1031, -0.1092, -0.2429,  0.0221, -0.1686, -0.2252]],
       grad_fn=<SelectBackward0>)
Attention output with weights: tensor([[ 0.0271,  0.0995, -0.1134, -0.2398,  0.0464, -0.2873,  0.0098, -0.2535,
          0.0373, -0.0007,  0.1214, -0.0417, -0.0646, -0.1667, -0.0210, -0.0912],
        [ 0.1833,  0.0517, -0.4616, -0.1873,  0.2027, -0.2449, -0.0279, -0.3226,
          0.2955, -0.1358, -0.1031, -0.1092, -0.2429,  0.0221, -0.1686, -0.2252]],
       grad_fn=<SelectBackward0>)
Attention weights without weights: None
Attention weights with weights: tensor([[[0.1046, 0.1101, 0.0788, 0.1206, 0.0839, 0.1156, 0.0909, 0.1094,
          0.1080, 0.0780],
         [0.1026, 0.1240, 0.0885, 0.0901

In [4]:
# Relies on `mha`, `query`, `key` and `value` created by the setup cell.

# Without weights: the layer still returns a pair, but its second slot is None.
output_without_weights, weights_without = mha(
    query=query,
    key=key,
    value=value,
    need_weights=False,
)
print("With need_weights=False:")
print("Output shape:", output_without_weights.shape)
print("Weights:", weights_without)

# With weights: the second slot is now a tensor, so report its shape
# instead of dumping its values.
output_with_weights, weights_with = mha(
    query=query,
    key=key,
    value=value,
    need_weights=True,
)
print("\nWith need_weights=True:")
print("Output shape:", output_with_weights.shape)
print("Weights shape:", weights_with.shape)

With need_weights=False:
Output shape: torch.Size([10, 2, 16])
Weights: None

With need_weights=True:
Output shape: torch.Size([10, 2, 16])
Weights shape: torch.Size([2, 10, 10])
