In [7]:
import torch
import torch.nn as nn
import math
import numpy as np

In [9]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are linearly projected to
    ``n_heads * d_k`` features, split into ``n_heads`` heads of size ``d_k``,
    attended independently, concatenated and projected back to ``d_model``.

    Args:
        d_model: Feature size of the inputs and of the returned tensor.
        n_heads: Number of attention heads.
        d_k: Feature size of each head (used for queries, keys and values).
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k

        # Input projections for queries, keys and values (all heads at once).
        self.W_Q = nn.Linear(d_model, n_heads * d_k)
        self.W_K = nn.Linear(d_model, n_heads * d_k)
        self.W_V = nn.Linear(d_model, n_heads * d_k)
        # Projection from the concatenated heads back to d_model.
        self.output = nn.Linear(n_heads * d_k, d_model)
        # Normalises the scores over the key axis (last dimension).
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, projected):
        """Reshape (batch, seq, n_heads * d_k) into (batch, n_heads, seq, d_k)."""
        batch_size, seq_len = projected.size(0), projected.size(1)
        return projected.reshape(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """Apply attention of ``query`` over ``key`` / ``value``.

        Args:
            query: Tensor of shape (batch, tgt_len, d_model).
            key: Tensor of shape (batch, src_len, d_model). ``src_len`` may
                differ from ``tgt_len`` (cross-attention).
            value: Tensor of shape (batch, src_len, d_model).
            attn_mask: Optional tensor of shape (batch, tgt_len, src_len).
                It is converted to bool; positions that are True are masked
                out (their score is replaced by -1e4 before the softmax).
                The same mask is applied to every head.

        Returns:
            Tensor of shape (batch, tgt_len, d_model).

        Raises:
            AssertionError: If ``attn_mask`` is given with a shape other than
                (batch, tgt_len, src_len).
        """
        batch_size = query.size(0)
        tgt_len = query.size(1)
        # Key/value length is read from the key itself so that it does not
        # have to match the query length.
        src_len = key.size(1)

        query = self._split_heads(self.W_Q(query))
        key = self._split_heads(self.W_K(key))
        value = self._split_heads(self.W_V(value))

        # (batch, n_heads, tgt_len, src_len)
        scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.d_k)

        if attn_mask is not None:
            expected_shape = (batch_size, tgt_len, src_len)
            assert attn_mask.size() == expected_shape, (
                f"attn_mask must have shape {expected_shape}, got {tuple(attn_mask.size())}"
            )
            # A singleton head axis broadcasts over all heads; no need to
            # materialise n_heads copies of the mask.
            head_mask = attn_mask.unsqueeze(1).bool()
            scores = scores.masked_fill(head_mask, -1e4)

        attns = self.softmax(scores)
        context = torch.matmul(attns, value)
        # Merge the heads back: (batch, tgt_len, n_heads * d_k).
        context = context.transpose(1, 2).reshape(batch_size, -1, self.d_k * self.n_heads)

        return self.output(context)

In [11]:
# Seed the RNG first so that both the layer's initial weights and the random
# inputs below are the same on every Restart & Run All.
torch.manual_seed(0)

d_model = 512  # total model dimension, shared by the layer and the inputs

MHA = MultiHeadAttention(
    d_model=d_model,  # total model dimension
    n_heads=8,        # 8 attention heads
    d_k=64,           # dimension of each head
)

# Usage example: self-attention-shaped random inputs.
batch_size = 32
seq_length = 100
query = torch.randn(batch_size, seq_length, d_model)
key = torch.randn(batch_size, seq_length, d_model)
value = torch.randn(batch_size, seq_length, d_model)

# Build an attention mask (optional). All entries are False, so nothing is masked.
attn_mask = torch.zeros(batch_size, seq_length, seq_length).bool()

# Forward pass
output = MHA(query, key, value, attn_mask)

# Sanity check: the layer maps d_model features back to d_model features.
assert output.shape == query.shape

print(query)
print(output)

tensor([[[-0.5374,  0.8965,  1.1350,  ...,  1.6581,  0.1686,  1.3624],
         [ 0.5993, -0.4241,  0.6466,  ...,  1.6504,  0.6699,  1.4069],
         [-1.6111,  0.5703,  0.5384,  ...,  3.2205, -0.7980, -0.4078],
         ...,
         [-0.0642, -0.8640, -1.7854,  ...,  0.0496, -0.1097, -0.0803],
         [ 0.3494,  0.5224, -1.7362,  ..., -0.7658,  1.4892,  0.4058],
         [-0.9695,  1.0183, -3.5471,  ..., -0.4584, -0.8055,  0.2513]],

        [[-0.9149,  0.9571,  0.4021,  ..., -0.5145, -0.2525, -0.5677],
         [-0.9343,  1.0164,  1.0843,  ...,  0.9977, -0.5970,  0.5019],
         [-0.7378, -2.5707, -1.7871,  ..., -0.5649,  0.5830,  0.3226],
         ...,
         [-0.2303, -1.5828, -0.9122,  ..., -0.2734, -1.0095,  0.1967],
         [ 0.7680, -0.1118,  0.4220,  ..., -0.2372,  0.0644,  1.7829],
         [ 0.6531, -1.6374, -0.9833,  ..., -0.1718, -1.0883, -0.2823]],

        [[-1.1870, -0.1303,  0.7256,  ..., -0.3316,  0.9571,  0.8577],
         [-1.0492, -1.2005, -0.0238,  ..., -1