## Train a character-level GPT on some text data

The input here is a plain text file, which we chop up into individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare and train it to predict the text one character at a time.
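
Character-level encoding comes down to two lookup tables, the same `stoi`/`itos` dictionaries that `CharDataset` builds below. Here is a small sketch on a toy string. The string and the `encode`/`decode` helper names are illustrative only and are not part of minGPT.

```python
# A minimal character-level "tokenizer": every distinct character gets an integer id.
text = "hello world"

chars = sorted(set(text))                      # vocabulary, in a deterministic (sorted) order
stoi = {ch: i for i, ch in enumerate(chars)}   # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}   # integer id -> character

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

ids = encode("hello")
print(chars)        # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(ids)          # [3, 2, 4, 4, 5]
print(decode(ids))  # hello
```

Sorting the character set keeps the id assignment reproducible from run to run, which matters when a trained model is later decoded with the same tables.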

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,2,3"

In [2]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S",
        level=logging.INFO,
)

In [3]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)
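
The body of `mingpt.utils.set_seed` in this JAX port isn't shown here. In the original PyTorch minGPT it seeds Python's `random`, NumPy, and torch (`torch.manual_seed`, `torch.cuda.manual_seed_all`). The following is a rough sketch under that assumption. It leaves torch out so it runs without it, and `set_seed_sketch` is a hypothetical name.

```python
import random
import numpy as np

def set_seed_sketch(seed):
    """Hypothetical stand-in for mingpt.utils.set_seed: seeds the global RNGs
    of Python and NumPy (the PyTorch minGPT version also seeds torch)."""
    random.seed(seed)
    np.random.seed(seed)

set_seed_sketch(42)
a = (random.random(), np.random.rand())
set_seed_sketch(42)
b = (random.random(), np.random.rand())
print(a == b)  # True: same seed, same draws
```

Global seeds like these do not affect JAX, whose random functions take explicit keys. That is why the notebook also creates `jax.random.PRNGKey(42)` before building the model.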

In [4]:
import jax
import jax.numpy as jnp
import haiku as hk
from functools import partial
import torch
from torch.utils.data import Dataset

In [5]:
jax.default_backend()

27/09/2021 08:00:15 - INFO - absl -   Starting the local TPU driver.
27/09/2021 08:00:15 - INFO - absl -   Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
27/09/2021 08:00:16 - INFO - absl -   Unable to initialize backend 'tpu': tpu_client_timer_callback() got an unexpected keyword argument 'timer'


'gpu'

In [6]:
jax.device_count()

3

In [7]:
jax.local_devices()

[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0)]

In [8]:
class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        
        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
    
    def __len__(self):
        return len(self.data) - self.block_size  # number of valid start indices

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y. Notice that
        the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next
        
        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during training will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for batched inputs X and targets Y, both of shape (B, T) where B is the
        batch size and T is block_size, the network will be trained to make
        B*T predictions simultaneously, all at once! Of course, at test time
        we can parallelize across batch B, but unlike during training we
        cannot parallelize across the time dimension T - we have to run a
        forward pass of the network to recover the next single character of
        each sequence in the batch, and repeatedly feed that character back in
        to get the next one.
        
        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward
        passes.
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
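
The docstring's "hello" example can be checked directly. The minimal sketch below reproduces the slicing of `__getitem__` with plain Python lists instead of torch tensors. `make_example` is a hypothetical helper, not part of minGPT.

```python
def make_example(data, stoi, block_size, idx):
    """Same slicing as CharDataset.__getitem__, but returning plain lists."""
    chunk = data[idx:idx + block_size + 1]     # block_size + 1 characters
    dix = [stoi[s] for s in chunk]             # encode to integers
    return dix[:-1], dix[1:]                   # inputs x and targets y, shifted by one

data = "hello"
stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
itos = {i: ch for ch, i in stoi.items()}
block_size = 4

x, y = make_example(data, stoi, block_size, 0)
print("".join(itos[i] for i in x), "->", "".join(itos[i] for i in y))  # hell -> ello

# The prefix x[:t+1] is the context for target y[t], so one example holds block_size predictions.
for t in range(block_size):
    print(repr("".join(itos[i] for i in x[:t + 1])), "->", repr(itos[y[t]]))

# __len__ counts the valid start indices: only idx=0 fits here.
print(len(data) - block_size)  # 1
```

The loop prints the four "multitask" pairs listed in the docstring: `'h' -> 'e'`, `'he' -> 'l'`, `'hel' -> 'l'`, `'hell' -> 'o'`.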

In [9]:
# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r') as f:
    text = f.read()
train_dataset = CharDataset(text, block_size=128)  # one line of verse is roughly 50 characters

data has 1115394 characters, 65 unique.


In [10]:
from mingpt.model import gpt, loss_fn, GPTConfig

rng = jax.random.PRNGKey(42)
gpt_config = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                       n_layer=8, n_head=8, n_embd=512)

In [11]:
hk_loss_fn = hk.transform(partial(loss_fn, config=gpt_config, is_training=True))

In [12]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
rng, subkey = jax.random.split(rng)
tconf = TrainerConfig(max_epochs=2, batch_size=512//2, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20,
                      final_tokens=2*len(train_dataset)*train_dataset.block_size,
                      num_workers=4, rng=subkey)
trainer = Trainer(hk_loss_fn, train_dataset, None, tconf)
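
`lr_decay=True`, `warmup_tokens` and `final_tokens` point to the token-based schedule of the original PyTorch minGPT trainer: a linear warmup, followed by cosine decay down to a floor of 10% of the base rate. The JAX `Trainer` isn't shown here, so the following is a sketch under that assumption. The learning rates logged during training (about 3.0e-4 at the end of epoch 1, 6e-5 at the end) are consistent with it. `lr_at` is a hypothetical helper.

```python
import math

def lr_at(tokens, learning_rate=6e-4, warmup_tokens=512 * 20, final_tokens=1):
    """Linear warmup, then cosine decay floored at 10% of learning_rate
    (the schedule of the PyTorch minGPT trainer, assumed here for the JAX port)."""
    if tokens < warmup_tokens:
        lr_mult = tokens / max(1, warmup_tokens)
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return learning_rate * lr_mult

# final_tokens = 2 epochs * len(train_dataset) * block_size
final_tokens = 2 * (1115394 - 128) * 128
for t in [0, 512 * 20, final_tokens // 2, final_tokens]:
    print(t, f"{lr_at(t, final_tokens=final_tokens):.6e}")
```

If tokens are counted per step as in minGPT, each step with `batch_size=256` and `block_size=128` processes 32,768 tokens. That means `warmup_tokens=512*20` (10,240 tokens) is used up within the very first step.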

In [13]:
params = trainer.init_params()

27/09/2021 08:00:23 - INFO - mingpt.trainer -   number of parameters: 24309248


In [14]:
params, _ = trainer.train(params)

epoch 1 iter 4356: train loss 0.19369. lr 3.000718e-04: 100%|██████████| 4357/4357 [11:47<00:00,  6.16it/s]
epoch 2 iter 8713: train loss 0.13794. lr 6.000000e-05: 100%|██████████| 4357/4357 [11:17<00:00,  6.43it/s]
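
The progress-bar counts follow directly from the dataset size. The quick check below assumes the data loader keeps the final partial batch, which the 4357 count suggests.

```python
import math

n_chars, block_size, batch_size, max_epochs = 1115394, 128, 512 // 2, 2

n_examples = n_chars - block_size                     # CharDataset.__len__
iters_per_epoch = math.ceil(n_examples / batch_size)  # the last batch is partial
last_batch = n_examples - (iters_per_epoch - 1) * batch_size

print(n_examples, iters_per_epoch, max_epochs * iters_per_epoch - 1)  # 1115266 4357 8713
print(last_batch)  # 130
```

The result matches the bars: 4357 steps per epoch, with the global iteration index ending at 4356 after epoch 1 and at 8713 after epoch 2.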


In [15]:
# alright, let's sample some character-level Shakespeare
from mingpt.utils import sample

In [16]:
model = hk.transform(partial(gpt, config=gpt_config, is_training=False))
model = hk.without_apply_rng(model).apply

In [17]:
context = "O God, O God!"
x = jnp.array([train_dataset.stoi[s] for s in context])
y = sample(params, model, gpt_config, x, 2000, temperature=1.0, sample=True, top_k=10, progress=True)
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)

100%|██████████| 2000/2000 [15:01<00:00,  2.22it/s] 

O God, O God! that e'er this tongue of mine,
That laid the sentence of dread banishment
On yon proud man, should take it off again
With words of sooth! O that I were as great
As is my grief, or lesser than my name!
Or that I could forget what I have been,
Or not remember what I must be now!
Swell'st thou, proud heart? I'll give thee scope to beat,
Since foes have scope to beat both thee and me.

DUKE OF AUMERLE:
Northumberland comes back from Bolingbroke.

KING RICHARD II:
What must the king do now? must he submit?
The king shall do it: must he be deposed?
The king shall be contented: must he lose
The name of king? o' God's name, let it go:
I'll give my jewels for a set of beads,
My gorgeous palace for a hermitage,
My gay apparel for an almsman's gown,
My figured goblets for a dish of wood,
My sceptre for a palmer's walking staff,
My subjects for a pair of carved saints
And my large kingdom for a little grave,
A little little grave, an obscure grave;
Or I'll be buried in the king's hig
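
The sampling loop behind `sample(..., temperature=1.0, sample=True, top_k=10)` can be sketched as follows. This version is modeled on the original PyTorch minGPT `sample`. The JAX `mingpt.utils.sample` used above may differ in details, and `sample_sketch` and `toy_logits` are hypothetical names. Each new character costs one full forward pass over up to `block_size` characters of context, which is the train/test asymmetry described in the `CharDataset` docstring.

```python
import numpy as np

def sample_sketch(logits_fn, x, steps, block_size, temperature=1.0, top_k=None, rng=None):
    """Autoregressive top-k sampling, one forward pass per generated id.
    logits_fn maps a 1-D int array of context ids to (len(context), vocab) logits."""
    rng = np.random.default_rng(0) if rng is None else rng
    x = list(x)
    for _ in range(steps):
        context = np.array(x[-block_size:])            # crop to the model's context window
        logits = logits_fn(context)[-1] / temperature  # only the last position predicts the next id
        if top_k is not None:
            kth = np.sort(logits)[-top_k]              # k-th largest logit
            logits = np.where(logits < kth, -np.inf, logits)  # drop everything below it
        probs = np.exp(logits - logits.max())          # numerically stable softmax
        probs /= probs.sum()
        x.append(int(rng.choice(len(probs), p=probs)))
    return x

# Hypothetical stand-in for the GPT: strongly prefers (last id + 1) mod vocab.
vocab = 5
def toy_logits(context):
    out = np.zeros((len(context), vocab))
    out[np.arange(len(context)), (context + 1) % vocab] = 10.0
    return out

print(sample_sketch(toy_logits, [0], steps=6, block_size=3, top_k=1))  # [0, 1, 2, 3, 4, 0, 1]
```

With `top_k=1` only the single best id survives, so sampling becomes greedy and deterministic. With the notebook's `top_k=10` and `temperature=1.0`, the model draws from its ten most likely characters at each step. In the PyTorch original, `sample=False` takes the argmax instead of drawing.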




In [18]:
# well that was fun