In [1]:
import torch
import torch.nn as nn
import math

class MHA(nn.Module):
    """Multi-head self-attention over a ``(batch, seq_len, hidden_dim)`` input.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
        nums_head: Number of attention heads; must evenly divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``nums_head`` is not positive or does not divide
            ``hidden_dim`` (the per-head split in ``forward`` would fail).
    """

    def __init__(self, hidden_dim, nums_head, dropout=0.1) -> None:
        super().__init__()
        if nums_head <= 0 or hidden_dim % nums_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"nums_head ({nums_head})"
            )
        self.nums_head = nums_head

        self.head_dim = hidden_dim // nums_head
        self.hidden_dim = hidden_dim

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch, seq_len, hidden_dim)``.
            attention_mask: Optional tensor broadcastable to
                ``(batch, nums_head, seq_len, seq_len)``. Positions equal to 0
                are excluded from attention; all other positions are kept.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.
        """
        # (b, s, hidden_dim)
        batch_size, seq_len, _ = X.size()

        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        # Split the hidden axis into heads: (b, nums_head, s, head_dim)
        q_state = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        k_state = K.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        v_state = V.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (b, nums_head, s, s)
        attention_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Fill with the most negative finite value rather than -inf: a row
            # whose keys are all masked would otherwise softmax to NaN. Masked
            # keys in partially masked rows still get exactly zero weight; a
            # fully masked row falls back to uniform weights.
            fill_value = torch.finfo(attention_weight.dtype).min
            attention_weight = attention_weight.masked_fill(attention_mask == 0, fill_value)

        attention_weight = self.dropout(torch.softmax(attention_weight, dim=-1))

        # Weighted sum of values, then merge heads back: (b, s, hidden_dim)
        output = attention_weight @ v_state
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output



# Per-sample key mask of shape (3, 2): 1 = keep the key, 0 = mask it out.
# Indexing with None inserts the head and query axes -> (3, 1, 1, 2), which is
# then broadcast to (batch=3, heads=8, query=2, key=2).
attention_mask = torch.tensor([[0, 1], [0, 0], [1, 0]])[:, None, None, :].expand(3, 8, 2, 2)

# Random input of shape (batch=3, seq_len=2, hidden_dim=128).
x = torch.rand((3, 2, 128))
net = MHA(hidden_dim=128, nums_head=8)
net(x, attention_mask=attention_mask).shape

torch.Size([3, 2, 128])

https://blog.csdn.net/qq_40671063/article/details/130285398

https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

class
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, 
add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, 
device=None, dtype=None)

forward(query, key, value, key_padding_mask=None, need_weights=True, 
attn_mask=None, average_attn_weights=True, is_causal=False)

query (Tensor) – 查询嵌入。非批处理输入时形状为 (L, E_q)；batch_first=False 时为 (L, N, E_q)；batch_first=True 时为 (N, L, E_q)。其中 L 是目标序列长度，N 是批量大小，E_q 是查询的嵌入维度 embed_dim。查询会与键-值对进行比较以生成输出。

key (Tensor) – 键嵌入。非批处理输入时形状为 (S, E_k)；batch_first=False 时为 (S, N, E_k)；batch_first=True 时为 (N, S, E_k)。其中 S 是源序列长度，N 是批量大小，E_k 是键的嵌入维度 kdim。

value (Tensor) – 值嵌入。非批处理输入时形状为 (S, E_v)；batch_first=False 时为 (S, N, E_v)；batch_first=True 时为 (N, S, E_v)。其中 S 是源序列长度，N 是批量大小，E_v 是值的嵌入维度 vdim。

average_attn_weights (bool) – 如果为 True，返回的 attn_weights 会在各个注意力头之间取平均；否则分别返回每个注意力头的 attn_weights。请注意，此标志仅在 need_weights=True 时有效。默认值：True（即对各注意力头取平均）。

In [None]:
import torch
import torch.nn as nn

# Problem sizes for the self-attention demo.
embed_dim = 16   # feature dimension of each token
num_heads = 4    # number of attention heads
batch_size = 2
seq_len = 5

# batch_first=True makes the module take and return (B, T, C) tensors.
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.rand((batch_size, seq_len, embed_dim))  # (B, T, C)

# Self-attention: query, key and value are all the same tensor x.
attn_output, attn_weights = mha(query=x, key=x, value=x)
print("输出形状:", attn_output.shape)  # (B, T, C)
# Per-head weights would be (B, num_heads, T, S); average_attn_weights defaults
# to True, so only the head-averaged weights of shape (B, T, S) are returned.
print("注意力权重形状:", attn_weights.shape)

输出形状: torch.Size([2, 5, 16])
注意力权重形状: torch.Size([2, 5, 5])


In [3]:
import torch
import torch.nn as nn

# Problem sizes for the cross-attention demo.
embed_dim = 32
num_heads = 4

cross_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Encoder output supplies both keys and values: (B=2, S=6, C=32).
encoder_out = torch.rand((2, 6, embed_dim))

# Decoder input supplies the queries: (B=2, T=3, C=32).
decoder_in = torch.rand((2, 3, embed_dim))

# Cross-attention: decoder positions attend over the encoder positions.
output, attn_weights = cross_attn(decoder_in, encoder_out, encoder_out)

print("Cross attention output:", output.shape)       # (2, 3, 32)
print("Cross attention weights:", attn_weights.shape)  # (2, 3, 6)


Cross attention output: torch.Size([2, 3, 32])
Cross attention weights: torch.Size([2, 3, 6])
