In [2]:
# Produce one batch of input_embeddings for the attention mechanisms in the cells below
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id tensors.

    The text is tokenized once; every window of ``max_length`` token ids becomes
    an input sample, and the matching target is the same window shifted one
    token to the right.

    Args:
        txt: Text to tokenize.
        tokenizer: Object with an ``encode(text, allowed_special=...)`` method
            that returns a sequence of token ids.
        max_length: Number of token ids in each input/target window.
        stride: Distance, in tokens, between the starts of consecutive windows.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Encode the whole text in one call.
        ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # A window may only start where a full, one-token-shifted target still fits.
        window_starts = range(0, len(ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(ids[start:start + max_length])
            for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(ids[start + 1:start + 1 + max_length])
            for start in window_starts
        ]

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` pair stored at ``idx``."""
        inputs = self.input_ids[idx]
        targets = self.target_ids[idx]
        return inputs, targets


def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    """Wrap ``txt`` in a ``GPTDatasetV1`` and return a ``DataLoader`` over it.

    Args:
        txt: Text to tokenize with tiktoken's ``"gpt2"`` encoding.
        batch_size: Number of samples per batch.
        max_length: Number of token ids per sample window.
        stride: Step, in tokens, between consecutive windows.
        shuffle: Forwarded to the ``DataLoader``.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, target_ids)`` batches.
    """
    windowed_dataset = GPTDatasetV1(
        txt,
        tiktoken.get_encoding("gpt2"),
        max_length,
        stride,
    )
    return DataLoader(windowed_dataset, shuffle=shuffle, batch_size=batch_size)


# Read the sample text that the demo batch is built from.
with open("small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

# tiktoken's "gpt2" encoding.
tokenizer = tiktoken.get_encoding("gpt2")
# NOTE(review): encoded_text is not read again in the cells visible here.
encoded_text = tokenizer.encode(raw_text)

vocab_size = 50257  # number of rows in the token embedding table
output_dim = 3  # width of every embedding vector
max_len = 1024
context_length = max_len  # number of rows in the positional embedding table


# NOTE(review): no seed is set in this cell before these layers are created, and
# create_dataloader defaults to shuffle=True, so the printed batch presumably
# differs from run to run.
token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

max_length = 6  # tokens per sample; stride == max_length below, so windows do not overlap
dataloader = create_dataloader(raw_text, batch_size=1, max_length=max_length, stride=max_length)
# Only the first batch is used: the loop body runs once and then breaks.
for batch in dataloader:
    x, y = batch  # x: input token ids, y: the same window shifted by one token

    token_embeddings = token_embedding_layer(x)
    # One embedding per position 0..max_length-1, added to every sample in the batch below.
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))

    input_embeddings = token_embeddings + pos_embeddings
    break
# Rebind `batch` from the (x, y) pair to the embedded inputs; later cells feed
# this tensor to the attention modules.
batch = input_embeddings
print(batch) # [batch_size,max_length,output_dim]

tensor([[[ 1.8090, -0.2384, -0.5709],
         [-1.3427,  1.4408, -0.9398],
         [ 1.0149,  1.4192,  0.6250],
         [ 2.8832, -0.2751,  0.1739],
         [ 1.2050, -0.4352,  0.2508],
         [-0.9460,  0.3540, -1.6115]]], grad_fn=<AddBackward0>)


In [4]:
# simplified self-attention
class SimplifiedSelfAttention(nn.Module):
    """Self-attention without trainable weights.

    The input serves as queries, keys and values at once: attention scores are
    the pairwise dot products between token vectors, normalized with a softmax
    over the last dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: Tensor indexed as [batch_size, num_tokens, dim].

        Returns:
            Tensor with the same shape as ``x``.
        """
        similarity = torch.matmul(x, x.transpose(1, 2))
        weights = similarity.softmax(dim=-1)
        return torch.matmul(weights, x)
# Run the weight-free attention on the embedded batch built in the first cell.
simplified_self_attention = SimplifiedSelfAttention()
context_vecs = simplified_self_attention(batch)
print(context_vecs) # [batch_size,max_length,output_dim]

tensor([[[ 2.6048, -0.2498,  0.0533],
         [-1.2471,  1.2364, -1.0516],
         [ 1.5205,  0.7765,  0.3711],
         [ 2.8240, -0.2700,  0.1480],
         [ 2.4528, -0.2141,  0.0753],
         [-1.0771,  0.8017, -1.3121]]], grad_fn=<UnsafeViewBackward0>)


In [None]:
# scaled dot-product attention
d_in = output_dim  # input width equals the embedding width set in the first cell
d_out = 2  # width of the query/key/value projections
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        qkv_bias: Whether the three projections carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (query, key, value) is kept so seeded initialisation is unchanged.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return context vectors for ``x``.

        Args:
            x: Tensor indexed as [batch_size, num_tokens, d_in].

        Returns:
            Tensor indexed as [batch_size, num_tokens, d_out].
        """
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Scale by sqrt(d_k), the width of the key vectors, before the softmax.
        scale = k.shape[-1] ** 0.5
        scores = torch.matmul(q, k.transpose(1, 2))
        weights = torch.softmax(scores / scale, dim=-1)
        return torch.matmul(weights, v)
torch.manual_seed(789) # fix the seed so the nn.Linear weights are identical on every run
sa = SelfAttention(d_in,d_out)
context_vecs = sa(batch)
print(context_vecs) # [batch_size,max_length,d_out]

tensor([[[-0.1946,  0.3152],
         [-0.2176,  0.3954],
         [-0.1699,  0.3002],
         [-0.0302,  0.1061],
         [-0.1162,  0.2214],
         [-0.0977,  0.1975]]], grad_fn=<UnsafeViewBackward0>)


In [5]:
# causal attention
d_in = output_dim  # input width equals the embedding width set in the first cell
d_out = 2  # width of the query/key/value projections
# Rebinds the global set to max_len (1024) in the first cell to the batch's
# sequence length (6); pos_embedding_layer was already built with the old value.
context_length = max_length
class CausalSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and to earlier positions: scores
    for later positions are set to -inf before the softmax.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        dropout: Dropout probability applied to the attention weights.
        context_length: Side length of the stored causal mask.
        qkv_bias: Whether the three projections carry a bias term.
    """
    def __init__(self,d_in,d_out,dropout,context_length,qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Matrix with 1 strictly above the diagonal and 0 elsewhere; the 1s mark
        # the "future" positions that must be hidden.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length,context_length),diagonal=1),
        )
    def forward(self,x):
        """Return causally masked context vectors for ``x``.

        Args:
            x: Tensor indexed as [batch_size, num_tokens, d_in].

        Returns:
            Tensor indexed as [batch_size, num_tokens, d_out].
        """
        # Take the sequence length from the input itself rather than from a
        # notebook-level global, so the module works for any length that fits
        # inside the stored mask.
        num_tokens = x.shape[1]
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1,2)
        # .bool() turns the 0/1 mask into False/True; positions that are True
        # are overwritten with -inf so the softmax gives them zero weight.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens,:num_tokens],-torch.inf
        )
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vecs = attn_weights @ values
        return context_vecs
torch.manual_seed(123)  # seed before construction so the nn.Linear weights repeat across runs
csa = CausalSelfAttention(d_in,d_out,0,context_length)  # third argument: dropout of 0
context_vecs = csa(batch)
print(context_vecs) # [batch_size,max_length,d_out]

tensor([[[-0.6819, -0.3161],
         [ 0.1707, -0.7459],
         [-0.5489, -0.5477],
         [-1.0141, -0.3355],
         [-0.8303, -0.2648],
         [ 0.1199, -0.5792]]], grad_fn=<UnsafeViewBackward0>)


In [9]:
# multi-head attention(simple implementation:stacking)
num_heads = 2  # number of attention heads used by the multi-head modules below
dropout = 0  # dropout probability passed to the multi-head modules below
class MultiHeadAttentionWrapper_v1(nn.Module):
    """Multi-head attention built by stacking independent causal heads.

    Every head is a separate ``CausalSelfAttention`` module; their outputs are
    concatenated along the last dimension.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Output size of each individual head.
        dropout: Dropout probability handed to every head.
        context_length: Context length handed to every head.
        num_heads: Number of heads to create.
        qkv_bias: Whether the heads' query/key/value projections carry a bias.
    """
    def __init__(self,d_in,d_out,dropout,context_length,num_heads,qkv_bias=False):
        super().__init__()
        # Forward the caller's qkv_bias instead of a hard-coded False, so the
        # argument actually takes effect.
        self.heads = nn.ModuleList(
            [CausalSelfAttention(d_in,d_out,dropout,context_length,qkv_bias=qkv_bias) for _ in range(num_heads)])
    def forward(self,x):
        """Run every head on ``x`` and concatenate the results on the last dimension.

        The last dimension of the result has size ``d_out * num_heads``.
        """
        return torch.cat([head(x) for head in self.heads],dim = -1)
torch.manual_seed(123)  # seed before construction so the heads' nn.Linear weights repeat across runs
mha = MultiHeadAttentionWrapper_v1(d_in,d_out,dropout,context_length,num_heads)  
context_vecs = mha(batch)
print(context_vecs) # [batch_size,max_length,d_out*num_heads]

tensor([[[-0.6819, -0.3161,  0.6822,  0.3591],
         [ 0.1707, -0.7459,  0.0434,  0.4049],
         [-0.5489, -0.5477,  0.3481,  0.5484],
         [-1.0141, -0.3355,  0.7766,  0.5552],
         [-0.8303, -0.2648,  0.6747,  0.4604],
         [ 0.1199, -0.5792,  0.4882,  0.3680]]], grad_fn=<CatBackward0>)


In [None]:
# multi-head attention(alternative implementation with weight splits)
class MultiHeadAttentionWrapper_v2(nn.Module):
    """Multi-head causal self-attention using one projection split across heads.

    Queries, keys and values are each computed with a single ``nn.Linear`` of
    width ``d_out`` and then reshaped into ``num_heads`` heads of width
    ``d_out // num_heads``. The heads' outputs are merged back to ``d_out`` and
    passed through an output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the projections carry a bias term.
        context_length: Optional side length of a pre-built causal mask stored
            as the ``mask`` buffer. When omitted, ``mask`` is ``None`` and the
            causal mask is built from the input's sequence length on every
            forward call, so the module does not depend on any global variable.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """
    def __init__(self,d_in,d_out,dropout,num_heads,qkv_bias=False,context_length=None):
        super().__init__()
        # assert: execution continues when the condition is True; when it is
        # False an AssertionError carrying the message below is raised.
        assert (d_out % num_heads == 0) , "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.head_dim = d_out // num_heads
        self.num_heads = num_heads
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): the output projection's bias follows qkv_bias; kept
        # as-is so existing outputs do not change.
        self.out_proj = nn.Linear(d_out,d_out,bias=qkv_bias)
        if context_length is None:
            causal_mask = None
        else:
            # 1 strictly above the diagonal marks the "future" positions.
            causal_mask = torch.triu(torch.ones(context_length,context_length),diagonal=1)
        self.register_buffer("mask", causal_mask)
    def forward(self,x):
        """Return causally masked multi-head context vectors for ``x``.

        Args:
            x: Tensor of shape [batch_size, num_tokens, d_in].

        Returns:
            Tensor of shape [batch_size, num_tokens, d_out].
        """
        batch_size,num_tokens,d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Split the last dimension: d_out -> (num_heads, head_dim).
        queries = queries.view(batch_size,num_tokens,self.num_heads,self.head_dim)
        keys = keys.view(batch_size,num_tokens,self.num_heads,self.head_dim)
        values = values.view(batch_size,num_tokens,self.num_heads,self.head_dim)
        # Move the head axis forward: [batch_size, num_heads, num_tokens, head_dim].
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        attn_scores = queries @ keys.transpose(-1,-2)
        if self.mask is None:
            # No stored mask: build one that matches this input's length.
            future_positions = torch.triu(
                torch.ones(num_tokens,num_tokens,device=x.device),diagonal=1
            ).bool()
        else:
            future_positions = self.mask.bool()[:num_tokens,:num_tokens]
        attn_scores.masked_fill_(future_positions,-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vecs = attn_weights @ values
        # Back to [batch_size, num_tokens, num_heads, head_dim], then merge the heads.
        context_vecs = context_vecs.transpose(1,2)
        context_vecs = context_vecs.contiguous().view(batch_size,num_tokens,self.d_out)
        context_vecs = self.out_proj(context_vecs)
        return context_vecs
torch.manual_seed(123)  # seed before construction so the nn.Linear weights repeat across runs
mha = MultiHeadAttentionWrapper_v2(d_in,d_out,dropout,num_heads,qkv_bias=False)  
context_vecs = mha(batch)
# The heads are merged back to d_out and passed through out_proj, so the last
# dimension is d_out (not head_dim).
print(context_vecs) # [batch_size,max_length,d_out]

tensor([[[ 0.0420, -0.3826],
         [-0.1632, -0.1290],
         [-0.0262, -0.3907],
         [ 0.0807, -0.5150],
         [ 0.0717, -0.4186],
         [-0.1194, -0.1058]]], grad_fn=<UnsafeViewBackward0>)
