# Let's Train a GPT 2 Model



In [2]:
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

In [14]:
from pathlib import Path
import sys

if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
else:
    jaxpt_dir = str(Path().absolute().parent / "jaxpt" )

sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/jaxpt


In [15]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [16]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.5.0
Available devices: [CpuDevice(id=0)]
Using device: cpu
172 ms ± 5.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox", max_length=20) # Make sure you can do a forward pass

> The Clever Foxanganurtleensible Variant Birch controversiesRPG crippDeltaQuest Ulster circleTry Yugoslav Province PRE\":
> The Clever Fox inhibitoren Sussex pollutionazyx look comm Tibet Grailonis medium colonization counterparts Historic Wow murderer
> The Clever Fox Daytona escalatinggioLaun…." hurricanes Ukrainians straps immersiveernallections blowing Cyber trained Crist swoop Abrams
> The Clever Fox ALP printers Berkshire,- synthesDb hopefully%] Rear integ GREEN larHAELPART Ligarices Chattanooga
> The Clever Fox pard doct 146 Freddy strengthening Sanchezrentaghan enhancements NebraskaTexturesmatically voter nutrients repository Affologists


In [19]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = jnp.array(enc.encode(text))
print(type(data), data.shape)

<class 'jaxlib.xla_extension.ArrayImpl'> (163084,)


In [30]:
# Set up the optimizer
max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10

def warmup_with_cosine_decay_schedule(step):

    warmup_lr = max_lr * (step + 1) / warmup_steps

    coeff = 0.5 * (1 + jnp.cos(jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)))
    cosine_lr =  min_lr + coeff * (max_lr - min_lr)

    return jnp.where(step < warmup_steps, 
                     warmup_lr, 
                     jnp.where(step < max_steps, cosine_lr, min_lr))

max_grad_norm = 1.0  # Clip gradients to this norm
optimizer = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(warmup_with_cosine_decay_schedule, b1=0.9, b2=0.95)
)
optimizer = nnx.Optimizer(m, optimizer)

In [33]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

B, T = 16, 32
print(f"Number of iterations per epoch: {len(data) // B // T}")

step_fn = partial(train_step, m, optimizer, data, B, T)
m.train()

for step in range(50):
  start = time()
  key, subkey = jax.random.split(key)
  loss, gradient_norm = step_fn(subkey)
  jax.block_until_ready(loss)
  iter_time = time() - start
  tokens_per_sec = B*T / iter_time
  lr = warmup_with_cosine_decay_schedule(step)
  clear_output(wait=True)
  print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {gradient_norm:0.4f} | time: {iter_time*1000:0.2f}ms | tok/sec: {tokens_per_sec:0.2f}")


 step: 49 | lr: 6.08e-05 | loss: 5.8881 | norm: 1.7857 | time: 1569.26ms | tok/sec: 326.27
CPU times: user 5min 58s, sys: 1min 8s, total: 7min 6s
Wall time: 1min 9s
