In [1]:
import torch
import numpy as np

In [3]:
# Allocate an uninitialised 6x6 tensor just to inspect its shape.
torch.empty((6, 6)).shape

torch.Size([6, 6])

In [None]:
import tiktoken
import torch

vocab_size = 50257  # presumably the GPT-2 vocabulary size; not used in this cell
dim = 256           # not used in this cell
text = "You journey starts with one step"
tokenizer = tiktoken.get_encoding('gpt2')
# Encode the sentence with the GPT-2 BPE tokenizer and wrap the ids in a tensor.
tokenids = torch.tensor(tokenizer.encode(text))
# NOTE(review): vocab_size / dim look intended for a torch.nn.Embedding(vocab_size, dim)
# lookup that was never written here -- confirm before relying on them.
# End on the shape so the cell actually produces its recorded output on a re-run.
tokenids.shape

torch.Size([6])

In [None]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: size of the projected query / key / value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Parameter (capital P) is the tensor subclass that registers a weight
        # with the module. nn.parameter is the submodule it lives in and is not
        # callable, so using it here raised TypeError at construction time.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two axes: identical to .T for 2-D input, and it keeps
        # any leading batch dimensions in place (.T would reverse all of them).
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [None]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built from ``nn.Linear`` layers.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: size of the projected query / key / value vectors.
        qkv_bias: whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # transpose(-2, -1) equals .T for 2-D input but leaves leading batch
        # dimensions alone, so batched input works instead of failing in matmul.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

In [8]:
import torch

context_length = 6
# Boolean causal mask: True strictly above the main diagonal (the "future" positions).
torch.ones(context_length, context_length).triu(diagonal=1).bool()

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [None]:
# Random stand-in for a square matrix of attention scores (no seed is set in this cell).
attn_scorese = torch.rand((context_length, context_length))
attn_scorese

tensor([[0.8811, 0.6388, 0.5549, 0.7745, 0.6322, 0.8204],
        [0.3097, 0.8237, 0.1903, 0.3551, 0.5903, 0.2216],
        [0.8307, 0.4468, 0.0244, 0.8714, 0.3977, 0.0296],
        [0.3355, 0.3269, 0.1694, 0.0745, 0.5342, 0.5232],
        [0.8604, 0.5418, 0.3902, 0.8457, 0.7714, 0.2177],
        [0.3008, 0.4227, 0.3015, 0.4377, 0.5562, 0.1437]])

In [10]:
# True strictly above the diagonal marks the positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1).bool()
# Out-of-place fill: attn_scorese itself is left unchanged.
masked = attn_scorese.masked_fill(mask, float('-inf'))
masked

tensor([[0.8811,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.3097, 0.8237,   -inf,   -inf,   -inf,   -inf],
        [0.8307, 0.4468, 0.0244,   -inf,   -inf,   -inf],
        [0.3355, 0.3269, 0.1694, 0.0745,   -inf,   -inf],
        [0.8604, 0.5418, 0.3902, 0.8457, 0.7714,   -inf],
        [0.3008, 0.4227, 0.3015, 0.4377, 0.5562, 0.1437]])

In [13]:
# Same upper-triangular pattern as the boolean mask, kept as 0./1. floats.
torch.ones(context_length, context_length).triu(diagonal=1)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [11]:
# Row-wise softmax: the -inf entries receive exactly zero weight.
masked.softmax(dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3743, 0.6257, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4700, 0.3202, 0.2099, 0.0000, 0.0000, 0.0000],
        [0.2771, 0.2747, 0.2347, 0.2134, 0.0000, 0.0000],
        [0.2352, 0.1710, 0.1470, 0.2317, 0.2151, 0.0000],
        [0.1557, 0.1759, 0.1558, 0.1785, 0.2010, 0.1331]])

In [12]:
# Seed first so the printed dropout pattern is repeatable.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
# Show the input, then the same tensor after dropout.
print(example, dropout(example), sep="\n")

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [None]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over batched sequences.

    Args:
        d_in: size of each input token vector.
        d_out: size of the projected query / key / value vectors.
        context_length: side length of the square mask built at construction.
        dropout: drop probability applied to the attention weights.
        qkv_bias: whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal flag the positions each token must not see.
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future)

    def forward(self, x):
        """Attend causally over ``x`` of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        _, seq_len, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # (b, L, d) @ (b, d, L) -> (b, L, L)
        scores = torch.matmul(q, k.transpose(1, 2))
        # Crop the stored mask to the actual sequence length.
        hide = self.mask[:seq_len, :seq_len].bool()
        # masked_fill_ works in place, so no extra tensor is allocated;
        # masked_fill (no trailing underscore) would return a new one.
        scores.masked_fill_(hide, -torch.inf)
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return torch.matmul(weights, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a final output projection.

    Args:
        d_in: size of each input token vector.
        d_out: total output width; split evenly across the heads.
        context_length: side length of the square mask built at construction.
        dropout: drop probability applied to the attention weights.
        num_heads: number of attention heads; must divide ``d_out``.
        qkv_bias: whether the query / key / value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Ones strictly above the diagonal flag the positions each token must not see.
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future)

    def _split_heads(self, t, batch, seq_len):
        """Reshape (b, len, d_out) into (b, num_heads, len, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend causally over ``x`` of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        batch, seq_len, _ = x.shape
        # Project, then split into heads. Batched matmul acts on the last two axes,
        # so every head of every batch entry is computed independently.
        q = self._split_heads(self.W_query(x), batch, seq_len)
        k = self._split_heads(self.W_key(x), batch, seq_len)
        v = self._split_heads(self.W_value(x), batch, seq_len)

        # (b, h, len, head_dim) @ (b, h, head_dim, len) -> (b, h, len, len)
        scores = torch.matmul(q, k.transpose(2, 3))
        hide = self.mask[:seq_len, :seq_len].bool()
        scores.masked_fill_(hide, -torch.inf)
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))  # (b, h, len, len)

        # (b, h, len, len) @ (b, h, len, head_dim) -> (b, h, len, head_dim)
        per_head = torch.matmul(weights, v)
        # Back to (b, len, h, head_dim), then merge the heads into (b, len, d_out).
        merged = per_head.transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(merged)



In [None]:
# Reference only: PyTorch's built-in decoder classes. Naming a class constructs
# nothing, and a notebook cell echoes only its last expression.
nn.TransformerDecoderLayer
nn.TransformerDecoder

In [14]:
def generate_square_subsequent_mask(seq_len):
    """Build an additive causal attention mask.

    Args:
        seq_len: side length of the square mask.

    Returns:
        A float ``(seq_len, seq_len)`` tensor holding 0.0 on and below the
        diagonal and ``-inf`` strictly above it.
    """
    # Positions strictly above the diagonal are the ones to block.
    blocked = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
    return torch.zeros(seq_len, seq_len).masked_fill(blocked, float('-inf'))

In [15]:
# Show the additive causal mask for a 5-token sequence.
generate_square_subsequent_mask(seq_len=5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [17]:
# Upper triangle including the diagonal, as 0./1. floats.
torch.ones(5, 5).triu()

tensor([[1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.]])

In [16]:
# The same upper triangle (diagonal included) as a boolean tensor.
torch.ones(5, 5).triu().eq(1)

tensor([[ True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True]])