In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets
import math

pl.seed_everything(42)

shakespeare = datasets.load_dataset('tiny_shakespeare')["train"][0]["text"]

In [2]:
device = "cpu"
block_size = 64

class CharDataset(torch.utils.data.Dataset):
    def __init__(self, text):
        super().__init__()
        vocab = sorted(set(text))
        self.vocab = vocab
        self.stoi = { ch: i for i, ch in enumerate(vocab) }
        self.itos = { i: ch for i, ch in enumerate(vocab) }
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long).to(device)

    def __len__(self):
        return self.data.size(0) - block_size - 1

    def __getitem__(self, i):
        end = i + block_size
        return self.data[i:end], self.data[i + 1:end + 1]

dataset = CharDataset(shakespeare)

In [3]:
# adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
# and https://github.com/karpathy/nanoGPT/blob/master/model.py
# defaults for gpt-mini
class MinGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=192, num_heads=6, num_layers=6, dropout=0.1):
        super().__init__()

        self.transformer = nn.ModuleDict(dict(
            token_emb = nn.Embedding(vocab_size, embed_dim),
            pos_emb = nn.Embedding(block_size, embed_dim),
            drop = nn.Dropout(dropout),
            layers = nn.Sequential(*[DecoderLayer(embed_dim, num_heads, dropout) for _ in range(num_layers)]),
            norm = nn.LayerNorm(embed_dim),
        ))

        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def forward(self, x):
        pos = torch.arange(0, x.size(1), dtype=torch.long).unsqueeze(0).to(device)

        x = self.transformer.token_emb(x) + self.transformer.pos_emb(pos)
        x = self.transformer.drop(x)
        x = self.transformer.layers(x)
        x = self.transformer.norm(x)
        x = self.lm_head(x)
        return x

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, top_k=10):
        for _ in range(max_new_tokens):
            out = self(input_ids[:, -block_size:])
            logits = out[:, -1, :]
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            input_ids = torch.cat((input_ids, step_res), dim=1)

        return input_ids

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.register_buffer("mask", ~torch.tril(torch.ones(block_size, block_size)).to(bool).to(device))

    def forward(self, x):
        B, T, _ = x.shape
        x = self.ln1(x)
        x = x + self.attn(x, x, x, need_weights=False, attn_mask=self.mask[:T, :T])[0]
        x = x + self.mlp(self.ln2(x))
        return x

class Model(pl.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = MinGPT(vocab_size, embed_dim=128, num_heads=4, num_layers=4, dropout=0.1)
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

        if (batch_idx % 1000) == 0:
            with torch.no_grad():
                x = torch.tensor([dataset.stoi[s] for s in "O God, O God!"], dtype=torch.long)[None,...].to(device)
                y = self.model.generate(x, 64)[0]
                print("".join([dataset.itos[int(i)] for i in y]))

        self.log("train_loss", loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.001)

model = Model(len(dataset.vocab))
trainer = pl.Trainer(gradient_clip_val=1.0, max_epochs=1, enable_progress_bar=True, log_every_n_steps=100, accelerator="gpu" if device == "cuda" else "cpu")
trainer.fit(model, torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=0))

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name  | Type   | Params
---------------------------------
0 | model | MinGPT | 818 K 
---------------------------------
818 K     Trainable params
0         Non-trainable params
818 K     Total params
3.273     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/15685 [00:00<?, ?it/s] O God, O God!bQOoOdtYEJ:QZIjNc?$OPvxEeiktLWOgbuOLtSg!fzTL;p!t-VmTVlJumllcgXch
Epoch 0:   6%|▋         | 1000/15685 [02:41<39:28,  6.20it/s, loss=2.12, v_num=8]O God, O God!

Cillich yet hatelf poitizecthe  ples nond h his your
Thinteld 
Epoch 0:  13%|█▎        | 2000/15685 [05:29<37:31,  6.08it/s, loss=1.85, v_num=8]O God, O God!
Coufish the have one prisheds some. His
How tay fore fringo: wa
Epoch 0:  19%|█▉        | 3000/15685 [08:17<35:01,  6.04it/s, loss=1.97, v_num=8]O God, O God!
That fall mitchame o a hand witham teis whon othen tark,
Thouse
Epoch 0:  26%|██▌       | 4000/15685 [11:21<33:09,  5.87it/s, loss=1.81, v_num=8]O God, O God! my gracie, here in is thought
Thusersiivy so we to your lover i
Epoch 0:  32%|███▏      | 5000/15685 [14:27<30:54,  5.76it/s, loss=1.87, v_num=8]O God, O God!
My more say stry a do,
For to to free trike the, bring farest t
Epoch 0:  38%|███▊      | 6000/15685 [17:34<28:22,  5.69it/s, loss=1.8, v_nu

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 15685/15685 [47:50<00:00,  5.46it/s, loss=1.57, v_num=8]


In [5]:
y = model.to(device).model.generate(torch.tensor([dataset.stoi[s] for s in "ROMEO:"]).unsqueeze(0).to(device), 2000)
print("".join([dataset.itos[int(i)] for i in y[0]]))

ROMEO:
I well, say, thou can me, so, I with his thought thy stock
To-singest her ask. Fair hers, all this idle.

PETReARINA:
Sir, him it scares him, he, sim thou sat, sile.
Thou, belut should you well me? Baptient more have me.
Bapting at see not I way.

TGR:
He charge hold, wish, yell all, an these shall I may.

GRAMIO:
Nay's tressist men and by thee ye the believe
For my mother take than that should. Any what
all her ald, we make in all thither by to word
So this her's father her heaven in you; with, faulto me.

PATR:
Hare when sister, you before if whom I my go,
Which is it, for that and struck wrion well.

GRAMINA:
I beay, subjech tweay? Bianca, you that maigo? Biance, what you that so well not.
Twhid yourning you sh, whose all when thirt and widow me.

GRA:
Belield sister, sir, thee, she your loved,
The friend, stand held the to brive for we twenty.
Hard a but as stand the friends, belike old.

TRANIA:
O, suitor a mostlemen, Is their this gived, you have you
Stild all so we this s