<a href="https://colab.research.google.com/github/daixinguang/code_snippets/blob/master/multi_head_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention

Paper: `Transformer` Attention Is All You Need (NIPS 2017)

Code:
- [Official TensorFlow implementation](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
- [PyTorch implementation](https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py)

The two core sub-layers of the Transformer architecture:
- Multi-Head Attention layer
- Feed Forward Network layer

Reference:
- `Enzo_Mi` [Multi-Head Attention | Algorithm + Code](https://www.bilibili.com/video/BV1qo4y1F7Ep)
- `黑白` [Transformer Code and Walkthrough (PyTorch)](https://zhuanlan.zhihu.com/p/345993564)
- `大师兄` [Transformer Explained in Detail (Attention Is All You Need) - Zhihu](https://zhuanlan.zhihu.com/p/48508221)
- `于建民` [The Illustrated Transformer (Chinese translation)](https://blog.csdn.net/yujianmin1990/article/details/85221271)
- `Jay Alammar` [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)

In [None]:
# The Transformer as implemented in PyTorch (torch.nn.Transformer)

import torch

transformer_model = torch.nn.Transformer(d_model=512, 
                                        nhead=8, 
                                        num_encoder_layers=6, 
                                        num_decoder_layers=6, 
                                        dim_feedforward=2048,
                                        dropout=0.1)

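src = torch.rand((10, 32, 512))  # (source length, batch, d_model); batch_first defaults to False
tgt = torch.rand((20, 32, 512))  # (target length, batch, d_model)
out = transformer_model(src, tgt)
print(out.shape)  # same shape as tgt: torch.Size([20, 32, 512])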

In [None]:
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):

    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0 # d_model must be multiple of num_heads

        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)

        self.scale = 1 / sqrt(d_model // num_heads) # 1/sqrt(dk), with dk = d_model // num_heads

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, n, dim_in = x.shape # x: shape(batch, n, dim_in)
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.d_model // nh

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1,2) # (batch,nh,n,dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1,2) # (batch,nh,n,dk)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1,2) # (batch,nh,n,dk)

        dist = torch.matmul(q,k.transpose(2,3)) * self.scale # (batch,nh,n,n)
        dist = torch.softmax(dist, dim=-1)

        att = torch.matmul(dist, v) # weights applied to the values: (batch,nh,n,dk)
        att = att.transpose(1,2).reshape(batch, n, self.d_model) # merge heads: (batch,n,d_model)

        output = self.fc(att)

        return output

x = torch.rand((1,4,2))
multi_head_att = MultiHeadSelfAttention(x.shape[2], 6, 3)
output = multi_head_att(x)

print(x, '\n', output)
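
The `reshape(...).transpose(1,2)` pattern in `forward` is what splits the `d_model` features into heads, and the reverse pattern merges them again. A minimal sketch of just that step, assuming the same toy sizes as the cell above (`batch=1`, `n=4`, `d_model=6`, `num_heads=3`); the tensor `proj` is a hypothetical stand-in for the output of `linear_q`:

```python
import torch

batch, n, d_model, nh = 1, 4, 6, 3
dk = d_model // nh  # per-head width

# Hypothetical stand-in for linear_q(x): one d_model-wide vector per position.
proj = torch.arange(batch * n * d_model, dtype=torch.float32).reshape(batch, n, d_model)

# Split: (batch, n, d_model) -> (batch, n, nh, dk) -> (batch, nh, n, dk)
heads = proj.reshape(batch, n, nh, dk).transpose(1, 2)

# Head h holds features [h*dk, (h+1)*dk) of every position.
head0_matches = torch.equal(heads[0, 0], proj[0, :, 0:dk])

# Merge: (batch, nh, n, dk) -> (batch, n, nh, dk) -> (batch, n, d_model)
merged = heads.transpose(1, 2).reshape(batch, n, d_model)
round_trip_ok = torch.equal(merged, proj)

print(heads.shape, head0_matches, round_trip_ok)
```

Each head therefore works on a contiguous slice of the projected features, and merging the heads restores the original layout.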

In [None]:
# Each sub-layer takes an input of dimension 512, so dim_in=512
# In the original paper: d_model=512, dk=dv=64, h=8
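
These values are tied together by dk = dv = d_model / h. A small arithmetic sketch, assuming bias-free projection matrices as in the paper's notation (the class above uses `nn.Linear` with biases, so it would hold slightly more parameters):

```python
d_model, h = 512, 8
d_k = d_v = d_model // h  # per-head width

# W_Q, W_K, W_V and the output projection W_O are each d_model x d_model
# once the per-head matrices are stacked side by side.
weights_per_projection = d_model * d_model
attention_weights = 4 * weights_per_projection

print(d_k, attention_weights)  # 64 1048576
```

Splitting into more heads does not change the weight count under this assumption; it only changes how the `d_model` features are partitioned.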




ScaledDotProductAttention: formula and code

$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

In [None]:

dist = torch.matmul(q,k.transpose(2,3)) * self.scale # (batch,nh,n,n)
dist = torch.softmax(dist, dim=-1)
att = torch.matmul(dist, v) # Attention(Q,K,V): (batch,nh,n,dk)

att = att.transpose(1,2).reshape(batch, n, self.d_model)
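
The lines above rely on names defined inside `forward`. The formula can also be exercised on its own; the following is a minimal sketch, where the function name `scaled_dot_product_attention` and the random inputs are illustrative assumptions rather than part of the referenced implementations:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V, applied over the last two dimensions."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (..., n, n)
    weights = torch.softmax(scores, dim=-1)                         # each row sums to 1
    return torch.matmul(weights, v), weights                        # (..., n, d_v)

torch.manual_seed(0)
q = torch.rand(1, 3, 4, 2)  # (batch, nh, n, dk)
k = torch.rand(1, 3, 4, 2)
v = torch.rand(1, 3, 4, 2)

att, weights = scaled_dot_product_attention(q, k, v)
print(att.shape, weights.shape)
```

Because the softmax runs over the last dimension, every query position gets a distribution over all key positions, and the output is the matching weighted average of the value vectors.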

Position-wise Feed Forward: formula and code

$$
\mathrm{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2
$$



In [None]:
import torch.nn.functional as F

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        # add & norm
        x += residual

        x = self.layer_norm(x)

        return x
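
"Position-wise" means the same two linear maps are applied to every sequence position independently. A minimal sketch of that property, assuming toy sizes (`d_in=2`, `d_hid=8`) and leaving out the dropout, residual connection and layer norm of the class above; the helper `ffn` is hypothetical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_hid = 2, 8
w_1 = nn.Linear(d_in, d_hid)
w_2 = nn.Linear(d_hid, d_in)

def ffn(t):
    # FFN(x) = max(0, x W_1 + b_1) W_2 + b_2
    return w_2(torch.relu(w_1(t)))

x = torch.rand(1, 4, d_in)  # (batch, n, d_in)

with torch.no_grad():
    whole = ffn(x)  # all positions in one call
    # Same network applied to one position at a time, then re-stacked.
    one_by_one = torch.stack([ffn(x[:, i]) for i in range(x.size(1))], dim=1)

same = torch.allclose(whole, one_by_one, atol=1e-6)
print(whole.shape, same)
```

Since `nn.Linear` acts only on the last dimension, no information moves between positions in this sub-layer; mixing across positions happens only in the attention sub-layer.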

In [None]:
# The `@` operator and torch.matmul compute the same batched matrix product
x = torch.rand((1,4,2))

res1 = x @ x.transpose(1,2)
res2 = torch.matmul(x, x.transpose(1,2))

print(res1, '\n', res2)