### 1. 🚀了解公式
![self-attention公式](./pics/self-attention.png)

### 2. 代码实现
### · Self-AttentionV1 公式实现

In [7]:
import math
import torch
import torch.nn as nn

class Self_AttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention, written straight from the formula.

    Computes ``softmax(Q @ K^T / sqrt(hidden_dim)) @ V``, where Q, K and V are
    three independent linear projections of the same input.

    Args:
        hidden_dim: size of the input's last dimension and of every projection's
            output. Defaults to 728.
            NOTE(review): 768 is the usual transformer width; 728 may be a typo — confirm.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        # One square projection per role; created in Q, K, V order.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Attend over the sequence and return the weighted sum of values.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; the shape comments
               below assume ``[batch_size, seq_len, hidden_dim]``.

        Returns:
            Tensor with the same shape as the value projection, i.e.
            ``[batch_size, seq_len, hidden_dim]`` for a 3-D input.
        """
        queries = self.query_proj(x)
        keys = self.key_proj(x)
        values = self.value_proj(x)

        # Raw similarity of every query with every key: [batch_size, seq_len, seq_len].
        scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d) to keep the softmax out of its saturated region,
        # then normalise each query's row into a probability distribution.
        weights = torch.softmax(scores / math.sqrt(self.hidden_dim), dim=-1)

        # Weighted sum of values: [batch_size, seq_len, hidden_dim].
        return weights @ values

# Demo input: batch of 3 sequences, 2 tokens each, 4 features per token.
x = torch.randn(3, 2, 4)

self_att_net = Self_AttentionV1(hidden_dim=4)

# Last expression of the cell, so the notebook displays the result.
self_att_net(x)

tensor([[[ 0.5010, -1.4322, -0.2414,  0.1632],
         [ 0.5117, -1.5739, -0.2755,  0.1640]],

        [[-0.4387,  0.4075, -0.5785,  0.1173],
         [-0.4382,  0.4043, -0.5777,  0.1205]],

        [[ 0.5927, -0.7009, -0.0076, -0.5356],
         [ 0.6485, -0.7573,  0.0375, -0.7019]]], grad_fn=<UnsafeViewBackward0>)

### · Self-AttentionV2 效率优化

In [8]:
class Self_AttentionV2(nn.Module):
    """Self-attention that produces Q, K and V with a single fused projection.

    A single ``Linear(dim, 3 * dim)`` replaces three separate layers. Its output
    is split along the last dimension into the query, key and value parts,
    in that order.

    Args:
        dim: size of the input's last dimension and of each of Q, K and V.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Fused projection emitting [Q | K | V] concatenated on the last axis.
        self.proj = nn.Linear(dim, dim * 3)

    def forward(self, x):
        """Return ``softmax(Q @ K^T / sqrt(dim)) @ V`` for input ``x``.

        Args:
            x: tensor whose last dimension is ``dim``; shape comments assume
               ``(batch_size, seq_len, dim)``.
        """
        # (batch_size, seq_len, dim * 3) -> three (batch_size, seq_len, dim) slices.
        q, k, v = self.proj(x).split(self.dim, dim=-1)

        scores = q @ k.transpose(-1, -2)
        probs = torch.softmax(scores / math.sqrt(self.dim), dim=-1)

        return torch.matmul(probs, v)
        
x = torch.randn(3, 2, 4)  # [batch=3, seq=2, dim=4]
self_att_netV2 = Self_AttentionV2(dim=4)
self_att_netV2(x)  # displayed as the cell's output

tensor([[[-0.2136,  0.2236,  0.6926,  0.2348],
         [-0.1595,  0.1828,  0.6793,  0.1907]],

        [[-0.3588,  0.5055, -0.2845,  0.2861],
         [-0.3767,  0.5226, -0.2824,  0.3109]],

        [[-0.8789,  0.5472, -0.4421, -0.5628],
         [-0.9618,  0.2683,  0.0440, -0.0110]]], grad_fn=<UnsafeViewBackward0>)

### · Self-AttentionV3 加入细节


In [29]:
# What V3 adds compared with V2:
# 1. dropout on the attention weights
# 2. an optional attention mask
# 3. an output projection (output 矩阵映射)

class Self_AttentionV3(nn.Module):
    """Self-attention with a fused QKV projection, an optional mask, dropout and output projection.

    Args:
        dim: size of the input's last dimension and of each of Q, K and V.
        Dropout_rate: probability for the ``nn.Dropout`` applied to the
            attention weights after softmax.
    """

    def __init__(self, dim, Dropout_rate = 0.1):
        super().__init__()
        self.dim = dim

        # Fused projection producing [Q | K | V] on the last axis.
        self.proj = nn.Linear(dim, dim * 3)
        self.attention_dropout = nn.Dropout(Dropout_rate)

        # Final linear map applied to the attention result.
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, attention_mask = None):
        """Compute masked self-attention followed by the output projection.

        Args:
            x: tensor whose last dimension is ``dim``; the shape comments
               assume ``[B, seq, D]``.
            attention_mask: optional tensor broadcastable to the score matrix
               (``[B, seq, seq]`` in the demo). Score positions where it equals
               0 are excluded from the softmax.

        Returns:
            ``output_proj(softmax(mask(Q @ K^T / sqrt(dim))) @ V)``, with dropout
            applied to the weights in training mode.
        """
        # x shape: [B, seq, D] -> QKV shape: [B, seq, 3 * D]
        QKV = self.proj(x)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        # attention_weight shape: [B, seq, seq]
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)

        if attention_mask is not None:
            # Use the most negative finite value of the score dtype instead of
            # a hard-coded -1e20. -1e20 is not representable in float16, so
            # masked_fill rejects it under half precision. finfo.min is finite
            # in every float dtype, so a fully masked row still gets uniform
            # weights (never NaN), just as with -1e20 in float32.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)

        # Dropout on the normalised weights (active only in training mode).
        attention_weight = self.attention_dropout(attention_weight)
        attention_result = attention_weight @ V

        return self.output_proj(attention_result)

x = torch.randn(3, 4, 2)  # [B=3, seq=4, D=2]

# Per-key mask, shape [B, seq]: entries equal to 0 are masked out in V3.
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
print(mask.shape)

# Give every query row its own copy: [B, seq] -> [B, 1, seq] -> [B, seq, seq].
mask = mask[:, None, :].repeat(1, 4, 1)
print("repeat mask shape: ", mask.size())

self_att_netV3 = Self_AttentionV3(dim=2)
self_att_netV3(x, mask)

torch.Size([3, 4])
repeat mask shape:  torch.Size([3, 4, 4])


tensor([[[-0.3208, -0.6199],
         [-0.6620, -0.5617],
         [-0.5866, -0.5632],
         [-0.6273, -0.5385]],

        [[-1.2397, -1.0820],
         [-1.1849, -1.0373],
         [-1.1710, -1.0259],
         [-0.3342, -0.6982]],

        [[-1.0509, -0.7904],
         [-1.0509, -0.7904],
         [-1.0509, -0.7904],
         [-1.0509, -0.7904]]], grad_fn=<ViewBackward0>)

### · Self-Attention interview

In [None]:
class Self_AttentionV4(nn.Module):
    """Interview-style single-head self-attention with separate Q/K/V layers.

    Args:
        dim: size of the input's last dimension and of each projection's output.
        dropout_rate: probability for the ``nn.Dropout`` applied to the
            attention weights after softmax.
    """

    def __init__(self, dim, dropout_rate = 0.1):
        super().__init__()
        self.dim = dim

        self.Q = nn.Linear(dim, dim)
        self.K = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, atten_mask = None):
        """Return ``softmax(mask(Q @ K^T / sqrt(dim))) @ V`` for input ``x``.

        Args:
            x: tensor whose last dimension is ``dim``; assumed ``[b, s, d]``.
            atten_mask: optional tensor broadcastable to the ``[b, s, s]`` score
                matrix. Score positions where it equals 0 are excluded.

        Returns:
            The attention output, ``[b, s, d]`` for a 3-D input. Dropout is
            applied to the weights in training mode.
        """
        # [b, s, d]
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)

        # [b, s, s]
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if atten_mask is not None:
            # Fill with the most negative *finite* value rather than -inf.
            # If every key of a query row is masked, softmax over a row of
            # -inf is NaN, which would poison the output. With finfo.min that
            # row becomes uniform instead. In partially masked rows the masked
            # entries still receive (effectively) zero weight.
            attention_weight = attention_weight.masked_fill(
                atten_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)

        # Dropout on the normalised weights (active only in training mode).
        attention_weight = self.attention_dropout(attention_weight)

        output = attention_weight @ V

        return output

x = torch.rand(3, 4, 2)  # [b=3, s=4, d=2], uniform in [0, 1)

# Per-key mask, shape [b, s]: entries equal to 0 are masked out in V4.
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])

# One mask row per query: [b, s] -> [b, 1, s] -> [b, s, s].
repeat_mask = mask[:, None, :].repeat(1, 4, 1)

# The module is left in its default training mode here, so the dropout on
# the attention weights is active and the output varies between calls.
net = Self_AttentionV4(dim=2)
net(x, repeat_mask)

tensor([[[0.3428, 0.3319, 0.3253, 0.0000],
         [0.3513, 0.3294, 0.3193, 0.0000],
         [0.3465, 0.3309, 0.3226, 0.0000],
         [0.3500, 0.3297, 0.3203, 0.0000]],

        [[0.5091, 0.4909, 0.0000, 0.0000],
         [0.5079, 0.4921, 0.0000, 0.0000],
         [0.5110, 0.4890, 0.0000, 0.0000],
         [0.5091, 0.4909, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[ 0.0035,  0.2287],
         [ 0.0671,  0.3224],
         [ 0.1035,  0.4267],
         [ 0.1388,  0.3070]],

        [[-0.0160,  0.2526],
         [ 0.0379,  0.4628],
         [ 0.0374,  0.4630],
         [ 0.0377,  0.4629]],

        [[ 0.0437,  0.4510],
         [ 0.0437,  0.4510],
         [ 0.0437,  0.4510],
         [ 0.0437,  0.4510]]], grad_fn=<UnsafeViewBackward0>)