In [6]:
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_

import sys

sys.path.append("../../")

import torch
import math
from torch import nn
from d2l import d2l_en as d2l

In [8]:
def masked_softmax(X,valid_len):
    """Softmax over the last axis of ``X`` that ignores positions past ``valid_len``.

    Args:
        X: Score tensor. The last axis is the one being normalised; the
            attention layers in this notebook pass a 3-D tensor.
        valid_len: ``None`` for a plain softmax, or a tensor of valid lengths.
            A 1-D tensor holds one length per entry of the leading axis and is
            repeated ``X.shape[1]`` times so every row along axis 1 shares it.
            Any other rank is flattened and used as one length per row.

    Returns:
        A tensor with the same shape as ``X``.
    """
    if valid_len is None:
        # Normalise over the last axis, exactly like the masked branch below.
        # (Using dim=1 here would normalise 3-D scores across the wrong axis.)
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_len.dim() == 1:
        # One length per leading-axis entry -> one length per flattened row.
        valid_len = torch.repeat_interleave(valid_len, repeats=shape[1], dim=0)
    else:
        valid_len = valid_len.reshape(-1)

    # Flatten to 2-D rows so each row is paired with one valid length.
    # NOTE(review): d2l.sequence_mask is defined outside this file; presumably
    # it overwrites the entries past each row's valid length with `value` --
    # confirm, and confirm whether it modifies its input in place.
    # A very negative fill value makes those entries vanish after the softmax.
    X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_len, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)

In [13]:
# Sanity check: the two batch entries get valid lengths 2 and 3, so every
# position past those lengths should come out with weight 0.
masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))

tensor([[[0.3649, 0.6351, 0.0000, 0.0000],
         [0.5457, 0.4543, 0.0000, 0.0000]],

        [[0.3728, 0.4230, 0.2042, 0.0000],
         [0.3328, 0.2840, 0.3833, 0.0000]]])

In [14]:
# X(b,n,m) * Y(b,m,k) -> Result(b,n,k): batched matrix product.
# Each of the n m-dimensional vectors is dotted with each of the k
# m-dimensional vectors, which is exactly the q*k computation -- every
# query is handled in a single call.
torch.bmm(torch.ones(2,1,3),torch.ones(2,3,2))

tensor([[[3., 3.]],

        [[3., 3.]]])

### Dot-Product Attention

In [17]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights.

    Args:
        dropout: Drop probability handed to ``nn.Dropout``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_len=None):
        """Return the attention-weighted combination of ``value``.

        ``torch.bmm`` requires 3-D inputs, so ``query`` is
        (batch, #queries, d), ``key`` is (batch, #kv_pairs, d) and ``value``
        is (batch, #kv_pairs, value_dim). ``valid_len`` is passed unchanged
        to ``masked_softmax``.
        """
        # d is the size of the query's last axis (its feature dimension).
        scale = math.sqrt(query.shape[-1])

        # Swap the last two axes of the keys so one batched matmul yields
        # every query-key dot product: (b, n, d) x (b, d, m) -> (b, n, m).
        raw_scores = torch.bmm(query, key.transpose(1, 2))

        weights = masked_softmax(raw_scores / scale, valid_len)
        return torch.bmm(self.dropout(weights), value)

In [20]:
# Build the layer and switch it to eval mode so dropout is disabled.
atten  = DotProductAttention(dropout=0.5)
atten.eval()  # disable dropout
keys = torch.ones(2,10,2)
# The value feature size (4) does not have to match the key feature size (2).
values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)
# All keys are identical, so every valid position gets the same weight: the
# result is the mean of the first 2 (resp. 6) value rows of each batch entry.
atten(torch.ones(2,1,2),keys,values,torch.tensor([2,6]))

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])

In [22]:
# Confirm the shape of the `values` tensor: (batch, #kv_pairs, value_dim).
values.size()

torch.Size([2, 10, 4])

#### MLP Attention

$$ \alpha(k,q) = v^T \tanh(W_k k + W_q q) $$

In [29]:
class MLPAttention(nn.Module):
    """Additive (MLP) attention: score(q, k) = v^T tanh(W_k k + W_q q).

    Args:
        key_size: Feature size of the incoming keys.
        query_size: Feature size of the incoming queries.
        units: Width of the shared hidden space that keys and queries are
            projected into. It is independent of the value feature size.
        dropout: Drop probability applied to the attention weights.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, units, dropout, **kwargs):
        super().__init__(**kwargs)
        # Bias-free projections of keys and queries into the hidden space.
        self.W_k = nn.Linear(key_size, units, bias=False)
        self.W_q = nn.Linear(query_size, units, bias=False)
        # Collapses each hidden vector to a single scalar score.
        self.v = nn.Linear(units, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_len):
        """Return the attention-weighted combination of ``value``.

        ``valid_len`` is passed unchanged to ``masked_softmax``.
        """
        # Insert singleton axes so the sum broadcasts over all pairs:
        # (batch, #queries, 1, units) + (batch, 1, #kv_pairs, units).
        projected_q = self.W_q(query).unsqueeze(2)
        projected_k = self.W_k(key).unsqueeze(1)
        hidden = torch.tanh(projected_q + projected_k)

        # Drop the trailing size-1 axis -> (batch, #queries, #kv_pairs).
        scores = self.v(hidden).squeeze(-1)

        weights = masked_softmax(scores, valid_len)
        return torch.bmm(self.dropout(weights), value)

In [30]:
# key_size=2, query_size=2, units=8, dropout=0.1.
# Note: this rebinds `atten`, which previously held the dot-product layer.
atten = MLPAttention(2,2,8,0.1)

In [31]:
# Eval mode disables dropout; the cell output is the module's repr.
atten.eval()

MLPAttention(
  (W_k): Linear(in_features=2, out_features=8, bias=False)
  (W_q): Linear(in_features=2, out_features=8, bias=False)
  (v): Linear(in_features=8, out_features=1, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [32]:
# Reuses `keys` and `values` from the dot-product demo. All keys are
# identical, so every valid position receives the same score and the output
# equals the dot-product attention result.
atten(torch.ones(2,1,2),keys,values,torch.tensor([2,6]))

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)

#### 总结

* An attention layer explicitly selects related information.
注意力机制选择信息的关联强弱

* An attention layer’s memory consists of key-value pairs, so its output is close to the values whose keys are similar to the queries.


* Two commonly used attention models are dot product attention and MLP attention.