## Train a character-level GPT on some text data

The inputs here are simple text files, which we chop up into individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. The original version of this example feeds it Shakespeare; here we feed it a much smaller text file, `matt_hall_tiny.txt`, and get it to predict the text one character at a time.
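Chopping text into characters amounts to building a small vocabulary and mapping each character to an integer id, which is what `train_dataset.stoi` and `train_dataset.itos` are used for when sampling at the end of this notebook. The sketch below assumes `CharDataset` builds its vocabulary the way the original minGPT notebook does (sorted unique characters). `build_vocab`, `encode` and `decode` are hypothetical helper names, not part of mingpt.

```python
# Minimal sketch of character-level tokenization.
# Assumption: CharDataset builds stoi/itos from the sorted unique characters.
def build_vocab(text):
    chars = sorted(set(text))                     # unique characters, in a stable order
    stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
    itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character
    return stoi, itos

def encode(text, stoi):
    return [stoi[ch] for ch in text]

def decode(ids, itos):
    return ''.join(itos[i] for i in ids)

stoi, itos = build_vocab("to be or not to be")
ids = encode("to be", stoi)
print(len(stoi), ids, decode(ids, itos))  # 7 [6, 4, 0, 1, 2] to be
```

Any character in the sampling context must appear in the training text, otherwise the `stoi` lookup raises a `KeyError`.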

In [40]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [41]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [42]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [43]:
from mingpt.char_dataset import CharDataset

In [44]:
block_size = 128 # spatial extent of the model for its context

In [45]:
# the original tiny Shakespeare input can be downloaded at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
# text = open('input.txt', 'r').read() # don't worry we won't run out of file handles

text = open('matt_hall_tiny.txt', 'r').read()
train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters

data has 962 characters, 37 unique.
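Each training example is a window of `block_size + 1` consecutive characters: the first `block_size` are the input, and the same window shifted by one position gives the targets, so the model learns to predict every next character in parallel. Assuming `CharDataset` follows the original minGPT notebook (one window starting at every position that still fits), 962 characters with `block_size = 128` give 962 - 128 = 834 examples, which at `batch_size=512` is two batches per epoch, consistent with the `2/2` in the training log. The helper names below are hypothetical.

```python
import math

def make_example(ids, idx, block_size):
    """Return (inputs, targets) for the window starting at idx."""
    chunk = ids[idx: idx + block_size + 1]  # block_size + 1 consecutive ids
    return chunk[:-1], chunk[1:]            # targets are the inputs shifted left by one

def num_examples(n_chars, block_size):
    # one window per start position that still leaves room for block_size + 1 ids
    return n_chars - block_size

x, y = make_example(list(range(10)), 2, 4)
print(x, y)                       # [2, 3, 4, 5] [3, 4, 5, 6]
n = num_examples(962, 128)
print(n, math.ceil(n / 512))      # 834 2
```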


In [48]:
num_layers = 2
num_heads = 2
embed_dim = 128 # n_embd: the embedding dimension, i.e. the length of the vector each token is mapped to

from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=num_layers, n_head=num_heads, n_embd=embed_dim)
model = GPT(mconf)

08/07/2023 17:23:02 - INFO - mingpt.model -   number of parameters: 4.226560e+05
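The logged count can be checked by hand. Assuming the older minGPT layout (learned token and position embeddings; per block two LayerNorms, separate key, query, value and output projections with biases, and an MLP four times as wide as `n_embd`; then a final LayerNorm and a bias-free output head), the total for `vocab_size=37`, `block_size=128`, `n_layer=2` and `n_embd=128` comes to 422,656, which matches the `4.226560e+05` above. The number of heads does not enter the count, because splitting `n_embd` across heads adds no weights. `gpt_param_count` is a hypothetical helper.

```python
def gpt_param_count(vocab_size, block_size, n_layer, n_embd):
    """Count parameters of a minGPT-style model (assumed layout, see above)."""
    def linear(fan_in, fan_out):
        return fan_in * fan_out + fan_out      # weight matrix + bias

    layernorm = 2 * n_embd                     # scale + shift
    attn = 4 * linear(n_embd, n_embd)          # key, query, value, output projection
    mlp = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
    block = 2 * layernorm + attn + mlp         # one transformer block

    tok_emb = vocab_size * n_embd              # token embedding table
    pos_emb = block_size * n_embd              # learned position embeddings
    head = n_embd * vocab_size                 # output projection, no bias
    return tok_emb + pos_emb + n_layer * block + layernorm + head

print(gpt_param_count(37, 128, 2, 128))        # 422656
```

Almost all of the parameters sit in the two blocks (198,272 each), so `num_layers` and `embed_dim` are the settings that drive model size here.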


In [49]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=2, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                      num_workers=4)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

epoch 1 iter 1: train loss 3.33415. lr 3.237152e-04: 100%|██████████| 2/2 [00:19<00:00,  9.98s/it]
epoch 2 iter 1: train loss 3.19113. lr 6.000000e-05: 100%|██████████| 2/2 [00:19<00:00,  9.95s/it]
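The `lr` values in the log follow from `lr_decay=True`. Assuming this `Trainer` uses the original minGPT schedule (count the target tokens processed, warm up linearly over `warmup_tokens`, then follow a cosine curve down to `final_tokens`, never dropping below 10% of `learning_rate`), the sketch below reproduces both logged values. With `warmup_tokens=512*20` (10,240 tokens) smaller than the first batch (512 × 128 = 65,536 tokens), warmup is over after the very first step, and because `final_tokens` is exactly two epochs' worth of tokens, the second epoch ends on the 10% floor. `cosine_warmup_lr` is a hypothetical name.

```python
import math

def cosine_warmup_lr(tokens, base_lr, warmup_tokens, final_tokens):
    if tokens < warmup_tokens:
        mult = tokens / max(1, warmup_tokens)                         # linear warmup
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine decay, 10% floor
    return base_lr * mult

block_size, n_examples, batch_size = 128, 834, 512
warmup_tokens = 512 * 20
final_tokens = 2 * n_examples * block_size

tokens = 0
epoch_lrs = []
for epoch in range(2):
    for start in range(0, n_examples, batch_size):
        tokens += min(batch_size, n_examples - start) * block_size  # every target counts
        lr = cosine_warmup_lr(tokens, 6e-4, warmup_tokens, final_tokens)
    epoch_lrs.append(lr)
    print(f"epoch {epoch + 1}: lr {lr:e}")
# epoch 1: lr 3.237152e-04
# epoch 2: lr 6.000000e-05
```

For scale, a model that assigns equal probability to all 37 characters has a cross-entropy loss of ln 37 ≈ 3.61, so the logged losses of 3.33 and 3.19 after only four optimizer steps (two per epoch) are still close to guessing.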


In [51]:
# alright, let's sample some character-level text, starting from the context "Ma"
from mingpt.utils import sample

context = "Ma"
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)

Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 360, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
Exception ignored in: <function _ConnectionBase.__del__ at 0x10a7b8680>
Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 132, in __del__
Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11

Ma senet aaatnlo a eres ntel ilollee la  bielotnoii s a e ansela s aleteentnoarli atr mant e th aeaelis  mri isaltaeenes  ena llannenndtrol nennllel sonnoe et lnre ar aaresetrisn inatrisor adsteedna  m nnlrea tral iriethtrthte liedadso otnrmshne s ndenenn eoeaon ite l onltraltlnte  e o oo ee ien tirdiainratao leaea  eleeetn  n ee eselleolane oteait  e s  lnee nesllnrn alnlnoiieeaol  i iolee stlldltriloanad  the etttteala loinai n  a iseloniloo lrittiiit n oiiilae aner mrldeei lo molltndet l tttatesle le il snedel s adr  mnntedo eine nt tth nea sriea a s  noneetti aeotllrnrmsot etoatotostrll  ne t medtlatt inreal eiesl an ne nniodtltt ltteedeairtndror stien leeshnnndelnal en ntene ltiei llo  eon ed inlt ino lntrnoatatent th odadaonrtoao aadlarriito mlraaatototellri  aao  t art onrrt lin ee aai ntitntonthnten  ndldonlo ed  ale ltenie mla lea aen i nalndaleeal irittteesro lne inite tl ttaaradialdsanoneeendoio si eenlinollllttn oetn   m n ildao   ti   nted ialao lnn titni  on   l l nalntto
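Each of the 2000 characters above is drawn with `temperature=1.0` and `top_k=10`: the model's scores (logits) for the next character are divided by the temperature, every character outside the 10 highest-scoring ones is discarded, and the next character is sampled from a softmax over what remains; it is then appended to the context and the process repeats. The sketch below shows one such step in NumPy. It illustrates top-k sampling in general and is not the `mingpt.utils.sample` code; `sample_next` is a hypothetical name.

```python
import numpy as np

def sample_next(logits, temperature=1.0, top_k=None, rng=None):
    """Draw one token id from logits using temperature scaling and top-k filtering."""
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k is not None:
        kth_largest = np.sort(logits)[-top_k]
        logits = np.where(logits < kth_largest, -np.inf, logits)  # keep only the top k
    probs = np.exp(logits - logits.max())  # softmax; exp(-inf) becomes 0
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = [0.1, 2.0, -1.0, 1.5, 0.3]
draws = {sample_next(logits, top_k=2, rng=rng) for _ in range(100)}
print(sorted(draws))  # [1, 3]: only the two highest-scoring ids are ever drawn
```

Lowering `temperature` below 1.0 sharpens the distribution toward the highest-scoring characters, and a smaller `top_k` narrows the candidates further.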

In [None]:
# well that was fun