In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets
import math
import numpy as np

pl.seed_everything(89026614)

text = datasets.load_dataset('tiny_shakespeare')["train"][0]["text"]

In [32]:
device = "cpu"
block_size = 128
fixed_epoch_size = block_size * 75

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, text):
        super().__init__()
        vocab = sorted(set(text))
        self.vocab = vocab
        self.stoi = { ch: i for i, ch in enumerate(vocab) }
        self.itos = { i: ch for i, ch in enumerate(vocab) }
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long).to(device)

    def __len__(self):
        return min(fixed_epoch_size, self.data.size(0) - block_size - 1)

    def __getitem__(self, i):
        i = np.random.randint(0, self.data.size(0) - block_size - 1)
        end = i + block_size
        return self.data[i:end], self.data[i + 1:end + 1]

dataset = MyDataset(text)

In [33]:
# adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
# and https://github.com/karpathy/nanoGPT/blob/master/model.py
# defaults for gpt-mini
class MinGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=192, num_heads=6, num_layers=6, dropout=0.1):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            token_emb = nn.Embedding(vocab_size, embed_dim),
            pos_emb = nn.Embedding(block_size, embed_dim),
            drop = nn.Dropout(dropout),
            layers = nn.Sequential(*[DecoderLayer(embed_dim, num_heads, dropout) for _ in range(num_layers)]),
            norm = nn.LayerNorm(embed_dim),
        ))
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def forward(self, x):
        pos = torch.arange(0, x.size(1), dtype=torch.long).unsqueeze(0).to(device)
        x = self.transformer.drop(self.transformer.token_emb(x) + self.transformer.pos_emb(pos))
        x = self.transformer.norm(self.transformer.layers(x))
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, top_k=10):
        for _ in range(max_new_tokens):
            out = self(input_ids[:, -block_size:])
            logits = out[:, -1, :]
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            input_ids = torch.cat((input_ids, step_res), dim=1)

        return input_ids

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.register_buffer("mask", ~torch.tril(torch.ones(block_size, block_size)).to(bool).to(device))

    def forward(self, x):
        B, T, _ = x.shape
        x = self.ln1(x)
        x = x + self.attn(x, x, x, need_weights=False, attn_mask=self.mask[:T, :T])[0]
        x = x + self.mlp(self.ln2(x))
        return x

class Model(pl.LightningModule):
    def __init__(self, vocab_size, lr=0.004):
        super().__init__()
        self.lr = lr
        self.model = MinGPT(vocab_size, embed_dim=192, num_heads=4, num_layers=2, dropout=0.1)
    
    def forward(self, x):
        return self.model(x)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(dataset, batch_size=128, num_workers=0)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
        #self.log("train_loss", loss)
        return loss
    
    def training_epoch_end(self, outs):
        with torch.no_grad():
            x = torch.tensor([dataset.stoi[s] for s in "O God, O God!"], dtype=torch.long)[None,...].to(device)
            y = self.model.generate(x, 100)[0]
            print("".join([dataset.itos[int(i)] for i in y]))
    
    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=self.lr)
        sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999, last_epoch=-1)
        return [optim], [sched]

model = Model(len(dataset.vocab))
trainer = pl.Trainer(gradient_clip_val=1.0, max_epochs=15, enable_progress_bar=True, log_every_n_steps=50, accelerator="gpu" if device == "cuda" else "cpu")
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params
---------------------------------
0 | model | MinGPT | 939 K 
---------------------------------
939 K     Trainable params
0         Non-trainable params
939 K     Total params
3.759     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 75/75 [00:34<00:00,  2.16it/s, loss=2.4, v_num=65] O God, O God! COMICHe INCENIN:
Tharn meak mee yout wile tur mange amy'lor thor y isou sthow t mppangh s t thenead
Epoch 1: 100%|██████████| 75/75 [00:35<00:00,  2.14it/s, loss=1.99, v_num=65]O God, O God!

Fir, my compere wishe and brord ther do musied.

GLOUCIO:
Whontass madake so swe and of thy shem w
Epoch 2: 100%|██████████| 75/75 [00:34<00:00,  2.15it/s, loss=1.78, v_num=65]O God, O God!

LERWICK:
Godd land for not thous one that,
And hand heard with the dords with thought show such th
Epoch 3: 100%|██████████| 75/75 [00:36<00:00,  2.08it/s, loss=1.65, v_num=65]O God, O God!

KING HENRY GO:
Wher breather.
There why, who was thou more what
Mand that sade black
This ble dood
Epoch 4: 100%|██████████| 75/75 [00:41<00:00,  1.82it/s, loss=1.58, v_num=65]O God, O God!

ROMEO:

LADY GAUDIO:

Clords, thor, my light?

BENVOLIO:
He mucame are the hard, that in in the co
Epoch 5: 100%|██████████| 75/75 [00:42<00:00,

`Trainer.fit` stopped: `max_epochs=15` reached.


Epoch 14: 100%|██████████| 75/75 [00:38<00:00,  1.95it/s, loss=1.41, v_num=65]


In [39]:
y = model.to(device).model.generate(torch.tensor([dataset.stoi[s] for s in "O God, O God!"]).unsqueeze(0).to(device), 650)
print("".join([dataset.itos[int(i)] for i in y[0]]))

O God, O God! what do't--
For the conting tried me sad, these crown,
Why warm be done that all stroke of his fonds.

PRINCE EDWARD:
And I can naildle he comes spent from minish;
They shall, and make, he is formal hands.
How deed, that he, she-may it serving in my fortunate
I think that time age of my son honour,
An to-night from abetwot, my blood will pardon,
That you tlander this bade:
If you most thank your grace one sets.
Where, I did. What to wast we, boy, my suffect
As I be domininish against his shall.' 
Firsteremy; what and a lower set boy?

ROMEO:
Not you to mean, take, sever humble murder.

MENENIUS:
Were am my brother to-night.

MARCIUS:
If I kn
