## NLP with MPS

In [71]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers

# Start from an untrained BPE model: text is NFKC-normalised and split on whitespace
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
tokenizer.normalizer = normalizers.NFKC()

# Trainer targets a 2**10 = 1024 entry vocabulary and reserves four special tokens
trainer = trainers.BpeTrainer(
    special_tokens=["<pad>", "<unk>", "<s>", "</s>"],
    vocab_size=2 ** 10,
)

# Stop right away if the corpus is not where we expect it
import os
corpus_path = "../../data/shakespeare/main.txt"
assert os.path.exists(corpus_path), f"File not found: {corpus_path}"

# Fit the tokenizer on the corpus, reporting the vocabulary size beforehand
print("Vocab size before training:", tokenizer.get_vocab_size())
tokenizer.train([corpus_path], trainer)

# Write the trained tokenizer to disk, then report the final vocabulary size
tokenizer.save("bpe_tokenizer.json")
print("Vocab size after training:", tokenizer.get_vocab_size())


Vocab size before training: 0



Vocab size after training: 1024


In [72]:
# Helper functions
import torch

def dec2bin(x, bits):
    """Expand an integer tensor into its lowest `bits` binary digits, MSB first.

    A trailing dimension of size `bits` is appended to `x`. Each entry is 1.0
    when the corresponding bit of the input value is set and 0.0 otherwise;
    bits above position `bits - 1` are not represented in the output.
    """
    # Exponents run bits-1 ... 0 so the most significant bit comes first.
    exponents = torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    place_values = 2 ** exponents
    bit_is_set = torch.bitwise_and(x.unsqueeze(-1), place_values) != 0
    return bit_is_set.float()


def bin2dec(b, bits):
    """Collapse a trailing dimension of `bits` binary digits (MSB first) into numbers.

    Each digit is weighted by its power of two and the last dimension is
    summed away. The weights are created with the dtype and device of `b`.
    """
    exponents = torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
    place_values = 2 ** exponents
    return (place_values * b).sum(-1)

# # Example
# NUM_BITS_PER_TOKEN = 10
# d = torch.randint(0, 16, (3, 6))
# b = dec2bin(d, NUM_BITS_PER_TOKEN)
# d_rec = bin2dec(b, NUM_BITS_PER_TOKEN)

In [111]:
# ---------------------------------------------------------------------
# Run training using MPS
# ---------------------------------------------------------------------

import math
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from ptn.models.modelling_nanogpt import GPT, GPTConfig
from ptn.dists._abc import AbstractDisributionHeadConfig
from ptn.dists.mps_sigma_lsf import MPS_SIGMA_LSF

#  --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0

# ---------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------
# These are plain module-level globals; the classes and the training code
# further down read them by name.
lr = 3e-4  # learning rate handed to the AdamW optimizer
block_size = 32  # number of tokens in each training window
batch_size = 12  # windows per DataLoader batch
# n_layer / n_head / n_embd / dropout are forwarded to GPTConfig. TTModel reads
# only vocab_size and block_size from that config, so these four matter only
# when the GPT model is used instead.
n_layer = 4
n_head = 4
n_embd = 128
dropout = 0.0
bit_size = 2  # symbol alphabet size: d_output of the MPS head and the log base for symbols-per-token
# Binary digits per token id, used when decoding generated samples.
# NOTE(review): presumably must equal log_{bit_size}(vocab size), which is what
# TTModel uses to size its horizon -- verify when changing either value.
n_bits_per_token = 10
mps_rank = 8  # rank passed to the MPS head configuration
# ---------------------------------------------------------------------

# ---------------------------------------------------------------------
# 1) Load the trained tokenizer
# ---------------------------------------------------------------------
# Reads the JSON file written by the tokenizer-training cell, so that cell must
# have been run at least once before this one.
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")

# ---------------------------------------------------------------------
# 2) Dataset and DataLoader
# ---------------------------------------------------------------------
class TextDataset(Dataset):
    """Sliding-window next-token dataset over one tokenized text file.

    Item ``i`` is the pair ``(tokens[i : i + block_size],
    tokens[i + 1 : i + 1 + block_size])``, i.e. the target is the input shifted
    by one token. When ``n_bits_per_token`` is given, both windows are expanded
    to binary digits with ``dec2bin`` and flattened, so each returned tensor
    has ``block_size * n_bits_per_token`` entries.

    Args:
        path: text file to read (decoded as UTF-8).
        tokenizer: object whose ``encode(text).ids`` yields the token ids.
        block_size: number of tokens per window.
        n_bits_per_token: if not None, number of binary digits per token.
    """

    def __init__(self, path, tokenizer, block_size, n_bits_per_token=None):
        with open(path, 'r', encoding='utf-8') as f:
            self.text = f.read()
        # Encode the entire corpus once; items are cut from this list on demand.
        self.tokens = tokenizer.encode(self.text).ids
        self.block_size = block_size
        self.n_bits_per_token = n_bits_per_token

    def __len__(self):
        # A corpus shorter than one window has no items; clamp so len() never
        # receives a negative number (which would raise ValueError).
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` window pair starting at token ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. This
                is also what lets plain iteration over the dataset terminate.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")

        stop = idx + self.block_size
        x = torch.tensor(self.tokens[idx:stop], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1 : stop + 1], dtype=torch.long)
        if self.n_bits_per_token is not None:
            # dec2bin returns floats with a trailing bit dimension; flatten to
            # one long sequence of 0/1 symbols per window.
            x_binary = dec2bin(x, self.n_bits_per_token)
            y_binary = dec2bin(y, self.n_bits_per_token)
            return x_binary.reshape(-1).to(torch.long), y_binary.reshape(-1).to(torch.long)
        return x, y

# ---------------------------------------------------------------------
# 3) Instantiate model
# ---------------------------------------------------------------------

class TTModel(nn.Module):
    """Sequence model consisting of a single ``MPS_SIGMA_LSF`` head.

    The head is always driven by a constant all-ones input, so the content of
    the batch passed to ``forward`` is ignored; only its batch size and device
    are used. The head's horizon is ``block_size`` tokens times the number of
    base-``bit_size`` symbols needed per token.

    Reads the module-level globals ``bit_size`` and ``mps_rank``.

    Args:
        config: object exposing ``vocab_size`` and ``block_size``.

    Raises:
        ValueError: if ``config.vocab_size`` is not an exact power of ``bit_size``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.block_size is not None
        print(f"config.vocab_size: {config.vocab_size}")
        # Integer arithmetic instead of math.log: the floating-point quotient is
        # inexact for some exact powers (e.g. 2**29 in base 2, 1000 in base 10),
        # which would wrongly reject a valid vocabulary size.
        word_size = self._exact_log(config.vocab_size, bit_size)
        if word_size is None:
            raise ValueError(f"vocab_size must be a power of {bit_size}, got {config.vocab_size}")
        # Printed as a float so the log line keeps its previous format.
        print(f"word_size: {float(word_size)}")
        self.mps = MPS_SIGMA_LSF(AbstractDisributionHeadConfig(
            d_model=1,
            d_output=bit_size,
            horizon=config.block_size * word_size,
            rank=mps_rank
        ))

    @staticmethod
    def _exact_log(value, base):
        """Return the integer ``k`` with ``base ** k == value``, or None if none exists."""
        if base < 2 or value < 1:
            return None
        exponent = 0
        while value > 1:
            value, remainder = divmod(value, base)
            if remainder:
                return None
            exponent += 1
        return exponent

    def forward(self, x, targets=None):
        """Run the MPS head on a constant input and return ``(logits, loss)``.

        Args:
            x: batch tensor; only ``x.shape[0]`` and ``x.device`` are used.
            targets: forwarded unchanged as the second argument of the head.
                NOTE(review): whether the head accepts ``None`` here is not
                visible from this file -- confirm before calling without targets.
        """
        B = x.shape[0]
        x = torch.ones(B, 1, device=x.device)
        # Pass the explicit argument; the previous code referenced an undefined
        # local `y` and only ran when a global of that name happened to exist.
        out = self.mps(x, targets)
        return out.logits, out.loss

    def generate(self):
        """Return ``self.mps.generate`` applied to a single constant (1, 1) input.

        The input is placed on the device of the model's first parameter.
        """
        x = torch.ones(1, 1, device=next(self.parameters()).device)
        return self.mps.generate(x)

# Collect the model settings in one config object; the vocabulary size is
# taken from the loaded tokenizer rather than repeated by hand.
config = GPTConfig(
    block_size=block_size,
    vocab_size=tokenizer.get_vocab_size(),
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
)
# MPS model
model = TTModel(config)

# # GPT model
# model = GPT(config)

# ---------------------------------------------------------------------
# 4) Basic training setup
# ---------------------------------------------------------------------
# Use a CUDA GPU when one is visible, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# model = torch.compile(model) # requires PyTorch 2.0

# Encode tokens with the shared n_bits_per_token setting (not a second literal)
# so the dataset's bit layout always matches the one used to decode samples.
train_dataset = TextDataset(
    "../../data/shakespeare/main.txt",
    tokenizer,
    block_size,
    n_bits_per_token=n_bits_per_token,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Fallback loss: the training loop uses it only when the model returns loss=None
criterion = nn.CrossEntropyLoss()

# ---------------------------------------------------------------------
# 5) Training loop
# ---------------------------------------------------------------------
epochs = 5
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    for i, (x, y) in pbar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits, loss = model(x, targets=y)  # many GPT impls return both
        if loss is None:  # if model doesn't return loss internally
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        loss.backward()
        optimizer.step()

        # Read the scalar once per step; every .item() call copies it to the host.
        loss_value = loss.item()
        total_loss += loss_value
        if i % 50 == 0:
            # Sample from the model. No gradients are needed here, so skip
            # building an autograd graph for the generation pass.
            with torch.no_grad():
                xb = model.generate()
            # Regroup the flat symbol sequence into n_bits_per_token digits per
            # token and convert each group back to a token id.
            xd = bin2dec(xb.reshape(xb.size(0), -1, n_bits_per_token), n_bits_per_token)
            print(f"Epoch {epoch+1} Step {i}/{len(train_loader)} | Loss: {loss_value:.4f}")
            print(f"Sample: {tokenizer.decode(xd[0].tolist())}")

        pbar.set_postfix(loss=loss_value)

    # Report the running total that was previously accumulated but never shown.
    print(f"Epoch {epoch+1} mean loss: {total_loss / max(len(train_loader), 1):.4f}")

# ---------------------------------------------------------------------
# 6) Save checkpoint
# ---------------------------------------------------------------------
# Write only the model's state_dict (not the whole module object) to disk.
# NOTE(review): the file name says "gpt", but `model` is whatever was built
# above -- TTModel unless the commented-out GPT line is enabled. Consider
# renaming the file if the two should not be confused.
torch.save(model.state_dict(), "gpt_shakespeare.pt")
print("✅ Training complete. Model saved to gpt_shakespeare.pt")


config.vocab_size: 1024
word_size: 10.0


  0%|          | 1/31803 [00:04<37:08:47,  4.20s/it, loss=0.695]

Epoch 1 Step 0/31803 | Loss: 0.6953
Sample: earth whom pp row ST leave war ING dra d PET B brother cond IOLAN sha f ght S q d should ven prince than i cut 3 pardon uty ither hich


  0%|          | 51/31803 [00:20<12:54:16,  1.46s/it, loss=0.693]

Epoch 1 Step 50/31803 | Loss: 0.6931
Sample: B M ight ONTES crown What BRO cl RAN llow meet wr : orn ber CL ed ers irst ken fi It orn will da ta ase conf bus young g re


  0%|          | 101/31803 [00:36<12:39:53,  1.44s/it, loss=0.691]

Epoch 1 Step 100/31803 | Loss: 0.6909
Sample: gone little ay fort ere vi w EDW ex UCES ace ure uke down LEONTES hence say ation ure did sen H young ca pr In con GR their IV why last


  0%|          | 151/31803 [00:52<12:20:59,  1.40s/it, loss=0.688]

Epoch 1 Step 150/31803 | Loss: 0.6885
Sample: no e ad ' une like But ed It You sc What up fall up ISAB Shall upon then a b Lord s pr still honour for thou - ES DU thy


  1%|          | 201/31803 [01:09<12:39:45,  1.44s/it, loss=0.686]

Epoch 1 Step 200/31803 | Loss: 0.6863
Sample: mer IZ qui UM man good thine g ce sub reat pp be r fal jo Nay hi N news pro I am by York qu ven KE all WARWICK did ARD


  1%|          | 251/31803 [01:26<12:56:44,  1.48s/it, loss=0.685]

Epoch 1 Step 250/31803 | Loss: 0.6851
Sample: lit The VI rep WARWIC z gr ? Here when ISABELLA by A ' ass , ow from name ng -- IUS when vo tle set ose what ck Th da that


  1%|          | 301/31803 [01:42<13:07:24,  1.50s/it, loss=0.684]

Epoch 1 Step 300/31803 | Loss: 0.6839
Sample: ENCE rep tong ord IOLAN e LADY YOR vy of ad st ath LUCIO are ment it sa ep EDWARD entle ON Than dr ction f V may ur ic ouse tain


  1%|          | 351/31803 [01:59<12:56:01,  1.48s/it, loss=0.681]

Epoch 1 Step 350/31803 | Loss: 0.6811
Sample: af , bre der WARWIC M try x : some happ she hast kes A st $ el iness na ven a LADY morrow ! sha ve gra bear As be ver


  1%|▏         | 401/31803 [02:16<12:53:54,  1.48s/it, loss=0.681]

Epoch 1 Step 400/31803 | Loss: 0.6805
Sample: ere king come est IET she : ces ,-- . rom CAL itizen ins Citizen Yet off 3 O GR clo you hear shi pla lo II age D off wor tra


  1%|▏         | 451/31803 [02:32<12:40:39,  1.46s/it, loss=0.678]

Epoch 1 Step 450/31803 | Loss: 0.6779
Sample: die true ake ves CORIOLANUS noble tre ck ook What house their cious fl and B INC upon lo ent To much when n ust ite : ans um the und need


  2%|▏         | 501/31803 [02:49<12:58:24,  1.49s/it, loss=0.677]

Epoch 1 Step 500/31803 | Loss: 0.6771
Sample: eng hour The ur sel st den BRUTUS pt ive EL CORIOLAN KING There our tr shame la EN mo ey out would fa c ven ABETH . ght UTUS q K


  2%|▏         | 551/31803 [03:06<13:14:28,  1.53s/it, loss=0.676]

Epoch 1 Step 550/31803 | Loss: 0.6756
Sample: kes ry he nd sweet life to MEN IUS thee ti is bear cannot GLOUCESTER why master : come KE grace of bo shall self Thou brother RAN M by an


  2%|▏         | 601/31803 [03:23<12:48:15,  1.48s/it, loss=0.673]

Epoch 1 Step 600/31803 | Loss: 0.6734
Sample: Q ser t SIC UT ds A any ces SICINIUS ? can h G HAM EDW c et QUEEN wi ting under reat ast gr ca thought q know fore ? cl


  2%|▏         | 651/31803 [03:39<12:29:44,  1.44s/it, loss=0.673]

Epoch 1 Step 650/31803 | Loss: 0.6733
Sample: LUC H able : mother p which true ull how on qui par MENENIUS vant ve ure Q sha BR ady e sir swe K 3 DW E men D GRE shall


  2%|▏         | 701/31803 [03:55<12:39:10,  1.46s/it, loss=0.671]

Epoch 1 Step 700/31803 | Loss: 0.6714
Sample: thought wr ri ig come th fa or ? pri hear look ess nd sha ere cha ing th R ARD For his ass BRO He ou hi dis Than gue k


  2%|▏         | 751/31803 [04:11<12:25:21,  1.44s/it, loss=0.669]

Epoch 1 Step 750/31803 | Loss: 0.6690
Sample: all ve um never mother CA come ace shall ate ul pardon Lord w J hold ra Nor Nay land in n is lady m , young f wer hath ME


  3%|▎         | 801/31803 [04:27<12:22:57,  1.44s/it, loss=0.668]

Epoch 1 Step 800/31803 | Loss: 0.6677
Sample: when KING hi sor uke ter s by D sp ook is ly ep ther shall B and row great : H tongue they heaven By ome selves V - &


  3%|▎         | 851/31803 [04:43<12:33:11,  1.46s/it, loss=0.667]

Epoch 1 Step 850/31803 | Loss: 0.6671
Sample: maid & IN P grace ou qu Have TES um Nor M ETH ang sc f No N , as iz $ were bear ay an on me Z and


  3%|▎         | 901/31803 [05:00<12:56:44,  1.51s/it, loss=0.67] 

Epoch 1 Step 900/31803 | Loss: 0.6695
Sample: ill have ght Henry PETRUCH ARD There P hat char - llow the A de mother before OR can own ward se y sa am o ook y pray or tain ord


  3%|▎         | 951/31803 [05:16<12:25:56,  1.45s/it, loss=0.663]

Epoch 1 Step 950/31803 | Loss: 0.6628
Sample: ar ENCE R S whi ans f d at X r de su er atch ap ll o some n J d SICINIUS EDW of should therefore N grace head A


  3%|▎         | 1001/31803 [05:32<12:28:20,  1.46s/it, loss=0.661]

Epoch 1 Step 1000/31803 | Loss: 0.6615
Sample: house ? oo ing dy in N th B can far with ti TES And ser - I are ong his ing or . God it de hand A A jo ry


  3%|▎         | 1051/31803 [05:49<12:35:25,  1.47s/it, loss=0.662]

Epoch 1 Step 1050/31803 | Loss: 0.6616
Sample: lt you with int uke DUKE bed ould Here nd pl st ish their L N am , uch say sp ere sha ow and se W E ight U art


  3%|▎         | 1101/31803 [06:06<12:41:47,  1.49s/it, loss=0.663]

Epoch 1 Step 1100/31803 | Loss: 0.6625
Sample: fear re E fi z 3 w N , P PETRUCHIO my K A rom We int al Q mother The m s ! ed um ant K set hi law


  4%|▎         | 1151/31803 [06:22<12:18:33,  1.45s/it, loss=0.659]

Epoch 1 Step 1150/31803 | Loss: 0.6591
Sample: ou come ELIZ V man K , ; self EDWARD Not hand IUS Th lord pe himself K RICH ter str or pr wor head ce ' their se ess ive S


  4%|▍         | 1201/31803 [06:38<12:25:28,  1.46s/it, loss=0.66] 

Epoch 1 Step 1200/31803 | Loss: 0.6599
Sample: is ond k d EN have f q h S re um st d GLO p se est e b may ta f d N hor sand His : vir


  4%|▍         | 1251/31803 [06:54<12:20:09,  1.45s/it, loss=0.658]

Epoch 1 Step 1250/31803 | Loss: 0.6580
Sample: ation pe ord - him ord b must the before bro y any will He come v g . ar see ook him ar p a noble ed soul ever ; the


  4%|▍         | 1301/31803 [07:11<12:22:35,  1.46s/it, loss=0.657]

Epoch 1 Step 1300/31803 | Loss: 0.6573
Sample: true COR X B will onour jo g b ght : S t An char re ves ook have can son is w us OM ck e M Romeo F O -


  4%|▍         | 1351/31803 [07:27<12:13:54,  1.45s/it, loss=0.653]

Epoch 1 Step 1350/31803 | Loss: 0.6530
Sample: M Z ak er f uke Z fear ar ir hat death ; hath C : ar E father am k b ER un K oo wi R p KING A


  4%|▍         | 1401/31803 [07:43<12:10:58,  1.44s/it, loss=0.659]

Epoch 1 Step 1400/31803 | Loss: 0.6586
Sample: IOLAN are shall x COR me shall , se pe re ROME ou ; think en ICH L ! er P Z fi P man the ink J on ook ardon K


  5%|▍         | 1451/31803 [07:59<11:54:52,  1.41s/it, loss=0.652]

Epoch 1 Step 1450/31803 | Loss: 0.6525
Sample: no ted As ? we ar it ing his z Your & The know x V & Go fl swe they ong l ! reat wor , CA und ore la


  5%|▍         | 1501/31803 [08:16<12:20:09,  1.47s/it, loss=0.652]

Epoch 1 Step 1500/31803 | Loss: 0.6516
Sample: am pe an bo IN ut ck D - ut . ro es ro in ETH . were J g ENIUS father u ! J Q it sent ould : ct ,


                                                                   

KeyboardInterrupt: 

In [109]:
# Draw one sample, regroup its symbols into n_bits_per_token digits per token,
# and display the shape of the resulting token-id tensor.
# NOTE(review): this cell's recorded execution count (109) is lower than the
# training cell's (111), so the stored output may come from an earlier model
# configuration -- re-run after training to refresh it.
res = model.generate()
bin2dec(res.reshape(res.shape[0], -1, n_bits_per_token), n_bits_per_token).shape


torch.Size([1, 8])