In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    """Small decoder-only language model: embedding -> stacked ``Layer`` blocks -> LayerNorm -> vocab projection.

    Args:
        vocab_size: number of rows in the token embedding and number of output logits.
        block_size: forwarded to every ``Layer`` (which sizes its mask/bias buffers
            with it) and also stored on the instance as ``self.block_size``.
        embed_dim: width of the embeddings and of every hidden activation.
        num_layers: how many ``Layer`` blocks are stacked.
    """

    def __init__(self, vocab_size, block_size=256, embed_dim=64, num_layers=4):
        super().__init__()
        self.block_size = block_size

        # Sub-modules are created in embedding -> layers -> norm -> head order so
        # that seeded parameter initialisation consumes the RNG in a fixed order.
        embedding = nn.Embedding(vocab_size, embed_dim)
        layers = [Layer(block_size, embed_dim) for _ in range(num_layers)]
        final_norm = nn.LayerNorm(embed_dim)

        # The nested Sequential layout determines the state_dict key names
        # ("transformer.0", "transformer.1.<i>", "transformer.2").
        self.transformer = nn.Sequential(embedding, nn.Sequential(*layers), final_norm)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Map integer token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        hidden = self.transformer(x)
        return self.lm_head(hidden)

class Layer(nn.Module):
    """Pre-norm transformer block: single-head causal self-attention with a linear
    (ALiBi-style) distance bias, followed by a 4x-wide GELU MLP. Both sub-blocks
    are wrapped in residual connections.

    Args:
        block_size: maximum sequence length; sizes the ``mask`` and ``alibi`` buffers.
        embed_dim: channel width of the input and output.

    Buffers:
        mask:  (block_size, block_size) bool, True where key position j > query position i.
        alibi: (block_size, block_size) int64, ``j - i`` for j <= i and 0 elsewhere.
    """

    def __init__(self, block_size, embed_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )

        # offset[i, j] = j - i: signed distance from query i to key j.
        pos = torch.arange(block_size)
        offset = pos[None, :] - pos[:, None]
        # Future positions (j > i) are the ones to hide from the query.
        self.register_buffer("mask", offset > 0)
        # Non-positive bias that grows linearly with distance into the past;
        # the upper triangle is zeroed (it is masked out in forward anyway).
        self.register_buffer("alibi", torch.tril(offset))

    def forward(self, x):
        """Apply attention + MLP to ``x`` of shape (B, T, C); returns the same shape.

        Raises:
            RuntimeError: if T exceeds the ``block_size`` the layer was built with.
        """
        B, T, C = x.shape
        max_len = self.mask.size(0)
        if T > max_len:
            # Without this guard a block_size of 1 would broadcast the (1, 1)
            # buffers over any T and silently disable causal masking; other
            # sizes would fail later with an opaque broadcasting error.
            raise RuntimeError(
                f"sequence length {T} exceeds the block_size {max_len} this layer was built with"
            )

        q, k, v = self.qkv(self.ln1(x)).chunk(3, dim=-1)
        w = q @ k.transpose(-2, -1) / math.sqrt(C)  # (B, T, T)
        w = w + self.alibi[:T, :T]
        w = w.masked_fill(self.mask[:T, :T], float("-inf"))
        w = F.softmax(w, dim=-1)
        attn = self.proj(w @ v)  # (B, T, C)

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x

In [2]:
import lightning as pl
from shared import corpus, tokenizers, trainers

# Run configuration, gathered in one place so it is easy to tweak.
SEED = 89026614
DEVICE = "mps"
BATCH_SIZE = 36
EPOCHS = 25

# Load the corpus and build a tokenizer from it
# (presumably one token per distinct character -- confirm in shared.tokenizers).
text = corpus.shakespeare()
tokenizer = tokenizers.unique_chars(text)

# Seed before constructing the model so weight initialisation is repeatable.
pl.seed_everything(SEED)
vocab_size = tokenizer.get_vocab_size()
model = GPT(vocab_size)

trainer = trainers.CausalTrainer(model, tokenizer, device=DEVICE)
trainer.train(text, batch_size=BATCH_SIZE, epochs=EPOCHS)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 602.80it/s]
Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 207 K 
-------------------------------
207 K     Trainable params
0         Non-trainable params
207 K     Total params
0.829     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.51it/s]And now;ryy,x;x$ q-Q I;BpXMcNJrkSmTb&F-qKb
 bd3Bl&FNv-$Is?pXxazZJBCT b-
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 202/202 [00:10<00:00, 19.15it/s, loss=1.75, v_num=70]And now,
The beart the two met while the dangue.
Ingeat and turn. Marem
Epoch 1: 100%|██████████| 202/202 [00:07<00:00, 26.66it/s, loss=1.57, v_num=70, test_loss=1.890]And now,
Again, you are the were sir, a sir, in, she before of Geaver a
Epoch 2: 100%|██████████| 202/202 [00:07<00:00, 26.76it/s, loss=1.49, v_num=70, test_loss=1.720]And now,
As show me been
Wherei wit thyself a prisons and mine a charts
Epoch 3: 100%|██████████| 202/202 [00:07<00:00, 26.77it/s, loss=1.44, v_num=70, test_loss=1.600]And now and speak and souls, tender of the wingle is and at thou,
Or yo
Epoch 4: 100%|██████████| 202/202 [00:07<00:00, 26.78it/s, loss=1.41, v_num=70, test_loss=1.530]And now about in the grows not bie speak'st blaze, far turn his blood l
Epoch 5: 100%|██████████| 202/202 [00:07<00:00, 26.47it/s, loss=1.39, v_num=70, test_loss=1.540]And now which bried,
For the times fair
But be so murder and appearer-s
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 202/202 [00:10<00:00, 19.75it/s, loss=1.25, v_num=70, test_loss=1.400]


In [3]:
# Generate a continuation of the prompt with the trained model and print it.
# NOTE(review): 650 is presumably the number of tokens to generate -- confirm in shared.trainers.
prompt = "O God, O God!"
sample = trainer.wrapper.generate(prompt, 650)
print(sample)

O God, O God! what wicked the more to so time to mribunes.
If thou mayor her half their bones in hope,
You stand and with those could be thy cause away:
Nay, were as my heart.

CATESBY:
Yet, thou hap to call the will heart, in stand here in backs.
Have you to comfort, I shall be my honour,
Such shall be put to the thirst nost make us hanging weight defend home
It in the multies.
This, my lord, and, a month the whole my face;
Sailorous sure these moon the gates their sweet banish'd in a word.

MENENIUS:
O, misery have made for any
A distress safety to hear
My point, while at Plantagenet; so I'll say, is so honour:
No, then come
the prison is the clouds, sh
