In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
import datasets
import numpy as np

pl.seed_everything(89026614)

text = datasets.load_dataset('tiny_shakespeare')["train"][0]["text"]
#text = open('../../Downloads/simplebooks/simplebooks-2-raw/train.txt', 'r').read()

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 89026614
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 853.48it/s]


In [2]:
device = "cpu"
block_size = 64

class MyDataset(data.Dataset):
    def __init__(self, text):
        super().__init__()
        vocab = sorted(set(text))
        self.vocab = vocab
        self.stoi = { ch: i for i, ch in enumerate(vocab) }
        self.itos = { i: ch for i, ch in enumerate(vocab) }
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long).to(device)

    def __len__(self):
        return self.data.size(0) - block_size - 1

    def __getitem__(self, i):
        end = i + block_size
        return self.data[i:end], self.data[i + 1:end + 1]

dataset = MyDataset(text)
train = data.Subset(dataset, range(0, len(dataset) - 500))
test = data.Subset(dataset, range(len(dataset) - 500, len(dataset)))

In [3]:
# inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
class MinGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, dropout):
        super().__init__()
        self.transformer = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Dropout(dropout),
            nn.Sequential(*[DecoderLayer(embed_dim, dropout) for _ in range(num_layers)]),
            nn.LayerNorm(embed_dim),
        )
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def forward(self, x):
        return self.lm_head(self.transformer(x))

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, top_k=10):
        for _ in range(max_new_tokens):
            out = self(input_ids[:, -block_size:])
            logits = out[:, -1, :]
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            input_ids = torch.cat((input_ids, step_res), dim=1)

        return input_ids

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = AFTFullCumsum(embed_dim, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.register_buffer("mask", ~torch.tril(torch.ones(block_size, block_size)).to(bool).to(device))

    def forward(self, x):
        B, T, _ = x.shape
        x = self.ln1(x)
        x = x + self.attn(x, self.mask)
        x = x + self.mlp(self.ln2(x))
        return x

# https://arxiv.org/pdf/2105.14103.pdf
# but divide by cumulative sum of exp(k) instead of sum of exp(w) @ exp(k)
class AFTFullCumsum(nn.Module):
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.wbias = nn.Parameter(torch.ones(block_size, block_size), requires_grad=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        exp_w = self.wbias[:T,:T].masked_fill(mask[:T, :T], -float("Inf")).exp()
        exp_k = k.exp()

        weighted_avg = torch.einsum("ik,bkj->bij", exp_w, exp_k * v) / torch.cumsum(exp_k, dim=1)
        Yt = F.sigmoid(q) * weighted_avg

        return self.drop(self.proj(Yt))


class Model(pl.LightningModule):
    def __init__(self, vocab_size, lr=0.004):
        super().__init__()
        self.lr = lr
        self.model = MinGPT(vocab_size, embed_dim=192, num_layers=3, dropout=0.1)
    
    def forward(self, x, y=None):
        logits = self.model(x)
        return logits if y is None else F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

    def train_dataloader(self):
        return data.DataLoader(train, batch_size=64, num_workers=0, sampler=data.RandomSampler(train, False, 5_000))

    def training_step(self, batch, batch_idx):
        return self(*batch)

    def val_dataloader(self):
        return data.DataLoader(test, batch_size=16, num_workers=0)
    
    def validation_step(self, batch, batch_idx):
        loss = self(*batch)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_epoch_end(self, outs):
        with torch.no_grad():
            x = torch.tensor([dataset.stoi[s] for s in "And now "], dtype=torch.long)[None,...].to(device)
            y = self.model.generate(x, 64)[0]
            print("".join([dataset.itos[int(i)] for i in y]))

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=self.lr)
        sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999, last_epoch=-1)
        return [optim], [sched]

model = Model(len(dataset.vocab))
trainer = pl.Trainer(gradient_clip_val=1.0, max_epochs=25, enable_progress_bar=True, accelerator="gpu" if device == "cuda" else "cpu")
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params
---------------------------------
0 | model | MinGPT | 1.4 M 
---------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.489     Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 49.39it/s]

  rank_zero_warn(


And now -k
xP,,upRonjPjR.ZwCAK&K&i$?dQF'KCMcOdQGKCrVUDyyXVGgAl yIhduUqmw
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 111/111 [00:13<00:00,  8.25it/s, loss=2.36, v_num=89]And now tor
Thowe late spoved ateltel or besthisereld mut ts ban are an 
Epoch 1: 100%|██████████| 111/111 [00:13<00:00,  8.19it/s, loss=2.06, v_num=89, test_loss=2.260]And now herink the,
And she acer blod hout now deth the thichs.

LETERD:
Epoch 2: 100%|██████████| 111/111 [00:13<00:00,  8.35it/s, loss=1.89, v_num=89, test_loss=2.110]And now the hove that tome and man thouperting him of lople
wither'd thy
Epoch 3: 100%|██████████| 111/111 [00:13<00:00,  8.41it/s, loss=1.8, v_num=89, test_loss=2.000] And now me; and their where boyodself
And intied the sipion his buieroug
Epoch 4: 100%|██████████| 111/111 [00:13<00:00,  8.44it/s, loss=1.73, v_num=89, test_loss=1.840]And now fell of ford shalf,
His dwo seech at is a do flest manges;
Or ma
Epoch 5: 100%|██████████| 111/111 [00:13<00:00,  8.53it/s, loss=1.68, v_num=89, test_loss=1.740]And now thoughts they his precend that
As servest for of your soul all
A
Epo

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 111/111 [00:15<00:00,  7.32it/s, loss=1.44, v_num=89, test_loss=1.430]


In [4]:
y = model.to(device).model.generate(torch.tensor([dataset.stoi[s] for s in "O God, O God!"]).unsqueeze(0).to(device), 650)
print("".join([dataset.itos[int(i)] for i in y[0]]))

O God, O God!

DORSET:
Why, that, all so? I have
Be cannot the second speak some, forth theirl,
On all swo, he is: I could not hold not both,
Thou wot, with you'll did.

JULIET:
The, sir, you have, speed, she their sidest is a honour.

JULIET:
I am him, I was so burder has wing
I' too lady; afore noble.

RICHARD:
Isabs, see, where't.
A good not hear him hath the from it;
And I know, I am stright, the duke of all thee.
And I be speak, aster, not so in he hath speeling he wind his ferewars!
For all on that bring other, but by his woe!
If sild not my command is a thou hast:
And he is sons and the dre show not.

PRINN:
I would not him, but most, set thou lep-
