In [61]:
#|default_exp xformer

# Transformer

In [1]:
from dataclasses import dataclass
import torch
import matplotlib.pyplot as plt

from torch import nn, optim, tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from functools import partial

import fastcore.all as fc
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset
import tiktoken

from miniai.datasets import * 
from miniai.activations import *
from miniai.learner import *
from miniai.conv import * 
from miniai.resnet import *
from miniai.init import * 
from miniai.sgd import *
from miniai.augment import * 

In [2]:
from IPython.core.debugger import set_trace
%load_ext autoreload
%autoreload 2

In [3]:
# Notebook display settings: print tensors with 2 decimals, 140-char lines and
# no scientific notation, and switch matplotlib to its 'fast' style.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
plt.style.use('fast')

In [4]:
# Load the dataset by name through `datasets.load_dataset`; later cells index
# the result as dsd['train'] and dsd['validation'].
name = "tiny_shakespeare"
dsd = load_dataset(name)

Found cached dataset tiny_shakespeare (/Users/leonardourbina/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
@dataclass
class TConfig:
    """Hyperparameters shared by the dataset, the dataloaders and the model.

    Every setting is an annotated dataclass field, so it can be overridden per
    instance (e.g. ``TConfig(depth=2)``) and is included in the generated
    ``__init__``/``__repr__``/``__eq__``. ``TConfig()`` yields the same values
    as before.
    """
    batch_size: int = 32          # samples per DataLoader batch
    ctx_size: int = 100           # number of tokens in one input window
    num_workers: int = 0          # DataLoader worker count (was `False`, which only worked because False == 0)
    n_embed: int = 512            # embedding width
    num_heads: int = 8            # attention heads per block
    # Derived from the *default* n_embed/num_heads when the class is created;
    # it is not recomputed if those two are overridden on an instance.
    head_width: int = n_embed//num_heads
    encoding: str = 'cl100k_base' # tiktoken encoding name
    bias: bool = False            # bias term for the MLP linear layers
    dropout: float = 0.1
    fanout: int = 4 # MLP fanout
    # Zero-argument factory; called as `config.act()` to build an activation module.
    act: partial = partial(GeneralReLU, leak=0.1, sub=0.4)
    depth: int = 4                # number of transformer blocks

In [6]:
# Default hyperparameters; this module-level `config` is read by the cells below.
config = TConfig()

In [7]:
@fc.delegates(Dataset)
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a single tokenized text.

    Item ``i`` is ``(tokens[i : i + ctx_size], target)`` where ``target`` is a
    one-hot ``long`` vector of length ``n_vocab`` marking the token that
    immediately follows the window.

    Args:
        text: raw string to tokenize.
        encoder: tokenizer exposing ``n_vocab`` and ``encode_ordinary``.
        config: supplies ``ctx_size``, the window length.
        **kwargs: accepted for signature compatibility; unused.
    """
    def __init__(self, text, encoder, config: TConfig, **kwargs):
        self.n_vocab = encoder.n_vocab
        # Remember the window length from the config that was passed in, rather
        # than reading whatever notebook-level `config` exists at call time.
        self.ctx_size = config.ctx_size
        self.text = tensor(encoder.encode_ordinary(text))

    def __getitem__(self, i):
        end = i + self.ctx_size
        target = torch.zeros(self.n_vocab, dtype=torch.long)
        # The window covers tokens i .. end-1, so the next token sits at `end`
        # (the previous `end + 1` skipped one token).
        target[self.text[end]] = 1
        return self.text[i:end], target

    def __len__(self):
        # One item per position that still has a following token; never negative,
        # so texts shorter than a window yield an empty dataset instead of an error.
        return max(0, len(self.text) - self.ctx_size)

In [8]:
# Tokenizer plus one sliding-window dataset per split.
enc = tiktoken.get_encoding(config.encoding)
train_ds = TextDataset(dsd['train']['text'][0], enc, config)
valid_ds = TextDataset(dsd['validation']['text'][0], enc, config)

# Neighbouring items overlap in all but one token, so the training loader is
# shuffled to decorrelate the samples within a batch; validation keeps its order.
train_dl = DataLoader(train_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=config.batch_size, num_workers=config.num_workers)
dls = DataLoaders(train_dl, valid_dl)

In [9]:
# Grab the first training batch and display the input/target shapes.
xb, yb = next(iter(dls.train))
xb.shape, yb.shape

(torch.Size([32, 100]), torch.Size([32, 100277]))

In [10]:
class AttentionHead(nn.Module): # Decoder-only dot product attention
    """One causal, scaled dot-product self-attention head.

    Args:
        head_width: size of the last input dimension; keys, queries and values
            are bias-free linear maps from ``head_width`` to ``head_width``.
        ctx_size: longest sequence length the causal mask is built for.

    ``forward`` takes ``x`` of shape ``(B, T, head_width)`` with
    ``T <= ctx_size`` and returns a tensor of the same shape in which position
    ``t`` only attends to positions ``<= t``.
    """
    def __init__(self, head_width, ctx_size):
        super().__init__()
        self.keys = nn.Linear(head_width, head_width, bias=False)     # B, T, C -> B, T, C
        self.queries = nn.Linear(head_width, head_width, bias=False)  # B, T, C -> B, T, C
        self.values = nn.Linear(head_width, head_width, bias=False)   # B, T, C -> B, T, C
        # Lower-triangular ones: entry (t, s) is 1 when position t may look at position s.
        self.register_buffer('tril', torch.tril(torch.ones(ctx_size, ctx_size))) # ctx_size, ctx_size

    def forward(self, x):
        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)

        B, T, C = keys.shape
        scores = queries @ keys.transpose(-2, -1) * C**(-0.5)  # (B, T, C) @ (B, C, T) -> (B, T, T)
        # Crop the mask to the actual sequence length so inputs shorter than
        # ctx_size work, and mask out-of-place to leave the autograd graph alone.
        future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(future, -torch.inf)

        return F.softmax(scores, dim=-1) @ values  # (B, T, T) @ (B, T, C) -> (B, T, C)

In [11]:
class MultiHeadAttention(nn.Module):
    """Runs several independent ``AttentionHead`` modules on the same input.

    Args:
        head_width: width passed to every head.
        num_heads: number of heads to create.
        ctx_size: maximum sequence length passed to every head.

    ``forward`` feeds ``x`` to each head and concatenates the results along the
    last dimension, so the output's last dimension is the sum of the heads'
    output widths.
    """
    def __init__(self, head_width, num_heads, ctx_size):
        super().__init__()
        # Use the constructor argument; the head count previously came from the
        # notebook-global `config`, silently ignoring `num_heads`.
        self.heads = nn.ModuleList([AttentionHead(head_width, ctx_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [12]:
class MLP(nn.Module):
    """Feed-forward stage: linear -> activation -> linear -> dropout.

    The first linear layer expects ``n_embed * num_heads`` input features and
    widens to ``n_embed * fanout``; the second projects back to ``n_embed``.
    Both honour ``config.bias``; the activation is built by calling
    ``config.act()``.
    """
    def __init__(self, config: TConfig):
        super().__init__()
        in_width = config.n_embed * config.num_heads
        hidden_width = config.n_embed*config.fanout
        self.lin1 = nn.Linear(in_width, hidden_width, bias=config.bias)
        self.act = config.act()
        self.lin2 = nn.Linear(hidden_width, config.n_embed, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.act(self.lin1(x))
        return self.dropout(self.lin2(hidden))

In [13]:
class Block(nn.Module):
    """One transformer block built from ``MultiHeadAttention`` and ``MLP``.

    ``forward`` computes ``ln_2(inp + mlp(attn(ln_1(inp))))``: the input is
    normalised, passed through attention and then the MLP, added back onto the
    un-normalised input, and normalised once more.
    """
    def __init__(self, config: TConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        # Each head is built with the full embedding width as its head width.
        self.attn = MultiHeadAttention(config.n_embed, config.num_heads, config.ctx_size)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, inp):
        attended = self.attn(self.ln_1(inp))
        residual = inp + self.mlp(attended)
        return self.ln_2(residual)

In [14]:
class GPT(nn.Module):
    """Decoder-only language model: token + position embeddings, a stack of
    ``Block`` modules, a final LayerNorm and a linear projection to the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        config: supplies ``n_embed``, ``ctx_size`` and ``depth``.

    ``forward`` takes integer token ids ``idx`` of shape ``(B, T)`` and returns
    the projection output with ``vocab_size`` features per position. All layers
    are created on the default device; move the model with ``.to(...)``.
    """
    def __init__(self, vocab_size, config: "TConfig"):
        super().__init__()
        # No explicit device here: the embeddings are created like every other
        # layer so the module never starts out split across devices.
        self.token_embedding = nn.Embedding(vocab_size, config.n_embed)         #  token embedding
        self.position_embedding = nn.Embedding(config.ctx_size, config.n_embed) #  positional embedding
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.depth)])
        self.ln = nn.LayerNorm(config.n_embed)
        self.project = nn.Linear(config.n_embed, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        # Build the position ids on the input's device instead of a notebook-level
        # `device` global, so the model works wherever it and its input live.
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)  # (1, T)

        x = self.token_embedding(idx) + self.position_embedding(pos)  # (B, T, n_embed)
        for block in self.blocks:
            x = block(x)
        x = self.ln(x)
        x = self.project(x)
        return x

    def __iter__(self):
        # Iteration over the model is delegated to the external `model_iter` helper.
        return model_iter(self)

In [17]:
# Rebuild the config and the model so that re-running this cell starts from
# freshly initialised weights.
config = TConfig()
gpt = GPT(enc.n_vocab, config)
gpt = gpt.to(device)

# The one-cycle schedule is sized for `epochs` full passes over the training loader.
epochs = 5
tmax = len(dls.train) * epochs

# Callbacks: activation statistics on the GeneralReLU layers, accuracy metric,
# per-batch LR schedule, device placement and a live progress plot.
# NOTE(review): GPT returns per-position logits while the dataset yields one
# one-hot target per window — confirm F.cross_entropy / MulticlassAccuracy
# accept these shapes before training.
astats = ActivationStats(fc.risinstance(GeneralReLU))
metrics = MetricsCB(accuracy=MulticlassAccuracy())
sched = BatchSchedCB(
    partial(optim.lr_scheduler.OneCycleLR, max_lr=3e-4, total_steps=tmax)
)
cbs = [
    DeviceCB(device=device),
    astats,
    metrics,
    sched,
    ProgressCB(plot=True),
]
learn = Learner(
    gpt, dls, F.cross_entropy,
    lr=1e-4, cbs=cbs, opt_func=optim.AdamW,
)

In [19]:
# Train for the number of epochs the OneCycleLR schedule was sized for
# (total_steps = epochs * len(dls.train)). A bare `fit()` falls back to the
# Learner's default epoch count (presumably 1 in miniai — verify), which would
# leave the schedule unfinished.
learn.fit(epochs)

> [0;32m/Users/leonardourbina/code/ml/fastai2022p2/miniai/activations.py[0m(65)[0;36m__init__[0;34m()[0m
[0;32m     63 [0;31m        [0mms_iters[0m [0;34m=[0m [0;34m([0m[0mmodel_iter[0m[0;34m([0m[0mm[0m[0;34m)[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mms[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        hook_fns = {key: [t for t in (hook(l) for l in ms_iters) if t is not None] 
[0m[0;32m     66 [0;31m                          for key, hook in hooks.items()}
[0m[0;32m     67 [0;31m        [0msuper[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0mhook_fns[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/Users/leonardourbina/code/ml/fastai2022p2/miniai/activations.py[0m(66)[0;36m__init__[0;34m()[0m
[0;32m     64 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        hook_fns = {key: [t for t in (hook(l) for l in ms_iters) if t is not None] 
[0m[0;32m---> 66 [0;31m                          for key, hook in hooks.items()}
[0m[0;32m     67 [0;31m        [0msuper[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0mhook_fns[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/Users/leonardourbina/code/ml/fastai2022p2/miniai/activations.py[0m(65)[0;36m__init__[0;34m()[0m
[0;32m     63 [0;31m        [0mms_iters[0m [0;34m=[0m [0;34m([0m[0mmodel_iter[0m[0;34m([0m[0mm[0m[0;34m)[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mms[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        hook_fns = {key: [t for t in (hook(l) for l in ms_iters) if t is not None] 
[0m[0;32m     66 [0;31m                          for key, hook in hooks.items()}
[0m[0;32m     67 [0;31m        [0msuper[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0mhook_fns[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


AttributeError: 'generator' object has no attribute 'register_forward_hook'
> [0;32m/Users/leonardourbina/code/ml/fastai2022p2/miniai/activations.py[0m(65)[0;36m__init__[0;34m()[0m
[0;32m     63 [0;31m        [0mms_iters[0m [0;34m=[0m [0;34m([0m[0mmodel_iter[0m[0;34m([0m[0mm[0m[0;34m)[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mms[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        hook_fns = {key: [t for t in (hook(l) for l in ms_iters) if t is not None] 
[0m[0;32m     66 [0;31m                          for key, hook in hooks.items()}
[0m[0;32m     67 [0;31m        [0msuper[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0mhook_fns[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
Unexpected exception formatting exception. Falling back to standard excep

Traceback (most recent call last):
  File "/Users/leonardourbina/mambaforge/envs/fastai/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/p2/0hwfshcj20gblzjrlt3_kvp80000gn/T/ipykernel_12855/2544765029.py", line 1, in <module>
    learn.fit()
  File "/Users/leonardourbina/code/ml/fastai2022p2/miniai/learner.py", line 127, in fit
    self._fit(train, valid)
  File "/Users/leonardourbina/code/ml/fastai2022p2/miniai/learner.py", line 103, in _f
    o._callback(f'before_{self.name}')
  File "/Users/leonardourbina/code/ml/fastai2022p2/miniai/learner.py", line 168, in _callback
    def _callback(self, method_name): run_cbs(self.cbs, method_name, self)
  File "/Users/leonardourbina/code/ml/fastai2022p2/miniai/learner.py", line 46, in run_cbs
    if method is not None: method(learn)
  File "/Users/leonardourbina/code/ml/fastai2022p2/miniai/activations.py", line 100, in before_fit
    

In [32]:
# Sanity check of the positional embedding on its own, outside the model.
B, T = xb.shape
pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)  # (1, T) position ids
# Keep the embedding on the same device as `pos`, otherwise the lookup fails off-CPU.
pos_emb = nn.Embedding(config.ctx_size, config.n_embed).to(device)
# Display the lookup result, (1, T, n_embed); the recorded output of this cell
# is such a tensor, but the expression producing it was missing.
pos_emb(pos)


tensor([[[-0.18, -0.63,  0.78,  ...,  0.66, -0.04, -0.84],
         [-0.37, -0.39, -1.51,  ...,  1.43,  0.35, -0.02],
         [-0.88,  0.90,  0.29,  ...,  1.12, -1.85, -0.51],
         ...,
         [-1.46, -0.24,  0.13,  ...,  0.23, -2.42, -0.39],
         [-1.03,  0.91, -1.23,  ..., -1.91,  0.20, -1.25],
         [-1.21,  0.45,  1.26,  ..., -0.51,  0.42, -0.73]]], grad_fn=<EmbeddingBackward0>)