In [2]:
import torch
import torch.nn as nn
# Example input: six rows of three values each, one row per word of
# "Your journey starts with one step" (labelled x^1 .. x^6).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

# Projection sizes passed to the attention modules below.
d_in = 3
d_out = 2

In [None]:
# nn.Parameter version
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)`` filled by ``torch.rand``. They are applied with a
    plain matrix product, so there is no bias term.
    """

    def __init__(self, d_in, d_out):
        """Create the three ``(d_in, d_out)`` projection matrices."""
        super().__init__()
        shape = (d_in, d_out)
        # Drawn in query -> key -> value order from the global RNG.
        self.W_query = nn.Parameter(torch.rand(shape))
        self.W_key = nn.Parameter(torch.rand(shape))
        self.W_value = nn.Parameter(torch.rand(shape))

    def forward(self, x):
        """Return the context vectors for ``x``.

        ``x`` is right-multiplied by the ``(d_in, d_out)`` weight matrices, so
        its last dimension must be ``d_in``; the result's last dimension is
        ``d_out``.
        """
        q = torch.matmul(x, self.W_query)
        k = torch.matmul(x, self.W_key)
        v = torch.matmul(x, self.W_value)

        # Query-key dot products, divided by sqrt(key width) before the
        # softmax over the last dimension.
        scale = k.shape[-1] ** 0.5
        scores = torch.matmul(q, k.transpose(-1, -2))
        weights = torch.softmax(scores / scale, dim=-1)
        return torch.matmul(weights, v)
# SelfAttention_v1 fills its weights with torch.rand, so without a seed the
# printed context vectors change on every run. Seed the global RNG first to
# make this cell reproducible.
torch.manual_seed(123)
sa1 = SelfAttention_v1(d_in, d_out)
context_vecs = sa1(inputs)
print(context_vecs)

tensor([[0.6853, 1.0605],
        [0.7021, 1.0783],
        [0.7015, 1.0777],
        [0.6893, 1.0656],
        [0.6823, 1.0582],
        [0.6958, 1.0724]], grad_fn=<MmBackward0>)


In [12]:
# nn.Linear version
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built from ``nn.Linear``.

    The query, key and value projections are ``nn.Linear(d_in, d_out)``
    layers. ``qkv_bias`` is forwarded as the ``bias`` argument of all three
    layers and defaults to ``False``.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        """Create the three projection layers (query, then key, then value)."""
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``."""
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Query-key dot products, divided by sqrt(key width) before the
        # softmax over the last dimension.
        scale = k.shape[-1] ** 0.5
        scores = torch.matmul(q, k.transpose(-1, -2))
        weights = torch.softmax(scores / scale, dim=-1)
        return torch.matmul(weights, v)
# Seed the global RNG before building the module so the randomly initialised
# nn.Linear weights, and the context vectors printed here, are repeatable.
torch.manual_seed(123)
sa2 = SelfAttention_v2(d_in, d_out)
context_vecs = sa2(inputs)
print(context_vecs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


In [None]:
# nn.Parameter version used to verify the weight matrices against the nn.Linear version
class SelfAttention_v1(nn.Module):
    """nn.Parameter self-attention whose weights are copied from an nn.Linear-based module.

    Each projection matrix is the transpose of the corresponding ``nn.Linear``
    weight of ``source`` (``nn.Linear`` stores its weight as
    ``(out_features, in_features)``), so ``x @ W`` here matches the bias-free
    ``source.W_*(x)``.

    NOTE: this class reuses the name ``SelfAttention_v1``. Once this cell has
    run, it replaces the randomly initialised class of the same name defined
    earlier in the notebook.

    Args:
        d_in: Kept for signature compatibility; not used to size the
            parameters, which take their shape from ``source``.
        d_out: Kept for signature compatibility; see ``d_in``.
        source: Object exposing ``W_query``, ``W_key`` and ``W_value``
            attributes, each with a ``.weight`` tensor. When ``None``
            (default), the module-level ``sa2`` instance is used, which is the
            original behaviour.
    """

    def __init__(self, d_in, d_out, source=None):
        super().__init__()
        if source is None:
            # Backward-compatible default: read the notebook-level instance.
            source = sa2
        self.W_query = nn.Parameter(self._copy_transposed(source.W_query.weight))
        self.W_key = nn.Parameter(self._copy_transposed(source.W_key.weight))
        self.W_value = nn.Parameter(self._copy_transposed(source.W_value.weight))

    @staticmethod
    def _copy_transposed(weight):
        """Return an independent, transposed copy of ``weight``.

        ``nn.Parameter(view)`` would share storage with the source weight, so
        an in-place update of either model would silently change the other.
        ``detach().clone()`` gives this module its own memory.
        """
        return weight.detach().transpose(-1, -2).clone()

    def forward(self, x):
        """Return the context vectors for ``x``."""
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        # Query-key dot products, divided by sqrt(key width) before the
        # softmax over the last dimension.
        attn_scores = queries @ keys.transpose(-1, -2)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vecs = attn_weights @ values
        return context_vecs
# Build the nn.Parameter variant from sa2's weights and print its context
# vectors; they are expected to match the ones printed for sa2.
sa1 = SelfAttention_v1(d_in, d_out)
context_vecs = sa1(inputs)
print(context_vecs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
