![img](../img/attention1.png)

![img](../img/attention2.png)

![img](../img/attention3.png)

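Create the query, key, and value sequences Q, K, and V, then compute the attention scores using the dot product as the similarity function, scaled by the square root of the key dimension: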

In [4]:
import torch
from math import sqrt

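# Batch of 2 sequences, 4 tokens each, 768-dimensional vectors
Q = K = V = torch.rand(2, 4, 768)
dim_k = K.size(-1)
# Batched Q @ K^T, divided by sqrt(dim_k)
scores = torch.bmm(Q, K.transpose(1, 2)) / sqrt(dim_k)
print(scores.size())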

torch.Size([2, 4, 4])


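Here Q and K both have a sequence length of 4, so each of the 2 sequences in the batch gets a 4 x 4 matrix of attention scores. The next step is to apply Softmax to normalize the scores into attention weights: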

In [5]:
import torch.nn.functional as F

weights = F.softmax(scores, dim=-1)
print(weights.sum(dim=-1))

tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]])
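
The Softmax step is also the reason the scores are divided by `sqrt(dim_k)`. The sketch below illustrates this under an assumption that differs from the cells above: the vectors are drawn with `torch.randn` (zero mean, unit variance) rather than `torch.rand`, and the batch of 1000 queries with 8 keys each is made up for the demonstration. Under that assumption the unscaled dot products have a standard deviation of roughly `sqrt(768) ≈ 27.7`, which tends to push Softmax toward a near one-hot output, while the scaled scores stay close to unit standard deviation.

```python
import torch
import torch.nn.functional as F
from math import sqrt

torch.manual_seed(0)
dim_k = 768

# Assumption for illustration: zero-mean, unit-variance components (torch.randn),
# unlike the torch.rand inputs used in the cells above
q = torch.randn(1000, 1, dim_k)   # 1000 independent queries
k = torch.randn(1000, 8, dim_k)   # 8 keys per query

raw = torch.bmm(q, k.transpose(1, 2))   # unscaled scores, shape (1000, 1, 8)
scaled = raw / sqrt(dim_k)              # scaled scores, same shape

# Spread of the scores before and after scaling
print(raw.std().item(), scaled.std().item())

# Average weight of the largest entry in each Softmax row:
# values near 1 mean the row is almost one-hot
raw_peak = F.softmax(raw, dim=-1).max(dim=-1).values.mean().item()
scaled_peak = F.softmax(scaled, dim=-1).max(dim=-1).values.mean().item()
print(raw_peak, scaled_peak)
```

With the scaled scores, each query spreads its weight over several keys instead of collapsing onto a single one, which is the behavior the later cells rely on.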


Finally, multiply the attention weights by the value sequence:

In [None]:
attn_outputs = torch.bmm(weights, V)
print(attn_outputs.shape)

torch.Size([2, 4, 768])


This completes a simplified Scaled Dot-product Attention. The steps above can be wrapped in a function for later reuse:

In [6]:
import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
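    # Combine per-token masks (nonzero = keep, 0 = mask out) into a (batch, q_len, k_len) mask
    if query_mask is not None and key_mask is not None:
        mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
    # Masked positions are set to -inf so that Softmax gives them zero weight.
    # Note: a row that is masked everywhere becomes all -inf, and Softmax returns NaN for it.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))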
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)
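
The mask arguments are not exercised anywhere above, so here is a minimal usage sketch. The toy sizes and the padding pattern (last token is padding) are assumptions made for illustration, and the function is repeated only so the block runs on its own. It checks two things: a masked key contributes nothing to the output, and passing `query_mask` together with `key_mask` leaves the padded query's row fully masked, which yields NaN for that row.

```python
import torch
import torch.nn.functional as F
from math import sqrt

# Same function as defined above, repeated so this block is self-contained
def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if query_mask is not None and key_mask is not None:
        mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

torch.manual_seed(0)
x = torch.rand(1, 4, 8)   # toy input: 1 sequence, 4 tokens, 8 dimensions

# Hypothetical padding mask: the last token is padding (1 = real token, 0 = padding)
key_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
# Let every query attend to the 3 real keys only: shape (1, 4, 4)
mask = key_mask.unsqueeze(1).expand(-1, 4, -1)

out = scaled_dot_product_attention(x, x, x, mask=mask)

# Overwrite the value vector at the padded position; the output should not change
v_changed = x.clone()
v_changed[0, 3] = 100.0
out_changed = scaled_dot_product_attention(x, x, v_changed, mask=mask)
print(torch.allclose(out, out_changed))

# With query_mask as well, the padded query's row is masked everywhere
out_q = scaled_dot_product_attention(x, x, x, query_mask=key_mask, key_mask=key_mask)
print(torch.isnan(out_q[0, 3]).all())    # padded query row is NaN
print(torch.isnan(out_q[0, :3]).any())   # real query rows are unaffected
```

If the NaN rows are a problem downstream, one option is to pass only a key-based `mask` as in the first call and discard the outputs at padded query positions afterwards.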

![img](../img/attention4.png)

In [7]:
from torch import nn

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):
        attn_outputs = scaled_dot_product_attention(
            self.q(query), self.k(key), self.v(value), query_mask, key_mask, mask)
        return attn_outputs

![img](../img/attention5.png)

In [8]:
class MultiHeadAttention(nn.Module):
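    def __init__(self, embed_dim, num_heads, head_dim=None):
        super().__init__()
        # Split embed_dim evenly across the heads unless head_dim is given explicitly
        if head_dim is None:
            head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        # Project the concatenated head outputs back to embed_dim
        self.output_linear = nn.Linear(num_heads * head_dim, embed_dim)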

    def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):
        x = torch.cat([
            h(query, key, value, query_mask, key_mask, mask) for h in self.heads
        ], dim=-1)
        x = self.output_linear(x)
        return x

In [9]:
x = torch.rand(2, 4, 768)
mha = MultiHeadAttention(embed_dim=768, num_heads=12, head_dim=64)

In [10]:
y = mha(x, x, x)
y.shape

torch.Size([2, 4, 768])
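
The call above uses the same tensor for query, key, and value. Because `forward` takes the three inputs separately, the module can also be used when the query sequence and the key/value sequence differ in length. The sketch below shows this with condensed copies of the definitions above (mask arguments omitted to keep it short); the `memory` tensor and its length of 5 are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F
from math import sqrt
from torch import nn

# Condensed copies of the definitions above (mask arguments omitted)
# so that this block runs on its own
def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value):
        return scaled_dot_product_attention(
            self.q(query), self.k(key), self.v(value))

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        x = torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        return self.output_linear(x)

torch.manual_seed(0)
mha = MultiHeadAttention(embed_dim=768, num_heads=12)

query = torch.rand(2, 3, 768)    # 3 query tokens per sequence
memory = torch.rand(2, 5, 768)   # 5 key/value tokens per sequence (hypothetical)

out = mha(query, memory, memory)
print(out.shape)        # torch.Size([2, 3, 768])

# A single head projects to head_dim = 768 // 12 = 64
head_out = mha.heads[0](query, memory, memory)
print(head_out.shape)   # torch.Size([2, 3, 64])
```

The output length follows the query, while the key/value length only determines how many positions each query attends over.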