In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import lightning as pl
import datasets
import numpy as np

# Seed the global RNGs through Lightning so repeated runs are comparable.
pl.seed_everything(89026614)

# Training corpus: the "text" field of the first record in the `train` split.
shakespeare = datasets.load_dataset("tiny_shakespeare")
text = shakespeare["train"][0]["text"]
# Alternative corpus read from a local file (kept for reference):
#text = open('../../Downloads/simplebooks/simplebooks-2-raw/train.txt', 'r').read()

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 89026614
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 945.44it/s]


In [2]:
# Pick the accelerator at runtime instead of assuming Apple Silicon.
# MPS is tried first so the value is unchanged on the machine this notebook
# was written on; CUDA and CPU are fallbacks so the notebook runs elsewhere.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

block_size = 100   # context length: tokens per training window / positional table size
test_size = 1500   # number of trailing dataset windows held out for validation
batch_size = 64    # windows per batch for both the training and validation loaders

class MyDataset(data.Dataset):
    """Character-level next-token dataset over a single string.

    Each item is a pair ``(inputs, targets)`` of id tensors sliced from the
    encoded text, where ``targets`` is ``inputs`` shifted one position to the
    right. The window length comes from the module-level ``block_size``.

    Attributes set by ``__init__``:
        vocab: sorted list of the distinct characters in ``text``.
        stoi:  character -> integer id.
        itos:  integer id -> character.
        data:  1-D ``torch.long`` tensor holding the encoded text.
    """

    def __init__(self, text):
        super().__init__()
        self.vocab = sorted(set(text))
        # Build both lookup tables in a single pass over the vocabulary.
        self.stoi = {}
        self.itos = {}
        for idx, ch in enumerate(self.vocab):
            self.stoi[ch] = idx
            self.itos[idx] = ch
        encoded = [self.stoi[ch] for ch in text]
        self.data = torch.tensor(encoded, dtype=torch.long)

    def __len__(self):
        # Number of start positions exposed to samplers.
        return len(self.data) - (block_size + 1)

    def __getitem__(self, i):
        start, stop = i, i + block_size
        inputs = self.data[start:stop]
        targets = self.data[start + 1:stop + 1]
        return inputs, targets

dataset = MyDataset(text)

# Hold out the last `test_size` window start positions for validation.
# NOTE(review): the final training windows extend past the split index, so
# they share characters with the first held-out windows.
split_at = len(dataset) - test_size
train = data.Subset(dataset, range(split_at))
test = data.Subset(dataset, range(split_at, len(dataset)))

In [3]:
class GPT(nn.Module):
    """Small decoder-style transformer mapping token ids to next-token logits.

    Args:
        vocab_size: number of distinct token ids (embedding rows / output logits).
        embed_dim:  width of the embeddings and of every layer.
        num_heads:  attention heads per ``Layer``.
        num_layers: number of stacked ``Layer`` blocks.

    The positional table has ``block_size`` rows (module-level constant).
    """

    def __init__(self, vocab_size, embed_dim=92, num_heads=4, num_layers=2):
        super().__init__()
        # Sub-modules are instantiated in the same order as they are
        # registered, which keeps seeded initialisation reproducible.
        token_emb = nn.Embedding(vocab_size, embed_dim)
        pos_emb = nn.Embedding(block_size, embed_dim)
        blocks = [Layer(embed_dim, num_heads) for _ in range(num_layers)]
        final_norm = nn.LayerNorm(embed_dim)
        self.transformer = nn.ModuleDict({
            "token_emb": token_emb,
            "pos_emb": pos_emb,
            "layers": nn.Sequential(*blocks),
            "norm": final_norm,
        })
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Return logits with one row per input position; ``x`` holds token ids."""
        seq_len = x.size(1)
        # Shape (1, seq_len) so the positional embedding broadcasts over the batch.
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
        hidden = self.transformer.token_emb(x) + self.transformer.pos_emb(positions)
        hidden = self.transformer.layers(hidden)
        hidden = self.transformer.norm(hidden)
        return self.lm_head(hidden)

class Layer(nn.Module):
    """Pre-norm transformer block: causal self-attention, then an MLP.

    Both sub-layers are wrapped in residual connections, with LayerNorm
    applied to the input of each sub-layer.

    Args:
        embed_dim: width of the input/output features.
        num_heads: number of attention heads.
        max_len:   side length of the precomputed causal mask. ``None``
                   (the default) uses the module-level ``block_size``.
    """

    def __init__(self, embed_dim, num_heads, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = block_size
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, bias=False)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # True marks positions a query may NOT attend to (strict upper triangle).
        # The buffer is created on the default device and deliberately not
        # pinned to a global device: registered buffers move with the module
        # on `.to(...)`, so pinning is redundant and breaks construction on
        # machines where that device does not exist.
        allowed = torch.tril(torch.ones(max_len, max_len)).to(torch.bool)
        self.register_buffer("mask", ~allowed)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, embed_dim)."""
        _, T, _ = x.shape
        xn = self.ln1(x)
        # Crop the mask to the current sequence length.
        attn_out, _ = self.attn(xn, xn, xn, need_weights=False, attn_mask=self.mask[:T, :T])
        x = x + attn_out
        return x + self.mlp(self.ln2(x))

class Model(pl.LightningModule):
    """Lightning wrapper that trains ``GPT`` as a character-level language model.

    The data loaders read the module-level ``train`` / ``test`` subsets and
    ``batch_size``; ``generate`` uses the module-level ``dataset`` for
    encoding/decoding and ``block_size`` to crop the context.

    Args:
        vocab_size: number of distinct token ids.
        lr:         initial Adam learning rate.
    """

    def __init__(self, vocab_size, lr=0.007):
        super().__init__()
        self.lr = lr
        self.model = GPT(vocab_size)

    def forward(self, x, y=None):
        """Return logits for ``x``, or the cross-entropy loss against ``y`` if given."""
        logits = self.model(x)
        if y is None:
            return logits
        # Flatten batch and sequence so every position is one classification.
        return F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

    def train_dataloader(self):
        # Each epoch draws 6,000 windows without replacement rather than
        # sweeping the whole training subset.
        sampler = data.RandomSampler(train, replacement=False, num_samples=6_000)
        return data.DataLoader(train, batch_size=batch_size, num_workers=0, sampler=sampler)

    def training_step(self, batch, batch_idx):
        return self(*batch)

    def val_dataloader(self):
        return data.DataLoader(test, batch_size=batch_size, num_workers=0)

    def validation_step(self, batch, batch_idx):
        loss = self(*batch)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        # Print a short sample after every validation pass as a qualitative check.
        # This hook replaces `validation_epoch_end(outs)`, which Lightning 2.0
        # removed; `on_validation_epoch_end` is available in 1.x and 2.x.
        # `generate` already runs under `torch.no_grad()`.
        print(self.generate("And now", 64))

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.95, last_epoch=-1)
        return [optim], [sched]

    # inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    @torch.no_grad()
    def generate(self, str, max_new_tokens, top_k=10):
        """Continue the prompt ``str`` by sampling ``max_new_tokens`` characters.

        Args:
            str:            non-empty prompt; every character must be in
                            ``dataset.stoi``.
            max_new_tokens: number of characters to append.
            top_k:          sample only among the ``top_k`` most likely next
                            characters (clamped to the vocabulary size).

        Returns:
            The prompt followed by the generated characters.

        Raises:
            ValueError: if the prompt is empty (there is no position to
                        predict from).
            KeyError:   if the prompt contains a character outside the vocabulary.
        """
        if len(str) == 0:
            raise ValueError("generate() needs a non-empty prompt")
        ids = torch.tensor([dataset.stoi[ch] for ch in str], dtype=torch.long).unsqueeze(0).to(self.device)
        for _ in range(max_new_tokens):
            # Only the last `block_size` tokens fit the positional table.
            logits = self(ids[:, -block_size:])[:, -1, :]
            # torch.topk raises if k exceeds the number of classes.
            k = min(top_k, logits.size(-1))
            v, _ = torch.topk(logits, k)
            # Mask everything below the k-th best logit before sampling.
            logits[logits < v[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            ids = torch.cat((ids, step_res), dim=1)
        return "".join(dataset.itos[int(i)] for i in ids[0])

# One output class per symbol in the dataset vocabulary.
model = Model(len(dataset.vocab))

# "cuda" is passed to Lightning as accelerator "gpu"; any other device
# string is forwarded unchanged.
trainer = pl.Trainer(
    accelerator=device if device != "cuda" else "gpu",
    enable_progress_bar=True,
    max_epochs=25,
)
trainer.fit(model)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 226 K 
-------------------------------
226 K     Trainable params
0         Non-trainable params
226 K     Total params
0.905     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  7.55it/s]And nowSMzROiPpyw:Vp csHDEEfNvyw:VgGa-FBSM,mBNqHAnMr
:pCc'S
K.z3WkhyyN$
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 118/118 [00:06<00:00, 18.63it/s, loss=2.36, v_num=78]And now so hod sand bare ade my aigh sheroouse to womes,
I he soto he a
Epoch 1: 100%|██████████| 118/118 [00:04<00:00, 24.40it/s, loss=2.01, v_num=78, test_loss=2.280]And nowath the the that tumer a seat
Mared frove mane hes shat the wivi
Epoch 2: 100%|██████████| 118/118 [00:04<00:00, 24.32it/s, loss=1.79, v_num=78, test_loss=2.010]And now who the but should armser, and son,
I lode seopanith, sincelain
Epoch 3: 100%|██████████| 118/118 [00:04<00:00, 24.25it/s, loss=1.69, v_num=78, test_loss=1.920]And now will tends it will would
An hles it were time broked
Was to the
Epoch 4: 100%|██████████| 118/118 [00:04<00:00, 24.27it/s, loss=1.62, v_num=78, test_loss=1.840]And now.

DUKE VINCENTIO:
Whyself you.

BROKE VINCENTIO:
Here havide th
Epoch 5: 100%|██████████| 118/118 [00:04<00:00, 24.24it/s, loss=1.58, v_num=78, test_loss=1.800]And now not mut of thy wisdom sust fall: thereing is tiss,
And have mar
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 118/118 [00:06<00:00, 18.00it/s, loss=1.35, v_num=78, test_loss=1.480]


In [4]:
# Sample a longer continuation from the trained model; print() keeps the line breaks.
print(model.generate("O God, O God!", max_new_tokens=1024))

O God, O God!

PERCY:
All, make I thank, I send to tell,
Where yourself those from them that shall be high,
For that to a sighto alter him; found wonder heart
As to pain thy seasilds.

KING LEWIS XI:
Hark'd wined answer with the trusts of the thurn way
The coming out at a
was my peace absed,
This, was shall be her tale him.

LEONTES:
How we say that have send to the woe of woe?
I will begin the can our siege,
Be now in and strengthend as
To carvile's, sevent their answers?
O hath hind was intend accused?

MARCIUS:
I had now!

BRUTUS,
Which in his present for two alone: it,
And humility.
O, sir, so my time wrongs of these come?

LUCIO:
Is no, haved she?

AUTOLYCUS:
How is forbid for soldiers;
So men to the royal of and had,
As must bear that this soul crown a law in too
I mine seated.

FLORIZEL:
Too well, and your could for some fools,
Which were we secreth and stard; this lack before you
As young, in my before tears.' apple, and with his crease sufford
By hand a bawd; but for for thee 