# 因果注意力的掩码实现

In [1]:
import torch
from torch import nn

In [5]:
context_length = 6
ones = torch.ones(context_length, context_length)
print(ones)


tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


In [6]:
mask = torch.triu(ones, diagonal=1)
print("Causal mask:\n", mask)

Causal mask:
 tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [11]:
attn_scores = torch.rand(6, 6)
print("Attention scores before masking:\n", attn_scores)

Attention scores before masking:
 tensor([[0.9255, 0.0453, 0.5155, 0.9033, 0.1484, 0.8146],
        [0.6408, 0.8407, 0.0115, 0.4654, 0.0346, 0.0527],
        [0.7147, 0.1970, 0.1273, 0.3279, 0.1407, 0.6383],
        [0.2883, 0.7989, 0.6767, 0.3149, 0.8252, 0.4870],
        [0.5743, 0.1173, 0.0506, 0.7206, 0.8865, 0.5583],
        [0.4314, 0.9827, 0.0439, 0.4720, 0.5172, 0.4862]])


In [12]:
masked_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Attention scores after applying causal mask:\n", masked_scores)

Attention scores after applying causal mask:
 tensor([[0.9255,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.6408, 0.8407,   -inf,   -inf,   -inf,   -inf],
        [0.7147, 0.1970, 0.1273,   -inf,   -inf,   -inf],
        [0.2883, 0.7989, 0.6767, 0.3149,   -inf,   -inf],
        [0.5743, 0.1173, 0.0506, 0.7206, 0.8865,   -inf],
        [0.4314, 0.9827, 0.0439, 0.4720, 0.5172, 0.4862]])


In [13]:
attn_weights = torch.softmax(masked_scores, dim=-1)
print("Attention weights after applying causal mask:\n", attn_weights)

Attention weights after applying causal mask:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4502, 0.5498, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4648, 0.2769, 0.2583, 0.0000, 0.0000, 0.0000],
        [0.1935, 0.3224, 0.2853, 0.1987, 0.0000, 0.0000],
        [0.2105, 0.1333, 0.1247, 0.2437, 0.2877, 0.0000],
        [0.1515, 0.2629, 0.1028, 0.1578, 0.1650, 0.1600]])


# 使用dropout 掩码额外的注意力权重

In [14]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


经过dropout的结果是 50%的数变成0 ，剩余50%变成了原来的一倍

In [15]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9295, 0.5539, 0.5166, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6448, 0.5707, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2666, 0.0000, 0.4874, 0.0000, 0.0000],
        [0.0000, 0.5258, 0.2056, 0.3155, 0.3301, 0.0000]])


# 实现一个简单的因果注意力类

In [16]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
batch = torch.stack((inputs, inputs), dim=0)  # (B=2, T=6, C=3)
print(batch.shape)

torch.Size([2, 6, 3])


In [24]:
class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()

        # 按照d_in, d_out初始化Q,K,V的线性变换矩阵
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

        # 初始化掩码矩阵
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # 初始化dropout层
        self.dropout = nn.Dropout(dropout)

        # 初始化softmax层
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X):
        
        queries = self.W_q(X)  # (B, T, d_out)
        keys = self.W_k(X)     # (B, T, d_out)
        values = self.W_v(X)   # (B, T, d_out)

        attn_scores = queries @ keys.transpose(1,2)
        # 掩码矩阵取一下前T行前T列，为了适应不同长度的输入序列
        attn_scores = attn_scores.masked_fill(self.mask.bool()[:X.shape[1], :X.shape[1]], -torch.inf)
        attn_weights = self.softmax(attn_scores/ keys.shape[-1]**0.5)
        masked_attn_weights = self.dropout(attn_weights)

        context_vec = masked_attn_weights @ values
        return context_vec




        

In [25]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalSelfAttention(d_in=3, d_out=2, context_length=context_length, dropout=0.0)
context_vec = ca(batch)
print("Context vectors shape:", context_vec.shape)

Context vectors shape: torch.Size([2, 6, 2])


# 多头注意力

In [26]:
# 通过一个Wrapper 来实现一个多头注意力的封装类
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkbv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalSelfAttention(d_in, d_out, context_length, dropout, qkbv_bias)
            for _ in range(num_heads)
        ])

    def forward(self, X):
        head__outpus = [head(X) for head in self.heads]
        # 在最后一个维度上进行拼接
        return torch.cat(head__outpus, dim=-1)

In [27]:
# 使用我们的多头注意力封装类
torch.manual_seed(123)
context_length = batch.shape[1]
mha = MultiHeadAttentionWrapper(d_in=3, d_out=2, context_length=context_length, num_heads=2, dropout=0.0)
context_vec = mha(batch)

print("Context vectors shape:", context_vec.shape)
print(context_vec)

Context vectors shape: torch.Size([2, 6, 4])
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


# 权重划分实现多头注意力

In [29]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias= False):
        super().__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # 初始化Q,K,V的线性变换矩阵
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

        # 初始化dropout层
        self.dropout = nn.Dropout(dropout)

        # 初始化掩码矩阵
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # 初始化softmax层
        self.softmax = nn.Softmax(dim=-1)

        # 输出的线性变换矩阵
        self.W_o = nn.Linear(d_out, d_out)

    def forward(self, X):
        num_btach, num_tokens, d_in = X.shape

        queries = self.W_q(X)  # (num_batch, num_tokens, d_out)
        keys = self.W_k(X)     # (num_batch, num_tokens, d_out)
        values = self.W_v(X)   # (num_batch, num_tokens, d_out)

        # 将d_out维度拆分成(num_heads, head_dim)
        queries = queries.view(num_btach, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(num_btach, num_tokens, self.num_heads, self.head_dim)
        values = values.view(num_btach, num_tokens, self.num_heads, self.head_dim)

        # 钱换一下维度顺序，方便后续计算
        keys = keys.transpose(1,2)      # (num_batch, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1,2) # (num_batch, num_heads, num_tokens, head_dim)
        values = values.transpose(1,2)   # (num_batch, num_heads, num_tokens, head_dim)


        # 计算注意力分数,并进行对角线掩码
        attn_scores =- queries @ keys.transpose(2,3) # (num_batch, num_heads, num_tokens, num_tokens)
        masked_attn_scores = attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        # 缩放的注意力权重计算
        scaled_scores = masked_attn_scores / keys.shape[-1]**0.5
        attn_weights = self.softmax(scaled_scores)

        # 计算输出上下文向量
        context_vec = attn_weights @ values # (num_batch, num_heads, num_tokens, head_dim)
        context_vec = context_vec.transpose(1,2)

        context_vec = context_vec.contiguous().view(num_btach, num_tokens, self.d_out)

        # 最后的线性变换
        output = self.W_o(context_vec)
        return output


          


    

In [30]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape 
d_out = 2
mha = MultiHeadAttention(d_in=d_in, d_out=d_out, context_length=context_length, num_heads=2, dropout=0.0)
context_vec = mha(batch)
print("Context vectors shape:", context_vec.shape)
print(context_vec)

Context vectors shape: torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2926, 0.3957],
         [0.2840, 0.3645],
         [0.2687, 0.3912],
         [0.2618, 0.3974],
         [0.2571, 0.4064]],

        [[0.3190, 0.4858],
         [0.2926, 0.3957],
         [0.2840, 0.3645],
         [0.2687, 0.3912],
         [0.2618, 0.3974],
         [0.2571, 0.4064]]], grad_fn=<ViewBackward0>)
