In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are linearly projected, split into ``n_head`` heads of width
  ``d_model // n_head``, attended per head, merged back together, and passed
  through a final linear layer.

  Args:
    d_model: Size of the last dimension of the inputs and of the output.
    n_head: Number of attention heads; must evenly divide ``d_model``.

  Raises:
    ValueError: If ``n_head`` is not positive or does not divide ``d_model``.
  """

  def __init__(self, d_model, n_head):
    super().__init__()

    # Fail early with a clear message instead of an opaque `view` error
    # later in `forward`.
    if n_head <= 0 or d_model % n_head != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
      )

    self.n_head = n_head
    self.d_model = d_model
    self.d_k = d_model // n_head

    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_combine = nn.Linear(d_model, d_model)
    self.softmax = nn.Softmax(dim=-1)

  def _split_heads(self, x, batch_size, seq_len):
    """Reshape (batch, seq, d_model) into (batch, n_head, seq, d_k)."""
    return x.view(batch_size, seq_len, self.n_head, self.d_k).permute(0, 2, 1, 3)

  def forward(self, q, k, v, mask=None):
    """Compute attention of ``q`` over ``k`` / ``v``.

    Args:
      q: Query tensor of shape (batch, q_len, d_model).
      k: Key tensor of shape (batch, kv_len, d_model).
      v: Value tensor of shape (batch, kv_len, d_model).
      mask: Optional tensor; positions where ``mask == 0`` are excluded
        from attention. A mask with fewer than 4 dimensions gets a head
        axis inserted at dim 1 (e.g. (batch, q_len, kv_len) becomes
        (batch, 1, q_len, kv_len)); a 4-D mask is used as-is. Either way
        it must broadcast against (batch, n_head, q_len, kv_len).

    Returns:
      Tensor of shape (batch, q_len, d_model).
    """
    batch_size, q_len, _ = q.shape
    # Keys/values may have a different sequence length than the queries
    # (cross-attention), so each length is read from its own tensor.
    k_len = k.size(1)
    v_len = v.size(1)

    # Linear projections, then split into heads.
    q = self._split_heads(self.w_q(q), batch_size, q_len)
    k = self._split_heads(self.w_k(k), batch_size, k_len)
    v = self._split_heads(self.w_v(v), batch_size, v_len)

    # Scaled dot-product scores: (batch, n_head, q_len, k_len).
    scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

    # Apply the mask, if any.
    if mask is not None:
      if mask.dim() < 4:
        mask = mask.unsqueeze(1)
      # Use the dtype's most negative finite value so the fill is also
      # representable in reduced precision (a literal -1e9 overflows float16).
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Attention weights and the weighted sum of values.
    attention = self.softmax(scores)
    context = attention @ v

    # Merge the heads back into (batch, q_len, d_model).
    context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.d_model)

    # Final output projection.
    output = self.w_combine(context)
    return output

# Smoke test: run self-attention on a random batch.
d_model, n_head = 512, 8
attention = MultiHeadAttention(d_model=d_model, n_head=n_head)

# Random input with batch size 32, sequence length 10 and width d_model.
x = torch.randn(32, 10, d_model)
# Self-attention: the same tensor serves as query, key and value.
out = attention(q=x, k=x, v=x)

输出形状: torch.Size([32, 10, 512])
