# 第三章 编码注意力机制

In [1]:
import torch
import torch.nn as nn

## 手写多头注意力机制

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model,context_length,dropout,num_heads,qkv_bias=False):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.dropout = dropout
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias
        assert d_model % num_heads == 0
        self.head_dim = d_model // num_heads
        self.q = nn.Linear(d_model,d_model,bias=qkv_bias)
        self.k = nn.Linear(d_model,d_model,bias=qkv_bias)
        self.v = nn.Linear(d_model,d_model,bias=qkv_bias)
        self.out = nn.Linear(d_model,d_model,bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
    
    def forward(self,x):
        b,t,d_model = x.shape
        q = self.q(x).view(b,t,self.num_heads,self.head_dim).transpose(1,2)  # b, num_heads, t, head_dim
        k = self.k(x).view(b,t,self.num_heads,self.head_dim).transpose(1,2)  # b, num_heads, t, head_dim
        v = self.v(x).view(b,t,self.num_heads,self.head_dim).transpose(1,2)  # b, num_heads, t, head_dim
        atten_scores = q @ k.transpose(-2,-1) / (self.head_dim ** 0.5)  # b, num_heads, t, t
        atten_scores = atten_scores.masked_fill_(self.mask.bool()[:t,:t],-torch.inf)
        atten_weights = torch.softmax(atten_scores,dim=-1)  # b, num_heads, t, t
        atten_weights = self.attn_drop(atten_weights)
        out = atten_weights @ v  # b, num_heads, t, head_dim
        out = out.transpose(1,2).contiguous().view(b,t,d_model)
        
        return out

In [3]:
# 测试
torch.manual_seed(0)
batch = torch.ones((2,4,8))  # (b,t,d_model)
mha = MultiHeadAttention(d_model=8,context_length=4,dropout=0.1,num_heads=4)
output = mha(batch)
print(output)

tensor([[[-0.4942,  1.2725, -1.4697, -0.3354, -0.9314,  0.1962,  0.1261,
           0.6202],
         [-0.4942,  1.2725, -1.4697, -0.3354, -0.9314,  0.1962,  0.1261,
           0.6202],
         [-0.4942,  1.2725, -1.4697, -0.3354, -0.9314,  0.1962,  0.0841,
           0.4135],
         [-0.4942,  1.2725, -1.1023, -0.2516, -0.6985,  0.1471,  0.1261,
           0.6202]],

        [[-0.4942,  1.2725,  0.0000,  0.0000, -0.9314,  0.1962,  0.1261,
           0.6202],
         [-0.2471,  0.6362, -0.7348, -0.1677, -0.9314,  0.1962,  0.1261,
           0.6202],
         [-0.4942,  1.2725, -1.4697, -0.3354, -0.9314,  0.1962,  0.1261,
           0.6202],
         [-0.4942,  1.2725, -1.4697, -0.3354, -0.9314,  0.1962,  0.1261,
           0.6202]]], grad_fn=<ViewBackward0>)
