In [1]:
from shared import corpus, tokenizers, datasets

# Load the Shakespeare corpus through the project's `shared.corpus` helper.
text = corpus.shakespeare()
# Build the tokenizer from the corpus itself; presumably one token per distinct
# character (going by the helper's name) — TODO confirm against `shared.tokenizers`.
tokenizer = tokenizers.unique_chars(text)
# Encode the whole corpus to token ids and let `datasets.causal` turn them into the
# train/test datasets used by the data loaders below (presumably next-token
# prediction pairs — verify against `shared.datasets`).
train, test = datasets.causal(tokenizer.encode(text).ids)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 1229.88it/s]


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    """Token-level language model: embedding -> stacked ``Layer`` blocks -> LayerNorm -> vocab logits.

    Args:
        vocab_size: number of distinct token ids (embedding rows and output logits).
        embed_dim: width of the hidden representation.
        num_layers: how many ``Layer`` blocks are stacked.
    """

    def __init__(self, vocab_size, embed_dim=64, num_layers=4):
        super().__init__()
        # Submodules are created in the same order they are applied.
        token_embedding = nn.Embedding(vocab_size, embed_dim)
        blocks = nn.Sequential(*(Layer(embed_dim) for _ in range(num_layers)))
        final_norm = nn.LayerNorm(embed_dim)
        self.transformer = nn.Sequential(token_embedding, blocks, final_norm)
        # Bias-free projection from the hidden width back to one logit per token id.
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Map token ids ``x`` to logits with a trailing dimension of size ``vocab_size``."""
        hidden = self.transformer(x)
        return self.lm_head(hidden)

class Layer(nn.Module):
    """Residual block: a gated mix with the previous position, followed by an MLP.

    In place of softmax attention, a sigmoid gate computed from the current
    position multiplies a linear projection of the *previous* position
    (obtained by shifting the sequence one step along its second-to-last axis).
    An output at position ``t`` therefore only sees positions ``t`` and ``t - 1``.

    Args:
        embed_dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)
        # Output projection of the gated values. Defined exactly once: a second
        # assignment would only build and discard an extra Linear (and advance the RNG).
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, time, embed_dim)``; the shape is preserved."""
        xn = self.ln1(x)
        gate = torch.sigmoid(self.q(xn))
        # Shift one step along the time axis: prepend a zero row and drop the last
        # row, so row t holds the normalised input of row t - 1 (zeros at t = 0).
        prev = F.pad(xn, (0, 0, 1, -1))
        mixed = self.proj(gate * self.v(prev))

        x = x + mixed
        return x + self.mlp(self.ln2(x))


import lightning as pl
import torch.utils.data as data
# Pick the accelerator: Apple's Metal backend first (the notebook's original
# setting), then CUDA, and finally the CPU so the notebook still runs on
# machines without an MPS device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Mini-batch size shared by the training and validation data loaders.
batch_size = 36

class Model(pl.LightningModule):
    """Lightning wrapper that trains a ``GPT`` and samples text from it.

    Relies on notebook-level globals defined in other cells: ``GPT``,
    ``tokenizer``, ``train``, ``test`` and ``batch_size``.

    Args:
        vocab_size: number of token ids, forwarded to ``GPT``.
        lr: initial learning rate for Adam.
    """

    def __init__(self, vocab_size, lr=0.007):
        super().__init__()
        self.lr = lr
        self.model = GPT(vocab_size)

    def forward(self, x, y=None):
        """Return logits for ``x``; when targets ``y`` are given, return the cross-entropy loss instead.

        Target entries equal to ``-1`` are excluded from the loss.
        """
        logits = self.model(x)
        if y is None:
            return logits
        return F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

    def train_dataloader(self):
        """Draw 6,000 training items per epoch, at random and without replacement."""
        sampler = data.RandomSampler(train, replacement=False, num_samples=6_000)
        return data.DataLoader(train, batch_size=batch_size, num_workers=0, sampler=sampler)

    def training_step(self, batch, batch_idx):
        # The batch is unpacked into forward(x, y) — presumably (inputs, targets);
        # verify against the dataset in `shared.datasets`.
        return self(*batch)

    def val_dataloader(self):
        """Iterate over the whole held-out set in order."""
        return data.DataLoader(test, batch_size=batch_size, num_workers=0)

    def validation_step(self, batch, batch_idx):
        loss = self(*batch)
        # Logged once per epoch under the key "test_loss".
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        """Print a short sample after every validation pass.

        Uses ``on_validation_epoch_end`` because the older
        ``validation_epoch_end(outs)`` hook was removed in Lightning 2.0; this
        hook is available in both 1.x and 2.x. ``generate`` already disables
        gradient tracking, so no extra ``no_grad`` context is needed here.
        """
        print(self.generate("And now", 64))

    def configure_optimizers(self):
        """Adam with the learning rate decayed by a factor of 0.95 on each scheduler step."""
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.95, last_epoch=-1)
        return [optim], [sched]

    # inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    @torch.no_grad()
    def generate(self, str, max_new_tokens, top_k=10):
        """Continue the prompt ``str`` by sampling ``max_new_tokens`` tokens.

        Args:
            str: prompt text (the name shadows the builtin; kept so existing
                keyword calls keep working).
            max_new_tokens: number of tokens to append to the prompt.
            top_k: sample only among the ``top_k`` most likely tokens at each
                step. Values larger than the vocabulary are clamped; ``None``
                disables the filter.

        Returns:
            The decoded prompt plus the generated continuation.
        """
        ids = torch.tensor(tokenizer.encode(str).ids, dtype=torch.long).unsqueeze(0).to(self.device)
        for _ in range(max_new_tokens):
            # The whole sequence so far is fed back in; only the last position's
            # logits are used to pick the next token.
            logits = self(ids)[:, -1, :]
            if top_k is not None:
                # torch.topk raises if k exceeds the dimension size, so clamp it.
                k = min(top_k, logits.size(-1))
                kth_best, _ = torch.topk(logits, k)
                logits[logits < kth_best[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            ids = torch.cat((ids, step_res), dim=1)
        return tokenizer.decode(ids[0].tolist())

# Fix the random seed before the model is created so that weight initialisation
# and sampling are repeatable.
pl.seed_everything(89026614)
model = Model(tokenizer.get_vocab_size())
# Lightning names the CUDA accelerator "gpu"; any other device string is passed through.
trainer = pl.Trainer(
    max_epochs=25,
    enable_progress_bar=True,
    accelerator=device if device != "cuda" else "gpu",
)
trainer.fit(model)

Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 190 K 
-------------------------------
190 K     Trainable params
0         Non-trainable params
190 K     Total params
0.764     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 22.64it/s]And nowX;rncvc'UBles&aRwgBobqxZgUv;Si&Rw-l&cvXqMw:ce3VcNxoiqMvtKIIkkcfE
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 209/209 [00:05<00:00, 35.82it/s, loss=1.67, v_num=59]And now to true in them at your grats think flembs
Them thy purse me, i
Epoch 1: 100%|██████████| 209/209 [00:05<00:00, 38.31it/s, loss=1.53, v_num=59, test_loss=1.830]And now, and his dasting ship's the hand,
Which his shall this fight an
Epoch 2: 100%|██████████| 209/209 [00:05<00:00, 38.26it/s, loss=1.46, v_num=59, test_loss=1.680]And now throat my side: an oble to sleep: interton burding another.

DU
Epoch 3: 100%|██████████| 209/209 [00:05<00:00, 38.33it/s, loss=1.43, v_num=59, test_loss=1.590]And now woman many welces and worthy heirs in.

DUKE VINCENTIO:
The sen
Epoch 4: 100%|██████████| 209/209 [00:05<00:00, 38.94it/s, loss=1.41, v_num=59, test_loss=1.530]And now him will persual dance thank not all that I am somethink may be
Epoch 5: 100%|██████████| 209/209 [00:05<00:00, 38.99it/s, loss=1.39, v_num=59, test_loss=1.540]And now? which these my body byself and from him bearing me,
Thou but I
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 209/209 [00:12<00:00, 16.38it/s, loss=1.29, v_num=59, test_loss=1.420]


In [3]:
print(model.generate("O God, O God!", 650))

O God, O God! what wise, inconspirator:
And thy warrant. Somerset a pleased, only till you have not frowns are them the period to do nor hands upon, this so the sanst me to his heart will, and him, my lord's say 'Ay.
I authority:
If I means over business?
And woodfellow? where's my honour, much,
Have your garpends the searchs disgraces may should sleep out of your gentlemen,--

QUEEN:
All the writ, a month the whose my face;
Saddle will chide, thus. He do you better sleep himself in such a dream'd,
I have, I would done;
We must some presence as at his sudden so beseech you?

RIVERS:
How so still hold, man?
Once the fires to the daughters o' the blood
Hast
