# Attention is all you need
[https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

![image.png](img\attention.png)  
![scaled-dot-product-attention.png](img\scaled-dot-product-attention.png)

In [45]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask=None) -> torch.Tensor:
        # q: (b, n, dk)
        # k: (b, m, dk)
        # v: (b, m, dk)

        # (b, n, m)
        x = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        if mask is not None:
            x = x * mask
        
        # (b, n, dk)
        return F.softmax(x, dim=-1).matmul(v)

为什么要除以根号dk？
与Softmax函数两边的扁平区有关，两边的梯度接近于0，所以softmax进去的值不能特大或特小，然而
Q.matmul(K.T)的方差可能比较大，容易造成梯度消失。
假设Q和K的均值为0，方差为1，且Q和K独立，则Q.matmul(K.T)的方差 = sum to dk(var(Q * K)) = sum to dk(var(Q) * var(K)) = dk
要让方差变为1:
var(Q * K / sqrt(dk)) = 1 / dk * var(Q * K) = 1

In [29]:
dk = 512
Q = torch.rand([1024, dk], requires_grad=True)
K = torch.rand([1024, dk], requires_grad=True)
sm = torch.softmax(Q.matmul(K.T), dim=-1)
sm[0, 0].backward()
print("Normal dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")

sm = torch.softmax(Q.matmul(K.T) / math.sqrt(dk), dim=-1)
sm[0, 0].backward()
print("Scaled dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")

Normal dot product softmax grad:
grad of Q: 1.1947786049404385e-07
grad of K: 1.7698897636364563e-07
Scaled dot product softmax grad:
grad of Q: 1.9767354388022795e-05
grad of K: 3.987304444308393e-05


![mha.png](img\mha.png)

In [46]:
class MHA(nn.Module):
    def __init__(self, d_model, num_head):
        super().__init__()
        self.d_model = d_model
        self.num_head = num_head
        self.d_k = self.d_model // self.num_head
        self.W_Q = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.W_K = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.W_V = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.attn = [ScaledDotProductAttention()] * self.num_head
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        # (b, n, d_k)
        q_proj = [w_q(Q) for w_q in self.W_Q]
        # (b, m, d_k)
        k_proj = [w_k(K) for w_k in self.W_K]
        v_proj = [w_v(V) for w_v in self.W_V]

        attn_out = [attn(q, k, v, mask) for attn, q, k, v in zip(self.attn, q_proj, k_proj, v_proj)]

        # (b, n, d_model)
        return self.linear_out(torch.concat(attn_out, dim=-1))

In [88]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
mha = MHA(d_model=d_model, num_head=num_head)
Q = torch.rand([batch_size, n, d_model])
K = torch.rand([batch_size, m, d_model])
V = torch.rand([batch_size, m, d_model])
mha_out = mha(Q, K, V)
print(mha_out)
print(mha_out.size())
assert(mha_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[ 0.0192, -0.1477,  0.1066,  ...,  0.0278,  0.2166, -0.0585],
         [ 0.0199, -0.1477,  0.1066,  ...,  0.0286,  0.2163, -0.0579],
         [ 0.0200, -0.1481,  0.1070,  ...,  0.0279,  0.2166, -0.0585],
         ...,
         [ 0.0195, -0.1477,  0.1072,  ...,  0.0288,  0.2161, -0.0585],
         [ 0.0200, -0.1470,  0.1068,  ...,  0.0275,  0.2167, -0.0576],
         [ 0.0196, -0.1476,  0.1067,  ...,  0.0279,  0.2168, -0.0584]]],
       grad_fn=<ViewBackward0>)
torch.Size([1, 16, 512])


![encode_block](img\encode_block.png)

In [47]:
class FFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        
        self.d_ffn = 4 * d_model
        self.ffn_hidden = nn.Linear(d_model, self.d_ffn)
        self.relu = nn.ReLU()
        self.ffn_out = nn.Linear(self.d_ffn, d_model)

    def forward(self, x):
        return self.ffn_out(self.relu(self.ffn_hidden(x)))

In [65]:
class ResMHA(nn.Module):
    def __init__(self, d_model, num_head):
        super().__init__()

        self.mha = MHA(d_model, num_head)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, mask=None, dropout=None):
        mha_out = self.mha(Q=Q, K=K, V=V, mask=mask)
        if dropout:
            mha_out = dropout(mha_out)
        x = self.norm(Q + mha_out)
        return x

In [66]:
class ResFFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.ffn = FFN(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, dropout=None):
        ffn_out = self.ffn(x)
        if dropout:
            ffn_out = dropout(ffn_out)
        x = self.norm(x + ffn_out)
        return x

In [67]:
class EncodeBlock(nn.Module):
    def __init__(self, d_model=512, num_head=8):
        super().__init__()

        self.mha = ResMHA(d_model, num_head)
        self.ffn = ResFFN(d_model)

    def forward(self, x, dropout=None):
        x = self.mha(x,x,x, dropout)
        x = self.ffn(x, dropout)
        return x

In [68]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
Q = torch.rand([batch_size, n, d_model])

encode_block = EncodeBlock(d_model, num_head)
encode_out = encode_block(Q)
print(encode_out)
print(encode_out.size())
assert(encode_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[-1.2226, -0.0778,  0.8386,  ...,  0.1688, -0.8030, -0.6777],
         [ 0.4460,  0.8320, -1.4905,  ...,  0.8372, -1.1637,  0.0205],
         [-1.7901, -0.5204,  0.7034,  ...,  1.6489,  0.5827, -0.6641],
         ...,
         [-1.6955, -1.1857, -0.0091,  ...,  0.3037, -0.0612, -0.0898],
         [-0.3293, -0.2477, -0.9037,  ...,  1.0064, -0.6258,  0.0540],
         [-1.0016,  0.9891,  0.0200,  ..., -0.3595,  1.4110, -0.5751]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 16, 512])


![decode_block](img\decode_block.png)

In [69]:
def sequence_mask(d):
    mask = torch.ones([d, d])
    tril_mask = torch.tril(mask) == 0
    return mask.masked_fill(tril_mask, value=-1e-9).unsqueeze(0)

mask = sequence_mask(4)
print(mask)

tensor([[[ 1.0000e+00, -1.0000e-09, -1.0000e-09, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00, -1.0000e-09, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00]]])


In [74]:
class DecodeBlock(nn.Module):
    def __init__(self, d_model=512, num_head=8):
        super().__init__()

        self.mha = ResMHA(d_model, num_head)
        self.ffn = ResFFN(d_model)

    def forward(self, x, encoder_out_kv, mask=None, dropout=None):
        if mask is None:
            mask = sequence_mask(x.size(-2))

        q = self.mha(Q=x, K=x, V=x, mask=mask, dropout=dropout)
        x = self.mha(Q=q, K=encoder_out_kv, V=encoder_out_kv, dropout=dropout)
        x = self.ffn(x)
        return x

In [77]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
Q = torch.rand([batch_size, n, d_model])

encode_block = EncodeBlock(d_model, num_head)
encode_out = encode_block(Q)

decode_block = DecodeBlock(d_model, num_head)
decode_out = decode_block(Q, encode_out)
print(decode_out)
print(decode_out.size())
assert(decode_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[-0.3511, -0.0940, -1.1568,  ..., -0.0988, -2.7415,  1.1955],
         [ 0.4982,  0.5903,  1.1022,  ..., -1.3403, -0.2859,  2.1205],
         [-0.9958,  0.4812,  0.0228,  ...,  0.0433, -1.1576,  0.9903],
         ...,
         [-1.8417,  1.3384, -0.0246,  ...,  0.3239, -0.2813,  1.0861],
         [-0.3814,  0.3114,  0.8897,  ...,  0.6380, -2.3141,  0.4224],
         [ 0.2877,  0.5784,  0.8718,  ..., -0.1036, -1.9636,  0.8790]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 16, 512])
