#  <font color='#FFE15D'><b>💎 Build GPT-2 </b></font>

# 🔴 **Import**

In [None]:
import time
from dataclasses import dataclass

from datasets import load_dataset
from tokenizers import Tokenizer

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

# 🔴 **Utils**

In [None]:
def prepare_data(tokens, seq_len):
    """Cut a token tensor into rows of ``seq_len`` tokens.

    Tokens at the end that do not fill a complete row are dropped.

    Args:
        tokens: Tensor sliced along its first dimension (presumably a 1D
            tensor of token ids -- confirm against the caller).
        seq_len: Number of tokens per row; must be positive.

    Returns:
        A 2D tensor with ``seq_len`` columns. When there are not enough
        tokens for a single row, an empty ``(0, seq_len)`` tensor with the
        same dtype and device as ``tokens``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    # Trim tokens so that total length is divisible by seq_len
    n_tokens = (tokens.shape[0] // seq_len) * seq_len
    trimmed = tokens[:n_tokens]

    # view(-1, seq_len) cannot infer the row count of a zero-element tensor,
    # so the "no complete row" case is answered explicitly.
    if trimmed.numel() == 0:
        return trimmed.new_empty((0, seq_len))

    # Reshape to 2D tensor
    return trimmed.view(-1, seq_len)


In [None]:
def num_trainable_params(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total / 1e6

In [None]:
def calculate_time(model, x, num_runs=10):
    """Return the average wall-clock time of one ``model(*x)`` call.

    Args:
        model: Callable to benchmark.
        x: Sequence of positional arguments, unpacked into every call.
        num_runs: Number of calls to average over.

    Returns:
        Mean elapsed time per call, in seconds.
    """
    # CUDA kernels run asynchronously, so the GPU must be drained before the
    # clock is read. On a CPU-only machine torch.cuda.synchronize() raises,
    # so it is skipped there.
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and has higher resolution than time.time().
    start = time.perf_counter()
    for _ in range(num_runs):
        model(*x)
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_runs

# 🔴 **Init**

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

# 🔴 **Dataset**

In [None]:
# Load the "roneneldan/TinyStories" dataset with the `datasets` library;
# the bare name below displays its summary as the cell output.
dataset = load_dataset("roneneldan/TinyStories")
dataset

In [None]:
# Load a tokenizer from a local JSON file (presumably a BPE tokenizer
# trained on TinyStories, going by the filename -- confirm).
tokenizer = Tokenizer.from_file("bpe-tokenizer_tinystories.json")
tokenizer

In [None]:
# Load tokens from pytorch file
# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source; consider passing weights_only=True if these files hold
# plain tensors -- confirm what was saved.
train_token_ids = torch.load('tokenized-train-samples_vocab-10k.pt')
valid_token_ids = torch.load('tokenized-valid-samples_vocab-10k.pt')

# Report the size of each split with thousands separators.
print("📊 Number of Tokens")
print(f"🔹 Train: {len(train_token_ids):,} tokens")
print(f"🔹 Valid: {len(valid_token_ids):,} tokens")

In [None]:
class TinyStoriesDataset(Dataset):
    """Dataset of (input, target) pairs for next-token prediction.

    ``prepare_data`` cuts the token stream into rows of ``seq_len + 1``
    tokens. Each row yields an input (all but the last token) and a target
    (all but the first token), i.e. the input shifted by one position.
    """

    def __init__(self, data, seq_len):
        # One extra token per row so that both the input and the shifted
        # target end up with seq_len tokens.
        rows = prepare_data(data, seq_len + 1)
        self.seq_len = seq_len
        self.data = rows

    def __len__(self):
        # Number of rows produced by prepare_data.
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        row = self.data[idx]
        inputs, targets = row[:-1], row[1:]
        return inputs, targets

In [None]:
# Number of input tokens per sample.
seq_len = 128

train_set = TinyStoriesDataset(data=train_token_ids, seq_len=seq_len)
valid_set = TinyStoriesDataset(data=valid_token_ids, seq_len=seq_len)

# Report how many fixed-length samples each split produced.
print("📊 Number of Samples")
print(
    f"🔹 Train: {len(train_set):,} samples",
    f"🔹 Valid: {len(valid_set):,} samples",
    sep="\n",
)

In [None]:
# Seed the global RNG first (presumably so the shuffled batch order is
# reproducible across runs).
torch.manual_seed(1337)
batch_size = 64

# Only the training loader shuffles; both pin host memory.
train_loader = DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True
)
valid_loader = DataLoader(
    dataset=valid_set, batch_size=batch_size, shuffle=False, pin_memory=True
)

print("📊 Number of Batches")
print(
    f"🔹 Train: {len(train_loader):,} batches",
    f"🔹 Valid: {len(valid_loader):,} batches",
    sep="\n",
)

In [None]:
# Pull a single batch from the training loader to use in the demos below.
x_batch, y_batch = next(iter(train_loader))

print("📊 Batch Shapes")
print(
    f"🔹 Input: {x_batch.shape}",
    f"🔹 Target: {y_batch.shape}",
    sep="\n",
)

# 🔴 **Model**

## 🟠 Embedding

In [None]:
# Token embedding table: one 100-dimensional vector per vocabulary entry.
wte = nn.Embedding(tokenizer.get_vocab_size(), 100)
# Look up three token ids; the cell displays the shape of the result.
wte(torch.tensor([1, 2, 100])).shape

In [None]:
# Position embedding table: one 100-dimensional vector per position,
# with seq_len positions in total.
wpe = nn.Embedding(seq_len, 100)
# NOTE: index 100 is only a valid position because seq_len (128) exceeds it.
wpe(torch.tensor([1, 2, 100])).shape

In [None]:
# Add the position embeddings for positions 0..T-1 to the token embeddings
# of the batch; the position term is broadcast across the batch dimension.
x = wte(x_batch) + wpe(torch.arange(x_batch.shape[1]))
x.shape

## 🟠 Scaled Dot-Product Attention

In [None]:
# Self-attention walkthrough: queries, keys and values are all the embedded batch.
q = k = v = x
print(q.shape)

# Lower-triangular matrix of ones: position i may attend to positions <= i.
mask = torch.ones(seq_len, seq_len).tril()

# Scaled similarity scores; masked entries become -inf so that the softmax
# assigns them zero weight.
scores = (q @ k.transpose(-2, -1)) / k.shape[-1] ** 0.5
scores = scores.masked_fill(mask == 0, float('-inf')).softmax(dim=-1)
print(scores.shape)

# Attention-weighted sum of the values.
z = scores @ v
z.shape

In [None]:
def scaled_dot_product_attention(q, k, v):
    """Causal scaled dot-product attention.

    Args:
        q: Query tensor of shape ``(..., T_q, d)``.
        k: Key tensor of shape ``(..., T_k, d)``.
        v: Value tensor of shape ``(..., T_k, d_v)``.

    Returns:
        Tensor of shape ``(..., T_q, d_v)`` in which query position ``i``
        only attends to key positions ``<= i``.
    """
    n_q, n_k = q.shape[-2], k.shape[-2]
    # Build the mask on the inputs' own device rather than the
    # notebook-global `device`, so the function works wherever q/k/v live.
    allowed = torch.ones(n_q, n_k, dtype=torch.bool, device=q.device).tril()

    scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)
    # Disallowed positions get -inf so that softmax gives them zero weight.
    scores = scores.masked_fill(~allowed, float('-inf'))
    return scores.softmax(dim=-1) @ v

In [None]:
# Run the hand-written attention on the embedded batch moved to `device`;
# the cell displays the output shape.
scaled_dot_product_attention(x.to(device), x.to(device), x.to(device)).shape

In [None]:
# Random query/key/value tensors of shape (128, 1024, 768), created directly
# on `device`, used as benchmark inputs in the following cells.
q = torch.randn((128, 1024, 768), device=device)
k = torch.randn((128, 1024, 768), device=device)
v = torch.randn((128, 1024, 768), device=device)
q.shape

In [None]:
# Output shape of the hand-written attention on the random inputs.
scaled_dot_product_attention(q, k, v).shape

In [None]:
# Average seconds per call of the hand-written attention, over 20 runs.
calculate_time(scaled_dot_product_attention, (q, k, v), num_runs=20)

In [None]:
# Same computation with PyTorch's built-in attention, using causal masking.
F.scaled_dot_product_attention(q, k, v, is_causal=True).shape

In [None]:
# Largest absolute element-wise difference between the hand-written and the
# built-in causal attention outputs.
torch.abs(scaled_dot_product_attention(q, k, v) - F.scaled_dot_product_attention(q, k, v, is_causal=True)).max()

In [None]:
# Average seconds per call of the built-in attention, over 20 runs.
# is_causal=True is passed explicitly so that this times the same causal
# computation as the hand-written version; calling the function with only
# (q, k, v) would benchmark non-causal attention instead.
calculate_time(
    lambda q_, k_, v_: F.scaled_dot_product_attention(q_, k_, v_, is_causal=True),
    (q, k, v),
    num_runs=20,
)

## 🟠 Multi Head Attention

In [None]:
@dataclass
class GPTConfig:
    """Hyperparameters for the attention layer.

    Declared as a dataclass so that fields can be overridden by keyword,
    e.g. ``GPTConfig(n_head=4)``.
    """
    n_embd: int = 100  # embedding width
    n_head: int = 5    # number of attention heads

config = GPTConfig()
config.n_embd

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused Q/K/V projection.

    Args:
        config: Object exposing ``n_embd`` (model width) and ``n_head``
            (number of heads).

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # Fail early with a clear message; otherwise the mismatch only
        # surfaces as a shape error inside forward().
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_size = self.n_embd // self.n_head

        # A single linear layer produces queries, keys and values at once.
        self.qkv_proj = nn.Linear(self.n_embd, 3*self.n_embd, bias=False)

        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Marker attribute: GPT._init_weights checks for it with hasattr()
        # to give this layer a smaller initialisation std.
        self.c_proj.residual = True

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(B, T, n_embd)``."""
        B, T, C = x.shape
        # (B, T, 3*C) -> (B, T, 3*n_head, head_size) -> (B, 3*n_head, T, head_size)
        qkv = self.qkv_proj(x).view(B, T, 3*self.n_head, self.head_size).transpose(1, 2)
        # Split the head axis into three groups of n_head: q, k and v.
        q, k, v = qkv.chunk(3, dim=-3)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge the heads back: (B, n_head, T, head_size) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(y)

In [None]:
# Instantiate the attention layer and check that it preserves the input shape.
mha = MultiHeadAttention(config)
mha(x).shape

In [None]:
# Toy example: print a (2, 2, 3, 2) tensor of the integers 0..23, then
# display the same elements regrouped as (2, 3, 4).
xx = torch.arange(24).view(2, 2, 3, 2)
print(xx)
xx.reshape(2, 3, 4)

In [None]:
# Average seconds per forward pass of the attention layer on `device`,
# over 20 runs. Note that mha.to(device) moves the module itself.
calculate_time(mha.to(device), (x.to(device),), num_runs=20)

## 🟠 Feed Forward (MLP)

In [None]:
@dataclass
class GPTConfig:
    """Hyperparameters for the attention and feed-forward layers.

    Declared as a dataclass so that fields can be overridden by keyword,
    e.g. ``GPTConfig(f_expnd=2)``.
    """
    n_embd: int = 100   # embedding width
    n_head: int = 5     # number of attention heads
    f_expnd: float = 4  # MLP hidden width as a multiple of n_embd

config = GPTConfig()
config.n_embd

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to ``f_expnd * n_embd``, apply GELU, project back.

    Both linear layers are bias-free. ``down_proj`` carries a ``residual``
    marker attribute.
    """

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.f_expnd = config.f_expnd

        # f_expnd may be fractional, so the hidden width is truncated to an int.
        hidden_dim = int(self.f_expnd * self.n_embd)
        self.up_proj = nn.Linear(self.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, self.n_embd, bias=False)
        self.down_proj.residual = True

    def forward(self, x):
        hidden = self.up_proj(x)
        activated = F.gelu(hidden)
        return self.down_proj(activated)

In [None]:
# Instantiate the MLP and check that it preserves the input shape.
mlp = FeedForward(config)
mlp(x).shape

In [None]:
# The helper reports millions, so multiplying by 1000 gives thousands of parameters.
num_trainable_params(mlp)*1000

In [None]:
# Average seconds per MLP forward pass over 20 runs; mlp and x are used as
# created, without an explicit move to `device`.
calculate_time(mlp, (x, ), num_runs=20)

## 🟠 Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Transformer decoder block with pre-layer normalisation.

    Self-attention and then the MLP are each applied to a LayerNorm'd copy of
    the input and added back to it as a residual.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.n_embd = width

        # Attention sub-layer and its input norm.
        self.ln1 = nn.LayerNorm(width)
        self.mha = MultiHeadAttention(config)

        # Feed-forward sub-layer and its input norm.
        self.ln2 = nn.LayerNorm(width)
        self.mlp = FeedForward(config)

    def forward(self, x):
        attn_out = self.mha(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

In [None]:
# Instantiate one decoder block and check that it preserves the input shape.
decoder = DecoderBlock(config)
decoder(x).shape

In [None]:
# Trainable parameters of one block, in thousands (the helper returns millions).
num_trainable_params(decoder) * 1e3

In [None]:
# Average time per decoder forward pass over 20 runs, converted to milliseconds.
calculate_time(decoder, (x, ), num_runs=20) * 1e3

## 🟠 GPT

In [None]:
@dataclass
class GPTConfig:
    """Hyperparameters of the GPT model.

    Declared as a dataclass so that individual fields can be overridden by
    keyword, e.g. ``GPTConfig(seq_len=256, n_layer=4)``; a plain class with
    annotated attributes has no such constructor.
    """
    vocab_size: int = 10_000  # rows of the token embedding / size of the output logits
    seq_len: int = 128        # rows of the position embedding (maximum sequence length)
    n_layer: int = 12         # number of decoder blocks
    n_embd: int = 100         # embedding width
    n_head: int = 5           # number of attention heads
    f_expnd: float = 4        # MLP hidden width as a multiple of n_embd


config = GPTConfig()
config.n_embd

In [None]:
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Token and learned position embeddings are summed, passed through
    ``config.n_layer`` decoder blocks and a final LayerNorm, then projected
    to vocabulary logits by ``lm_head``, whose weight is shared with the
    token embedding.

    Args:
        config: Object exposing ``vocab_size``, ``seq_len``, ``n_layer`` and
            ``n_embd``, plus whatever ``DecoderBlock`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.seq_len, config.n_embd)
        # Kept as a ModuleList because forward() iterates the blocks explicitly.
        self.decoders = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)])
        self.lnf = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output projection reuses the token embedding matrix.
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02**2).

        Linear layers carrying a ``residual`` attribute use a std scaled by
        ``(2 * n_layer) ** -0.5``. Linear biases, when present, are zeroed.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'residual'):
                std *= (2*self.config.n_layer)**-0.5
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, idx):
        """Return next-token logits for a batch of token ids.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``T`` exceeds ``config.seq_len``, the number of
                learned positions.
        """
        B, T = idx.shape
        if T > self.config.seq_len:
            raise ValueError(
                f"sequence length {T} exceeds the model's maximum of {self.config.seq_len}"
            )

        # Positions are created on the input's device rather than the
        # notebook-global `device`, so the model works wherever it is placed.
        positions = torch.arange(T, device=idx.device)
        x = self.wte(idx) + self.wpe(positions)

        for decoder in self.decoders:
            x = decoder(x)

        x = self.lnf(x)
        logits = self.lm_head(x)
        return logits

In [None]:
# Build the model on `device` and check the logits shape for one batch.
model = GPT(config).to(device)
model(x_batch.to(device)).shape

In [None]:
# Trainable parameters (in millions) of the whole model, the decoder stack
# and the output head.
num_trainable_params(model), num_trainable_params(model.decoders), num_trainable_params(model.lm_head)

In [None]:
# Average time per model forward pass over 100 runs, converted to milliseconds.
calculate_time(model, (x_batch.to(device),), num_runs=100) * 1e3

## 🟠 Initialization

In [None]:
# Rebuild the model with a different configuration; its freshly initialised
# weights are inspected in the histogram cells that follow.
# NOTE(review): passing fields by keyword requires GPTConfig to accept them
# in its constructor (e.g. by being a @dataclass).
model = GPT(
    GPTConfig(
        seq_len=256, vocab_size=10_000, n_layer=4, n_embd=256, n_head=4
        )).to(device)

In [None]:
# Histogram (50 bins) of the first block's attention output-projection weights.
plt.hist(model.decoders[0].mha.c_proj.weight.flatten().detach().cpu(), bins=50);

In [None]:
# Histogram (50 bins) of up to the first 100,000 position-embedding weights.
# With seq_len=256 and n_embd=256 the table has 65,536 entries, so the slice
# keeps all of them.
plt.hist(model.wpe.weight.flatten()[:100_000].detach().cpu(), bins=50);

In [None]:
# Histogram (50 bins) of the third block's MLP down-projection weights.
plt.hist(model.decoders[2].mlp.down_proj.weight.flatten().detach().cpu(), bins=50);

# 🔴 **Config**