<a href="https://colab.research.google.com/github/daixinguang/code_snippets/blob/master/multi_head_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention

Paper: *Attention Is All You Need* (Vaswani et al., NIPS 2017), which introduced the `Transformer`

Code:
- [官方TensorFlow实现](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
- [Pytorch实现](https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py)

Transformer architecture的两个核心sub-layers
- Multi-Head Attention layer
- Feed Forward Network layer

Reference:
- `Enzo_Mi` [Multi-Head Attention | 算法 + 代码](https://www.bilibili.com/video/BV1qo4y1F7Ep)
- `黑白` [Transformer代码及解析(Pytorch)](https://zhuanlan.zhihu.com/p/345993564)
- [详解Transformer （Attention Is All You Need） - 知乎](https://zhuanlan.zhihu.com/p/48508221)
- `于建民` [The Illustrated Transformer【译】](https://blog.csdn.net/yujianmin1990/article/details/85221271)
- `Jay Alammar` [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)

In [3]:
import torch
import torch.nn as nn

# Sizes for the demo: source/target sequence lengths, batch size and model width.
# nn.Transformer is used with its default tensor layout; per the PyTorch docs that
# is batch_first=False, i.e. inputs are (seq_len, batch, d_model).
src_len, tgt_len, batch_size, d_model = 10, 20, 32, 512

transformer_model = nn.Transformer(d_model=d_model, nhead=16, num_encoder_layers=12)
src = torch.rand(src_len, batch_size, d_model)
tgt = torch.rand(tgt_len, batch_size, d_model)
out = transformer_model(src, tgt)
print(out)



tensor([[[-1.0862,  1.6395,  0.6158,  ..., -0.1350, -0.0343,  0.8277],
         [-0.4080,  0.8118,  0.3556,  ..., -1.6812, -0.0687,  0.2144],
         [-0.1890,  1.4179, -0.1665,  ..., -1.4447, -0.5229, -0.5293],
         ...,
         [ 0.1457,  0.7281, -0.5304,  ..., -1.7234,  0.8478,  0.8505],
         [-0.3402,  2.2362,  0.4259,  ..., -1.0775,  0.1720,  0.7748],
         [-0.1124,  1.3646, -0.2595,  ..., -1.3014, -0.5485,  1.4648]],

        [[-0.4230,  1.9957,  0.1823,  ..., -0.0650, -0.2859, -0.0335],
         [ 0.0601,  0.7105,  0.7234,  ..., -0.7285,  0.6098,  1.2656],
         [-0.6270,  1.6926, -0.3442,  ..., -0.8210, -0.7875,  0.6791],
         ...,
         [-0.2587,  0.7681, -0.8603,  ..., -1.6751, -0.7452,  0.5740],
         [-0.7096,  1.0769,  0.6512,  ...,  0.3322, -0.4507,  1.4602],
         [-0.4582,  1.6992, -0.1435,  ..., -1.2184,  0.1195,  0.8606]],

        [[-0.9396,  1.8701,  0.2147,  ..., -1.3849,  0.0541,  0.2041],
         [ 0.1012,  0.4822,  0.1239,  ..., -1

In [1]:
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``d_model``.
    These are split into ``num_heads`` heads of width
    ``d_k = d_model // num_heads``, and each head computes
    ``softmax(Q K^T / sqrt(d_k)) V``. The heads are then concatenated and
    passed through a final linear layer.

    Args:
        dim_in: Feature size of the input (last dimension of ``x``).
        d_model: Total width of the Q/K/V projections and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, dim_in, d_model, num_heads=3):
        super().__init__()

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0, "d_model must be a multiple of num_heads"

        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)

        # Attention logits are scaled by 1/sqrt(d_k), where d_k is the
        # per-head width, to keep the softmax out of its saturated regime.
        self.scale = 1 / sqrt(d_model // num_heads)

        # Output projection applied to the concatenated heads.
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, d_model)``.
        """
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in, "last dimension of x must equal dim_in"

        nh = self.num_heads
        dk = self.d_model // nh

        # Project, then split the feature axis into heads: (batch, nh, n, dk).
        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)

        # Attention weights per head: (batch, nh, n, n), rows sum to 1.
        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale
        dist = torch.softmax(dist, dim=-1)

        # Weighted sum of the VALUE vectors: (batch, nh, n, dk).
        att = torch.matmul(dist, v)
        # Concatenate heads back into one feature axis: (batch, n, d_model).
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)

        return self.fc(att)

# Demo: one sequence of 4 tokens with 2 input features, projected to d_model=6
# and split across 3 heads (d_k = 6 // 3 = 2 per head).
x = torch.rand((1, 4, 2))  # (batch=1, n=4, dim_in=2)
multi_head_att = MultiHeadSelfAttention(dim_in=x.shape[-1], d_model=6, num_heads=3)
output = multi_head_att(x)

print(x, output, sep=' \n ')

tensor([[[0.4756, 0.6228],
         [0.1560, 0.2808],
         [0.4538, 0.6251],
         [0.6369, 0.5038]]]) 
 tensor([[[-0.1063,  0.1082, -0.4711,  0.1384, -0.4054, -0.7985],
         [-0.1065,  0.1071, -0.4713,  0.1380, -0.4051, -0.7985],
         [-0.1064,  0.1082, -0.4711,  0.1384, -0.4055, -0.7985],
         [-0.1052,  0.1079, -0.4713,  0.1380, -0.4041, -0.7985]]],
       grad_fn=<ViewBackward0>)


In [None]:
# In the original paper every sub-layer's input (and output) dimension is 512, i.e. dim_in = 512
# Paper settings: d_model = 512, h = 8 heads, d_k = d_v = d_model / h = 64





$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

In [None]:

# Excerpt of MultiHeadSelfAttention.forward (not runnable on its own: it uses
# q, k, v, batch, n and self from inside forward).
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # (batch, nh, n, n)
dist = torch.softmax(dist, dim=-1)
att = torch.matmul(dist, v)  # Attention(Q, K, V): weights applied to V, not to the raw input x

att = att.transpose(1, 2).reshape(batch, n, self.d_model)

In [None]:
# `a @ b` is the operator form of torch.matmul(a, b); for a (1, 4, 2) batch both
# produce the same (1, 4, 4) Gram matrix x x^T.
x = torch.rand((1, 4, 2))

res1 = x @ x.transpose(-2, -1)               # operator form
res2 = torch.matmul(x, x.transpose(-2, -1))  # function form

print(res1, res2, sep=' \n ')

tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]]) 
 tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]])
