In [8]:
import torch
from self_attention import SelfAttention_v2, inputs

In [17]:
# Input width is the size of the second dimension of `inputs`; the projection
# width is fixed at 2 for this walkthrough.
d_in, d_out = inputs.shape[1], 2

# Self-attention module whose query/key layers are reused in the cells below.
sa_v2 = SelfAttention_v2(d_in, d_out)

In [29]:
# Project the inputs through the query and key layers of the module.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

# Dot product of every query with every key.
attn_scores = torch.matmul(queries, keys.T)

# Divide by the square root of the key width, then softmax across the last
# dimension so that each row is normalised.
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1749, 0.1585, 0.1594, 0.1666, 0.1834, 0.1572],
        [0.1617, 0.1669, 0.1674, 0.1659, 0.1775, 0.1607],
        [0.1616, 0.1668, 0.1674, 0.1659, 0.1779, 0.1604],
        [0.1625, 0.1683, 0.1684, 0.1662, 0.1696, 0.1650],
        [0.1615, 0.1659, 0.1667, 0.1657, 0.1819, 0.1583],
        [0.1632, 0.1686, 0.1686, 0.1664, 0.1664, 0.1668]],
       grad_fn=<SoftmaxBackward0>)


#### 1. 使用 PyTorch 的 tril 来创建对角线以上元素为 0 的掩码矩阵

In [33]:
# Build a square matrix of ones and keep only the lower triangle, so every
# entry above the main diagonal is 0.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

# Element-wise product with the attention weights: entries above the main
# diagonal become 0, the diagonal and everything below it are kept.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[0.1749, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1617, 0.1669, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1616, 0.1668, 0.1674, 0.0000, 0.0000, 0.0000],
        [0.1625, 0.1683, 0.1684, 0.1662, 0.0000, 0.0000],
        [0.1615, 0.1659, 0.1667, 0.1657, 0.1819, 0.0000],
        [0.1632, 0.1686, 0.1686, 0.1664, 0.1664, 0.1668]],
       grad_fn=<MulBackward0>)


#### 2. 重新归一化权重，使得每一行的权重之和为 1

In [37]:
# Sum over the last dimension to get one total per row. keepdim=True keeps that
# dimension with size 1 so the division below broadcasts row by row.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
print(row_sums)

# Divide each row by its own total.
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

tensor([[0.1749],
        [0.3286],
        [0.4958],
        [0.6654],
        [0.8417],
        [1.0000]], grad_fn=<SumBackward1>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4921, 0.5079, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3259, 0.3364, 0.3376, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2529, 0.2531, 0.2498, 0.0000, 0.0000],
        [0.1919, 0.1971, 0.1980, 0.1968, 0.2161, 0.0000],
        [0.1632, 0.1686, 0.1686, 0.1664, 0.1664, 0.1668]],
       grad_fn=<DivBackward0>)


#### 使用 softmax 综合前面 1、2 步骤，实现更高效的掩码注意力

In [50]:
# diagonal=1 leaves out the main diagonal, so only the strictly upper
# triangle is set to 1.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask)
# masked_fill() writes the given value wherever the boolean mask is True.
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

# Softmax over the last axis. For this 2-D score matrix dim=-1 is the same axis
# as dim=1, but it keeps normalising over the key axis if a leading batch
# dimension is ever added. The -inf entries come out as exactly 0.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[-0.0763,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0408,  0.0038,    -inf,    -inf,    -inf,    -inf],
        [-0.0423,  0.0025,  0.0074,    -inf,    -inf,    -inf],
        [-0.0090,  0.0404,  0.0416,  0.0233,    -inf,    -inf],
        [-0.0586, -0.0207, -0.0141, -0.0227,  0.1097,    -inf],
        [ 0.0040,  0.0504,  0.0502,  0.0317,  0.0320,  0.0354]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4921, 0.5079, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3259, 0.3364, 0.3376, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2529, 0.2531, 0.2498, 0.0000, 0.0000],
        [0.1919, 0.1971, 0.1980, 0.1968, 0.2161, 0.0000],
        [0.1632, 0.1686, 0.1686, 0.1664, 0.1664, 0.1668]],
       grad_fn=<SoftmaxBackward0>)

#### 使用 dropout 随机屏蔽一部分注意力权重

> 减少过拟合

In [51]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)  # drop probability of 50%
example = torch.ones((6, 6))
print(example)
# About half of the entries are zeroed; the surviving ones are scaled up
# (1 -> 2 at p=0.5) so the overall magnitude stays balanced.
print(dropout(example))

# Apply the same dropout layer to the attention-weight matrix.
print(dropout(attn_weights))

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6729, 0.6752, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5062, 0.4997, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.4323, 0.0000],
        [0.3263, 0.3372, 0.0000, 0.3328, 0.3329, 0.3337]],
       grad_fn=<MulBackward0>)


### 封装一个处理批数据的简化因果注意力

In [55]:
import torch.nn as nn

class CausalAttention(nn.Module):
    """Simplified single-head causal self-attention for batched input.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Output size of the query, key and value projections.
        context_length: Side length of the square causal mask stored in the
            ``mask`` buffer.
        dropout: Drop probability handed to ``nn.Dropout`` and applied to the
            attention weights.
        qkv_bias: Whether the three linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        # Three linear projections, created in query / key / value order.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular matrix of ones: a 1 marks a position that
        # must not be attended to. Registered as a buffer on the module.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Compute causally masked attention context vectors.

        Args:
            x: 3-D tensor laid out as (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        # Unpacking into exactly three names requires a 3-D input.
        _, seq_len, _ = x.shape

        q = self.W_query(x)   # (batch, num_tokens, d_out)
        k = self.W_key(x)     # (batch, num_tokens, d_out)
        v = self.W_value(x)   # (batch, num_tokens, d_out)

        # Swap the last two axes of the keys so the product yields one
        # (num_tokens, num_tokens) score matrix per batch element.
        scores = torch.matmul(q, k.transpose(1, 2))

        # Trailing-underscore ops work in place on the tensor itself, which
        # avoids an extra copy. The mask is cropped to the current sequence.
        future = self.mask[:seq_len, :seq_len].bool()
        scores.masked_fill_(future, -torch.inf)

        # Scale by sqrt(d_out) and normalise over the key axis.
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)

        # Dropout on the attention weights.
        weights = self.dropout(weights)

        return torch.matmul(weights, v)

In [56]:
# Stack two copies of the input along a new leading dimension to simulate a
# batch of two identical sequences.
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [58]:
torch.manual_seed(123)
context_length = batch.size(1)  # tokens per sequence (6 for this batch)
# Dropout probability 0.0 disables dropout for this run.
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs)
print(context_vecs.shape)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 6, 2])
