In [2]:
import torch
import torch.nn as nn

# Demo: cross-attention between a short token sequence and a flattened
# spatial feature map, using nn.MultiheadAttention in both directions.
# nn.MultiheadAttention defaults to batch_first=False, so every tensor fed to
# it is laid out as [length, batch_size, embed_dim].

# Initialize the multi-head attention layer.
embed_dim = 512  # embedding dimension; must be divisible by num_heads
num_heads = 8  # number of attention heads (per-head dim = 512 // 8 = 64)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

# Randomly initialized input features.
batch_size = 32
seq_len = 2  # sequence length
h, w = 5, 9  # spatial grid size -> h * w = 45 spatial positions

# Sequence features: [seq_len, batch_size, embed_dim] = [2, 32, 512]
sequence_features = torch.randn(seq_len, batch_size, embed_dim)

# Flattened spatial features: [batch_size, h * w, embed_dim] = [32, 45, 512]
spatial_features = torch.randn(batch_size, h * w, embed_dim)

# Put the length axis first to match the batch_first=False layout:
# [h * w, batch_size, embed_dim] = [45, 32, 512]
spatial_features = spatial_features.permute(1, 0, 2)

# Mode 1: sequence as Query, spatial features as Key and Value.
query = sequence_features  # [seq_len, batch_size, embed_dim]
key = value = spatial_features  # [h * w, batch_size, embed_dim]

# Cross-attention: the output length always follows the query length.
# need_weights=False because the attention weights are discarded anyway.
output1, _ = multihead_attn(query, key, value, need_weights=False)
# Executable check so the documented shape cannot drift from the parameters.
assert output1.shape == (seq_len, batch_size, embed_dim)
print(f"序列作为 Query 的输出 shape: {output1.shape}")  # [2, 32, 512]

# Mode 2: spatial features as Query, sequence as Key and Value.
query = spatial_features  # [h * w, batch_size, embed_dim]
key = value = sequence_features  # [seq_len, batch_size, embed_dim]

# Cross-attention: one output vector per spatial position.
output2, _ = multihead_attn(query, key, value, need_weights=False)
assert output2.shape == (h * w, batch_size, embed_dim)
print(f"空间作为 Query 的输出 shape: {output2.shape}")  # [45, 32, 512]

序列作为 Query 的输出 shape: torch.Size([2, 32, 512])
空间作为 Query 的输出 shape: torch.Size([45, 32, 512])
