## 3.4 实现带可训练权重的自注意力机制

In [1]:
import torch

In [2]:
# Toy embedding matrix: one row per token, three features per token.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # "Your"
    [0.55, 0.87, 0.66],  # "journey"
    [0.57, 0.85, 0.64],  # "starts"
    [0.22, 0.58, 0.33],  # "with"
    [0.77, 0.25, 0.10],  # "one"
    [0.05, 0.80, 0.55],  # "step"
])

In [6]:
d_in = inputs.size(1)  # input embedding size (3 for the toy inputs)
d_out = 2              # output embedding size
x_2 = inputs[1]        # embedding of the second token

In [19]:
# Seed the RNG so the three random matrices below are reproducible.
# The draw order (query, key, value) determines the values each one receives.
torch.manual_seed(123)
# Projection matrices of shape (d_in, d_out); gradient tracking is switched
# off for them via requires_grad=False.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [20]:
# Project the second token's embedding into query, key and value space.
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)
print(query_2)

tensor([0.4306, 1.4551])


#### 计算所有的键向量和值向量

In [23]:
# Project every token embedding at once: one key row and one value row per token.
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


#### 计算第 2 个输入的注意力分数

In [27]:
# Score of the second query against the second key: a single dot product.
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(f"keys_2: {keys_2}, query_2: {query_2}, attn_score_22: {attn_score_22}")

# Scores of the second query against every key in one matrix product.
attn_scores_2 = torch.matmul(query_2, keys.T)
print(attn_scores_2)

keys_2: tensor([0.4433, 1.1419]), query_2: tensor([0.4306, 1.4551]), attn_score_22: 1.852384328842163
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


#### 缩放点积注意力：除以 √d_k，避免嵌入维度较大时点积过大、softmax 趋于饱和而导致梯度消失

In [28]:
# Divide the scores by sqrt(d_k) (d_k = key vector size) and normalise them
# with softmax so the weights over all tokens sum to 1.
d_k = keys.size(-1)
attn_weights_2 = (attn_scores_2 / d_k ** 0.5).softmax(dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


#### 计算上下文向量

In [29]:
# Context vector for the second token: attention-weighted sum of all value rows.
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor([0.3061, 0.8210])


### 完整 v1

In [35]:
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with trainable weights.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    initialised from ``torch.rand``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of the projected query/key/value vectors, and therefore
            of each returned context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Parameter defaults to requires_grad=True, so these matrices
        # receive gradients during backpropagation.
        # Creation order (query, key, value) is kept so a fixed seed yields
        # the same weights as before.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``; any number of leading
                batch dimensions ``(..., num_tokens, d_in)`` is also
                accepted. Must have at least two dimensions.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        # 1. Project the inputs into key, query and value space.
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # 2. Attention scores: each query ("what am I looking for") is
        #    compared with every key ("what do I contain"). Swapping only the
        #    last two axes (instead of `.T`) keeps batch dimensions in place.
        attn_scores = queries @ keys.transpose(-2, -1)

        # 3. Scale by sqrt(d_k) before normalising with softmax, so large
        #    dot products do not push softmax into its flat region.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # 4. Context vectors: attention-weighted sum of the value vectors.
        context_vec = attn_weights @ values
        return context_vec

In [36]:
# Seed the RNG before construction so the module's randomly initialised
# weight matrices are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
# Run the whole input matrix through the module: one context vector per row.
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


### 使用 nn.Linear 层优化 v1 的实现
> 当偏置单元被禁用时，nn.Linear 层可以高效地执行矩阵乘法。
> 它的优势是提供了优化的权重初始化方案，从而有助于模型训练的稳定性和有效性。

In [40]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of the projected query/key/value vectors, and therefore
            of each returned context vector.
        qkv_bias: Whether the three projection layers add a bias term.
            Defaults to ``False``.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (query, key, value) is kept so a fixed seed yields
        # the same initial weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``; any number of leading
                batch dimensions ``(..., num_tokens, d_in)`` is also
                accepted. Must have at least two dimensions.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        # 1. Project the inputs into key, query and value space.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # 2. Attention scores. Swapping only the last two axes (instead of
        #    `.T`) keeps batch dimensions in place.
        attn_scores = queries @ keys.transpose(-2, -1)

        # 3. Scale by sqrt(d_k), then normalise across the key axis.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1,
        )

        # 4. Context vectors: attention-weighted sum of the value vectors.
        context_vec = attn_weights @ values
        return context_vec

In [39]:
# Seed the RNG before construction so the nn.Linear layers' initial weights
# are reproducible.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
# Run the whole input matrix through the module: one context vector per row.
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
