In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import tiktoken

In [3]:
class GPTDataset(Dataset):
    """Next-token-prediction dataset built from sliding windows over tokenized text.

    Sample ``i`` is a pair ``(x, y)`` of 1-D int64 tensors of length ``context_size``,
    where ``y`` is ``x`` shifted one token to the right (``y[j]`` is the token that
    follows ``x[j]`` in the text).

    Args:
        text: Raw text to tokenize.
        tokenizer: Object exposing ``encode(text) -> list[int]`` (e.g. a tiktoken encoding).
        context_size: Number of tokens in each input window.
        stride: Step between the start positions of consecutive windows. The default of 1
            yields every possible window; ``stride=context_size`` yields non-overlapping ones.

    Raises:
        ValueError: If ``context_size`` or ``stride`` is not positive, or if the text encodes
            to ``context_size`` tokens or fewer (no complete ``(x, y)`` pair fits).
    """

    def __init__(self, text, tokenizer, context_size, stride=1):
        if context_size < 1:
            raise ValueError(f"context_size must be positive, got {context_size}")
        if stride < 1:
            raise ValueError(f"stride must be positive, got {stride}")
        token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
        if token_ids.numel() <= context_size:
            raise ValueError(
                f"text encodes to {token_ids.numel()} tokens; need more than "
                f"context_size={context_size} to form one (input, target) pair"
            )
        # unfold returns overlapping *views* into token_ids rather than one copy per window,
        # so memory stays O(num_tokens) instead of O(num_tokens * context_size).
        # Inputs come from tokens[:-1] and targets from tokens[1:], i.e. shifted by one.
        self.x = token_ids[:-1].unfold(0, context_size, stride)
        self.y = token_ids[1:].unfold(0, context_size, stride)

    def __len__(self):
        """Number of (input, target) windows."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(input_ids, target_ids)`` pair."""
        return self.x[idx], self.y[idx]

    def load_data(self, batch_size, shuffle=True):
        """Wrap this dataset in a single-process ``DataLoader``.

        ``drop_last=True`` discards a final partial batch, so a dataset with fewer than
        ``batch_size`` windows yields no batches at all.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, drop_last=True, num_workers=0)

In [8]:
class EmbeddingLayer(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        vocab_size: Number of distinct token ids.
        embed_size: Dimension of each embedding vector.
        max_len: Number of position embeddings, i.e. the longest sequence that can be embedded.
    """

    def __init__(self, vocab_size, embed_size, max_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_len, embed_size)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: Integer token-id tensor of shape ``(batch, seq_len)``; ``seq_len`` is read from
                ``x.shape[1]`` and must not exceed ``max_len``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_size)``.
        """
        tok_embed = self.token_embedding(x)
        # Create the position ids on x's device so the layer still works after .to("cuda").
        positions = torch.arange(x.shape[1], device=x.device)
        pos_embed = self.position_embedding(positions)
        # pos_embed is (seq_len, embed_size) and broadcasts over the batch dimension.
        return tok_embed + pos_embed

    # parameters() is inherited from nn.Module, which already yields the weights of both
    # embedding tables; no override is needed.

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Args:
        d_in: Feature width of the input.
        d_out: Feature width of the output; split evenly across the heads.
        n_heads: Number of attention heads; must divide ``d_out``.
        dropout: Dropout probability applied to the attention weights.
        masked: If True, position ``t`` may only attend to positions ``<= t``.
        max_len: Longest sequence supported by the precomputed causal mask.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_in, d_out, n_heads, dropout, masked, max_len):
        super().__init__()
        if d_out % n_heads != 0:
            raise ValueError(f"d_out ({d_out}) must be divisible by n_heads ({n_heads})")
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads
        self.masked = masked
        self.q = nn.Linear(d_in, d_out, bias=False)
        self.k = nn.Linear(d_in, d_out, bias=False)
        self.v = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)
        # Strict upper triangle (diagonal=1) marks "future" positions. Registered as a buffer
        # so it follows the module across devices and is saved in the state_dict.
        self.register_buffer("mask", torch.triu(torch.ones(max_len, max_len), diagonal=1))

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_in)``; with ``masked=True``,
                ``seq_len`` must not exceed ``max_len``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_out)``.
        """
        B, T, _ = x.shape
        # Project, then split the d_out features into heads: (B, T, d_out) -> (B, n_heads, T, head_dim).
        Q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attention_scores = Q @ K.transpose(-2, -1)  # (B, n_heads, T, T)
        if self.masked:
            mask = self.mask[:T, :T]
            attention_scores = attention_scores.masked_fill(mask.bool(), -torch.inf)
        attention_weights = torch.softmax(attention_scores / (self.head_dim ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)
        context_vectors = attention_weights @ V  # (B, n_heads, T, head_dim)
        # Merge the heads back. The merged width is d_out, which need not equal the input width d_in.
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(B, T, self.d_out)
        return self.out_proj(context_vectors)

    # parameters() is inherited from nn.Module, which already recurses into q, k, v and out_proj.
    
class Network(nn.Module):
    """Token + position embedding followed by a single multi-head self-attention layer.

    Args:
        vocab_size: Number of distinct token ids.
        embed_size: Embedding width, used as both input and output width of the attention.
        max_len: Maximum sequence length (number of position embeddings and causal-mask size).
        n_heads: Number of attention heads.
        dropout: Dropout probability on the attention weights.
        masked: Whether the attention is causal.
    """

    def __init__(self, vocab_size, embed_size, max_len, n_heads, dropout, masked):
        super().__init__()
        self.embedding_layer = EmbeddingLayer(vocab_size, embed_size, max_len)
        self.attention_layer = MultiHeadAttention(embed_size, embed_size, n_heads, dropout, masked, max_len)

    def forward(self, x):
        """Map token ids to contextualised vectors.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_size)``. These are hidden states, not
            vocabulary logits, because the network has no output head.
        """
        embedded = self.embedding_layer(x)
        attention_vectors = self.attention_layer(embedded)
        return attention_vectors

    # parameters() is inherited from nn.Module, which already recurses into both sub-layers.


In [9]:
# Read the whole training corpus into memory as a single string.
with open("../the-verdict.txt", encoding="utf-8", mode="r") as file:
    raw_text = file.read()

In [24]:
# Hyperparameters (presumably chosen to mirror the GPT-2 "small" configuration).
context_size = 1024  # tokens per training window; also passed to the model as max_len
batch_size = 32
embed_size = 768     # embedding / attention width; must be divisible by n_heads
n_heads = 12
dropout = 0.1        # dropout on the attention weights
masked = True        # causal (autoregressive) attention
# Weight init, DataLoader shuffling and dropout all draw from torch's global RNG.
# Seeding here makes Restart & Run All reproducible.
seed = 123
torch.manual_seed(seed)
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab

In [29]:
# Build the sliding-window dataset, an iterator over its shuffled batches, and the model.
# NOTE: `dataloader` holds iter(DataLoader), an *iterator*: every next() consumes a batch.
# Re-running this cell rebuilds it (and re-initialises the network).
dataset = GPTDataset(raw_text, tokenizer, context_size)
dataloader = iter(dataset.load_data(batch_size))
network = Network(vocab_size, embed_size, context_size, n_heads, dropout, masked)
print("Number of parameters:", sum(param.numel() for param in network.parameters()))

Number of parameters: 41743872


In [28]:
# Push one batch through the untrained model to sanity-check shapes.
# no_grad: this is inspection only, so don't keep the autograd graph alive alongside
# `output`. At these sizes that graph includes (batch, n_heads, T, T) attention tensors.
# NOTE: the network is still in training mode, so dropout is active and values are stochastic.
batch_inputs, batch_targets = next(dataloader)
with torch.no_grad():
    output = network(batch_inputs)
assert output.shape == (batch_size, context_size, embed_size)
print(output.shape)
print(output)

torch.Size([32, 1024, 768])
tensor([[[-4.0036e-01,  7.2065e-01,  1.2895e+00,  ..., -4.4990e-01,
           3.0665e-01, -1.1862e+00],
         [-9.5281e-01,  4.5434e-01,  8.1317e-01,  ..., -1.0560e-01,
          -1.8351e-01, -3.8012e-01],
         [-4.0295e-01,  2.1444e-01,  4.0911e-01,  ..., -7.1460e-01,
           6.4789e-02, -1.1589e-01],
         ...,
         [-4.0129e-02, -7.4827e-02,  9.2245e-02,  ..., -6.1951e-02,
           3.2196e-02,  7.0760e-03],
         [-8.8575e-02, -1.0987e-01,  6.9352e-02,  ...,  2.9009e-02,
           3.8491e-02,  2.0442e-02],
         [-5.1991e-02, -6.3246e-02,  4.5091e-02,  ...,  2.2020e-03,
           4.8783e-02,  1.3770e-03]],

        [[-5.3485e-02,  4.3702e-01,  5.7494e-01,  ..., -2.1828e-01,
           5.1636e-01, -1.0875e+00],
         [-1.4255e-01,  7.1311e-01,  1.1093e-01,  ..., -2.4189e-01,
           5.7655e-01, -6.1783e-01],
         [-2.0377e-01,  3.4689e-01, -3.4207e-02,  ..., -4.1896e-02,
           1.2861e-01, -1.0131e-01],
         ..