In [1]:
from shared import corpus, tokenizers, datasets

text = corpus.shakespeare()
tokenizer = tokenizers.unique_chars(text)
train, test = datasets.causal(tokenizer.encode(text))

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 1197.12it/s]


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, num_layers=4):
        super().__init__()
        self.transformer = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Sequential(*[Layer(embed_dim) for _ in range(num_layers)]),
            nn.LayerNorm(embed_dim),
        )
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def forward(self, x):
        return self.lm_head(self.transformer(x))

class Layer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.f = nn.Linear(embed_dim, embed_dim, bias=False)
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        xn = self.ln1(x)
        prev = F.pad(xn, (0, 0, 1, -1))
        prev2 = F.pad(xn, (0, 0, 2, -2))
        f = torch.sigmoid(self.f(prev)) # prev can say what should be forgotten from prev2 (x-2)
        q = torch.sigmoid(self.q(xn)) # what should be accepted from prev
        v = self.v(prev) # what the prev is providing
        attn = self.proj((q * v) - (f * self.v(prev2)))

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x


import lightning as pl
import torch.utils.data as data
device = "mps"
batch_size = 36

class Model(pl.LightningModule):
    def __init__(self, vocab_size, lr=0.007):
        super().__init__()
        self.lr = lr
        self.model = GPT(vocab_size)
    
    def forward(self, x, y=None):
        logits = self.model(x)
        return logits if y is None else F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)

    def train_dataloader(self):
        return data.DataLoader(train, batch_size=batch_size, num_workers=0, sampler=data.RandomSampler(train, False, 6_000))

    def training_step(self, batch, batch_idx):
        return self(*batch)

    def val_dataloader(self):
        return data.DataLoader(test, batch_size=batch_size, num_workers=0)
    
    def validation_step(self, batch, batch_idx):
        loss = self(*batch)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_epoch_end(self, outs):
        with torch.no_grad():
            print(self.generate("And now", 64))

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.95, last_epoch=-1)
        return [optim], [sched]

    # inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    @torch.no_grad()
    def generate(self, str, max_new_tokens, top_k=10):
        ids = torch.tensor(tokenizer.encode(str), dtype=torch.long).unsqueeze(0).to(self.device)
        for _ in range(max_new_tokens):
            #out = self(ids[:, -block_size:])
            out = self(ids)
            logits = out[:, -1, :]
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            step_res = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            # auto-regression
            ids = torch.cat((ids, step_res), dim=1)
        return tokenizer.decode(ids[0].tolist())

pl.seed_everything(89026614)
model = Model(tokenizer.vocab_size)
trainer = pl.Trainer(max_epochs=25, enable_progress_bar=True, accelerator="gpu" if device == "cuda" else device)
trainer.fit(model)

Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 207 K 
-------------------------------
207 K     Trainable params
0         Non-trainable params
207 K     Total params
0.829     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 27.29it/s]And nowo!UTR3TRNdq;HYTWtzHohdfcEAo'IfG
wo!!L;PPvPsOEqGAdkeldqD3GcULbgTR
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 209/209 [00:06<00:00, 30.55it/s, loss=1.6, v_num=44] And now we leave help,
And my turn'd fear to thrain mys mercy.

GLORGER
Epoch 1: 100%|██████████| 209/209 [00:06<00:00, 32.78it/s, loss=1.47, v_num=44, test_loss=1.770]And now, as many!

MENENIUS:
The peucenning.

LADY ANNE:
Menee in ourse
Epoch 2: 100%|██████████| 209/209 [00:06<00:00, 32.58it/s, loss=1.4, v_num=44, test_loss=1.580] And now master. Go. But my corson sorrow him spinch me to be, to him.


Epoch 3: 100%|██████████| 209/209 [00:06<00:00, 32.84it/s, loss=1.38, v_num=44, test_loss=1.510]And now is his but it wants be neighbour'd thee of the masts; and thy s
Epoch 4: 100%|██████████| 209/209 [00:06<00:00, 32.86it/s, loss=1.35, v_num=44, test_loss=1.480]And now on this state in mine attority.

BENVOLIO:
No, the world helse,
Epoch 5: 100%|██████████| 209/209 [00:06<00:00, 32.70it/s, loss=1.34, v_num=44, test_loss=1.440]And now, false time thy father! Ot their wit-match, with you. Tyrrel, a
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 209/209 [00:11<00:00, 18.48it/s, loss=1.2, v_num=44, test_loss=1.400]


In [4]:
print(model.generate("O God, O God!", 650))

O God, O God!

First Senator:
If you like.

First Hunger than your family! mark the worst of dark;
I came infrink, madam, farewell.

DUCHESS OF YORK:
My dearest of his people to his heart the time?
What still, I have such perfection, the most breathe forest bid him so breathize some hurt as none miles of damned at the harms and famous and play.
A begging of ill.

MARCIUS:
'Tyition:
Truly son
He had shed and down, and not so, because
He does arriving my princely good friends against the glorious prince you have not stir seem into
Supply she shall show your gentleman:
It more perfect the princely good as his face?

First Murderer:
I do proved him, with an o
