In [15]:
import torch
from torch import nn

### 缩放点积注意力

In [2]:
inputs = torch.tensor( 
 [[0.43, 0.15, 0.89], # Your (x^1) 
 [0.55, 0.87, 0.66], # journey (x^2) 
 [0.57, 0.85, 0.64], # starts (x^3) 
 [0.22, 0.58, 0.33], # with (x^4) 
 [0.77, 0.25, 0.10], # one (x^5) 
 [0.05, 0.80, 0.55]] # step (x^6) 
)

In [3]:
# 第二个作为查询
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [7]:
# 初始化Wq，Wk，Wv
torch.manual_seed(123)
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [9]:
# 计算q，k，v
query_2 = x_2 @ w_query
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value
print(query_2)

tensor([0.4306, 1.4551])


In [10]:
# 所有输入元素的k，v
keys = inputs @ w_key
values = inputs @ w_value
print("keys shape: ", keys.shape)
print("values shape: ", values.shape)

keys shape:  torch.Size([6, 2])
values shape:  torch.Size([6, 2])


In [None]:
# 计算注意力分数 q*v.T
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [13]:
# softmax得到注意力权重
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1) 
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [14]:
# 计算上下文 weights * v
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


# 构建子注意力类

In [16]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.w_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) 
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    def forward(self, x):
        q = x @ w_query
        k = x @ w_key
        v = x @ w_value
        attn_scores = q @ k.T
        attn_weughts = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        context_vec = attn_weughts @ v
        return context_vec

In [18]:
torch.manual_seed(789) 
sa_v2 = SelfAttention(d_in, d_out) 
print(sa_v2(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])
