# Attention Is All You Need
[https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

![attention](img/attention.png)

![scaled-dot-product-attention](img/scaled-dot-product-attention.png)

In [2]:
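class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask=None) -> torch.Tensor:
        # q: (b, n, dk)
        # k: (b, m, dk)
        # v: (b, m, dv)
        # Transpose only the last two dims of k; k.T would reverse all dims of a batched tensor
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # (b, n, m)
        if mask is not None:
            # Assumed convention: mask broadcastable to (b, n, m), 0 marks positions to ignore
            scores = scores.masked_fill(mask == 0, float("-inf"))
        return torch.matmul(F.softmax(scores, dim=-1), v)  # (b, n, dv)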
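As a rough check of the attention computation, here is a standalone functional sketch (so it runs without the class cell) compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. It assumes PyTorch 2.0 or later, where that function exists, and a boolean mask in which `True` marks positions that may be attended to, which is the built-in's convention. The name `sdpa_reference`, the causal mask, and the tensor sizes are illustrative choices, not part of the original notebook.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, mask=None):
    """softmax(q k^T / sqrt(dk)) v for q: (b, n, dk), k: (b, m, dk), v: (b, m, dv)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (b, n, m)
    if mask is not None:
        # mask is boolean and broadcastable to (b, n, m); False entries are blocked
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the m keys
    return weights @ v  # (b, n, dv)


torch.manual_seed(0)
b, n, dk, dv = 2, 5, 8, 16
q = torch.randn(b, n, dk)
k = torch.randn(b, n, dk)
v = torch.randn(b, n, dv)

# Causal mask: query i may attend only to keys 0..i
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))

out = sdpa_reference(q, k, v, mask=causal)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(out.shape)  # torch.Size([2, 5, 16])
print(torch.allclose(out, ref, atol=1e-5))
```

With a causal mask the first query can only see the first key, so its output row equals `v[:, 0]`; that makes a handy extra check when experimenting with masks.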

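Why divide by $\sqrt{d_k}$?

The reason lies in the flat regions at both ends of the softmax: when one input is much larger than the others, the output saturates toward a one-hot vector and its gradients are close to 0. The values fed into the softmax should therefore be neither very large nor very small. The entries of $QK^T$, however, can have a large variance, which easily leads to vanishing gradients.

Assume the components of a query $q$ and a key $k$ are independent, each with mean 0 and variance 1. Each entry of $QK^T$ is a dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$, so

$$\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i)\operatorname{Var}(k_i) = d_k,$$

where $\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\right)^2 = \operatorname{Var}(q_i)\operatorname{Var}(k_i)$ because both factors have mean 0. To bring the variance back to 1, divide by $\sqrt{d_k}$:

$$\operatorname{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k}\operatorname{Var}(q \cdot k) = 1.$$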
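The variance argument can be checked numerically. This sketch draws query and key entries i.i.d. from $\mathcal{N}(0, 1)$, matching the assumptions of the derivation; `num_pairs` and the seed are arbitrary illustrative choices.

```python
import math

import torch

torch.manual_seed(0)
dk = 512
num_pairs = 10_000

# Entries drawn i.i.d. from N(0, 1): mean 0, variance 1, q and k independent
q = torch.randn(num_pairs, dk)
k = torch.randn(num_pairs, dk)

dots = (q * k).sum(dim=-1)  # one dot product q.k per row
scaled = dots / math.sqrt(dk)

# Sample variances should land close to the theoretical values
print(f"Var(q.k)            = {dots.var().item():.1f} (theory: {dk})")
print(f"Var(q.k / sqrt(dk)) = {scaled.var().item():.3f} (theory: 1)")
```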

In [69]:
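dk = 512
# torch.rand samples uniformly from [0, 1), so these entries have mean 0.5,
# not the mean-0, variance-1 setup assumed in the derivation above
Q = torch.rand([1024, dk], requires_grad=True)
K = torch.rand([1024, dk], requires_grad=True)

# Gradient of a single softmax output w.r.t. Q and K, without scaling
sm = torch.softmax(Q.matmul(K.T), dim=-1)
sm[0, 0].backward()
print("Normal dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")

# backward() accumulates into .grad, so clear the first run's gradients
# before measuring the scaled version
Q.grad = None
K.grad = None

# Same measurement with the 1/sqrt(dk) scaling
sm = torch.softmax(Q.matmul(K.T) / math.sqrt(dk), dim=-1)
sm[0, 0].backward()
print("Scaled dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")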

Normal dot product softmax grad:
grad of Q: 1.606757393801672e-08
grad of K: 2.6459346713636478e-08
Scaled dot product softmax grad:
grad of Q: 1.804493513191119e-05
grad of K: 3.5442018997855484e-05
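The gap between the two runs can be related to the softmax Jacobian. For $p = \mathrm{softmax}(z)$, $\partial p_i / \partial z_j = p_i(\delta_{ij} - p_j)$, i.e. $\mathrm{diag}(p) - pp^T$, which tends to the zero matrix as $p$ approaches a one-hot vector. The sketch below compares how peaked the softmax is, and how large its Jacobian is, with and without the $1/\sqrt{d_k}$ scaling. Unlike the `torch.rand` experiment, it assumes zero-mean, unit-variance inputs as in the derivation, so its numbers are not comparable to the output above; the helper `softmax_jacobian` and the sizes are illustrative choices.

```python
import math

import torch

torch.manual_seed(0)
dk, n_queries, n_keys = 512, 256, 64

# Entries i.i.d. N(0, 1), matching the assumptions of the derivation
q = torch.randn(n_queries, dk)
k = torch.randn(n_keys, dk)
logits = q @ k.T  # (n_queries, n_keys), std roughly sqrt(dk)


def softmax_jacobian(z):
    # For p = softmax(z): dp_i/dz_j = p_i * (delta_ij - p_j), i.e. diag(p) - p p^T
    p = torch.softmax(z, dim=-1)
    return torch.diag_embed(p) - p.unsqueeze(-1) * p.unsqueeze(-2)


results = {}
for name, z in [("unscaled", logits), ("scaled", logits / math.sqrt(dk))]:
    # Average of the largest probability per query: close to 1 means near one-hot
    max_prob = torch.softmax(z, dim=-1).max(dim=-1).values.mean().item()
    # Median (over queries) Frobenius norm of the per-query Jacobian
    jac_norm = torch.linalg.matrix_norm(softmax_jacobian(z)).median().item()
    results[name] = (max_prob, jac_norm)
    print(f"{name:>8}: mean max prob = {max_prob:.3f}, "
          f"median Jacobian norm = {jac_norm:.2e}")
```

The median is used rather than the mean because, without scaling, a few queries happen to have two nearly tied top logits, and those rare cases would otherwise dominate the average.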
