In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    """Small decoder-only transformer language model.

    Token and learned positional embeddings are summed, passed through a stack
    of ``Layer`` blocks, normalized with a final ``LayerNorm`` and projected to
    vocabulary logits by a bias-free linear head.

    Args:
        vocab_size: Number of distinct token ids (size of the embedding table
            and of the output logits).
        block_size: Maximum sequence length the model can process; it sizes the
            positional embedding table and is forwarded to every ``Layer``.
        embed_dim: Width of the embeddings and of the residual stream.
        num_heads: Number of attention heads, forwarded to every ``Layer``.
        num_layers: Number of stacked ``Layer`` blocks.
    """

    def __init__(self, vocab_size, block_size=256, embed_dim=64, num_heads=1, num_layers=4):
        super().__init__()
        self.block_size = block_size
        # Submodule names and construction order are kept stable: they define
        # the state_dict keys and the order in which seeded init draws from the RNG.
        self.transformer = nn.ModuleDict(dict(
            token_emb = nn.Embedding(vocab_size, embed_dim),
            pos_emb = nn.Embedding(block_size, embed_dim),
            layers = nn.Sequential(*[Layer(block_size, embed_dim, num_heads) for _ in range(num_layers)]),
            norm = nn.LayerNorm(embed_dim),
        ))
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer tensor of token ids; dimension 1 is treated as the
                sequence axis (expected shape ``(batch, seq_len)``).

        Returns:
            Tensor of logits whose last dimension is ``vocab_size``.

        Raises:
            ValueError: If the sequence length exceeds ``block_size``.
        """
        seq_len = x.size(1)
        # Fail early with a clear message: positions >= block_size have no
        # positional embedding and would otherwise trigger an opaque
        # out-of-range index error (or backend-dependent failure) in pos_emb.
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        # Shape (1, seq_len) so the positional embedding broadcasts over the batch.
        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.transformer.token_emb(x) + self.transformer.pos_emb(pos)
        x = self.transformer.norm(self.transformer.layers(x))
        return self.lm_head(x)

class Layer(nn.Module):
    """One pre-norm transformer block.

    Applies causal multi-head self-attention followed by a two-layer GELU MLP,
    each preceded by a ``LayerNorm`` and wrapped in a residual connection.

    Args:
        block_size: Side length of the precomputed causal mask, i.e. the longest
            sequence the block can mask.
        embed_dim: Width of the residual stream.
        num_heads: Number of attention heads.
    """

    def __init__(self, block_size, embed_dim, num_heads):
        super().__init__()
        hidden_dim = 4 * embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_dim, embed_dim),
        )
        # Boolean mask where True marks a (query, key) pair that must not attend:
        # everything strictly above the diagonal, i.e. future positions.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", lower_triangle == 0)

    def forward(self, x):
        """Run the block on a 3-D input; dimension 1 is the sequence axis."""
        _, seq_len, _ = x.shape
        normed = self.ln1(x)
        # Crop the precomputed mask to the current sequence length.
        causal_mask = self.mask[:seq_len, :seq_len]
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.ln2(x))

In [2]:
import lightning as pl
from shared import corpus, tokenizers, trainers

# Run configuration, gathered in one place so it is easy to tweak.
SEED = 89026614
BATCH_SIZE = 36
EPOCHS = 25
DEVICE = "mps"

# Load the training corpus and build a tokenizer from it
# (presumably one token per distinct character -- confirm in shared.tokenizers).
text = corpus.shakespeare()
tokenizer = tokenizers.unique_chars(text)

# Seed before constructing the model so weight initialization is reproducible.
pl.seed_everything(SEED)
model = GPT(vocab_size=tokenizer.get_vocab_size())
trainer = trainers.CausalTrainer(model, tokenizer, device=DEVICE)
trainer.train(text, batch_size=BATCH_SIZE, epochs=EPOCHS)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 716.89it/s]
Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 223 K 
-------------------------------
223 K     Trainable params
0         Non-trainable params
223 K     Total params
0.895     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s]And nowItE,$u$-mJa$NCA
qGeXqhcSixMMM3Kh:yyDX'&'CT!zoq JW;P-3aeNFtxYT!EC
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 202/202 [00:12<00:00, 15.76it/s, loss=2.42, v_num=69]And nows thane w st my ou thtoord that ader outhe thomis
And and fullit
Epoch 1: 100%|██████████| 202/202 [00:11<00:00, 18.14it/s, loss=2.07, v_num=69, test_loss=2.410]And now: wheree hat thast,
Whis this he a we morth of mer shre to the h
Epoch 2: 100%|██████████| 202/202 [00:11<00:00, 17.95it/s, loss=1.78, v_num=69, test_loss=2.120]And now,
This be when and meavence your hone.

BRAKEY:
This teer, the s
Epoch 3: 100%|██████████| 202/202 [00:11<00:00, 17.95it/s, loss=1.63, v_num=69, test_loss=1.950]And now'd and twidones with think;
Became the discand strothink; nother
Epoch 4: 100%|██████████| 202/202 [00:11<00:00, 17.33it/s, loss=1.55, v_num=69, test_loss=1.860]And now, thou whill.

POMPEY:
Goven, thou would have it. Her you hast I
Epoch 5: 100%|██████████| 202/202 [00:11<00:00, 17.58it/s, loss=1.51, v_num=69, test_loss=1.720]And now, I have the he sustrible own soon your sweet of thrat.

GRUMIO:
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 202/202 [00:15<00:00, 12.69it/s, loss=1.3, v_num=69, test_loss=1.410]


In [None]:
# Sample from the trained model, continuing the given prompt.
# NOTE(review): 650 is presumably the number of tokens to generate -- confirm against the wrapper.
sample = trainer.wrapper.generate("O God, O God!", 650)
print(sample)