#  <font color='#FFE15D'><b>💎 Build GPT-2 </b></font>

# 🔴 **Import**

In [None]:
import time
from dataclasses import dataclass

from datasets import load_dataset
from tokenizers import Tokenizer

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

# 🔴 **Utils**

In [None]:
def prepare_data(tokens, seq_len):
    """Fold a flat token tensor into rows of ``seq_len`` tokens each.

    Tokens at the tail that do not fill a complete row are discarded.

    Args:
        tokens: Tensor whose first dimension indexes the tokens.
        seq_len: Number of tokens per row.

    Returns:
        A 2-D tensor with ``seq_len`` columns. When fewer than ``seq_len``
        tokens are available, an empty ``(0, seq_len)`` tensor with the same
        dtype and device as ``tokens`` is returned.
    """
    n_rows = tokens.shape[0] // seq_len

    # An empty slice cannot be viewed with an inferred (-1) dimension, so the
    # "not even one full row" case is answered explicitly.
    if n_rows == 0:
        return tokens.new_empty((0, seq_len))

    # Keep only the prefix that divides evenly into rows, then fold it.
    trimmed = tokens[:n_rows * seq_len]
    return trimmed.view(-1, seq_len)


In [None]:
def num_trainable_params(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total / 1e6

In [None]:
def calculate_time(model, x, num_runs=10):
    """Return the average wall-clock time, in seconds, of one ``model(*x)`` call.

    Args:
        model: Any callable; it is invoked as ``model(*x)``.
        x: Iterable of positional arguments unpacked into each call.
        num_runs: Number of back-to-back calls to average over.

    Returns:
        Elapsed seconds divided by ``num_runs``.
    """
    # Synchronizing is only possible when a CUDA device is present; calling
    # torch.cuda.synchronize() on a CPU-only host raises instead.
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        # Wait for previously queued GPU work so it is not billed to this run.
        torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(num_runs):
        model(*x)
    if use_cuda:
        # Wait for the timed GPU work to actually finish before stopping the clock.
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_runs

# 🔴 **Init**

In [None]:
# Pick the GPU when PyTorch reports one as available, otherwise use the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare name: echoes the chosen device as the cell output.
device

# 🔴 **Dataset**

In [None]:
# Load the "roneneldan/TinyStories" dataset through `datasets.load_dataset`.
dataset = load_dataset("roneneldan/TinyStories")
# Bare name: echoes the loaded object as the cell output.
dataset

In [None]:
# Restore a tokenizer from its serialized JSON file in the working directory.
tokenizer = Tokenizer.from_file("bpe-tokenizer_tinystories.json")
# Bare name: echoes the tokenizer as the cell output.
tokenizer

In [None]:
# Load the pre-tokenized train/validation token ids from PyTorch files.
# NOTE(review): torch.load unpickles the file contents — only load files from
# a trusted source.
train_token_ids = torch.load('tokenized-train-samples_vocab-10k.pt')
valid_token_ids = torch.load('tokenized-valid-samples_vocab-10k.pt')

# Report how many tokens each split holds (len() of the loaded objects).
print("ðŸ“Š Number of Tokens")
print(f"ðŸ”¹ Train: {len(train_token_ids):,} tokens")
print(f"ðŸ”¹ Valid: {len(valid_token_ids):,} tokens")

In [None]:
class TinyStoriesDataset(Dataset):
    """Next-token-prediction dataset built from a flat stream of token ids.

    The stream is folded into rows of ``seq_len + 1`` tokens. Each item is an
    ``(inputs, targets)`` pair of length ``seq_len`` where ``targets`` is
    ``inputs`` shifted one position to the right.
    """

    def __init__(self, data, seq_len):
        # One extra token per row so that both the input slice and the
        # one-step-shifted target slice have exactly seq_len entries.
        self.data = prepare_data(data, seq_len + 1)
        self.seq_len = seq_len

    def __len__(self):
        """Number of stored rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(row[:-1], row[1:])`` for the row at ``idx``."""
        row = self.data[idx]
        inputs, targets = row[:-1], row[1:]
        return inputs, targets

In [None]:
# Length of every input/target sequence handed to the model.
seq_len = 128

# Wrap both token streams; each sample is an (input, target) pair of seq_len tokens.
train_set = TinyStoriesDataset(train_token_ids, seq_len)
valid_set = TinyStoriesDataset(valid_token_ids, seq_len)

# Report how many samples each split yields.
print(f"ðŸ“Š Number of Samples")
print(f"ðŸ”¹ Train: {len(train_set):,} samples")
print(f"ðŸ”¹ Valid: {len(valid_set):,} samples")

In [None]:
# Seed PyTorch's global RNG before the shuffling loader is created.
torch.manual_seed(1337)
batch_size = 64

# Only the training loader shuffles; validation keeps the stored order.
# Worker processes are currently disabled (see the commented-out num_workers).
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)#, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)#, num_workers=4)

# Report how many batches each loader produces per pass.
print(f"ðŸ“Š Number of Batches")
print(f"ðŸ”¹ Train: {len(train_loader):,} batches")
print(f"ðŸ”¹ Valid: {len(valid_loader):,} batches")

In [None]:
# Pull a single batch from the training loader to inspect its shapes.
# x_batch / y_batch are reused by the model-building cells further down.
x_batch, y_batch = next(iter(train_loader))

print(f"ðŸ“Š Batch Shapes")
print(f"ðŸ”¹ Input: {x_batch.shape}")
print(f"ðŸ”¹ Target: {y_batch.shape}")

# 🔴 **Model**

## 🟠 Embedding

In [None]:
# Token-embedding table: one 100-dimensional vector per vocabulary entry.
wte = nn.Embedding(tokenizer.get_vocab_size(), 100)
# Look up three sample token ids and show the resulting shape.
wte(torch.tensor([1, 2, 100])).shape

In [None]:
# Position-embedding table: one 100-dimensional vector per position in [0, seq_len).
wpe = nn.Embedding(seq_len, 100)
# Look up three sample positions (each must be < seq_len) and show the shape.
wpe(torch.tensor([1, 2, 100])).shape

In [None]:
# Sum token and position embeddings. The position lookup covers indices
# 0 .. x_batch.shape[1] - 1 and broadcasts across the leading batch dimension.
x = wte(x_batch) + wpe(torch.arange(x_batch.shape[1]))
x.shape

## 🟠 Scaled Dot-Product Attention

In [None]:
# Step-by-step causal self-attention on the embedded batch.
# Queries, keys and values are all the same tensor here (no projections yet).
q = k = v = x
print(q.shape)

# Lower-triangular matrix of ones: entry (i, j) is 1 only when j <= i.
mask = torch.tril(torch.ones(seq_len, seq_len))

# Pairwise dot products, scaled by the square root of the last dimension of k.
scores = q @ k.transpose(-2, -1) / (k.shape[-1]**0.5)
# Set positions above the diagonal to -inf (in place) so softmax gives them zero weight.
scores.masked_fill_(mask ==0, float(-torch.inf))
# Normalise each row into attention weights.
scores = scores.softmax(dim=-1)
print(scores.shape)

# Weighted sum of the value vectors.
z = scores @ v
z.shape

In [None]:
def scaled_dot_product_attention(q, k, v):
    """Compute causal scaled dot-product attention.

    Args:
        q: Query tensor of shape ``(..., L, E)``.
        k: Key tensor of shape ``(..., S, E)``.
        v: Value tensor of shape ``(..., S, Ev)``.

    Returns:
        Tensor of shape ``(..., L, Ev)``. Query position ``i`` attends only to
        key positions ``j <= i``.
    """
    num_queries, num_keys = q.shape[-2], k.shape[-2]

    # Build the boolean mask directly on the inputs' device. This avoids a
    # host-to-device copy on every call and removes any reliance on a
    # module-level `device` variable matching where the tensors actually live.
    allowed = torch.ones(
        num_queries, num_keys, dtype=torch.bool, device=q.device
    ).tril()

    # Pairwise similarities, scaled by sqrt of the key/query feature size.
    scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)

    # `scores` is a freshly computed tensor, so filling it in place does not
    # touch the caller's data. -inf entries receive zero weight after softmax.
    scores.masked_fill_(~allowed, float("-inf"))

    weights = scores.softmax(dim=-1)
    return weights @ v

In [None]:
# Smoke test: run the function on the embedded batch (moved to `device`) and show the output shape.
scaled_dot_product_attention(x.to(device), x.to(device), x.to(device)).shape

In [None]:
# Large random inputs for benchmarking, created directly on `device`.
# NOTE(review): dimensions are presumably (batch, sequence, features) — confirm.
q = torch.randn((128, 1024, 768), device=device)
k = torch.randn((128, 1024, 768), device=device)
v = torch.randn((128, 1024, 768), device=device)
q.shape

In [None]:
# Output shape of the hand-written attention on the benchmark inputs.
scaled_dot_product_attention(q, k, v).shape

In [None]:
# Average seconds per call of the hand-written attention, over 20 runs.
calculate_time(scaled_dot_product_attention, (q, k, v), num_runs=20)

In [None]:
# Same inputs through PyTorch's built-in attention with causal masking enabled.
F.scaled_dot_product_attention(q, k, v, is_causal=True).shape

In [None]:
# Largest absolute element-wise difference between the hand-written and the
# built-in causal attention outputs on the same inputs.
torch.abs(scaled_dot_product_attention(q, k, v) - F.scaled_dot_product_attention(q, k, v, is_causal=True)).max()

In [None]:
# Average seconds per call of PyTorch's built-in attention, over 20 runs.
# The call is wrapped so that is_causal=True is applied: calculate_time only
# forwards positional arguments, and without the flag the built-in would be
# timed on non-causal attention, which is different work from the hand-written
# causal version it is compared against.
calculate_time(
    lambda *qkv: F.scaled_dot_product_attention(*qkv, is_causal=True),
    (q, k, v),
    num_runs=20,
)

## 🟠 Multi-Head Attention