# Attention is all you need
[https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)

In [120]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

![image.png](img\attention.png)  
![scaled-dot-product-attention.png](img\scaled-dot-product-attention.png)

In [121]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask=None) -> torch.Tensor:
        # q: (b, n, dk)
        # k: (b, m, dk)
        # v: (b, m, dk)

        # (b, n, m)
        x = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        if mask is not None:
            x = x * mask
        
        # (b, n, dk)
        return F.softmax(x, dim=-1).matmul(v)

为什么要除以根号dk？
与Softmax函数两边的扁平区有关，两边的梯度接近于0，所以softmax进去的值不能特大或特小，然而
Q.matmul(K.T)的方差可能比较大，容易造成梯度消失。
假设Q和K的均值为0，方差为1，且Q和K独立，则Q.matmul(K.T)的方差 = sum to dk(var(Q * K)) = sum to dk(var(Q) * var(K)) = dk
要让方差变为1:
var(Q * K / sqrt(dk)) = 1 / dk * var(Q * K) = 1

In [122]:
dk = 512
Q = torch.rand([1024, dk], requires_grad=True)
K = torch.rand([1024, dk], requires_grad=True)
sm = torch.softmax(Q.matmul(K.T), dim=-1)
sm[0, 0].backward()
print("Normal dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")

sm = torch.softmax(Q.matmul(K.T) / math.sqrt(dk), dim=-1)
sm[0, 0].backward()
print("Scaled dot product softmax grad:")
print(f"grad of Q: {Q.grad.max()}")
print(f"grad of K: {K.grad.max()}")

Normal dot product softmax grad:
grad of Q: 4.1509100157099965e-08
grad of K: 6.693417731185036e-08
Scaled dot product softmax grad:
grad of Q: 1.873047040135134e-05
grad of K: 3.744859714061022e-05


![mha.png](img\mha.png)

In [123]:
class MHA(nn.Module):
    def __init__(self, d_model, num_head):
        super().__init__()
        self.d_model = d_model
        self.num_head = num_head
        self.d_k = self.d_model // self.num_head
        self.W_Q = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.W_K = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.W_V = [nn.Linear(d_model, self.d_k)] * self.num_head
        self.attn = [ScaledDotProductAttention()] * self.num_head
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        # (b, n, d_k)
        q_proj = [w_q(Q) for w_q in self.W_Q]
        # (b, m, d_k)
        k_proj = [w_k(K) for w_k in self.W_K]
        v_proj = [w_v(V) for w_v in self.W_V]

        attn_out = [attn(q, k, v, mask) for attn, q, k, v in zip(self.attn, q_proj, k_proj, v_proj)]

        # (b, n, d_model)
        return self.linear_out(torch.concat(attn_out, dim=-1))

In [124]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
mha = MHA(d_model=d_model, num_head=num_head)
Q = torch.rand([batch_size, n, d_model])
K = torch.rand([batch_size, m, d_model])
V = torch.rand([batch_size, m, d_model])
mha_out = mha(Q, K, V)
print(mha_out)
print(mha_out.size())
assert(mha_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[-0.3248,  0.0545,  0.0064,  ..., -0.2376, -0.0765,  0.0862],
         [-0.3234,  0.0535,  0.0067,  ..., -0.2379, -0.0763,  0.0865],
         [-0.3241,  0.0537,  0.0067,  ..., -0.2381, -0.0761,  0.0868],
         ...,
         [-0.3243,  0.0534,  0.0067,  ..., -0.2376, -0.0760,  0.0861],
         [-0.3243,  0.0543,  0.0061,  ..., -0.2373, -0.0755,  0.0865],
         [-0.3248,  0.0546,  0.0060,  ..., -0.2373, -0.0759,  0.0852]]],
       grad_fn=<ViewBackward0>)
torch.Size([1, 16, 512])


![encode_block](img\encode_block.png)

In [125]:
class FFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        
        self.d_ffn = 4 * d_model
        self.ffn_hidden = nn.Linear(d_model, self.d_ffn)
        self.relu = nn.ReLU()
        self.ffn_out = nn.Linear(self.d_ffn, d_model)

    def forward(self, x):
        return self.ffn_out(self.relu(self.ffn_hidden(x)))

In [126]:
class ResMHA(nn.Module):
    def __init__(self, d_model, num_head):
        super().__init__()

        self.mha = MHA(d_model, num_head)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, mask=None, dropout=None):
        mha_out = self.mha(Q=Q, K=K, V=V, mask=mask)
        if dropout:
            mha_out = dropout(mha_out)
        x = self.norm(Q + mha_out)
        return x

In [127]:
class ResFFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.ffn = FFN(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, dropout=None):
        ffn_out = self.ffn(x)
        if dropout:
            ffn_out = dropout(ffn_out)
        x = self.norm(x + ffn_out)
        return x

In [128]:
class EncodeBlock(nn.Module):
    def __init__(self, d_model=512, num_head=8):
        super().__init__()

        self.mha = ResMHA(d_model, num_head)
        self.ffn = ResFFN(d_model)

    def forward(self, x, dropout=None):
        x = self.mha(Q=x, K=x, V=x, dropout=dropout)
        x = self.ffn(x, dropout)
        return x

In [129]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
Q = torch.rand([batch_size, n, d_model])

encode_block = EncodeBlock(d_model, num_head)
encode_out = encode_block(Q)
print(encode_out)
print(encode_out.size())
assert(encode_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[-0.5173,  1.3850,  0.3295,  ...,  1.6806,  0.7333, -1.0889],
         [-0.6384,  1.8470, -1.2366,  ..., -0.7849,  0.2638, -1.1329],
         [ 0.0967, -0.2581, -0.5965,  ..., -0.3920,  1.5811, -1.4578],
         ...,
         [-0.0209,  1.1334, -0.2954,  ...,  1.5240,  0.9521, -2.5016],
         [-0.8310,  0.9661,  0.4257,  ...,  0.1001, -0.1922, -0.9898],
         [ 0.3043, -0.6394,  0.2250,  ...,  1.1292,  0.7230, -0.0559]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 16, 512])


![decode_block](img\decode_block.png)

In [130]:
def sequence_mask(token_len):
    mask = torch.ones([token_len, token_len])
    tril_mask = torch.tril(mask) == 0
    return mask.masked_fill(tril_mask, value=-1e-9).unsqueeze(0)

mask = sequence_mask(4)
print(mask)

tensor([[[ 1.0000e+00, -1.0000e-09, -1.0000e-09, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00, -1.0000e-09, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e-09],
         [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00]]])


In [131]:
class DecodeBlock(nn.Module):
    def __init__(self, d_model=512, num_head=8):
        super().__init__()

        self.mha = ResMHA(d_model, num_head)
        self.ffn = ResFFN(d_model)

    def forward(self, x, encoder_out_kv, mask=None, dropout=None):
        if mask is None:
            mask = sequence_mask(x.size(-2))

        q = self.mha(Q=x, K=x, V=x, mask=mask, dropout=dropout)
        x = self.mha(Q=q, K=encoder_out_kv, V=encoder_out_kv, dropout=dropout)
        x = self.ffn(x)
        return x

In [132]:
d_model = 512
num_head = 8

batch_size = 1
n = 16
m = 32
Q = torch.rand([batch_size, n, d_model])

encode_block = EncodeBlock(d_model, num_head)
encode_out = encode_block(Q)

decode_block = DecodeBlock(d_model, num_head)
decode_out = decode_block(Q, encode_out)
print(decode_out)
print(decode_out.size())
assert(decode_out.size() == torch.Size([batch_size, n, d_model]))

tensor([[[ 1.3745, -0.1947,  0.1542,  ...,  0.6360,  0.6947, -1.8688],
         [ 1.6771, -0.3148, -0.9900,  ..., -0.6673,  1.1267, -1.0726],
         [ 1.6385, -1.1807,  0.5586,  ...,  1.5278,  0.3707, -1.2273],
         ...,
         [ 0.1696, -0.5213, -0.2874,  ...,  1.8333,  0.9122, -1.6084],
         [ 1.0831, -0.2956, -0.9377,  ...,  0.9470,  0.1393, -0.9638],
         [ 1.8224, -1.7668, -0.7125,  ..., -0.6468,  0.7662, -0.8900]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 16, 512])


![overall](img\overall_structure.png)

## Positional Encoding

![pe](img\pe.png)

In [133]:
def position_encoding(token_num, d_model):
    pe = torch.zeros([token_num, d_model])
    pos = torch.arange(0, token_num).repeat(d_model, 1).T
    power = 2.0 / d_model * torch.arange(0, d_model)
    pe[:, ::2] = torch.sin(pos[:, ::2] / torch.pow(10000, power[::2]))
    pe[:, 1::2] = torch.sin(pos[:, 1::2] / torch.pow(10000, power[1::2]))
    return pe

print(position_encoding(64, 512))

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 8.4147e-01,  8.2186e-01,  8.0196e-01,  ...,  1.1140e-08,
          1.0746e-08,  1.0366e-08],
        [ 9.0930e-01,  9.3641e-01,  9.5814e-01,  ...,  2.2279e-08,
          2.1492e-08,  2.0733e-08],
        ...,
        [-9.6612e-01,  7.4857e-01,  2.1454e-01,  ...,  6.7952e-07,
          6.5551e-07,  6.3235e-07],
        [-7.3918e-01, -1.1848e-01,  9.1145e-01,  ...,  6.9066e-07,
          6.6626e-07,  6.4271e-07],
        [ 1.6736e-01, -8.8357e-01,  8.7441e-01,  ...,  7.0180e-07,
          6.7700e-07,  6.5308e-07]])


## Encoder

In [134]:
class Encoder(nn.Module):
    def __init__(self, num_layer=6, d_model=512, num_head=8):
        super().__init__()

        self.num_layer = num_layer
        self.d_model = d_model
        self.num_head = num_head

        self.encode_blocks = [EncodeBlock(d_model, num_head)] * num_layer

    def forward(self, x, embedding: nn.Embedding, pe: torch.Tensor, dropout=None):
        x = embedding(x) + pe
        if dropout:
            x = dropout(x)
        outputs = []
        for block in self.encode_blocks:
            x = block(x, dropout=dropout)
            outputs.append(x)
        return outputs

## Decoder

In [135]:
class Decoder(nn.Module):
    def __init__(self, num_layer=6, d_model=512, num_head=8):
        super().__init__()

        self.num_layer = num_layer
        self.d_model = d_model
        self.num_head = num_head

        self.decode_blocks = [DecodeBlock(d_model, num_head)] * num_layer

    def forward(self, x, embedding: nn.Embedding, pe: torch.Tensor, encoder_out_kv, mask=None, dropout=None):
        x = embedding(x) * math.sqrt(self.d_model) + pe
        if dropout:
            x = dropout(x)
        for block, kv in zip(self.decode_blocks, encoder_out_kv):
            x = block(x, encoder_out_kv=kv, mask=mask, dropout=dropout)
        return x

## Transformer

In [141]:
class Transformer(nn.Module):
    def __init__(self, n_vocab, token_len, num_layer=6, d_model=512, num_head=8):
        super().__init__()

        self.n_vocab = n_vocab
        self.num_layer = num_layer
        self.d_model = d_model
        self.num_head = num_head

        self.embedding = nn.Embedding(n_vocab, d_model)
        self.pe = position_encoding(token_len, d_model)
        self.dropout = nn.Dropout(p=0.1)
        self.encoder = Encoder(num_layer, d_model, num_head)
        self.decoder = Decoder(num_layer, d_model, num_head)
        self.vocab_linear = nn.Linear(d_model, n_vocab)
        self.vocab_act = nn.ReLU()

    def forward(self, enc_x, dec_x, mask=None):
        enc_kv = self.encoder(enc_x, self.embedding, self.pe, dropout=self.dropout)
        x = self.decoder(dec_x, self.embedding, self.pe, enc_kv, mask=mask, dropout=self.dropout)
        x = self.vocab_act(self.vocab_linear(x))
        return F.softmax(x, dim=-1)

In [145]:
n_vocab = 1024
token_len = 32
model = Transformer(n_vocab, token_len)

enc_x = torch.randint(low=0, high=n_vocab, size=[1, token_len])
dec_x = torch.randint(low=0, high=n_vocab, size=[1, token_len])
mask = sequence_mask(dec_x.size(-1))
output = model(enc_x, dec_x, mask=mask)
print(output)
print(output.size())

tensor([[[0.0007, 0.0010, 0.0007,  ..., 0.0011, 0.0007, 0.0007],
         [0.0007, 0.0014, 0.0007,  ..., 0.0007, 0.0007, 0.0012],
         [0.0007, 0.0016, 0.0007,  ..., 0.0010, 0.0007, 0.0010],
         ...,
         [0.0007, 0.0014, 0.0007,  ..., 0.0010, 0.0007, 0.0007],
         [0.0007, 0.0015, 0.0007,  ..., 0.0008, 0.0007, 0.0020],
         [0.0007, 0.0009, 0.0010,  ..., 0.0009, 0.0007, 0.0014]]],
       grad_fn=<SoftmaxBackward0>)
torch.Size([1, 32, 1024])
