In [2]:
import torch
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Three separate linear layers project the input into queries, keys and
    values; ``nn.Linear`` owns the weight and bias parameters of each
    projection.
    """

    def __init__(self, d_in, d_out):
        """Create the query, key and value projection layers.

        Args:
            d_in: Size of the last dimension of the input given to ``forward``.
            d_out: Size of the last dimension produced by each projection.
        """
        super().__init__()
        # Layers are created in query, key, value order.
        self.query_proj = nn.Linear(d_in, d_out)
        self.key_proj = nn.Linear(d_in, d_out)
        self.value_proj = nn.Linear(d_in, d_out)

    def forward(self, x):
        """Attend over the second-to-last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_in``.

        Returns:
            A ``(context, attn_weights)`` tuple: the weighted sum of the value
            projections, and the softmax-normalised attention weights that
            produced it.
        """
        # Project the input into queries, keys and values.
        queries = self.query_proj(x)
        keys = self.key_proj(x)
        values = self.value_proj(x)

        # Scaled dot-product scores: Q K^T divided by sqrt of the key width.
        scale = keys.shape[-1] ** 0.5
        scaled_scores = (queries @ keys.transpose(-2, -1)) / scale

        # Normalise over the last dimension so each row of weights sums to 1.
        weights = torch.softmax(scaled_scores, dim=-1)

        # Weighted aggregation of the values.
        return weights @ values, weights


In [4]:
# Seed the global RNG first: torch.rand and the nn.Linear weight
# initialisation both draw from it, so without a fixed seed every
# Restart & Run All prints a different attention matrix.
torch.manual_seed(123)

num_tokens = 4
d_in = 5
d_out = 8

# Random stand-in input: one d_in-wide row per token.
x = torch.rand(num_tokens, d_in)
attn = SelfAttention_v2(d_in, d_out)

out, attn_weights = attn(x)
# The context has one d_out-wide row per token; the weight matrix is
# num_tokens x num_tokens, softmax-normalised along its last dimension.
print("输出 shape:", out.shape)
print("注意力矩阵:\n", attn_weights)


输出 shape: torch.Size([4, 8])
注意力矩阵:
 tensor([[0.2441, 0.2636, 0.2368, 0.2554],
        [0.2526, 0.2618, 0.2458, 0.2398],
        [0.2458, 0.2631, 0.2365, 0.2547],
        [0.2482, 0.2612, 0.2392, 0.2514]], grad_fn=<SoftmaxBackward0>)
