# Let's Train a GPT-2 Model



In [1]:
def is_colab():
    """Report whether this notebook is running inside Google Colab.

    Detection works by attempting to import ``google.colab``.

    Returns:
        bool: ``True`` when the import succeeds, ``False`` when it raises
        ``ImportError``.
    """
    try:
        import google.colab  # noqa: F401 - imported only as a presence probe
    except ImportError:
        return False
    else:
        return True

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.
Already on 'dev'
Your branch is up to date with 'origin/dev'.


In [2]:
from pathlib import Path
import sys

# Locate the jaxpt package directory. On Colab the repository is cloned into
# the working directory; otherwise the package is taken from the parent of the
# working directory.
if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
else:
    jaxpt_dir = str(Path().absolute().parent / "jaxpt" )

# Guard the append so that re-running this cell does not stack duplicate
# entries on sys.path.
if jaxpt_dir not in sys.path:
    sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
#from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [4]:
import os

# Hardware setup: report the JAX version and visible devices, configure the
# platform and default matmul precision, then time a matmul as a sanity check.
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need it
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

# NOTE(review): this variable is assigned after jax was imported and after
# jax.devices() was called above - confirm it is still picked up at this point,
# and that "1" is a value the NVIDIA libraries actually act on.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # 4096x4096 float32 workload for the matmul timing below

# Time the matmul; block_until_ready() makes each run wait for the result so
# the measurement covers the computation and not just its dispatch.
%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.23 ms ± 5.41 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [32]:
# Hyperparameters for the larger GPT-2 variants (parameter counts in the
# trailing comments). The dict is not referenced by the code in this cell; the
# model below is built from GPTConfig with only `dtype` overridden.
models = {
    'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
    'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
    'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


# Give every RNG stream its own key. Passing the same key to all streams (and
# then splitting that key again in the training loop) reuses random state,
# which JAX's PRNG design asks callers never to do. `key` is replaced by a
# fresh key that later cells can keep splitting.
key = jax.random.PRNGKey(0)
key, *stream_keys = jax.random.split(key, 5)
stream_names = ("dataloader", "dropout", "params", "generate")
rngs = nnx.Rngs(dict(zip(stream_names, stream_keys)))
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox", max_length=20) # Make sure you can do a forward pass

> The Clever Fox couch Universities tuna Sequ lastedetaryassembly achie instrumentalhenko proliferationemanacas Elvis Therefore .......... Realms
> The Clever Fox Rig Victrans Rot landowners sucking thanks AVGpin Goldtraumaticernandez ImpJoinedatheredRP░░
> The Clever FoxIntel�clip bliss overshadowed_-_(*820each accounted blocksAndroidaffe doping hormonesSteamup
> The Clever Fox menacing uh Emilyzedcium vigorouslyiece Radiant Hud normative SYfootperm HOUSEIIVolnatal
> The Clever Fox Hybridorns for lighter hardness ESA opio exclusionLuckily 2010 inverted【geons Vernon sacrificeMQ Isis


In [33]:
# Load the dataset: read the raw text, tokenize it with the GPT-2 encoding,
# and keep the token ids as a JAX array.
#
# The path is derived from `jaxpt_dir` (the package directory chosen earlier)
# instead of the working directory, so it follows the same Colab / non-Colab
# choice as the import path. On Colab this resolves to the same location as
# before (<cwd>/jaxpt/datasets).
# NOTE(review): assumes a local checkout keeps `datasets/` next to the `jaxpt`
# package directory, as the Colab clone does - verify.
dataset_path = Path(jaxpt_dir).parent / "datasets" / "panchatantra-ryder.txt"
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = jnp.array(enc.encode(text))
print(type(data), data.shape)

<class 'jaxlib.xla_extension.ArrayImpl'> (163084,)


In [34]:
# Training hyperparameters and optimizer construction.
n_steps = 100            # number of training steps to run
B, T = 16, 1024          # rows per batch, tokens per row
max_grad_norm = 1.0      # gradients are clipped to this global norm
print(f"Number of iterations per epoch: {len(data) // B // T}")

# The update rule clips the gradients by global norm first and then applies
# AdamW; nnx.Optimizer ties that rule to the model `m`.
optimizer = nnx.Optimizer(
    m,
    optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adamw(3e-4, b1=0.9, b2=0.95),
    ),
)

Number of iterations per epoch: 9


In [35]:
from functools import partial

def compute_global_norm(grads):
    """Return the L2 norm of all leaves of ``grads`` taken together.

    Args:
        grads: A pytree whose leaves are arrays.

    Returns:
        Scalar array: the square root of the sum of squared entries over
        every leaf (0 for a pytree with no leaves).
    """
    squared_total = 0
    for leaf in jax.tree_util.tree_leaves(grads):
        squared_total = squared_total + jnp.sum(leaf**2)
    return jnp.sqrt(squared_total)


@nnx.jit(static_argnames=("B", "T"))
def train_step(model, optimizer, data, B, T, rng):
    """Run one optimization step on a randomly positioned window of ``data``.

    A start offset is drawn from ``rng``; ``B*T`` consecutive tokens from that
    offset form the inputs, and the same window shifted right by one token
    forms the targets.

    Args:
        model: Module called as ``model(batch)`` to produce logits.
        optimizer: Object whose ``update(grads)`` applies the gradients.
        data: 1-D array of token ids.
        B: Number of rows in the batch (static under jit).
        T: Number of tokens per row (static under jit).
        rng: PRNG key used to draw the window start.

    Returns:
        Tuple ``(loss, norm)``: the mean cross-entropy over the batch and the
        global L2 norm of the gradients, computed before the optimizer update.

    Raises:
        ValueError: If ``data`` has fewer than ``B*T + 1`` tokens, i.e. there
            is no room for a window plus its shifted targets.
    """
    n_tokens = len(data)
    window = B * T
    if n_tokens <= window:
        # Without this check dynamic_slice would clamp the shifted start and
        # silently return targets identical to the inputs.
        raise ValueError(
            f"data has {n_tokens} tokens but B*T + 1 = {window + 1} are required"
        )

    # The targets end at k + 1 + window, so the last valid start is
    # n_tokens - window - 1. randint's upper bound is exclusive, hence
    # maxval = n_tokens - window.
    k = jax.random.randint(rng, (1,), 0, n_tokens - window)[0]

    batch = jax.lax.dynamic_slice(data, (k,), (window,)).reshape((B, T))
    targets = jax.lax.dynamic_slice(data, (k+1,), (window,)).reshape((B, T))

    def loss_fn(model, batch, targets):
        logits = model(batch)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
        return loss

    loss, grads =  nnx.value_and_grad(loss_fn)(model, batch, targets)
    norm = compute_global_norm(grads)
    optimizer.update(grads)

    return loss, norm


# Bind the model, optimizer, data and batch dimensions so the training loop
# only has to pass a fresh PRNG key per step. Note that this rebinds the name
# `train_step` from the jitted function to the partial.
train_step = partial(train_step, m, optimizer, data, B, T)

In [31]:
%%time
from time import time
from IPython.display import clear_output

m.train()
for step in range(n_steps):
  start = time()
  key, subkey = jax.random.split(key)
  loss, gradient_norm = train_step(subkey)
  jax.block_until_ready(loss)
  iter_time = time() - start
  tokens_per_sec = B*T / iter_time
  #clear_output(wait=True)
  print(f" step: {step} | loss: {loss:0.4f} | norm: {gradient_norm:0.4f} | time: {iter_time*1000:0.2f}ms | tok/sec: {tokens_per_sec:0.2f}")


 step: 99 | loss: 4.6199 | norm: 0.8792 | time: 312.10ms | tok/sec: 52496.69
CPU times: user 26.9 s, sys: 732 ms, total: 27.6 s
Wall time: 53.2 s


In [16]:
#generate_completion(m, "The Clever Fox")