An attention function can be described as
```
mapping a query and a set of key-value pairs to an output
```
**output**: computed as a weighted sum of the values  
**weights**: each weight is computed by a compatibility function of the query with the corresponding key

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

Scaled Dot-Product Attention, as defined in *Attention Is All You Need* (Vaswani et al., 2017):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

1. matmul  U = Q @ K^T
2. scale   U = U / sqrt(d_k)
3. mask    U = U.masked_fill(mask, -inf)
4. softmax A = softmax(U)
5. matmul  O = A @ V



In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention.

    Maps a query and a set of key-value pairs to an output by computing
    ``softmax(q @ k^T / scale) @ v``.

    Any number of leading batch dimensions ``*`` is accepted (e.g. ``[B]``
    or ``[B, heads]``), because the products act on the last two dims only.

    input:
        q      [*, n_q, d_k]   query and key feature sizes must match
        k      [*, n_k, d_k]
        v      [*, n_k, d_v]   exactly one value per key
        mask   optional bool tensor broadcastable to [*, n_q, n_k];
               True marks a key position that must NOT be attended to
    output:
        attn   [*, n_q, n_k]   attention weights, each row sums to 1
        output [*, n_q, d_v]

    Note: if every key in a row is masked, softmax over all -inf yields NaN
    for that row.
    """

    def __init__(self, scale):
        super().__init__()

        # Divisor applied to the raw scores (the paper uses sqrt(d_k)).
        self.scale = scale
        # Each query gets a probability distribution over the keys,
        # so normalize over the last (key) dimension.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        # [*, n_q, d_k] @ [*, d_k, n_k] -> [*, n_q, n_k]
        u = torch.matmul(q, k.transpose(-2, -1))
        # Scaling keeps large dot products from pushing softmax into a
        # saturated region where its gradients become very small.
        u = u / self.scale

        if mask is not None:
            u = u.masked_fill(mask, float("-inf"))

        attn = self.softmax(u)
        # [*, n_q, n_k] @ [*, n_k, d_v] -> [*, n_q, d_v]
        output = torch.matmul(attn, v)

        return attn, output

In [3]:
# Smoke test: run ScaledDotProductAttention on random inputs.
batch = 1
n_q, n_k, n_v = 2, 4, 4       # queries / keys / values per batch (n_k must equal n_v)
d_q, d_k, d_v = 128, 128, 64  # feature sizes (d_q must equal d_k)

# Draw q, k, v (in that order) from the global RNG.
q, k, v = (
    torch.randn(batch, n, d)
    for n, d in ((n_q, d_q), (n_k, d_k), (n_v, d_v))
)
# All-False mask: no key position is hidden from any query.
mask = torch.zeros((batch, n_q, n_k), dtype=torch.bool)

attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
attn, output = attention(q, k, v, mask=mask)

In [4]:
# One shape per line: attention weights first, then the attended output.
print(attn.shape, output.shape, sep="\n")

torch.Size([1, 2, 4])
torch.Size([1, 2, 64])


In [5]:
# Dump the attention weights, then the attended output, each on its own line.
print(attn, output, sep="\n")

tensor([[[0.3740, 0.1008, 0.0738, 0.4514],
         [0.6360, 0.2082, 0.0599, 0.0959]]])
tensor([[[-0.2677, -0.1240, -1.0508, -0.2012,  0.4684, -0.9366,  0.0795,
          -0.0091, -0.1899, -0.3730, -0.9943, -0.8041,  0.1018, -0.2556,
          -1.4128,  0.7581, -0.6784, -0.1562, -0.1802,  0.2760, -0.0950,
           1.0549,  0.3989, -0.2095, -0.6312, -0.9333, -0.4323,  0.2695,
          -0.3454, -0.3770, -0.6076,  0.2579, -0.0761,  0.2060, -0.7542,
           0.9421, -0.4471, -0.4111, -0.4425, -0.9084, -0.1988, -0.5302,
           0.0885, -0.7709, -1.0963, -0.9873,  0.8236,  0.5818,  0.8382,
           1.0121,  0.2256,  0.3858,  0.0234, -0.2685,  0.1367,  0.3838,
           0.3765,  0.5326,  0.5310, -0.7640, -0.7884,  0.2384,  0.1531,
          -0.2706],
         [-0.0314, -0.0598, -1.2965,  0.2059, -0.1421, -0.4629,  0.0685,
          -0.4133, -0.5710, -0.0967, -1.6474, -0.6449,  0.3460,  0.1037,
          -1.0440,  0.5452, -0.2436, -0.6906,  0.2568,  0.1788, -0.6554,
           1.729

**Transformer 的结构图如下所示，这里我们主要考虑 decoder 的结构。**  
![](../Source/transformer.png)

**输入的格式**  
一句话有 len 个单词，每个单词被编码为一个 hidden_dim 维的向量。
```python
    Q: [B, len_q, dim_q]
    K: [B, len_k, dim_k]
    V: [B, len_v, dim_v]
```

注意力机制首先计算 query 和 key 之间的相关性矩阵，这要求 query 与 key 的向量维度一致：
```python
    assert dim_q == dim_k
    attn_mask: [len_q, len_k]
```
softmax 之后得到注意力权重，再用它对 value 求加权和即可，因此 key 的个数必须与 value 的个数一致：
```python
    assert len_k == len_v
    output: [B, len_q, dim_v]
```
按照作者的抽象，最终的输出是 len_q 个向量（每个 query 对应一个），每个向量都是 value 向量的加权和，维度为 dim_v。  
一般情况下我们令这些维度都相等，即 dim_q = dim_k = dim_v。

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention sub-layer with a post-LN residual connection.

    Computes ``LayerNorm(q + fc(concat_heads(attention_h)))``, where each head
    runs scaled dot-product attention on its slice of the projected q, k, v.

    Args:
        dim: model (embedding) size; last dimension of q, k, v and of the output.
        nums_heads: number of attention heads.
        d_k: total projected size of queries/keys over all heads
            (defaults to ``dim``); must be divisible by ``nums_heads``.
        d_v: total projected size of values over all heads
            (defaults to ``dim``); must be divisible by ``nums_heads``.

    Raises:
        ValueError: if ``d_k`` or ``d_v`` is not divisible by ``nums_heads``.

    forward inputs:
        q     [B, len_q, dim]
        k     [B, len_k, dim]
        v     [B, len_k, dim]   one value per key
        mask  optional bool tensor broadcastable to [B, heads, len_q, len_k];
              a [B, len_q, len_k] mask is shared by all heads.
              True marks a key position that must NOT be attended to.
    forward output:
        [B, len_q, dim]
    """

    def __init__(self, dim=768, nums_heads=8, d_k=None, d_v=None):
        super().__init__()
        d_k = dim if d_k is None else d_k
        d_v = dim if d_v is None else d_v
        if d_k % nums_heads or d_v % nums_heads:
            raise ValueError(
                f"d_k ({d_k}) and d_v ({d_v}) must be divisible by "
                f"nums_heads ({nums_heads})"
            )

        # Queries and keys share d_k because they are compared by dot product.
        self.W_Q = nn.Linear(dim, d_k)
        self.W_K = nn.Linear(dim, d_k)
        self.W_V = nn.Linear(dim, d_v)
        # Projects the concatenated heads back to dim for the residual add.
        self.fc = nn.Linear(d_v, dim)
        self.ln = nn.LayerNorm(dim)

        self.nums_heads = nums_heads

    def forward(self, q, k, v, mask=None):
        residual = q
        batch, len_q = q.shape[0], q.shape[1]
        len_k, len_v = k.shape[1], v.shape[1]
        heads = self.nums_heads

        # Project, then split the feature dim into heads: [B, heads, len, d/heads].
        q = self.W_Q(q).reshape(batch, len_q, heads, -1).transpose(1, 2)
        k = self.W_K(k).reshape(batch, len_k, heads, -1).transpose(1, 2)
        v = self.W_V(v).reshape(batch, len_v, heads, -1).transpose(1, 2)

        # Scale by the per-head key size, not the total projection size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # broadcast one mask over all heads
            scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)  # [B, heads, len_q, len_k]

        # Merge heads back: [B, heads, len_q, d_v/heads] -> [B, len_q, d_v].
        output = torch.matmul(attn, v).transpose(1, 2).reshape(batch, len_q, -1)
        output = self.fc(output)

        return self.ln(output + residual)

In [None]:
class mlp(nn.Module):
    """Position-wise feed-forward sub-layer with a post-LN residual connection.

    Computes ``LayerNorm(x + Linear(ReLU(Linear(x))))``, applied independently
    along the last dimension of ``x``.

    Args:
        dim: size of the last dimension of the input (and of the output).
        expansion: width multiplier of the hidden layer (default 4, i.e. the
            hidden layer is ``dim * 4`` wide).
    """

    def __init__(self, dim, expansion=4) -> None:
        super().__init__()
        hidden = dim * expansion

        self.fc = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            # Project back down to dim so the residual addition shapes match.
            nn.Linear(hidden, dim),
        )
        # Used in forward(); it was previously referenced but never created.
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        residual = x
        output = self.fc(x)

        return self.ln(output + residual)

BatchNorm 是在一个 batch 的各个样本之间（对每个通道）做归一化，可学习参数 gamma、beta 的大小是 2×C。  
LayerNorm 则是在每个样本内部、沿特征维度做归一化，可学习参数的大小等于 2×normalized_shape，与 batch size 无关。

可以参考官方文档：LayerNorm 作用在输入的最后 D 个维度上，这 D 个维度的形状由参数 normalized_shape 指定。  
https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

为什么对于图像是`nn.LayerNorm([C, H, W])`而对于NLP是`nn.LayerNorm(embedding_dim)`  
考虑单独的实例，一张图像的组成就是[C, H, W]， 而一个单词只是[N]维的向量，这里做LN是在单词维度上，而不是句子，所以只需要最后一个维度。

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension(s) of the input.

    Each instance (e.g. each token) is normalized on its own, using the mean
    and biased variance of its own features. The result is then scaled by
    ``gamma`` and shifted by ``beta``.

    Args:
        dim: size of the single normalized trailing dimension (int), or a
            list/tuple of sizes for several trailing dimensions, analogous to
            ``normalized_shape`` of ``nn.LayerNorm`` (e.g. ``[C, H, W]``).
        eps: added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-12) -> None:
        super().__init__()
        shape = tuple(dim) if isinstance(dim, (list, tuple)) else (dim,)
        # gamma/beta act element-wise on the normalized features, so they
        # have exactly the normalized shape.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))  # was nn.parameter (a module)
        self.eps = eps
        # Reduce over the last len(shape) dimensions, e.g. (-1,) for an int dim.
        self._dims = tuple(range(-len(shape), 0))

    def forward(self, x):
        # Statistics are per instance, over the normalized trailing dims only.
        mean = x.mean(self._dims, keepdim=True)
        var = x.var(self._dims, unbiased=False, keepdim=True)

        out = (x - mean) / torch.sqrt(var + self.eps)  # zero mean, unit variance
        out = self.gamma * out + self.beta  # learnable per-feature affine map

        return out

In [4]:
# NLP example: LayerNorm normalizes each token's embedding vector on its own,
# i.e. over the last (embedding) dimension only.
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)

# One mean per token -> shape [batch, sentence_length, 1].
print(embedding.mean(dim=-1, keepdim=True).shape)

layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
# To apply it: layer_norm(embedding)

torch.Size([20, 5, 1])


tensor([[[ 4.5243e-01,  5.1689e-01,  8.7616e-01,  1.0935e+00,  1.0197e+00,
          -6.1018e-02,  3.2073e-01, -1.6677e+00, -1.7136e+00, -8.3704e-01],
         [ 3.4844e-01,  1.9638e-01, -8.7768e-01,  9.0137e-01, -1.2342e+00,
          -1.3312e+00, -8.8244e-01,  1.5841e+00,  1.2937e+00,  1.4217e-03],
         [ 1.6180e+00, -3.2185e-01, -1.2941e+00,  1.2915e+00,  5.7974e-01,
          -1.1048e+00, -1.0464e+00, -4.7261e-01,  9.9943e-01, -2.4889e-01],
         [-1.1997e+00,  9.4407e-01,  1.3663e+00,  5.6440e-01, -4.6944e-01,
          -6.2003e-01,  1.6438e+00, -9.0377e-01, -1.1536e+00, -1.7216e-01],
         [ 5.0406e-04,  1.3974e-01,  9.7277e-01, -1.9937e+00, -1.0500e+00,
           1.1876e+00,  9.5969e-01,  3.2783e-01, -1.0999e+00,  5.5541e-01]],

        [[-6.3729e-01, -6.3336e-01, -1.3643e+00,  1.0272e-01,  1.2369e+00,
           1.7993e+00, -1.3663e+00,  8.2434e-01, -3.7042e-02,  7.5135e-02],
         [-2.3591e-01,  1.4860e-01, -2.0166e+00, -5.7666e-01, -4.4758e-01,
           1.9476