# Self-Attention

Paper: `Transformer` Attention Is All You Need (NIPS 2017)

Code:
- [Official TensorFlow implementation](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
- [PyTorch implementation](https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py)

Reference:
- `Enzo_Mi` [Multi-Head Attention | Algorithm + Code](https://www.bilibili.com/video/BV1qo4y1F7Ep)
- `黑白` [Transformer Code and Walkthrough (PyTorch)](https://zhuanlan.zhihu.com/p/345993564)
- `于建民` [The Illustrated Transformer (Chinese translation)](https://blog.csdn.net/yujianmin1990/article/details/85221271)
- `Jay Alammar` [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)

In [None]:
%%time
import torch
import torch.nn.functional as F
query = torch.rand((10, 32, 512))
key = query
value = query
attn = F.scaled_dot_product_attention(query, key, value)

print("query \n", query.shape)
print("attn \n", attn.shape)
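
As a rough check on what `F.scaled_dot_product_attention` computes with its default arguments (no mask, no dropout), the sketch below recomputes softmax(Q Kᵀ / sqrt(d)) V by hand and compares the two results. The helper name `manual_attention` and the tolerance are illustrative choices, not part of PyTorch.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(query, key, value):
    """Hypothetical helper: plain softmax(Q K^T / sqrt(d)) V, no mask, no dropout."""
    d = query.size(-1)
    # Pairwise dot products, scaled by 1 / sqrt(d): (batch, seq_len, seq_len)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d)
    # Each row becomes a probability distribution over the keys
    weights = scores.softmax(dim=-1)
    # Weighted sum of the values: (batch, seq_len, d)
    return weights @ value

torch.manual_seed(0)
query = torch.rand((10, 32, 512))
key = value = query

builtin = F.scaled_dot_product_attention(query, key, value)
manual = manual_attention(query, key, value)

print(builtin.shape)
print(torch.allclose(builtin, manual, atol=1e-5))
```

The output keeps the shape of `query`, since every position receives one weighted average of the value vectors.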

### GPU acceleration is now available on Mac too

In [None]:
%%time
import torch
import torch.nn.functional as F

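# Use Apple's Metal (MPS) backend when available; otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

query = torch.rand((10, 32, 512), device=device)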
key = query
value = query
attn = F.scaled_dot_product_attention(query, key, value)

print("query \n", query.shape)
print("attn \n", attn.device)

In [None]:
import torch
import torch.nn as nn
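class SelfAttention(nn.Module):
    """Single-head self-attention: softmax(Q K^T / sqrt(dk)) V."""

    def __init__(self, dim, dk, dv):
        super().__init__()
        self.scale = dk ** -0.5      # 1 / sqrt(dk)
        self.q = nn.Linear(dim, dk)  # query projection
        self.k = nn.Linear(dim, dk)  # key projection
        self.v = nn.Linear(dim, dv)  # value projection

    def forward(self, x):
        # x: (batch, seq_len, dim)
        q = self.q(x)  # (batch, seq_len, dk)
        k = self.k(x)  # (batch, seq_len, dk)
        v = self.v(x)  # (batch, seq_len, dv)

        # Scaled pairwise similarity scores: (batch, seq_len, seq_len)
        attn = q @ k.transpose(-2, -1) * self.scale
        # Normalize each row into attention weights that sum to 1
        attn = attn.softmax(dim=-1)

        # Weighted sum of the values: (batch, seq_len, dv)
        x = attn @ v
        return x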

att = SelfAttention(dim=2,dk=2,dv=3)
x = torch.rand((1,4,2))
output = att(x)
print(x, '\n', output)
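
One way to sanity-check a module like `SelfAttention` is to feed its own projections into `F.scaled_dot_product_attention` and compare the outputs, and to confirm that each row of attention weights sums to 1. The sketch below redefines a compact copy of the class so that it runs on its own; the names `expected` and `weights` are illustrative, and the comparison assumes the built-in's default scale of 1 / sqrt(dk).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    # Compact copy of the class above so this cell is self-contained
    def __init__(self, dim, dk, dv):
        super().__init__()
        self.scale = dk ** -0.5
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)

    def forward(self, x):
        attn = (self.q(x) @ self.k(x).transpose(-2, -1) * self.scale).softmax(dim=-1)
        return attn @ self.v(x)

torch.manual_seed(0)
att = SelfAttention(dim=2, dk=2, dv=3)
x = torch.rand((1, 4, 2))

with torch.no_grad():
    output = att(x)
    # Same projections, but the attention itself is computed by the built-in
    expected = F.scaled_dot_product_attention(att.q(x), att.k(x), att.v(x))
    # Attention weights on their own, to inspect the row sums
    weights = (att.q(x) @ att.k(x).transpose(-2, -1) * att.scale).softmax(dim=-1)

print(output.shape)  # torch.Size([1, 4, 3])
print(torch.allclose(output, expected, atol=1e-5))
print(weights.sum(dim=-1))
```

Note that the output's last dimension is `dv`, not `dim`: the value projection alone decides the width of the result, while `dk` only affects how the scores are computed.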