In [2]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

### 第一重 直接写SelfAttention

In [20]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced by three separate linear layers,
    each mapping ``hidden_dim`` features to ``hidden_dim`` features.

    Args:
        hidden_dim: Size of the last dimension of the input (and of the
            output). Also used as the scaling denominator ``sqrt(hidden_dim)``.
    """

    def __init__(self, hidden_dim: int = 756):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Layers are created in the order query -> key -> value; keeping this
        # order keeps seeded parameter initialisation reproducible.
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with at least two dimensions whose last dimension equals
                ``hidden_dim`` (e.g. ``(3, 2, 756)`` in the demo cell).

        Returns:
            Tensor of the same shape as ``x``.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)

        # Pairwise similarity between positions along the second-to-last axis.
        scores = q @ k.transpose(-1, -2)

        # Scale by sqrt(hidden_dim), then normalise each row into a distribution.
        scale = math.sqrt(self.hidden_dim)
        weights = (scores / scale).softmax(dim=-1)

        # Each output position is a weighted average of the value vectors.
        return weights @ v

In [22]:
# Smoke test: a random input whose last dimension (756) matches hidden_dim.
# NOTE(review): the leading dims are presumably (batch=3, seq_len=2) — confirm.
x = torch.rand(3, 2, 756)

# No random seed is set, so the printed values differ on every run.
self_att = SelfAttention(756)
print(self_att(x))

tensor([[[-0.2301, -0.5728,  0.3617,  ..., -0.2509, -0.0339,  0.0432],
         [-0.2300, -0.5731,  0.3601,  ..., -0.2506, -0.0339,  0.0424]],

        [[-0.0930, -0.4355,  0.1877,  ...,  0.1289, -0.0683, -0.0112],
         [-0.0965, -0.4334,  0.1866,  ...,  0.1349, -0.0671, -0.0063]],

        [[-0.3861, -0.2914,  0.2270,  ...,  0.0846, -0.1521,  0.0388],
         [-0.3864, -0.2915,  0.2269,  ...,  0.0847, -0.1526,  0.0390]]],
       grad_fn=<UnsafeViewBackward0>)


### 第二重 效率优化

In [29]:
class SelfAttention_2(nn.Module):
    """Scaled dot-product self-attention with one fused Q/K/V projection.

    A single linear layer maps ``hidden_dim`` features to ``3 * hidden_dim``
    features; the result is split along the last dimension into queries, keys
    and values (in that order), each ``hidden_dim`` wide.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with at least two dimensions whose last dimension equals
                ``hidden_dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        fused = self.proj(x)
        # Consecutive hidden_dim-wide slices of the last axis: Q, then K, then V.
        q, k, v = fused.split(self.hidden_dim, dim=-1)

        scores = torch.matmul(q, k.transpose(-1, -2))
        weights = torch.softmax(scores / math.sqrt(self.hidden_dim), dim=-1)
        return torch.matmul(weights, v)

# Smoke test with a small hidden size (4) so the full output tensor is readable.
# Rebinds the notebook-level names `x` and `self_att` from the earlier demo cell.
x = torch.randn(3, 2, 4)
self_att = SelfAttention_2(4)
# Bare last expression: the notebook displays the result as the cell output.
# No random seed is set, so the values differ on every run.
self_att(x)

tensor([[[-1.0527,  0.8444,  0.2570,  0.4923],
         [-1.0414,  0.3257, -0.1195, -0.1469]],

        [[ 0.0901, -0.2256, -1.1794, -1.4731],
         [ 0.2486, -0.0071, -0.9408, -1.0595]],

        [[-0.0482,  0.1720, -0.2060, -0.2855],
         [-0.1129,  0.2088, -0.1971, -0.2510]]], grad_fn=<UnsafeViewBackward0>)

### 第三重 加入一些细节