In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    """Token-level language model: embedding -> stack of ``Layer`` blocks -> LayerNorm -> linear head.

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output projection.
        block_size: Stored on the instance as ``self.block_size``. It is not
            used anywhere inside this class.
        embed_dim: Width of the embedding and of every ``Layer``.
        num_layers: How many ``Layer`` blocks are stacked.
    """

    def __init__(self, vocab_size, block_size=256, embed_dim=64, num_layers=4):
        super().__init__()
        self.block_size = block_size

        # Submodules are created in the order they are applied; the outer
        # Sequential indexes them as 0 (embedding), 1 (blocks), 2 (final norm).
        token_embedding = nn.Embedding(vocab_size, embed_dim)
        blocks = nn.Sequential(*(Layer(embed_dim) for _ in range(num_layers)))
        final_norm = nn.LayerNorm(embed_dim)
        self.transformer = nn.Sequential(token_embedding, blocks, final_norm)

        # Bias-free projection from the hidden width to one score per vocabulary entry.
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Map integer token ids ``x`` to scores whose last dimension is ``vocab_size``."""
        hidden = self.transformer(x)
        logits = self.lm_head(hidden)
        return logits

class Layer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.f = nn.Linear(embed_dim, embed_dim, bias=False)
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        xn = self.ln1(x)
        prev = F.pad(xn, (0, 0, 1, -1))
        prev2 = F.pad(xn, (0, 0, 2, -2))
        f = torch.sigmoid(self.f(prev)) # prev can say what should be forgotten from prev2 (x-2)
        q = torch.sigmoid(self.q(xn)) # what should be accepted from prev
        v = self.v(prev) # what the prev is providing
        attn = self.proj((q * v) - (f * self.v(prev2)))

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x

In [2]:
import lightning as pl
from shared import corpus, tokenizers, trainers

# Run configuration, gathered in one place so it is easy to tweak.
SEED = 89026614
BATCH_SIZE = 36
EPOCHS = 25
DEVICE = "mps"

# Load the training text and derive the tokenizer from it.
text = corpus.shakespeare()
tokenizer = tokenizers.unique_chars(text)

# Seed before the model is built so weight initialization is repeatable.
pl.seed_everything(SEED)
vocab_size = tokenizer.get_vocab_size()
model = GPT(vocab_size)

# `trainer` is reused by later cells.
trainer = trainers.CausalTrainer(model, tokenizer, device=DEVICE)
trainer.train(text, batch_size=BATCH_SIZE, epochs=EPOCHS)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 439.70it/s]
Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | GPT  | 207 K 
-------------------------------
207 K     Trainable params
0         Non-trainable params
207 K     Total params
0.829     Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 27.27it/s]

  rank_zero_warn(


And nowO:i.iGAN&yJFRt!XWxt!seFrxbXqh. WWjkjwoeiyJGH!EnzNBU.O:3Opext3woh
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 202/202 [00:07<00:00, 27.95it/s, loss=1.6, v_num=68] And now,
But your libted frues me for I have but but her treach your ba
Epoch 1: 100%|██████████| 202/202 [00:06<00:00, 30.68it/s, loss=1.47, v_num=68, test_loss=1.770]And now, friends in proud, tillffer upon; seel tears mine.

ANGELO:
The
Epoch 2: 100%|██████████| 202/202 [00:06<00:00, 31.13it/s, loss=1.4, v_num=68, test_loss=1.560] And now?

CAMILLO:
Ay, sun my life,
But is the demany love of my hound 
Epoch 3: 100%|██████████| 202/202 [00:06<00:00, 31.41it/s, loss=1.38, v_num=68, test_loss=1.500]And now, a set it bates, or barks a word you been his lady's born sound
Epoch 4: 100%|██████████| 202/202 [00:06<00:00, 31.19it/s, loss=1.35, v_num=68, test_loss=1.480]And now turn then the doing thy sease.

Second Servingman one and taken
Epoch 5: 100%|██████████| 202/202 [00:06<00:00, 31.01it/s, loss=1.34, v_num=68, test_loss=1.440]And now, it would to his plichers of this so faint.

AUFIDIUS:
My such 
Epoch 6: 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 202/202 [00:12<00:00, 16.10it/s, loss=1.2, v_num=68, test_loss=1.410]


In [3]:
# Generate a continuation of the prompt with the trained model and show it.
# NOTE(review): 650 is presumably the number of tokens to generate — confirm against `generate`.
prompt = "O God, O God!"
sample = trainer.wrapper.generate(prompt, 650)
print(sample)

O God, O God!

First Senator:
If you like.

First Hunger than your family! mark the worst of dark;
I came infrink, madam, farewell.

DUCHESS OF YORK:
My dearest of his people to his heart the time?
What still, I have such perfection, the most breathe forest bid him so breathize some hurt as none miles of damned at the harms and famous and play.
A begging of ill.

MARCIUS:
'Tyition:
Truly son
He had shed and down, and not so, because
He does arriving my princely good friends against the glorious prince you have not stir seem into
Supply she shall show your gentleman:
It more perfect the princely good as his face?

First Murderer:
I do proved him, with an o
