In [1]:
import math
import torch
import torch.nn as nn

In [None]:
'''第一个版本的self-attention'''

####################

class SelfAttentionV1(nn.Module):
    """Baseline single-head scaled dot-product self-attention.

    Queries, keys and values are produced by three independent linear
    projections of the same input.

    Args:
        hidden_dim: Size of the feature (last) dimension of the input; all
            three projections map hidden_dim -> hidden_dim.
    """

    def __init__(self, hidden_dim: int = 728):
        super().__init__()
        self.hidden_dim = hidden_dim
        # One projection each for Q, K and V. nn.Linear acts on the last
        # dimension only: e.g. an input of shape (128, 20) passed through
        # nn.Linear(20, 30) yields shape (128, 30).
        self.query_Q_weight = nn.Linear(hidden_dim, hidden_dim)
        self.key_K_weight = nn.Linear(hidden_dim, hidden_dim)
        self.value_V_weight = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Attend every position of ``x`` to every other position.

        Args:
            x: Input of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        # Each projection keeps the shape (batch, seq_len, hidden_dim).
        queries = self.query_Q_weight(x)
        keys = self.key_K_weight(x)
        values = self.value_V_weight(x)

        # Pairwise similarity scores: (batch, seq_len, seq_len).
        scores = queries @ keys.transpose(-1, -2)

        # Scale by sqrt(hidden_dim), then normalise over the key axis so each
        # query row sums to 1.
        scale = math.sqrt(self.hidden_dim)
        attention_weights = (scores / scale).softmax(dim=-1)
        # Printed so the notebook shows the attention distribution.
        print(attention_weights)

        # Weighted sum of the values: (batch, seq_len, hidden_dim).
        return attention_weights @ values



# Demo input: a random batch of 3 sequences, 2 positions each, 4 features.
X = torch.rand((3, 2, 4))

# hidden_dim has to equal the size of X's last dimension.
attention_V1 = SelfAttentionV1(hidden_dim=4)

# Last expression of the cell, so the notebook displays the returned tensor
# (the attention weights are printed from inside forward()).
attention_V1(X)

tensor([[[0.4960, 0.5040],
         [0.4848, 0.5152]],

        [[0.4944, 0.5056],
         [0.4923, 0.5077]],

        [[0.5221, 0.4779],
         [0.5274, 0.4726]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.1299, -0.6092, -0.4495, -0.2527],
         [-0.1283, -0.6098, -0.4518, -0.2592]],

        [[-0.2092, -0.8641, -0.7713, -0.2240],
         [-0.2096, -0.8637, -0.7708, -0.2232]],

        [[-0.1089, -0.6132, -0.4714, -0.2210],
         [-0.1097, -0.6106, -0.4687, -0.2207]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
"""对原本的self-attention计算进行效率优化"""


class SelfAttentionV2(nn.Module):
    """Single-head self-attention using one fused Q/K/V projection.

    Computes the same scaled dot-product attention as a three-projection
    layer, but obtains Q, K and V from a single linear layer of width
    ``3 * hidden_dim`` and then splits the result.

    Args:
        hidden_dim: Size of the feature (last) dimension of the input.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Q, K and V weights merged into one large matrix so they are all
        # computed by a single linear layer.
        self.cal_proj_weight = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        # Author's note: this fused approach is only suitable when the model
        # is relatively small.
        fused = self.cal_proj_weight(x)
        # Cut the last dimension into three hidden_dim-wide pieces, each of
        # shape (batch, seq_len, hidden_dim).
        queries, keys, values = fused.split(self.hidden_dim, dim=-1)

        # Pairwise similarity scores: (batch, seq_len, seq_len).
        scores = queries @ keys.transpose(-1, -2)

        # Scale by sqrt(hidden_dim) and normalise over the key axis.
        scale = math.sqrt(self.hidden_dim)
        attention_weights = (scores / scale).softmax(dim=-1)
        # Printed so the notebook shows the attention distribution.
        print(attention_weights)

        # Weighted sum of the values: (batch, seq_len, hidden_dim).
        return attention_weights @ values


# Demo input: a random batch of 3 sequences, 2 positions each, 4 features.
X = torch.rand((3, 2, 4))

# hidden_dim has to equal the size of X's last dimension.
attention_V2 = SelfAttentionV2(hidden_dim=4)

# Last expression of the cell, so the notebook displays the returned tensor
# (the attention weights are printed from inside forward()).
attention_V2(X)

tensor([[[0.5215, 0.4785],
         [0.5253, 0.4747]],

        [[0.5161, 0.4839],
         [0.5132, 0.4868]],

        [[0.5149, 0.4851],
         [0.5078, 0.4922]]], grad_fn=<SoftmaxBackward0>)


tensor([[[ 0.4113, -0.3819, -0.2678, -0.0727],
         [ 0.4110, -0.3830, -0.2681, -0.0723]],

        [[ 0.2601, -0.1874, -0.0258, -0.2135],
         [ 0.2594, -0.1877, -0.0255, -0.2140]],

        [[ 0.7261, -0.3526, -0.4820, -0.0516],
         [ 0.7245, -0.3512, -0.4798, -0.0526]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
'''加入一些关于self-attention的细节'''

# * 1. Dropout的位置
# * 2. Attention mask, 因为在实际应用过程中，sequence的长度可能是不一样的
# * 3. output 矩阵映射（可选）

class SelfAttentionV3(nn.Module):
    """Single-head self-attention with mask, dropout and output projection.

    Q, K and V come from one fused linear layer. An optional attention mask
    removes positions before the softmax, dropout is applied to the attention
    weights, and a final linear layer maps the attended values.

    Args:
        hiddendim: Size of the feature (last) dimension of the input.
        dropout_rate: Probability handed to ``nn.Dropout`` for the attention
            weights.
    """

    def __init__(self, hiddendim, dropout_rate=0.1):
        super().__init__()
        self.hiddendim = hiddendim
        # Fused projection: one matrix produces Q, K and V together.
        self.cal_matrix_weight = nn.Linear(hiddendim, hiddendim * 3)
        self.dropout_layer = nn.Dropout(dropout_rate)

        # Output projection (an optional component of the design).
        self.output_mapping = nn.Linear(hiddendim, hiddendim)

    def forward(self, x, attention_mask=None):
        """Apply masked self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``hiddendim``.
            attention_mask: Optional tensor compared against 0 and applied to
                the (batch, seq, seq) score matrix; entries equal to 0 mark
                positions that must receive no attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        QKV_matrix = self.cal_matrix_weight(x)
        Q_matrix, K_matrix, V_matrix = torch.split(QKV_matrix, self.hiddendim, dim=-1)

        # Scaled similarity scores, shape (batch, seq, seq).
        attention_matrix = Q_matrix @ K_matrix.transpose(-1, -2) / math.sqrt(self.hiddendim)

        # The mask has to be applied before the softmax so that masked
        # positions end up with zero weight.
        if attention_mask is not None:
            # Use the most negative finite value of the score dtype. A
            # hard-coded -1e20 is not representable in float16, whereas
            # finfo.min is valid for every floating dtype.
            fill_value = torch.finfo(attention_matrix.dtype).min
            attention_matrix = attention_matrix.masked_fill(attention_mask == 0, fill_value)

        attention_weight = torch.softmax(attention_matrix, dim=-1)
        # Printed so the notebook shows the attention distribution.
        print(attention_weight)
        # Dropout acts on the attention weights themselves, hence it is
        # placed after the softmax and before the weighted sum.
        attention_weight = self.dropout_layer(attention_weight)

        output = attention_weight @ V_matrix
        output = self.output_mapping(output)

        return output
    


# Demo input: a random batch of 3 sequences, 4 positions each, 2 features.
X = torch.rand((3, 4, 2))

# One row per sequence; a 0 marks a position the attention layer masks out
# (it fills scores wherever the mask equals 0).
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)
print(mask.shape)

# Give every one of the 4 query rows the same key mask: (3, 4) -> (3, 4, 4),
# matching the (batch, seq, seq) score matrix.
mask = mask[:, None, :].repeat(1, 4, 1)
print(mask.shape)

# hiddendim has to equal the size of X's last dimension.
attention_V3 = SelfAttentionV3(hiddendim=2)

# Last expression of the cell, so the notebook displays the returned tensor.
attention_V3(X, mask)



torch.Size([3, 4])
torch.Size([3, 4, 4])
tensor([[[0.3364, 0.3375, 0.3261, 0.0000],
         [0.3198, 0.3440, 0.3362, 0.0000],
         [0.3247, 0.3414, 0.3339, 0.0000],
         [0.3294, 0.3390, 0.3315, 0.0000]],

        [[0.5025, 0.4975, 0.0000, 0.0000],
         [0.4829, 0.5171, 0.0000, 0.0000],
         [0.4916, 0.5084, 0.0000, 0.0000],
         [0.4966, 0.5034, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.3273, -0.6831],
         [-0.4053, -0.5083],
         [-0.3276, -0.6836],
         [-0.3275, -0.6833]],

        [[-0.2727, -0.7993],
         [-0.2698, -0.8035],
         [-0.2711, -0.8017],
         [-0.2718, -0.8006]],

        [[-0.3080, -0.6561],
         [-0.3080, -0.6561],
         [-0.3080, -0.6561],
         [-0.3080, -0.6561]]], grad_fn=<ViewBackward0>)

In [21]:
'''实际面试时的写法'''

class SelfAttentionV4(nn.Module):
    """Compact single-head self-attention with mask and dropout.

    Uses separate Q, K and V projections, an optional attention mask applied
    before the softmax, and dropout on the attention weights. There is no
    output projection.

    Args:
        hiddendim: Size of the feature (last) dimension of the input.
        dropout_rate: Probability handed to ``nn.Dropout`` for the attention
            weights.
    """

    def __init__(self, hiddendim, dropout_rate=0.1):
        super().__init__()

        self.hiddendim = hiddendim

        self.query = nn.Linear(hiddendim, hiddendim)
        self.key = nn.Linear(hiddendim, hiddendim)
        self.value = nn.Linear(hiddendim, hiddendim)

        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply masked self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``hiddendim``.
            attention_mask: Optional tensor compared against 0 and applied to
                the (batch, seq, seq) score matrix; entries equal to 0 mark
                positions that must receive no attention.

        Returns:
            Tensor of shape (batch, seq, hiddendim). A query row whose keys
            are all masked yields the plain average of the values, not NaN.
        """
        Q_matrix = self.query(x)
        K_matrix = self.key(x)
        V_matrix = self.value(x)

        # Scaled similarity scores, shape (batch, seq, seq).
        attention_matrix = Q_matrix @ K_matrix.transpose(-1, -2) / math.sqrt(self.hiddendim)

        if attention_mask is not None:
            # Fill with the most negative finite value instead of -inf: a row
            # made only of -inf turns into NaN under softmax, while a row of
            # equal finite values becomes a uniform distribution. Masked
            # entries next to unmasked ones still get exactly zero weight.
            fill_value = torch.finfo(attention_matrix.dtype).min
            attention_matrix = attention_matrix.masked_fill(attention_mask == 0, fill_value)

        # Normalise over the key axis: (batch, seq, seq).
        attention_weights = torch.softmax(attention_matrix, dim=-1)
        # Printed so the notebook shows the attention distribution.
        print(attention_weights)
        attention_weights = self.dropout_layer(attention_weights)

        # Weighted sum of the values: (batch, seq, hiddendim).
        output = attention_weights @ V_matrix

        return output

# Demo input: a random batch of 3 sequences, 4 positions each, 2 features.
X = torch.rand((3, 4, 2))

# One row per sequence; a 0 marks a position the attention layer masks out
# (it fills scores wherever the mask equals 0).
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)

# Give every one of the 4 query rows the same key mask: (3, 4) -> (3, 4, 4),
# matching the (batch, seq, seq) score matrix.
mask = mask[:, None, :].repeat(1, 4, 1)

# hiddendim has to equal the size of X's last dimension.
attention_V4 = SelfAttentionV4(hiddendim=2)

# Last expression of the cell, so the notebook displays the returned tensor.
attention_V4(X, mask)

tensor([[[0.3425, 0.3439, 0.3136, 0.0000],
         [0.3406, 0.3431, 0.3163, 0.0000],
         [0.3471, 0.3366, 0.3164, 0.0000],
         [0.3427, 0.3436, 0.3137, 0.0000]],

        [[0.4733, 0.5267, 0.0000, 0.0000],
         [0.4705, 0.5295, 0.0000, 0.0000],
         [0.4783, 0.5217, 0.0000, 0.0000],
         [0.4637, 0.5363, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[ 0.5100, -0.8924],
         [ 0.5101, -0.8929],
         [ 0.5123, -0.8970],
         [ 0.5101, -0.8926]],

        [[ 0.3029, -0.4966],
         [ 0.3026, -0.4956],
         [ 0.3036, -0.4985],
         [ 0.3017, -0.4929]],

        [[ 0.4814, -0.9871],
         [ 0.0000,  0.0000],
         [ 0.4814, -0.9871],
         [ 0.4814, -0.9871]]], grad_fn=<UnsafeViewBackward0>)