# 建模长序列问题

In [2]:
import torch
inputs = torch.tensor([[0.43, 0.15, 0.89],
                      [0.55, 0.87, 0.66],
                      [0.57, 0.85, 0.66],
                      [0.22, 0.58, 0.33],
                      [0.77, 0.25, 0.10],
                      [0.05, 0.80, 0.55]])




In [3]:
query = inputs[1] # 第二个输入作为查询
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(query, x_i)

print(attn_scores_2)


tensor([0.9544, 1.4950, 1.4886, 0.8434, 0.7070, 1.0865])


In [4]:
res = 0. # 初始化结果
for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]

print(res)
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


In [5]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print(attn_weights_2_tmp)
print(attn_weights_2_tmp.sum())

tensor([0.1452, 0.2274, 0.2264, 0.1283, 0.1075, 0.1652])
tensor(1.0000)


In [6]:
def softmax_native(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_native = softmax_native(attn_scores_2)
print(attn_weights_2_native)
print(attn_weights_2_native.sum())



tensor([0.1381, 0.2372, 0.2356, 0.1236, 0.1078, 0.1576])
tensor(1.)


## 计算所有输入标记的注意权重

In [7]:
attn_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i,j] = torch.dot(x_i, x_j)

print(attn_scores)


tensor([[0.9995, 0.9544, 0.9600, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4886, 0.8434, 0.7070, 1.0865],
        [0.9600, 1.4886, 1.4830, 0.8362, 0.7174, 1.0715],
        [0.4753, 0.8434, 0.8362, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7174, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0715, 0.6565, 0.2935, 0.9450]])


In [8]:
attn_scores = inputs @ inputs.T # 矩阵乘法
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9600, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4886, 0.8434, 0.7070, 1.0865],
        [0.9600, 1.4886, 1.4830, 0.8362, 0.7174, 1.0715],
        [0.4753, 0.8434, 0.8362, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7174, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0715, 0.6565, 0.2935, 0.9450]])


In [9]:
attn_weights = torch.softmax(attn_scores, dim=-1)

print(attn_weights)

tensor([[0.2091, 0.1999, 0.2010, 0.1238, 0.1216, 0.1446],
        [0.1381, 0.2372, 0.2356, 0.1236, 0.1078, 0.1576],
        [0.1395, 0.2366, 0.2353, 0.1232, 0.1094, 0.1559],
        [0.1433, 0.2071, 0.2056, 0.1460, 0.1261, 0.1718],
        [0.1526, 0.1958, 0.1978, 0.1366, 0.1878, 0.1295],
        [0.1381, 0.2179, 0.2146, 0.1417, 0.0986, 0.1891]])


In [10]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)

print("All rows sum:", attn_weights.sum(dim=1))

Row 2 sum: 1.0
All rows sum: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [11]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4425, 0.5940, 0.5832],
        [0.4423, 0.6521, 0.5732],
        [0.4434, 0.6504, 0.5730],
        [0.4306, 0.6301, 0.5553],
        [0.4671, 0.5911, 0.5306],
        [0.4181, 0.6508, 0.5690]])


### 自注意力机制

In [12]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [13]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)

print(W_query.shape)
print(W_key.shape)


torch.Size([3, 2])
torch.Size([3, 2])


In [14]:

q_2 = x_2 @ W_query
k_2 = x_2 @ W_key
v_2 = x_2 @ W_value

print(q_2)
print(k_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])


In [15]:
keys = inputs @ W_key
values = inputs @ W_value

print(keys.shape)
print(values.shape)





torch.Size([6, 2])
torch.Size([6, 2])


In [16]:
attn_scores = q_2 @ k_2.T
print(attn_scores)


tensor(1.8524)


  attn_scores = q_2 @ k_2.T


In [17]:
attn_scores = keys @ q_2.T
print(attn_scores)

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)



tensor([1.2705, 1.8524, 1.8338, 1.0795, 0.5577, 1.5440])
tensor([0.1393, 0.2493, 0.2447, 0.1151, 0.0683, 0.1832])


In [18]:
d_k = keys.shape[-1] 
attn_weights_2 = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

print(attn_weights_2)






tensor([0.1495, 0.2256, 0.2226, 0.1306, 0.0903, 0.1814])


In [19]:
context_vec_2 = attn_weights_2 @ values
print(values)
print(context_vec_2)






tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3903, 0.9996],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])
tensor([0.3069, 0.8253])


In [20]:
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec




In [21]:
torch.manual_seed(123)

sa_v1 = SelfAttention(d_in, d_out)

print(sa_v1(inputs))


tensor([[0.3003, 0.8092],
        [0.3069, 0.8253],
        [0.3069, 0.8252],
        [0.2954, 0.7975],
        [0.2933, 0.7926],
        [0.2997, 0.8079]], grad_fn=<MmBackward0>)


In [22]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [23]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)

print(sa_v2(inputs))





tensor([[-0.0746,  0.0706],
        [-0.0755,  0.0696],
        [-0.0755,  0.0697],
        [-0.0767,  0.0678],
        [-0.0770,  0.0673],
        [-0.0761,  0.0686]], grad_fn=<MmBackward0>)


In [24]:
# 将 W_query 权重迁移
sa_v1.W_query.data = sa_v2.W_query.weight.data.T.clone()

# 将 W_key 权重迁移
sa_v1.W_key.data = sa_v2.W_key.weight.data.T.clone()

# 将 W_value 权重迁移
sa_v1.W_value.data = sa_v2.W_value.weight.data.T.clone()







In [25]:
X = torch.randn(2, d_in)  # batch=2, 输入维度 d_in

Q_v2 = sa_v2.W_query(X)                  # nn.Linear
Q_v  = X @ sa_v1.W_query                  # nn.Parameter

print(Q_v2)
print(Q_v)
print("差异：", (Q_v2 - Q_v).abs().max())

tensor([[-0.3378, -0.0580],
        [ 0.3505, -0.1838]], grad_fn=<MmBackward0>)
tensor([[-0.3378, -0.0580],
        [ 0.3505, -0.1838]], grad_fn=<MmBackward0>)
差异： tensor(0., grad_fn=<MaxBackward1>)


###  应用因果注意力掩码

In [26]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

print(attn_weights)





tensor([[0.1920, 0.1645, 0.1657, 0.1549, 0.1720, 0.1509],
        [0.2039, 0.1657, 0.1671, 0.1494, 0.1663, 0.1476],
        [0.2038, 0.1657, 0.1671, 0.1495, 0.1664, 0.1476],
        [0.1868, 0.1666, 0.1673, 0.1570, 0.1660, 0.1563],
        [0.1830, 0.1668, 0.1674, 0.1588, 0.1657, 0.1584],
        [0.1933, 0.1662, 0.1672, 0.1541, 0.1664, 0.1528]],
       grad_fn=<SoftmaxBackward0>)


In [27]:
# 应用因果注意力掩码
tril_mask = torch.tril(torch.ones(keys.shape[0], keys.shape[0]))

print(tril_mask)



tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [28]:
masked_simple = attn_weights * tril_mask

print(masked_simple)





tensor([[0.1920, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2039, 0.1657, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2038, 0.1657, 0.1671, 0.0000, 0.0000, 0.0000],
        [0.1868, 0.1666, 0.1673, 0.1570, 0.0000, 0.0000],
        [0.1830, 0.1668, 0.1674, 0.1588, 0.1657, 0.0000],
        [0.1933, 0.1662, 0.1672, 0.1541, 0.1664, 0.1528]],
       grad_fn=<MulBackward0>)


In [29]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums

print(masked_simple_norm)






tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3798, 0.3088, 0.3114, 0.0000, 0.0000, 0.0000],
        [0.2756, 0.2458, 0.2469, 0.2317, 0.0000, 0.0000],
        [0.2174, 0.1982, 0.1989, 0.1886, 0.1969, 0.0000],
        [0.1933, 0.1662, 0.1672, 0.1541, 0.1664, 0.1528]],
       grad_fn=<DivBackward0>)


In [30]:
mask = torch.triu(torch.ones(keys.shape[0], keys.shape[0]), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(),-torch.inf)
print(masked)


tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4634, 0.1708, 0.1825,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1087, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0924, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1355, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [31]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
print(attn_weights)






tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3798, 0.3088, 0.3114, 0.0000, 0.0000, 0.0000],
        [0.2756, 0.2458, 0.2469, 0.2317, 0.0000, 0.0000],
        [0.2174, 0.1982, 0.1989, 0.1886, 0.1969, 0.0000],
        [0.1933, 0.1662, 0.1672, 0.1541, 0.1664, 0.1528]],
       grad_fn=<SoftmaxBackward0>)


### 使用dropout

In [32]:
dropout = nn.Dropout(0.5)
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7596, 0.6176, 0.6228, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4916, 0.4938, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3963, 0.0000, 0.3773, 0.0000, 0.0000],
        [0.0000, 0.3324, 0.3344, 0.3081, 0.3329, 0.0000]],
       grad_fn=<MulBackward0>)


In [33]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [34]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape #批次，token数量，输入向量维度
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec         

    
        

In [35]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

In [36]:
context_length = 6 
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)

In [37]:
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4821,  0.4336, -0.1471,  0.4106],
         [-0.5368,  0.5483, -0.2493,  0.3548],
         [-0.5568,  0.5899, -0.2792,  0.3358],
         [-0.4954,  0.5321, -0.2643,  0.2959],
         [-0.4602,  0.5177, -0.2202,  0.2206],
         [-0.4490,  0.4977, -0.2425,  0.2452]],

        [[-0.4821,  0.4336, -0.1471,  0.4106],
         [-0.5368,  0.5483, -0.2493,  0.3548],
         [-0.5568,  0.5899, -0.2792,  0.3358],
         [-0.4954,  0.5321, -0.2643,  0.2959],
         [-0.4602,  0.5177, -0.2202,  0.2206],
         [-0.4490,  0.4977, -0.2425,  0.2452]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [38]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) #构造一个线性层，将多头注意力的输出投影到与输入相同的维度
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape #批次，token数量，输入向量维度
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens,:num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        context_vec = self.out_proj(context_vec)

        return context_vec

        
        
        
        


In [39]:
a = torch.tensor([[[[0.2745, 0.6584,0.2775, 0.8573],
                   [0.8993, 0.0390,0.9268, 0.7388],
                   [0.7179, 0.7058,0.9156, 0.4340]],
                  [[0.0772, 0.3565,0.1479, 0.5331],
                   [0.4066, 0.2318,0.4545, 0.9737],
                   [0.4606, 0.5159,0.4220, 0.5786]]]])
print(a.shape)

torch.Size([1, 2, 3, 4])


In [40]:
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [41]:
first_head = a[0,0,:,:]
first_res = first_head @ first_head.T
print(first_res)

second_head = a[0,1,:,:]
second_res = second_head @ second_head.T
print(second_res)










tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [42]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)





tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2864, 0.3588],
         [0.2699, 0.3870],
         [0.2644, 0.3925],
         [0.2579, 0.4026]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2864, 0.3588],
         [0.2699, 0.3870],
         [0.2644, 0.3925],
         [0.2579, 0.4026]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [43]:
# 初始化多头注意力模块
multi_head_attention = MultiHeadAttention(
    d_in=768,           # 输入维度
    d_out=768,          # 输出维度  
    context_length=1024, # 上下文长度（根据需要调整）
    num_heads=12,       # 注意力头数
    dropout=0.1         # dropout率（可选）
)

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in = cfg["emb_dim"],
            d_out = cfg["emb_dim"],
            context_length = cfg["context_length"],
            dropout = cfg["dropout"],
            num_heads = cfg["num_heads"],
            qkv_bias = cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        
        