In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
import wandb
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from einops import einsum, rearrange, reduce
from dataclasses import dataclass
from tqdm.notebook import tqdm

In [2]:
seq_len = 512
def tokenize(raw_text):
  raw_text = raw_text['text'][0]
  token = [ord(x) for x in raw_text]
  current_token = []
  next_token = []
  for idx in range(len(token) // seq_len):
    t = token[idx:idx + seq_len + 1]
    current_token.append(t[:-1])
    next_token.append(t[1:])
  return {'current': current_token,
          'next': next_token}

In [3]:
raw_text_data = datasets.load_dataset('karpathy/tiny_shakespeare', split='train')
raw_text_data_val = datasets.load_dataset('karpathy/tiny_shakespeare', split='validation')
char_data_train = raw_text_data.map(tokenize, batched=True, remove_columns=['text']).with_format('torch')
char_data_val = raw_text_data_val.map(tokenize, batched=True, remove_columns=['text']).with_format('torch')
train_dataloader = DataLoader(char_data_train, batch_size=10, shuffle=True)
val_dataloader = DataLoader(char_data_val, batch_size=10, shuffle=True)

In [4]:
class Embedding(nn.Module):
  def __init__(self, d_model: int, seq_len: int):
    super().__init__()
    self.embedding_matrix = nn.Parameter(torch.zeros(256, d_model))
    nn.init.xavier_normal_(self.embedding_matrix)

    self.positional_encoding = nn.Parameter(torch.zeros(seq_len, d_model))
    nn.init.xavier_normal_(self.positional_encoding)

  def forward(self, data: Int[Tensor, "batch seq_len"]) -> Float[Tensor, "batch seq_len d_model"]:
    return self.embedding_matrix[data] + self.positional_encoding

In [5]:
class Attention(nn.Module):
  def __init__(self, n_head: int, d_model: int, d_head: int, seq_len: int):
    super().__init__()
    self.seq_len = seq_len
    self.d_head = d_head

    self.query_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.query_matrix)

    self.key_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.key_matrix)

    self.value_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.value_matrix)

    self.output_matrix = nn.Parameter(torch.zeros(n_head, d_model, d_head))
    nn.init.xavier_normal_(self.output_matrix)

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    query = einsum(data, self.query_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")
    key = einsum(data, self.key_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")
    value = einsum(data, self.value_matrix, "batch seq_len d_model, n_head d_head d_model -> batch n_head seq_len d_head")

    attn_pre = einsum(query, key, "batch n_head query_len d_head, batch n_head key_len d_head -> batch n_head query_len key_len")
    mask_idx = torch.triu_indices(self.seq_len, self.seq_len, offset=1)
    attn_pre[..., mask_idx[0], mask_idx[1]] = float('-inf')
    attn_pre /= self.d_head ** 0.5
    attn = F.softmax(attn_pre, dim=-1)

    output_pre = einsum(attn, value, "batch n_head query_len key_len, batch n_head key_len d_head -> batch n_head key_len d_head")
    output = einsum(self.output_matrix, output_pre, "n_head d_model d_head, batch n_head seq_len d_head -> batch seq_len d_model")
    return output

In [6]:
class MLP(nn.Module):
  def __init__(self, d_model: int, d_mlp: int):
    super().__init__()
    self.MLP = nn.Sequential(nn.Linear(d_model, d_mlp),
                             nn.ReLU(),
                             nn.Linear(d_mlp, d_model))

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    return self.MLP(data)

In [7]:
class LayerNorm(nn.Module):
  def __init__(self):
    super().__init__()
  
  def forward(self, data: Float[Tensor, "batch seq_len d_model"]):
    mean: Float[Tensor, "batch"] = data.mean(dim=[1, 2], keepdim=True)
    std: Float[Tensor, "batch"] = data.std(dim=[1, 2], keepdim=True)
    return (data - mean) / (std + 1e-5)
    

In [8]:
class Unembedding(nn.Module):
  def __init__(self, d_model: int, seq_len: int):
    super().__init__()
    self.unembedding_matrix = nn.Parameter(torch.zeros(d_model, 256))
    nn.init.xavier_normal_(self.unembedding_matrix)

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len 256"]:
    return einsum(self.unembedding_matrix, data, "d_model d_vocab, batch seq_len d_model -> batch seq_len d_vocab")


In [9]:
class TransformerLayer(nn.Module):
  def __init__(self, n_heads: int, d_model: int, d_head: int, seq_len: int, d_mlp: int ):
    super().__init__()
    self.Attn = Attention(n_heads, d_model, d_head, seq_len)
    self.MLP = MLP(d_model, d_mlp)
    self.LayerNorm = LayerNorm()

  def forward(self, data: Float[Tensor, "batch seq_len d_model"]) -> Float[Tensor, "batch seq_len d_model"]:
    resid_attn = self.LayerNorm(data + self.Attn(data))
    return self.LayerNorm(resid_attn + self.MLP(resid_attn))

In [10]:
@dataclass
class TransformerConfig:
  d_model: int
  ctx_len: int
  n_heads: int
  d_head: int
  d_mlp: int
  n_layers: int


In [11]:
class GPT2(nn.Module):
  def __init__(self, model_cfg: TransformerConfig):
    super().__init__()
    self.Embed = Embedding(model_cfg.d_model, model_cfg.ctx_len)   
    self.Unembed = Unembedding(model_cfg.d_model, model_cfg.ctx_len)
    self.Layers = nn.ModuleList([TransformerLayer(model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head, model_cfg.ctx_len, model_cfg.d_mlp) for _ in range(model_cfg.n_layers)])

  def forward(self, data: Int[Tensor, "batch seq_len"]) -> Float[Tensor, "batch seq_len 256"]:
    x = self.Embed(data)
    for tl in self.Layers: 
      x = tl(x)
    return self.Unembed(x)

In [20]:
cfg = TransformerConfig(768, 512, 12, 64, 3072, 3)
gpt2 = GPT2(cfg)

In [21]:
gpt2(torch.randint(256, (1, 512))).shape

torch.Size([1, 512, 256])

In [22]:
optim = torch.optim.AdamW(gpt2.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()

In [15]:
class Trainer:
  def __init__():
    pass

In [24]:
def train_epoch():
  for data in tqdm(train_dataloader, leave=False):
    current_tok = data['current']
    next_tok = data['next']

    logits = gpt2(current_tok)
    logits = rearrange(logits, "batch ctx_len d_vocab -> batch d_vocab ctx_len")
    loss = loss_fn(logits, next_tok)
    wandb.log({'loss': loss.item()})
    optim.zero_grad()
    loss.backward()
    optim.step()

In [25]:
wandb.init(project="gpt2-rep")
for _ in tqdm(range(10)):
  train_epoch()

0,1
loss,█▅▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,2.28882


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

In [18]:
tokens = torch.Tensor([[ord(x) for x in """No, no, it cannot be; and yet my heart
"""]]).long()

In [65]:
tokens.shape

torch.Size([1, 39])

In [19]:
import torch
import torch.nn.functional as F

# Auto-regressive sampling with padding handling
def sample(model, tokens, n_tokens, pad_token_id=0):
    # Pad the initial tokens to 512 tokens
    non_padded_len = tokens.size(1)
    pad_length = 512 - tokens.size(1)
    if pad_length > 0:
        tokens = torch.cat([tokens, torch.full((tokens.size(0), pad_length), pad_token_id, dtype=tokens.dtype)], dim=1)
    
    current_length = non_padded_len
    for _ in range(n_tokens):
        logits = model(tokens)
        logits = logits[0, current_length - 1]
        # Sample the next token probabilistically
        token = torch.multinomial(F.softmax(logits, dim=-1), 1)
        # replace the last token with the sampled token
        tokens[0, current_length] = token
        current_length += 1
    return tokens

# Convert the token back to string
def to_string(tokens):
    return ''.join([chr(x) for x in tokens[0]])

sampled_tokens = sample(gpt2, tokens, 300)
print(to_string(sampled_tokens))

No, no, it cannot be; and yet my heart
it:
Vesovesot o aendirthelknt nventhe k I.
isento moneng fen: fo ct haceme


piaio thed coutd tinelsaspry ty bun:

Velliuseld, yostit@s Cidizerind musthalit renthawen:
Wobe bothatheoun:
ullicogat bent or tut aluseculue.itr ui way, nf n mourtheans he an.
Wentould at izelases hary.


Whuthonve's hito                                                                                                                                                                             


In [45]:
tokens.shape

torch.Size([1, 512])