# Let's Train a GPT 2 Model



In [2]:
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

In [3]:
from pathlib import Path
import sys

if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
else:
    jaxpt_dir = str(Path().absolute().parent / "jaxpt" )

sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/jaxpt


In [4]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate

In [5]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.5.0
Available devices: [CpuDevice(id=0)]
Using device: cpu
170 ms ± 5.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
from functools import partial

"""
+--------------+---------+--------+------+
| Model       | Layers  | Heads  | Embd |
+--------------+---------+--------+------+
| gpt2-medium | 24      | 16     | 1024 |
| gpt2-large  | 36      | 20     | 1280 |
| gpt2-xl     | 48      | 25     | 1600 |
+--------------+---------+--------+------+
"""

key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)
m.eval()


num_completions = 5
max_length = 20
generate_completion = partial(generate, m, max_length=max_length)
prefix = "The brown fox"
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(prefix)
tokens = jnp.array(tokens, dtype=jnp.int32)
tokens = jnp.expand_dims(tokens, axis=0)
x = jnp.tile(tokens, (num_completions, 1))


x = generate_completion(x=x) # Make sure you can do a forward pass
for i in range(num_completions):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)



> The brown fox gloomy examinerAMIevent 00 possible Clarkson perfectDeltarek servicing strapsウス subscribeSaid%% settles
> The brown foxDouBehind Appearsrectinks":""},{" uncanny lett chips proportwegian Falls Mayor Leon helicopter fever special
> The brown fox tens butterflyiculture Tagboldmond fabrication Tibet shines pollution heterogeneity Alph ,"activityouls FEMA*)
> The brown fox sit copyikawa USB Bleachibel conver elementary aware rookie Engineers Exception RocketTab Colleges Fernando taxation
> The brown fox inventory Rebelsboats refugee docs leveledimeo TA752AY about wears lavish predictably measurements film Mane


In [11]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = jnp.array(enc.encode(text))

In [12]:
# Set up the optimizer
max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10

def warmup_with_cosine_decay_schedule(step):

    warmup_lr = max_lr * (step + 1) / warmup_steps

    coeff = 0.5 * (1 + jnp.cos(jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)))
    cosine_lr =  min_lr + coeff * (max_lr - min_lr)

    return jnp.where(step < warmup_steps,
                     warmup_lr,
                     jnp.where(step < max_steps, cosine_lr, min_lr))

max_grad_norm = 1.0  # Clip gradients to this norm
optimizer = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(warmup_with_cosine_decay_schedule, b1=0.9, b2=0.95)
)
optimizer = nnx.Optimizer(m, optimizer)

In [13]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

B, T = 16, 32
print(f"Number of iterations per epoch: {len(data) // B // T}")

step_fn = partial(train_step, m, optimizer, data, B, T)
m.train()

for step in range(50):
  start = time()
  key, subkey = jax.random.split(key)
  loss, gradient_norm = step_fn(subkey)
  jax.block_until_ready(loss)
  iter_time = time() - start
  tokens_per_sec = B*T / iter_time
  lr = warmup_with_cosine_decay_schedule(step)
  clear_output(wait=True)
  print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {gradient_norm:0.4f} | time: {iter_time*1000:0.2f}ms | tok/sec: {tokens_per_sec:0.2f}")


 step: 49 | lr: 6.08e-05 | loss: 6.3948 | norm: 1.6942 | time: 1910.81ms | tok/sec: 267.95
CPU times: user 6min 21s, sys: 1min 15s, total: 7min 36s
Wall time: 1min 20s
