In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets
import math

pl.seed_everything(42)

shakespeare = datasets.load_dataset('tiny_shakespeare')["train"][0]["text"]

In [5]:
device = "cpu"
block_size = 32

class CharDataset(torch.utils.data.Dataset):
    def __init__(self, text):
        super().__init__()
        vocab = sorted(set(text))
        self.vocab = vocab
        self.stoi = { ch: i for i, ch in enumerate(vocab) }
        self.itos = { i: ch for i, ch in enumerate(vocab) }
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long).to(device)

    def __len__(self):
        return self.data.size(0) - block_size - 1

    def __getitem__(self, i):
        end = i + block_size
        return self.data[i:end], self.data[i + 1:end + 1]

dataset = CharDataset(shakespeare)

In [None]:
# adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
# and https://github.com/karpathy/nanoGPT/blob/master/model.py
# defaults for gpt-mini
class MinGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=192, num_heads=6, num_layers=6, dropout=0.1):
        super().__init__()

        self.transformer = nn.ModuleDict(dict(
            token_emb = nn.Embedding(vocab_size, embed_dim),
            pos_emb = nn.Embedding(block_size, embed_dim),
            drop = nn.Dropout(dropout),
            layers = nn.Sequential(*[DecoderLayer(embed_dim, num_heads, dropout) for _ in range(num_layers)]),
            norm = nn.LayerNorm(embed_dim),
        ))

        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def forward(self, x):
        pos = torch.arange(0, x.size(1), dtype=torch.long).unsqueeze(0).to(device)

        x = self.transformer.token_emb(x) + self.transformer.pos_emb(pos)
        x = self.transformer.drop(x)
        x = self.transformer.layers(x)
        x = self.transformer.norm(x)
        x = self.lm_head(x)
        return x

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, top_k=10):
        for _ in range(max_new_tokens):
            out = self(input_ids[:, -block_size:])
            logits = out[:, -1, :]
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            input_ids = torch.cat((input_ids, step_res), dim=1)

        return input_ids

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.ln1(x)
        x = x + self.attn(x, x, x, need_weights=False, is_causal=True)[0]
        x = self.ln2(x)
        x = x + self.mlp(x)
        return x

class Model(pl.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = MinGPT(vocab_size, embed_dim=128, num_heads=4, num_layers=4, dropout=0.1)
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

        if (batch_idx % 100) == 0:
            with torch.no_grad():
                context = "O God, O God!"
                x = torch.tensor([dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(device)
                y = self.model.generate(x, 64)[0]
                print("".join([dataset.itos[int(i)] for i in y]))

        self.log("train_loss", loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.001, betas=(0.9, 0.95))

model = Model(len(dataset.vocab))
trainer = pl.Trainer(gradient_clip_val=1.0, max_epochs=1, enable_progress_bar=True, log_every_n_steps=100, accelerator="gpu" if device == "cuda" else "cpu")
trainer.fit(model, torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=0))

In [15]:
y = model.model.generate(torch.tensor([dataset.stoi[s] for s in "ROMEO:"]).unsqueeze(0), 1000)
print("".join([dataset.itos[int(i)] for i in y[0]]))

ROMEO:
How I would never yea the waters.

UTIO:
PATSSA: he this father here be than she?

GATHARINA:
Sir.

GATHARINA:
So the move taise. I have all thrate me not her: the sclice, him be wen thus shee till a baced fell me her
You well beside as me sioccess as all
And with thee you they me to fest himselve to great:
If whom you do not did him.

PATARINA:
What I that break what when but this,
If it a devery his good on the git sing
Whilh so not, whomest fle yet hand,
A a frighther of him. She shall call belike move.

PATHARINA:
It as a for isic sisting of any foance all that,
Whut telr at on me mistruchs a cate;
What her what if strine her done and be swas while: and nor not she be best be this the little deed:
If him, but sine I my stiller.

GRUMIATBA:
Givily, whon I dere then, if you cree me they that an the lees of a were wife.

TAARINA:
Have I do, sir, belike, how yea for thee without book thou that beleave of shell,
What sweet thou
An save which heir scoiceol, you bell you shall I do