<a href="https://colab.research.google.com/github/daixinguang/code_snippets/blob/master/multi_head_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention

Paper: `Transformer` — *Attention Is All You Need* (NIPS 2017)

Code:
- [官方TensorFlow实现](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
- [Pytorch实现](https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py)

Transformer architecture的两个核心sub-layers
- Multi-Head Attention layer
- Feed Forward Network layer

Reference:
- `Enzo_Mi` [Multi-Head Attention | 算法 + 代码](https://www.bilibili.com/video/BV1qo4y1F7Ep)
- `黑白` [Transformer代码及解析(Pytorch)](https://zhuanlan.zhihu.com/p/345993564)
- `大师兄`[详解Transformer （Attention Is All You Need） - 知乎](https://zhuanlan.zhihu.com/p/48508221)
- `于建民` [The Illustrated Transformer【译】](https://blog.csdn.net/yujianmin1990/article/details/85221271)
- `Jay Alammar` [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)

In [4]:
import torch
import torch.nn as nn

# Reference run of PyTorch's built-in Transformer with 16 attention heads and
# 12 encoder layers; every other hyper-parameter is left at its default.
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
# Random source / target batches. Per the nn.Transformer docs the default
# layout is (sequence length, batch, feature), i.e. 10 source steps and
# 20 target steps for a batch of 32 with 512 features -- TODO confirm against
# the installed PyTorch version.
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)
# The inputs are unseeded, so the printed values differ on every run.
print(out)



tensor([[[ 0.3641,  0.7691, -0.1651,  ..., -0.2663,  1.9459, -2.0577],
         [ 0.2793,  0.8363, -0.4215,  ..., -0.1447,  0.9408, -1.1444],
         [ 1.5120,  0.2229, -0.6750,  ..., -0.1214,  0.8203, -2.2667],
         ...,
         [ 0.9347,  0.2070, -0.2067,  ...,  0.2668,  1.4383, -1.9678],
         [ 0.6448,  1.1027,  0.6465,  ..., -0.9839,  1.4634, -1.6281],
         [ 0.0942,  0.7632,  0.4352,  ..., -0.5545,  1.3510, -2.2930]],

        [[ 0.7475,  0.0247,  0.0600,  ...,  0.3426,  1.3491, -1.7195],
         [-1.0700,  0.8039,  0.3605,  ...,  0.1669,  0.9159, -1.7930],
         [ 0.8181,  0.7471, -0.0348,  ...,  0.1656,  0.5735, -1.7630],
         ...,
         [ 0.9119,  0.8595, -0.5107,  ...,  1.3115,  0.5397, -1.7754],
         [-0.5404,  0.7633,  0.0762,  ..., -0.3760,  1.0363, -1.9690],
         [-0.2461,  0.5505,  0.5155,  ..., -0.6507,  1.4638, -2.1236]],

        [[-0.0719,  0.9742,  0.2095,  ..., -0.1136,  0.7836, -1.5936],
         [-0.6308, -0.0321,  0.0337,  ..., -0

In [1]:
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``d_model``,
    split into ``num_heads`` heads of width ``d_model // num_heads``, attended
    per head with ``softmax(Q K^T / sqrt(d_k)) V``, merged back to ``d_model``
    and passed through a final linear layer.

    Args:
        dim_in: Size of the last dimension of the input.
        d_model: Total width of the Q/K/V projections and of the output.
            Must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0, \
            "d_model must be a multiple of num_heads"

        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)

        # Scores are scaled by 1/sqrt(d_k), where d_k = d_model // num_heads is
        # the per-head width (not d_model // d_model, which is always 1).
        self.scale = 1 / sqrt(d_model // num_heads)

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, d_model)``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_in``.
        """
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in, \
            "last dimension of x must equal dim_in"

        nh = self.num_heads
        dk = self.d_model // nh

        # Project, then split the feature axis into heads: (batch, nh, n, dk).
        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)

        # Pairwise query/key similarity per head: (batch, nh, n, n).
        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale
        dist = torch.softmax(dist, dim=-1)

        # Weight the projected values V (not the raw input x): (batch, nh, n, dk).
        att = torch.matmul(dist, v)
        # Merge the heads back into one feature axis: (batch, n, d_model).
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)

        output = self.fc(att)

        return output

# Demo: a single sequence of 4 tokens with 2 features each, attended with
# 3 heads over a model width of 6.
x = torch.rand(1, 4, 2)
multi_head_att = MultiHeadSelfAttention(
    dim_in=x.shape[-1],
    d_model=6,
    num_heads=3,
)
output = multi_head_att(x)

# Show the input next to the attended output.
print(x, '\n', output)

tensor([[[0.4756, 0.6228],
         [0.1560, 0.2808],
         [0.4538, 0.6251],
         [0.6369, 0.5038]]]) 
 tensor([[[-0.1063,  0.1082, -0.4711,  0.1384, -0.4054, -0.7985],
         [-0.1065,  0.1071, -0.4713,  0.1380, -0.4051, -0.7985],
         [-0.1064,  0.1082, -0.4711,  0.1384, -0.4055, -0.7985],
         [-0.1052,  0.1079, -0.4713,  0.1380, -0.4041, -0.7985]]],
       grad_fn=<ViewBackward0>)


In [None]:
# Every sub-layer takes a 512-dimensional input, i.e. dim_in = 512.
# In the original paper: d_model = 512, d_k = d_v = 64, h = 8.





$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

In [None]:

# Scaled dot-product attention per head, following the formula above.
# q, k, v are the per-head projections of shape (batch, nh, n, dk).
dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # QK^T / sqrt(d_k): (batch, nh, n, n)
dist = torch.softmax(dist, dim=-1)  # normalise over the key axis
att = torch.matmul(dist, v)  # Attention(Q, K, V): weights are applied to V, not to the raw input x

# Merge the heads back into a single feature axis: (batch, n, d_model).
att = att.transpose(1, 2).reshape(batch, n, self.d_model)

In [None]:
# The `@` operator and torch.matmul perform the same batched matrix product:
# (1, 4, 2) times (1, 2, 4) gives a (1, 4, 4) Gram matrix either way.
x = torch.rand(1, 4, 2)

res1 = x @ x.transpose(-2, -1)
res2 = torch.matmul(x, x.transpose(-2, -1))

# Both printed tensors should be identical.
print(res1, '\n', res2)

tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]]) 
 tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]])
