## Train a character-level GPT on some text data

The inputs here are simple text files, which we chop up into individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare and train it to predict the text one character at a time.
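"Chopping up into characters" just means the vocabulary is the set of distinct characters, and each character gets an integer id. Here is a minimal sketch of that idea, assuming `CharDataset` builds its lookup tables the way the original minGPT does. The names `stoi` and `itos` mirror the `train_dataset.stoi` and `train_dataset.itos` attributes used in the sampling cell. `encode` and `decode` are hypothetical helpers for illustration only.

```python
# Hypothetical sketch of character-level tokenization (not the mingpt code itself).
text = "O God, O God!"

# The vocabulary is the sorted set of distinct characters in the corpus.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> character


def encode(s):
    """Map a string to a list of integer ids."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map a list of integer ids back to a string."""
    return "".join(itos[i] for i in ids)


ids = encode(text)
print(len(chars), ids)
print(decode(ids))
```

On the full Shakespeare file the same procedure gives the 65-character vocabulary reported in the log of the dataset cell.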

In [5]:
# set up logging
import logging
import torch

logging.basicConfig(
        format="%(levelname)s: %(name)s - %(message)s [%(asctime)s]",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [6]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)
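`set_seed` comes from `mingpt.utils`, whose source isn't shown here. The original minGPT version seeds Python's `random`, NumPy and PyTorch (CPU and all GPUs), and this fork may differ. A minimal sketch of that idea follows, using the hypothetical name `set_seed_sketch` and import guards so it also runs where NumPy or PyTorch are not installed.

```python
import random


def set_seed_sketch(seed):
    """Hypothetical sketch: seed every RNG we can find so runs are repeatable."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed, nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU
    except ImportError:
        pass  # PyTorch not installed, nothing to seed


set_seed_sketch(42)
a = random.random()
set_seed_sketch(42)
b = random.random()
print(a == b)
```

Seeding alone does not force deterministic GPU kernels. PyTorch's `torch.use_deterministic_algorithms` is the switch for that, at some cost in speed.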

In [3]:
from mingpt.CharDataset import CharDataset

# spatial extent of the model for its context
block_size = 128

# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
with open('data/shakespeare.txt', 'r') as f:
    text = f.read()

train_dataset = CharDataset(text, block_size)  # one line of poem is roughly 50 characters

INFO: root - data has 1115394 characters, 65 unique. [04/07/2022 13:39:32]
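What does each training example look like? A minimal sketch, assuming `mingpt.CharDataset` behaves like the original minGPT `CharDataset`: item `idx` is a window of `block_size` character ids, and the target is the same window shifted one character to the right, so the model learns to predict every next character in the window at once. `CharWindowDataset` is a hypothetical stand-in that returns plain lists instead of tensors.

```python
class CharWindowDataset:
    """Hypothetical stand-in for mingpt.CharDataset: every item is a block_size
    window of character ids (x) and the same window shifted one step right (y)."""

    def __init__(self, text, block_size):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.data = text

    def __len__(self):
        # one window per start position that still has block_size + 1 characters
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        ids = [self.stoi[c] for c in chunk]
        x = ids[:-1]  # inputs:  characters idx .. idx+block_size-1
        y = ids[1:]   # targets: characters idx+1 .. idx+block_size
        return x, y


ds = CharWindowDataset("hello world", block_size=4)
x, y = ds[0]
print(len(ds), "".join(ds.itos[i] for i in x), "".join(ds.itos[i] for i in y))
# 7 hell ello
```

If the real dataset defines its length the same way, the 1,115,394-character file yields 1,115,394 − 128 = 1,115,266 windows, which at `batch_size=512` is 2,179 batches per epoch. That agrees with the `0/2179` progress bar of the training cell.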


In [4]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)

INFO: mingpt.model - # Params: 25352192 [04/07/2022 13:39:32]
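Where do the 25,352,192 parameters come from? A back-of-the-envelope count, assuming `mingpt.model.GPT` keeps the original minGPT layout: a learned position embedding, separate key/query/value/output projections with biases, a 4× MLP with biases, two LayerNorms per block, a final LayerNorm, and an untied, bias-free output head. `gpt_param_count` is a hypothetical helper. Under those assumptions it reproduces the logged figure exactly, which suggests the layout matches.

```python
def gpt_param_count(vocab_size, block_size, n_layer, n_embd):
    """Count parameters of an original-minGPT-style model (assumed layout)."""
    d = n_embd
    tok_emb = vocab_size * d                      # token embedding table
    pos_emb = block_size * d                      # learned position embeddings
    attn = 4 * (d * d + d)                        # key, query, value, output projections (with bias)
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # d -> 4d -> d, with biases
    layernorms = 2 * (2 * d)                      # two LayerNorms (weight + bias) per block
    per_block = attn + mlp + layernorms           # = 12*d^2 + 13*d
    ln_f = 2 * d                                  # final LayerNorm
    head = d * vocab_size                         # output projection, no bias
    return tok_emb + pos_emb + n_layer * per_block + ln_f + head


print(gpt_param_count(vocab_size=65, block_size=128, n_layer=8, n_embd=512))
# 25352192
```

`n_head` does not appear: splitting attention into 8 heads only reshapes the 512-wide projections. Almost everything lives in the transformer blocks (8 × 3,152,384 = 25,219,072), so with a 65-character vocabulary the embeddings and head are a rounding error.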


In [None]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=2, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                      num_workers=4)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

  0%|          | 0/2179 [00:00<?, ?it/s]
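`warmup_tokens` and `final_tokens` are measured in tokens (characters), not steps. A minimal sketch of the schedule, assuming `mingpt.trainer` keeps the original minGPT behaviour: a linear warmup over the first `warmup_tokens`, then a cosine decay that bottoms out at 10% of `learning_rate` once `final_tokens` (two epochs' worth here) have been seen. `lr_at` is a hypothetical helper, and clamping `progress` at 1.0 is an addition of this sketch, not necessarily what the trainer does.

```python
import math


def lr_at(tokens_seen, learning_rate, warmup_tokens, final_tokens, min_mult=0.1):
    """Hypothetical token-based schedule: linear warmup, then cosine decay
    down to min_mult * learning_rate at final_tokens."""
    if tokens_seen < warmup_tokens:
        mult = tokens_seen / max(1, warmup_tokens)
    else:
        progress = (tokens_seen - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        progress = min(1.0, progress)  # clamped so the rate stays at the floor past final_tokens
        mult = max(min_mult, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return learning_rate * mult


block_size = 128
n_windows = 1115394 - block_size       # assumed len(train_dataset)
warmup = 512 * 20                      # warmup_tokens from the TrainerConfig
final = 2 * n_windows * block_size     # final_tokens from the TrainerConfig
for t in (0, warmup // 2, warmup, final // 2, final):
    print(t, lr_at(t, 6e-4, warmup, final))
```

Note that `512*20` tokens is only 80 windows of 128 characters, so with `batch_size=512` the warmup is over within the first batch.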

In [None]:
# alright, let's sample some character-level Shakespeare
from mingpt.utils import sample

context = "O God, O God!"
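# encode the prompt as a (1, T) batch of character ids on the trainer's device
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(trainer.device)
# extend it by 2000 characters, sampling each one from the 10 most likely (top_k=10)
y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
# map the generated ids back to characters
completion = ''.join([train_dataset.itos[int(i)] for i in y])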
print(completion)
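Roughly what `sample` does at each of the 2000 steps, assuming `mingpt.utils.sample` works like the original minGPT helper: crop the context to the last `block_size` ids (the prompt plus 2000 characters is far longer than 128), divide the logits for the next character by `temperature`, keep only the `top_k` largest, and draw from their softmax. This is a plain-Python sketch, not the library code. `sample_next`, `generate` and `toy_logits` are hypothetical names, and `toy_logits` stands in for the model.

```python
import math
import random


def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Pick the next token id from a list of logits (hypothetical sketch).

    Logits are divided by temperature, everything outside the top_k largest
    is dropped, and the survivors are sampled in proportion to their softmax.
    """
    scaled = [l / temperature for l in logits]
    candidates = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]
    m = max(scaled[i] for i in candidates)  # subtract the max for numerical stability
    weights = [math.exp(scaled[i] - m) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


def generate(next_logits, ids, steps, block_size, **kw):
    """Append `steps` sampled ids, conditioning on at most block_size of context."""
    ids = list(ids)
    for _ in range(steps):
        context = ids[-block_size:]  # crop to the model's context window
        ids.append(sample_next(next_logits(context), **kw))
    return ids


def toy_logits(context):
    """Hypothetical stand-in for the model: fixed scores for a 4-token vocabulary."""
    return [0.0, 2.0, 1.0, -1.0]


out = generate(toy_logits, [0], steps=20, block_size=4, top_k=2, rng=random.Random(0))
print(out)
```

With `top_k=2` only ids 1 and 2 can ever follow the prompt. Lower temperatures sharpen the distribution toward the top choice, and `top_k=1` reduces to greedy decoding.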

In [None]:
# well that was fun