# Multi-Head Attention

Paper: `Transformer` Attention Is All You Need (NIPS 2017)

Code:
- [Official TensorFlow implementation](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
- [PyTorch implementation](https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py)

The two core sub-layers of the Transformer architecture:
- Multi-Head Attention layer
- Feed Forward Network layer

References:
- `Enzo_Mi` [Multi-Head Attention | Algorithm + Code](https://www.bilibili.com/video/BV1qo4y1F7Ep)
- `黑白` [Transformer Code and Analysis (PyTorch)](https://zhuanlan.zhihu.com/p/345993564)
- [Transformer Explained (Attention Is All You Need) - Zhihu](https://zhuanlan.zhihu.com/p/48508221)
- `于建民` [The Illustrated Transformer (Chinese translation)](https://blog.csdn.net/yujianmin1990/article/details/85221271)
- `Jay Alammar` [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)

```python
import torch
import torch.nn as nn
from math import sqrt


class MultiHeadSelfAttention(nn.Module):

    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0  # d_model must be a multiple of num_heads

        # Linear projections that produce Q, K, V for all heads at once
        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)

        # Scaling factor 1 / sqrt(d_k), where d_k = d_model // num_heads is the per-head dimension
        self.scale = 1 / sqrt(d_model // num_heads)

        # Output projection applied after the heads are concatenated
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, n, dim_in = x.shape  # x: shape (batch, n, dim_in)
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.d_model // nh

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)

        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)  # attention weights; each row sums to 1

        att = torch.matmul(dist, v)  # weighted sum of the values: (batch, nh, n, dk)
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)  # concatenate heads: (batch, n, d_model)

        output = self.fc(att)

        return output


x = torch.rand((1, 4, 2))
multi_head_att = MultiHeadSelfAttention(x.shape[2], 6, 3)
output = multi_head_att(x)

print(x, '\n', output)
```

Output (`x` comes from `torch.rand`, so the values change on every run; the output shape `(1, 4, 6)` does not):

```
tensor([[[0.7723, 0.2628],
         [0.8771, 0.9445],
         [0.3468, 0.1092],
         [0.6148, 0.7719]]]) 
 tensor([[[ 0.2075,  0.0570,  0.1289, -0.2385, -0.2644,  0.7774],
         [ 0.2092,  0.0622,  0.1302, -0.2377, -0.2602,  0.7797],
         [ 0.2043,  0.0643,  0.1267, -0.2208, -0.2476,  0.7803],
         [ 0.2072,  0.0658,  0.1288, -0.2274, -0.2509,  0.7811]]],
       grad_fn=<AddBackward0>)
```
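
As a rough cross-check of the computation, the sketch below redoes the multi-head attention steps by hand (Q/K/V projection, head split, scaled softmax, weighted sum of V, head merge, output projection) using the weights of PyTorch's built-in `nn.MultiheadAttention`, and compares the two outputs. It assumes `dim_in == d_model` (the built-in uses `embed_dim` as both the query input size and the model size), no mask, no dropout, and a PyTorch version that supports `batch_first` (1.9 or later). `mha_forward` is a hypothetical helper name, not part of PyTorch.

```python
import math

import torch
import torch.nn as nn


def mha_forward(x, mha):
    """Hypothetical helper: recompute nn.MultiheadAttention (self-attention, no mask) step by step."""
    batch, n, e = x.shape
    nh = mha.num_heads
    dk = e // nh
    # in_proj_weight stacks W_q, W_k, W_v row-wise: shape (3 * e, e)
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    # Project, then split into heads: (batch, nh, n, dk)
    q = (x @ w_q.T + b_q).reshape(batch, n, nh, dk).transpose(1, 2)
    k = (x @ w_k.T + b_k).reshape(batch, n, nh, dk).transpose(1, 2)
    v = (x @ w_v.T + b_v).reshape(batch, n, nh, dk).transpose(1, 2)
    # softmax(Q K^T / sqrt(dk)) V, computed for every head in parallel
    weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(dk), dim=-1)
    att = (weights @ v).transpose(1, 2).reshape(batch, n, e)  # concatenate heads
    return mha.out_proj(att)  # output projection


torch.manual_seed(0)
x = torch.rand(1, 4, 6)  # (batch, n, d_model)
mha = nn.MultiheadAttention(embed_dim=6, num_heads=3, batch_first=True).eval()
ref, _ = mha(x, x, x, need_weights=False)
mine = mha_forward(x, mha)
print(torch.allclose(mine, ref, atol=1e-5))  # True
```

The head split follows the same convention as the notebook's `reshape(batch, n, nh, dk)`: head `i` owns the contiguous slice `[i * dk, (i + 1) * dk)` of the projected features.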


```python
# The input dimension of every sub-layer is 512, i.e. dim_in = 512
# In the original paper: d_model = 512, d_k = d_v = 64, h = 8
```
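
These numbers are consistent with d_k = d_model / h = 512 / 8 = 64. The sketch below uses those sizes (the batch size and sequence length are arbitrary assumptions) to show the reshape/transpose trick that `forward` applies to the projected Q, K and V; here a random tensor stands in for a projected Q. It splits d_model into h heads and checks that the inverse reshape merges them back losslessly.

```python
import torch

d_model, h = 512, 8
d_k = d_model // h  # 512 // 8 = 64, matching d_k = d_v = 64 in the paper
batch, n = 2, 10    # arbitrary batch size and sequence length (assumed)

x = torch.rand(batch, n, d_model)  # stand-in for a projected Q

# Split: (batch, n, d_model) -> (batch, n, h, d_k) -> (batch, h, n, d_k)
heads = x.reshape(batch, n, h, d_k).transpose(1, 2)
# Merge: undo the transpose and reshape, back to (batch, n, d_model)
merged = heads.transpose(1, 2).reshape(batch, n, d_model)

print(d_k, heads.shape, torch.equal(merged, x))  # 64 torch.Size([2, 8, 10, 64]) True
# Head i works on the contiguous feature slice [i * d_k, (i + 1) * d_k)
print(torch.equal(heads[:, 3], x[..., 3 * d_k:4 * d_k]))  # True
```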
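
Scaled dot-product attention, applied independently in each head ($d_k$ is the per-head key dimension):
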
$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$
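
A minimal sketch of the formula as a standalone function, assuming inputs shaped `(..., n, d_k)` for Q and K and `(..., n, d_v)` for V; `attention` is a hypothetical name, not part of the notebook or of PyTorch. Leading dimensions (batch, heads) pass straight through `@`, and the function also returns the softmax weights so that their rows can be checked to sum to 1.

```python
import math

import torch


def attention(q, k, v):
    """Hypothetical helper: softmax(Q K^T / sqrt(d_k)) V over the last two dimensions."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., n_q, n_k)
    weights = torch.softmax(scores, dim=-1)            # normalize over the keys
    return weights @ v, weights                        # (..., n_q, d_v), (..., n_q, n_k)


torch.manual_seed(0)
q = torch.rand(2, 5, 4)  # (batch, n_q, d_k)
k = torch.rand(2, 7, 4)  # (batch, n_k, d_k)
v = torch.rand(2, 7, 3)  # (batch, n_k, d_v)
out, w = attention(q, k, v)
print(out.shape, w.shape)  # torch.Size([2, 5, 3]) torch.Size([2, 5, 7])
```

PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention`, which computes the same quantity with optional masking and dropout.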

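```python
# Excerpt from MultiHeadSelfAttention.forward, mapped onto the formula above
dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # Q K^T / sqrt(d_k): (batch, nh, n, n)
dist = torch.softmax(dist, dim=-1)  # softmax over the key positions
att = torch.matmul(dist, v)  # Attention(Q, K, V): (batch, nh, n, dk)

att = att.transpose(1, 2).reshape(batch, n, self.d_model)  # concatenate heads: (batch, n, d_model)
```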
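
The final `transpose(1, 2).reshape(...)` is the paper's Concat(head_1, ..., head_h) written as a single reshape. The small check below uses toy sizes (assumed: 3 heads of size 2, as in the notebook's example) and compares it against an explicit `torch.cat` over the heads; a random tensor stands in for the per-head attention outputs.

```python
import torch

torch.manual_seed(0)
batch, nh, n, dk = 1, 3, 4, 2  # toy sizes: 3 heads of size 2, so d_model = 6
att = torch.rand(batch, nh, n, dk)  # stand-in for the per-head outputs softmax(...) @ v

# Merge as in forward(): (batch, nh, n, dk) -> (batch, n, nh, dk) -> (batch, n, nh * dk)
merged = att.transpose(1, 2).reshape(batch, n, nh * dk)
# Explicit concatenation of the heads along the feature dimension
concat = torch.cat([att[:, i] for i in range(nh)], dim=-1)

print(merged.shape, torch.equal(merged, concat))  # torch.Size([1, 4, 6]) True
```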

`@` is shorthand for `torch.matmul`; both produce the same result:

```python
import torch

x = torch.rand((1, 4, 2))

res1 = x @ x.transpose(1, 2)
res2 = torch.matmul(x, x.transpose(1, 2))

print(res1, '\n', res2)
```

```
tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]]) 
 tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]])
```
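
The same operator handles the 4-D tensors in the attention code: for `(batch, nh, n, dk) @ (batch, nh, dk, n)`, `torch.matmul` multiplies the last two dimensions and treats the leading `(batch, nh)` as batch dimensions. The sketch below (arbitrary sizes, assumed for illustration) compares this against an explicit loop over batches and heads.

```python
import torch

torch.manual_seed(0)
q = torch.rand(2, 3, 4, 5)  # (batch, nh, n, dk)
k = torch.rand(2, 3, 4, 5)  # (batch, nh, n, dk)

# One batched call: matrix product over the last two dims -> (batch, nh, n, n)
scores = q @ k.transpose(2, 3)

# Same computation, one (n, dk) x (dk, n) product per (batch, head) pair
loop = torch.empty(2, 3, 4, 4)
for b in range(2):
    for h in range(3):
        loop[b, h] = q[b, h] @ k[b, h].T

print(scores.shape, torch.allclose(scores, loop))  # torch.Size([2, 3, 4, 4]) True
```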
