### Attention Scoring Functions


Suppose we have a query $\mathbf{q} \in \mathbb{R}^q$ and $m$ key-value pairs
$(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m,\mathbf{v}_m)$, where $\mathbf{k}_i \in \mathbb{R}^k$ and $\mathbf{v}_i \in \mathbb{R}^v$. The attention pooling function $f$ is then expressed as a weighted sum of the values:
$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$$
where the attention weight of $\mathbf{q}$ and $\mathbf{k}_i$ is obtained by mapping the two vectors to a scalar with the attention scoring function $a$ and then applying a softmax:
$$
\alpha(\mathbf{q},\mathbf{k}_i)=\mathrm{softmax}(a(\mathbf{q},\mathbf{k}_i)) = \frac {\exp(a(\mathbf{q},\mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q},\mathbf{k}_j))} \in \mathbb{R}
$$


Different choices of the attention scoring function lead to different attention pooling operations.
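
To make the two formulas concrete, here is a minimal sketch in plain Python so that the arithmetic stays visible. The helper names (`softmax`, `attention_pooling`, `neg_sq_dist`) are hypothetical, and the negative squared distance is only an illustrative stand-in for the scoring function $a$; any function that maps a query and a key to a scalar could be passed instead.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pooling(query, keys, values, score_fn):
    """Return (output, weights): the softmax-weighted sum of the values."""
    weights = softmax([score_fn(query, k) for k in keys])
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return output, weights

def neg_sq_dist(q, k):
    # Hypothetical scoring function, chosen only for illustration
    return -sum((qi - ki) ** 2 for qi, ki in zip(q, k))

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0], [2.0], [3.0]]
output, weights = attention_pooling(query, keys, values, neg_sq_dist)
print(weights, output)
```

The key closest to the query receives the largest weight, so the output is pulled toward its value.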

In [2]:
import math
import torch
from torch import nn
import sys,os
sys.path.append(os.path.abspath("../"))
import lmy
from lmy import show_heatmaps
import d2l



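#### Masked softmax operation
Purpose: exclude meaningless elements that lie beyond a sequence's valid length (such as padding) so that the softmax is computed over the valid elements only.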

In [3]:
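def masked_softmax(X, valid_lens):
    """Softmax over the last axis, masking elements beyond the valid lengths."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for every query row
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Masked elements get a large negative value, so their softmax output is ~0
    X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    X = X.reshape(shape)
    print(X)  # debug output: the masked scores
    return nn.functional.softmax(X, dim=-1)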

In [4]:
masked_softmax(torch.arange(1, 5,dtype=torch.float32), torch.tensor(3))

tensor([ 1.0000e+00,  2.0000e+00,  3.0000e+00, -1.0000e+06])


tensor([0.0900, 0.2447, 0.6652, 0.0000])

In [5]:
# Same result as above: softmax over the three valid elements only
X = torch.arange(1,4,dtype=torch.float32)
X,X.softmax(-1)

(tensor([1., 2., 3.]), tensor([0.0900, 0.2447, 0.6652]))
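
The masking itself is delegated to `d2l.sequence_mask`, whose definition is not shown here. The plain-Python sketch below assumes it overwrites every position at index `valid_len` or later with the given fill value; `masked_softmax_row` and `fill` are hypothetical names used only for this illustration.

```python
import math

def masked_softmax_row(scores, valid_len, fill=-1e6):
    """Softmax over one row, ignoring positions at index >= valid_len."""
    # Assumed masking behaviour: overwrite invalid positions with `fill`
    masked = [s if i < valid_len else fill for i, s in enumerate(scores)]
    peak = max(masked)
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

weights = masked_softmax_row([1.0, 2.0, 3.0, 4.0], valid_len=3)
print([round(w, 4) for w in weights])  # [0.09, 0.2447, 0.6652, 0.0]
```

The exponential of the fill value underflows to zero, which is why the masked position drops out of the normalization and the remaining weights match a softmax over the first three elements.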

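### Additive Attention
In general, when the query and the key are vectors of different lengths, additive attention can be used as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is
$$
a(\mathbf{q},\mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q}+\mathbf{W}_k \mathbf{k}) \in \mathbb{R},
$$
where the learnable parameters are $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^h$. Equivalently, the query and the key are concatenated and fed into an MLP with a single hidden layer of $h$ units (a hyperparameter), using tanh as the activation function and no bias terms.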
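
As a small worked instance of the scoring function for a single query and a single key, the sketch below uses plain Python lists. The parameter values are made up for illustration, and `matvec` and `additive_score` are hypothetical helper names; the point is only that a query of length 3 and a key of length 2 are both projected to the same hidden size $h = 2$ before being added.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def additive_score(q, k, W_q, W_k, w_v):
    """a(q, k) = w_v^T tanh(W_q q + W_k k), with no bias terms."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_q, q), matvec(W_k, k))]
    return sum(w * h for w, h in zip(w_v, hidden))

# Illustrative (hypothetical) parameters: q has length 3, k has length 2, h = 2
W_q = [[0.1, 0.0, -0.1],
       [0.2, 0.1, 0.0]]
W_k = [[0.5, -0.5],
       [0.0, 1.0]]
w_v = [1.0, -1.0]

score = additive_score([1.0, 2.0, 3.0], [1.0, 1.0], W_q, W_k, w_v)
print(score)
```

Because tanh is bounded by 1 in absolute value, the score can never exceed the sum of the absolute values of the entries of `w_v`.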

In [6]:
class AdditiveAttention(nn.Module):
    """Additive attention."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)  # no bias term
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries = self.W_q(queries)
        keys = self.W_k(keys)
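        # After unsqueezing, queries has shape (batch_size, num_queries, 1, num_hiddens)
        # and keys has shape (batch_size, 1, num_kv_pairs, num_hiddens).

        # Sum the two using broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v has a single output, so drop that last dimension.
        # scores has shape (batch_size, num_queries, num_kv_pairs)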
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


In [12]:
queries = torch.normal(0, 1, size=(2, 1, 20))
keys = torch.ones((2, 10, 2))
values = torch.arange(40,dtype= torch.float32).reshape(1,10,4).repeat(2,1,1)
# queries,keys,values.shape
valid_lens = torch.tensor([2,6])
attention = AdditiveAttention(key_size=2,query_size=20,num_hiddens=8,dropout=.1)
attention.eval()  # evaluation mode: disables dropout
attention(queries,keys,values,valid_lens)

tensor([[[ 7.4718e-02,  7.4718e-02, -1.0000e+06, -1.0000e+06, -1.0000e+06,
          -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[-1.3514e-02, -1.3514e-02, -1.3514e-02, -1.3514e-02, -1.3514e-02,
          -1.3514e-02, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]]],
       grad_fn=<ReshapeAliasBackward0>)


tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
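
Because every key is the same all-ones vector, all valid positions within a batch element receive the same score (visible in the printed scores above). The attention weights are therefore uniform over the valid positions, and each output row is the mean of the first `valid_len` rows of `values`. A plain-Python check of that reasoning, with `mean_of_first_rows` as a hypothetical helper:

```python
# values: 10 rows of 4 numbers covering 0..39, as in the example above
values = [[4 * r + c for c in range(4)] for r in range(10)]

def mean_of_first_rows(rows, n):
    """Column-wise mean of the first n rows (uniform attention over n positions)."""
    return [sum(row[c] for row in rows[:n]) / n for c in range(len(rows[0]))]

print(mean_of_first_rows(values, 2))  # [2.0, 3.0, 4.0, 5.0]
print(mean_of_first_rows(values, 6))  # [10.0, 11.0, 12.0, 13.0]
```

These means agree with the two output rows for valid lengths 2 and 6.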