In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Character-level encoder built from MLP-Mixer style ``Layer`` blocks.

    In this notebook it is trained as a masked language model (see the
    ``MLMTrainer`` cell). Despite the attribute name ``transformer`` there is no
    attention: each ``Layer`` mixes information across positions with an MLP
    over the sequence axis. That MLP is sized by ``block_size``, so every input
    is right-padded to exactly ``block_size`` tokens before the stack runs.

    Args:
        vocab_size: number of token ids covered by the embedding and the output head.
        block_size: fixed sequence length the token-mixing MLPs are built for.
        embed_dim: width of the token embeddings.
        num_layers: number of stacked ``Layer`` blocks.
    """

    def __init__(self, vocab_size, block_size=256, embed_dim=64, num_layers=4):
        super().__init__()
        self.block_size = block_size
        # Attribute name kept as ``transformer`` so state_dict keys of existing
        # checkpoints still match.
        self.transformer = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Sequential(*[Layer(embed_dim, block_size) for _ in range(num_layers)]),
            nn.LayerNorm(embed_dim),
        )
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Compute per-position vocabulary logits.

        Args:
            x: integer token ids, 2-D ``(batch, seq)``. ``Layer`` unpacks the
                embedded tensor as 3-D, so ids must be 2-D. Sequences shorter than
                ``block_size`` are right-padded with token id 0.

        Returns:
            Logits of shape ``(batch, block_size, vocab_size)``. This covers the
            padded positions too, not just the ``seq`` positions given.

        Raises:
            ValueError: if ``seq`` exceeds ``block_size``. The token-mixing MLPs
                cannot process longer inputs.
        """
        seq_len = x.shape[1]
        if seq_len > self.block_size:
            # Fail early with a clear message instead of an opaque matmul shape
            # error from inside the first Layer's token-mixing nn.Linear.
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size={self.block_size}; "
                "split or truncate the input"
            )
        if seq_len < self.block_size:
            # Right-pad with id 0 (F.pad's default constant) up to block_size.
            x = F.pad(x, (0, self.block_size - seq_len))

        return self.lm_head(self.transformer(x))

class Layer(nn.Module):
    """One MLP-Mixer style block: token mixing, then channel mixing.

    Both halves are pre-norm residual updates. The token-mixing MLP acts along
    the sequence axis, so it is sized by ``block_size``. As a result the layer
    only accepts sequences whose length equals ``block_size``.

    Args:
        embed_dim: channel width C of the input.
        block_size: sequence length T the token-mixing MLP is built for.
    """

    def __init__(self, embed_dim, block_size):
        super().__init__()
        # Creation order matters: it fixes parameter order and, under a global
        # seed, which random draws each weight receives.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mlp1 = mlp(block_size, dim_hidden=block_size // 2)  # mixes across positions
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp2 = mlp(embed_dim, dim_hidden=embed_dim * 4)  # mixes across channels

    def forward(self, x):
        """Apply token mixing then channel mixing to a ``(B, T, C)`` tensor."""
        _batch, _seq_len, _channels = x.shape  # unpacking doubles as a rank-3 check

        # Token mixing: put positions on the last axis so nn.Linear acts over them.
        per_channel = self.ln1(x).transpose(1, 2)
        x = x + self.mlp1(per_channel).transpose(1, 2)

        # Channel mixing: an ordinary position-wise feed-forward.
        return x + self.mlp2(self.ln2(x))

def mlp(dim_in, dim_hidden, dim_out=None):
    """Build a two-layer feed-forward block: Linear -> GELU -> Linear.

    Args:
        dim_in: size of the last input dimension.
        dim_hidden: width of the hidden layer.
        dim_out: size of the last output dimension. ``None`` (the default) means
            "same as ``dim_in``", which gives a shape-preserving block usable in
            a residual connection.

    Returns:
        ``nn.Sequential`` mapping ``(..., dim_in)`` to ``(..., dim_out)``.
    """
    # Explicit ``is None`` check: ``dim_out or dim_in`` would silently replace an
    # explicitly passed falsy value such as 0 with dim_in.
    out_features = dim_in if dim_out is None else dim_out
    return nn.Sequential(
        nn.Linear(dim_in, dim_hidden),
        nn.GELU(),
        nn.Linear(dim_hidden, out_features),
    )

In [2]:
import lightning as pl
from shared import corpus, tokenizers, trainers

# Run configuration, kept in one place so a re-run can be tuned without editing calls.
SEED = 89026614
BATCH_SIZE = 36
EPOCHS = 25
MASK_TOKEN = "😷"
# The recorded run used Apple's MPS backend. Fall back to CPU elsewhere so
# Restart & Run All does not fail on machines without MPS.
# NOTE(review): assumes MLMTrainer accepts "cpu" as a device string — TODO confirm.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

text = corpus.shakespeare()
tokenizer = tokenizers.unique_chars(text, mask_token=MASK_TOKEN)

# Seed immediately before building the model so weight initialisation is reproducible.
pl.seed_everything(SEED)
model = Encoder(tokenizer.get_vocab_size())
trainer = trainers.MLMTrainer(model, tokenizer, device=DEVICE)
trainer.train(text, batch_size=BATCH_SIZE, epochs=EPOCHS)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset tiny_shakespeare (/Users/cztomsik/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 583.33it/s]
Global seed set to 89026614
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type    | Params
----------------------------------
0 | model | Encoder | 405 K 
----------------------------------
405 K     Trainable params
0         Non-trainable params
405 K     Total params
1.623     Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 27.67it/s]Hello, my n😷me is
[[('P', 0.046438705176115036), ('m', 0.03975217416882515), ('i', 0.03714496269822121)]]
                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0: 100%|██████████| 202/202 [00:06<00:00, 31.51it/s, loss=3.31, v_num=49]Hello, my n😷me is
[[(' ', 0.14500093460083008), ('e', 0.09139189124107361), ('o', 0.06297865509986877)]]
Epoch 1: 100%|██████████| 202/202 [00:06<00:00, 33.27it/s, loss=3.32, v_num=49, test_loss=3.330]Hello, my n😷me is
[[(' ', 0.15100760757923126), ('e', 0.07783792167901993), ('t', 0.0668095275759697)]]
Epoch 2: 100%|██████████| 202/202 [00:06<00:00, 33.31it/s, loss=2.84, v_num=49, test_loss=3.340]Hello, my n😷me is
[[('a', 0.1612151563167572), (' ', 0.16104263067245483), ('e', 0.1421004980802536)]]
Epoch 3: 100%|██████████| 202/202 [00:06<00:00, 33.04it/s, loss=2.41, v_num=49, test_loss=2.770]Hello, my n😷me is
[[(' ', 0.36132028698921204), ('a', 0.18417449295520782), ('i', 0.1098964512348175)]]
Epoch 4: 100%|██████████| 202/202 [00:06<00:00, 33.08it/s, loss=2.19, v_num=49, test_loss=2.320]Hello, my n😷me is
[[('o', 0.1993735134601593), (' ', 0.1922932118177414), ('a', 0.18683621287345886)]]
Epoch 5: 100%|████

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 202/202 [00:11<00:00, 18.24it/s, loss=1.48, v_num=49, test_loss=1.510]


In [4]:
# Ask the trained model which characters best fill the masked (😷) position;
# the candidates and their probabilities are shown in the output below.
prompt = "Make n😷 more"
trainer.wrapper.fill(prompt)

[[('o', 0.7050660848617554),
  ('a', 0.18505017459392548),
  ('i', 0.09816652536392212)]]