In [1]:
import torch
from torch import nn

In [4]:
# Toy embeddings: one 3-dimensional row per token of the example sentence
# "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Input width is read from the data; the projection width is fixed at 2.
d_in, d_out = inputs.shape[1], 2

In [5]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections (and of the output).
        qkv_bias: Whether the three linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or
                ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        # The projections are nn.Linear modules, so they are called rather
        # than matrix-multiplied, and must be reached through ``self``.
        q = self.w_query(x)
        k = self.w_key(x)
        v = self.w_value(x)

        # Transposing the last two axes (instead of ``.T``) keeps this valid
        # for batched input as well as for a single 2-D sequence.
        attn_scores = q @ k.transpose(-2, -1)

        # Scale by sqrt(d_out) taken from the local keys, not from any
        # notebook-level variable, then normalise each row.
        attn_weights = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ v
        return context_vec

### 因果注意力（掩码注意力）

#### 只考虑小于等于当前位置的注意力分数

### 步骤1 计算完整注意力权重

In [6]:
# Unmasked attention weights, built by hand from the module's own
# query/key projection layers (the module's forward() is not called here).
sa = SelfAttention(d_in, d_out)
queries = sa.w_query(inputs)
keys = sa.w_key(inputs)

# Raw scores are the pairwise query-key dot products.
attn_scores = torch.matmul(queries, keys.transpose(0, 1))

# Scale by sqrt(key dimension) and normalise each row to sum to 1.
attn_weights = (attn_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1692, 0.1621, 0.1623, 0.1693, 0.1701, 0.1670],
        [0.1713, 0.1623, 0.1626, 0.1679, 0.1707, 0.1652],
        [0.1712, 0.1625, 0.1627, 0.1678, 0.1705, 0.1652],
        [0.1695, 0.1642, 0.1643, 0.1672, 0.1690, 0.1657],
        [0.1677, 0.1666, 0.1666, 0.1661, 0.1671, 0.1658],
        [0.1708, 0.1625, 0.1627, 0.1680, 0.1704, 0.1655]],
       grad_fn=<SoftmaxBackward0>)


### 步骤2 将对角线以上的元素置0

In [7]:
# Square lower-triangular mask of ones: row i keeps columns 0..i.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [12]:
# Element-wise product zeroes every weight above the diagonal.
# The rows no longer sum to 1 after this step.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1692, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1713, 0.1623, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1712, 0.1625, 0.1627, 0.0000, 0.0000, 0.0000],
        [0.1695, 0.1642, 0.1643, 0.1672, 0.0000, 0.0000],
        [0.1677, 0.1666, 0.1666, 0.1661, 0.1671, 0.0000],
        [0.1708, 0.1625, 0.1627, 0.1680, 0.1704, 0.1655]],
       grad_fn=<MulBackward0>)


### 步骤3 再归一化（先用 -inf 掩码注意力分数，再做 softmax，使每行权重之和重新为 1）

In [15]:
# Ones strictly above the diagonal mark the future positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1)

# Put -inf on the raw scores at those positions so a later softmax gives
# them exactly zero weight.
masked = attn_scores.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)

tensor([[-0.0819,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0468, -0.1227,    -inf,    -inf,    -inf,    -inf],
        [-0.0443, -0.1178, -0.1160,    -inf,    -inf,    -inf],
        [-0.0235, -0.0685, -0.0674, -0.0426,    -inf,    -inf],
        [ 0.0134,  0.0038,  0.0041, -0.0005,  0.0077,    -inf],
        [-0.0498, -0.1206, -0.1188, -0.0730, -0.0530, -0.0944]],
       grad_fn=<MaskedFillBackward0>)


In [16]:
# Softmax over each row of the -inf-masked scores: masked entries become 0
# and the remaining entries in every row sum to 1.
attn_weights = (masked / keys.size(-1) ** 0.5).softmax(dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5134, 0.4866, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3448, 0.3274, 0.3278, 0.0000, 0.0000, 0.0000],
        [0.2548, 0.2468, 0.2470, 0.2514, 0.0000, 0.0000],
        [0.2011, 0.1997, 0.1998, 0.1991, 0.2003, 0.0000],
        [0.1708, 0.1625, 0.1627, 0.1680, 0.1704, 0.1655]],
       grad_fn=<SoftmaxBackward0>)


### 利用 dropout 掩码额外的注意力权重

#### dropout 通常应用在两个位置：一是计算注意力权重之后，二是将这些权重应用于值向量之后。

In [18]:
# Dropout with p=0.5: in training mode roughly half of the entries are
# zeroed and the survivors are scaled by 1 / (1 - p) = 2.
dropout = torch.nn.Dropout(0.5)

# Seed once, right before the random mask is drawn. Constructing the layer
# consumes no random numbers, so an extra seed call before it is redundant.
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6897, 0.6547, 0.6556, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4936, 0.4940, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3995, 0.0000, 0.3982, 0.0000, 0.0000],
        [0.0000, 0.3250, 0.3254, 0.3361, 0.3409, 0.0000]],
       grad_fn=<MulBackward0>)


### 实现掩码注意力类

In [19]:
# Duplicate the example sequence along a new leading axis to get a
# batch of two identical sequences.
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [22]:
class MaskAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout.

    Each token attends only to itself and to earlier tokens in the sequence.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections (and of the output).
        context_length: Maximum sequence length the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
        qkv_bais: Whether the three linear projections carry a bias term.
            (The spelling is kept as-is so existing keyword callers work.)
    """

    def __init__(self, d_in, d_out, context_length, dropout,
                qkv_bais = False):
        super().__init__()
        self.d_out = d_out
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bais)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bais)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bais)
        self.dropout = nn.Dropout(dropout)
        # diagonal=1 marks only the positions strictly above the diagonal
        # (the future). Including the diagonal would mask every score of
        # the first token and turn its softmax row into NaN.
        # Registered as a buffer so it follows the module across devices;
        # non-persistent so the state_dict layout is unchanged.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        """Compute causal context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, d_in = x.shape
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # (batch, num_tokens, num_tokens) pairwise query-key scores.
        attn_scores = q @ k.transpose(1, 2)
        # Crop the mask to the actual sequence length; -inf entries get
        # exactly zero weight after the softmax.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim = -1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ v
        return context_vec


In [23]:
# Seed before constructing the module so its randomly initialised
# projection weights are reproducible.
torch.manual_seed(123)

# The mask must cover the token dimension of the batch.
context_length = batch.size(1)
ca = MaskAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
