In [None]:
'''第一个版本的self-attention'''

####################
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head self-attention, first (most explicit) version.

    Q, K and V are produced by three separate linear projections of the input,
    and the result is softmax(Q @ K^T / sqrt(hidden_dim)) @ V.

    Args:
        hidden_dim: size of the input's last dimension; also the width of Q, K
            and V and the value under the sqrt() in the scaling factor.
            NOTE(review): the default 728 looks like a typo for 768 (the usual
            BERT-base hidden size) -- confirm before relying on the default.
        verbose: keyword-only. When True (the default, matching the original
            behaviour) the attention weights are printed on every forward
            pass; pass False to silence this debug output, e.g. in a loop.
    """

    def __init__(self, hidden_dim: int = 728, *, verbose: bool = True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.verbose = verbose
        # Separate weight matrices for Q, K and V. nn.Linear acts on the last
        # dimension: e.g. an input of shape (128, 20) through a 20 -> 30 layer
        # yields a result of shape (128, 30).
        self.query_Q_weight = nn.Linear(hidden_dim, hidden_dim)
        self.key_K_weight = nn.Linear(hidden_dim, hidden_dim)
        self.value_V_weight = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; in this notebook
                it is (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor with the same shape as ``x`` (batch, seq_len, hidden_dim).
        """
        Q_matrix = self.query_Q_weight(x)
        K_matrix = self.key_K_weight(x)
        V_matrix = self.value_V_weight(x)

        # Q, K, V: (batch, seq_len, hidden_dim)
        # Q @ K^T -> raw attention scores: (batch, seq_len, seq_len)
        attention_matrix = torch.matmul(Q_matrix, K_matrix.transpose(-1, -2))

        # Scaled dot-product attention: divide by sqrt(hidden_dim), then
        # normalise over the key axis so each query's weights sum to 1.
        attention_weights = torch.softmax(
            attention_matrix / math.sqrt(self.hidden_dim), dim=-1
        )
        if self.verbose:
            print(attention_weights)

        # (batch, seq_len, seq_len) @ (batch, seq_len, hidden_dim)
        # -> output: (batch, seq_len, hidden_dim)
        output = torch.matmul(attention_weights, V_matrix)
        return output



# Seed the RNG so the random input and the layer's initial weights -- and
# therefore the printed attention weights and the output -- are reproducible.
torch.manual_seed(42)

X = torch.rand(3, 2, 4)  # (batch_size=3, seq_len=2, hidden_dim=4)

attention_V1 = SelfAttentionV1(4)

attention_V1(X)

tensor([[[0.5218, 0.4782],
         [0.5193, 0.4807]],

        [[0.5030, 0.4970],
         [0.5044, 0.4956]],

        [[0.4886, 0.5114],
         [0.4859, 0.5141]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.6099, -0.1530, -0.8041,  0.2956],
         [-0.6100, -0.1535, -0.8045,  0.2958]],

        [[-0.3163, -0.0801, -0.2249,  0.3892],
         [-0.3164, -0.0800, -0.2249,  0.3891]],

        [[-0.3718, -0.0180, -0.4918,  0.2616],
         [-0.3723, -0.0169, -0.4922,  0.2607]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
"""对原本的self-attention计算进行效率优化"""


class SelfAttentionV2(nn.Module):
    """Single-head self-attention with a fused Q/K/V projection.

    A single ``nn.Linear(hidden_dim, 3 * hidden_dim)`` produces Q, K and V in
    one matmul; the result is split along the last dimension. The attention
    itself is softmax(Q @ K^T / sqrt(hidden_dim)) @ V.

    Args:
        hidden_dim: size of the input's last dimension; also the width of each
            of Q, K and V and the value under the sqrt() in the scaling factor.
        verbose: keyword-only. When True (the default, matching the original
            behaviour) the attention weights are printed on every forward
            pass; pass False to silence this debug output, e.g. in a loop.
    """

    def __init__(self, hidden_dim: int, *, verbose: bool = True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.verbose = verbose
        # Merge the three Q, K, V weight matrices into one larger projection.
        # Output columns [0, d) -> Q, [d, 2d) -> K, [2d, 3d) -> V.
        self.cal_proj_weight = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; in this notebook
                it is (batch, seq_len, hidden_dim).

        Returns:
            Tensor with the same shape as ``x`` (batch, seq_len, hidden_dim).
        """
        # NOTE (original author): this fused approach is only suitable when
        # the model is relatively small.
        QKV_matrix = self.cal_proj_weight(x)
        Q_matrix, K_matrix, V_matrix = torch.split(
            QKV_matrix, self.hidden_dim, dim=-1
        )

        # Q, K, V: (batch, seq_len, hidden_dim)
        # Q @ K^T -> raw attention scores: (batch, seq_len, seq_len)
        attention_matrix = torch.matmul(Q_matrix, K_matrix.transpose(-1, -2))

        # Scaled dot-product attention: divide by sqrt(hidden_dim), then
        # normalise over the key axis so each query's weights sum to 1.
        attention_weights = torch.softmax(
            attention_matrix / math.sqrt(self.hidden_dim), dim=-1
        )
        if self.verbose:
            print(attention_weights)

        # (batch, seq_len, seq_len) @ (batch, seq_len, hidden_dim)
        # -> output: (batch, seq_len, hidden_dim)
        output = torch.matmul(attention_weights, V_matrix)
        return output


# Seed the RNG so the random input and the fused projection's initial weights
# -- and therefore the printed attention weights and output -- are reproducible.
torch.manual_seed(42)

X = torch.rand(3, 2, 4)  # (batch=3, seq_len=2, hidden_dim=4)

attention_V2 = SelfAttentionV2(4)

attention_V2(X)

tensor([[[0.5215, 0.4785],
         [0.5253, 0.4747]],

        [[0.5161, 0.4839],
         [0.5132, 0.4868]],

        [[0.5149, 0.4851],
         [0.5078, 0.4922]]], grad_fn=<SoftmaxBackward0>)


tensor([[[ 0.4113, -0.3819, -0.2678, -0.0727],
         [ 0.4110, -0.3830, -0.2681, -0.0723]],

        [[ 0.2601, -0.1874, -0.0258, -0.2135],
         [ 0.2594, -0.1877, -0.0255, -0.2140]],

        [[ 0.7261, -0.3526, -0.4820, -0.0516],
         [ 0.7245, -0.3512, -0.4798, -0.0526]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
'''加入一些关于self-attention的细节'''

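# * 1. Where to place dropout
# * 2. Attention mask: in practice the sequences in a batch can have different lengths (shorter ones are padded), so the padded positions must be masked out of the attention
# * 3. Output projection matrix (optional)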
