In [299]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from einops import einsum, rearrange, reduce
from dataclasses import dataclass
from tqdm.notebook import tqdm

In [300]:
writer = SummaryWriter()

In [301]:
seq_len = 512
def tokenize(raw_text):
  raw_text = raw_text['text'][0]
  token = [ord(x) for x in raw_text]
  current_token = []
  next_token = []
  for idx in range(len(token) // seq_len):
    t = token[idx:idx + seq_len + 1]
    current_token.append(t[:-1])
    next_token.append(t[1:])
  return {'current': current_token,
          'next': next_token}

In [302]:
raw_text_data = datasets.load_dataset('karpathy/tiny_shakespeare', split='train')
char_data = raw_text_data.map(tokenize, batched=True, remove_columns=['text']).with_format('torch')
train_dataloader = DataLoader(char_data, batch_size=10, shuffle=True)

In [303]:
class Embedding(nn.Module):
  def __init__(self, d_model: int, seq_len: int):
    super().__init__()
    self.embedding_matrix = nn.Parameter(torch.zeros(256, d_model))
    nn.init.xavier_normal_(self.embedding_matrix)

    self.positional_encoding = nn.Parameter(torch.zeros(seq_len, d_model))
    nn.init.xavier_normal_(self.positional_encoding)

  def forward(self, data: Int[Tensor, "batch seq_len"]) -> Float[Tensor, "batch seq_len d_model"]:
    return self.embedding_matrix[data] + self.positional_encoding

In [304]:
class Attention(nn.Module):
  def __init__(self, n_head: int, d_model: int, d_head: int, seq_len: int):
    super().__init__()
    self.seq_len = seq_len
    self.d_head = d_head

    self.query_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.query_matrix)

    self.key_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.key_matrix)

    self.value_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.value_matrix)

    self.output_matrix = nn.Parameter(torch.zeros(n_head, d_model, d_head))
    nn.init.xavier_normal_(self.output_matrix)

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    query = einsum(data, self.query_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")
    key = einsum(data, self.key_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")
    value = einsum(data, self.value_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")

    attn_pre = einsum(query, key, "batch n_head query_len d_head, batch n_head key_len d_head -> batch n_head query_len key_len")
    mask_idx = torch.triu_indices(self.seq_len, self.seq_len, offset=1)
    attn_pre[..., mask_idx[0], mask_idx[1]] = float('-inf')
    attn_pre /= self.d_head ** 0.5
    attn = F.softmax(attn_pre, dim=-1)

    output_pre = einsum(attn, value, "batch n_head query_len key_len, batch n_head key_len d_head -> batch n_head key_len d_head")
    output = einsum(self.output_matrix, output_pre, "n_head d_model d_head, batch n_head seq_len d_head -> batch seq_len d_model")
    return output

In [305]:
class MLP(nn.Module):
  def __init__(self, d_model: int, d_mlp: int):
    super().__init__()
    self.MLP = nn.Sequential(nn.Linear(d_model, d_mlp),
                             nn.ReLU(),
                             nn.Linear(d_mlp, d_model))

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    return self.MLP(data)

In [306]:
class LayerNorm(nn.Module):
  def __init__(self):
    super().__init__()
  
  def forward(self, data: Float[Tensor, "batch seq_len d_model"]):
    mean: Float[Tensor, "batch"] = data.mean(dim=[1, 2], keepdim=True)
    std: Float[Tensor, "batch"] = data.std(dim=[1, 2], keepdim=True)
    return (data - mean) / (std + 1e-5)
    

In [307]:
class Unembedding(nn.Module):
  def __init__(self, d_model: int, seq_len: int):
    super().__init__()
    self.unembedding_matrix = nn.Parameter(torch.zeros(d_model, 256))
    nn.init.xavier_normal_(self.unembedding_matrix)

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len 256"]:
    return einsum(self.unembedding_matrix, data, "d_model d_vocab, batch seq_len d_model -> batch seq_len d_vocab")


In [308]:
class TransformerLayer(nn.Module):
  def __init__(self, n_heads: int, d_model: int, d_head: int, seq_len: int, d_mlp: int ):
    super().__init__()
    self.Attn = Attention(n_heads, d_model, d_head, seq_len)
    self.MLP = MLP(d_model, d_mlp)
    self.LayerNorm = LayerNorm()

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    resid_attn = self.LayerNorm(data + self.Attn(data))
    return self.LayerNorm(resid_attn + self.MLP(resid_attn))

In [309]:
@dataclass
class TransformerConfig:
  d_model: int
  ctx_len: int
  n_heads: int
  d_head: int
  d_mlp: int
  n_layers: int


In [310]:
class GPT2(nn.Module):
  def __init__(self, model_cfg: TransformerConfig):
    super().__init__()
    self.Embed = Embedding(model_cfg.d_model, model_cfg.ctx_len)   
    self.Unembed = Unembedding(model_cfg.d_model, model_cfg.ctx_len)
    self.Layers = nn.ModuleList([TransformerLayer(model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head, model_cfg.ctx_len, model_cfg.d_mlp) for _ in range(model_cfg.n_layers)])

  def forward(self, data: Int[Tensor, "batch seq_len"]) -> Float[Tensor, "batch seq_len 256"]:
    x = self.Embed(data)
    for tl in self.Layers: 
      x = tl(x)
    return self.Unembed(x)

In [311]:
cfg = TransformerConfig(768, 512, 12, 64, 3072, 5)
gpt2 = GPT2(cfg)

In [312]:
gpt2(torch.randint(256, (1, 512))).shape

torch.Size([1, 512, 256])

In [313]:
optim = torch.optim.AdamW(gpt2.parameters())
loss_fn = nn.CrossEntropyLoss()

In [314]:
gs : int
def train_epoch():
  global gs
  for data in train_dataloader:
    current_tok = data['current']
    next_tok = data['next']

    logits = gpt2(current_tok)
    logits = rearrange(logits, "batch ctx_len d_vocab -> batch d_vocab ctx_len")
    loss = loss_fn(logits, next_tok)
    writer.add_scalar('Loss/train', loss, gs)
    gs += 1

    optim.zero_grad()
    loss.backward()
    optim.step()

In [315]:
gs = 0
for _ in tqdm(range(10)):
  train_epoch()
writer.flush()

  0%|          | 0/10 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [318]:
tokens = torch.Tensor([[ord(x) for x in """No, no, it cannot be; and yet my heart
Will not confess he owes the malady
That doth my life besiege. Farewell, young lords;
Whether I live or die, be you the sons
Of worthy Frenchmen: let higher Italy,--
Those bated that inherit but the fall
Of the last monarchy,--see that you come
Not to woo honour, but to wed it; when
The bravest questant shrinks, find what you seek,
That fame may cry you loud: I say, farewell.   

Why, Doctor She: my lord, there's one arrived,
If you will see her: now, by my faith and h"""]]).long()
gpt2(tokens).softmax(dim=-1)[0, -1

tensor([1.3733e-06, 1.6671e-06, 2.1347e-06, 1.6876e-06, 2.5960e-06, 9.7359e-07,
        1.7977e-06, 2.1898e-06, 2.1488e-06, 1.6127e-06, 3.2502e-02, 1.4632e-06,
        2.1954e-06, 1.5323e-06, 2.0468e-06, 1.3713e-06, 1.7914e-06, 1.4395e-06,
        2.1790e-06, 1.7833e-06, 2.5914e-06, 2.3813e-06, 1.9690e-06, 1.2316e-06,
        1.8438e-06, 1.8967e-06, 1.4128e-06, 1.0319e-06, 2.3221e-06, 1.9426e-06,
        1.7120e-06, 2.5120e-06, 1.5068e-01, 1.7236e-03, 1.9330e-06, 1.5765e-06,
        2.0675e-06, 1.5935e-06, 1.4256e-06, 3.1214e-03, 2.2059e-06, 1.7522e-06,
        9.4134e-07, 1.2265e-06, 9.7987e-03, 6.9862e-04, 9.3067e-03, 1.5298e-06,
        2.3232e-06, 1.4556e-06, 1.6106e-06, 1.6072e-06, 1.5853e-06, 1.3552e-06,
        1.9654e-06, 2.9649e-06, 1.8951e-06, 1.6008e-06, 1.1470e-02, 3.7692e-03,
        2.5902e-06, 1.4453e-06, 1.2303e-06, 4.2620e-03, 1.7709e-06, 2.7167e-03,
        7.6326e-05, 9.0679e-03, 1.5164e-06, 7.8248e-04, 4.0406e-03, 2.1387e-06,
        4.9046e-04, 3.3679e-03, 1.5955e-