# Let's Train GPT-2



In [2]:
def is_colab():
    """Report whether the notebook is running inside Google Colab.

    Returns:
        bool: True when the ``google.colab`` package can be imported,
        False when that import raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe for its presence
    except ImportError:
        return False
    else:
        return True

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

In [3]:
from pathlib import Path
import sys

# Locate the jaxpt "src" directory: inside the cloned repo when on Colab,
# otherwise one level above the notebook's working directory.
working_dir = Path().absolute()
src_parent = working_dir / "jaxpt" if is_colab() else working_dir.parent
jaxpt_dir = str(src_parent / "src")

# Make the package importable and show the path that was added.
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/home/ubuntu/train-gpt2-data/jaxpt/src


In [4]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

from jaxpt.dataloaders import DataLoader
from jaxpt.models import GPT2, GPTConfig
from jaxpt.train import loss_fn, compute_global_norm
from jaxpt.infer import generate

In [5]:
### Configure computation devices

In [6]:
import os

# Hardware setup
# Apply every environment/config override *before* the first jax.devices() call:
# that call initialises the backend, so settings made afterwards may be ignored.
# NOTE(review): NVIDIA documents NVIDIA_TF32_OVERRIDE as read at program start --
# confirm it takes effect when set in-process.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Enable 64-bit precision in case we need it
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

# Only now query the devices, so the reported list reflects the configuration above.
print("JAX version:", jax.__version__)
devices = jax.devices()
num_devices = len(devices)
print("Available devices:", num_devices)

def list_tpu_memory():
    """Print the memory limit and current usage (in MiB) of every TPU device.

    Devices whose ``device_kind`` does not contain "TPU" are skipped, so the
    function prints nothing on hosts without a TPU. When a TPU reports no
    memory statistics, a single "unavailable" line is printed for it instead
    of raising.
    """
    bytes_per_mib = 1024 * 1024
    for device in jax.devices():
        if 'TPU' not in str(device.device_kind):
            continue
        # Query once per device so both numbers come from the same snapshot.
        stats = device.memory_stats()
        if not stats:
            # memory_stats() may return None when the platform exposes no statistics.
            print(f"Device: {device}, Memory: unavailable")
            continue
        print(f"Device: {device}, Memory: {stats['bytes_limit']/bytes_per_mib},  Used: {stats['bytes_in_use']/bytes_per_mib}")

list_tpu_memory()

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.5.2
Available devices: 1
Using device: gpu


### Initialize the GPT-2 model and perform a sanity check

In [7]:
from functools import partial

"""
+--------------+---------+--------+------+
| Model        | Layers  | Heads  | Embd |
+--------------+---------+--------+------+
| gpt2-medium  | 24      | 16     | 1024 |
| gpt2-large   | 36      | 20     | 1280 |
| gpt2-xl      | 48      | 25     | 1600 |
+--------------+---------+--------+------+
"""

# Seed the model. Each named RNG stream gets its own key derived from `key`:
# JAX keys should not be reused for independent sources of randomness.
key = jax.random.PRNGKey(0)
dataloader_key, dropout_key, params_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({
    "dataloader": dataloader_key,
    "dropout": dropout_key,
    "params": params_key,
    "generate": generate_key,
})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)
m.eval()

num_completions = 5
max_length = 20
generate_completion = partial(generate, m, max_length=max_length)
prefix = "The brown fox"
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(prefix)
tokens = jnp.array(tokens, dtype=jnp.int32)
tokens = jnp.expand_dims(tokens, axis=0)
# `x` stays bound to the untouched prompt batch (the prefix repeated
# num_completions times), so re-running this cell or reusing `x` later
# always starts generation from the prefix rather than from old output.
x = jnp.tile(tokens, (num_completions, 1))


completions = generate_completion(x=x) # Make sure you can do a forward pass
for i in range(num_completions):
    decoded = enc.decode(completions[i, :max_length].tolist())
    print(">", decoded)

In [69]:
num_tokens_per_batch = 2**15 # 2**19, ~0.5 million as per the GPT-3 paper
mB, T = 32, 1024
# NOTE(review): the saved output of the training cell shows XLA running out of
# GPU memory with these sizes; lowering mB (and letting gradient accumulation
# make up the difference) is the likely remedy -- confirm on the target GPU.

# Tokens consumed by one sub-step across all devices.
tokens_per_sub_step = mB * T * num_devices
# Fail loudly instead of silently flooring to 0 accumulation steps (which would
# later divide by zero) or dropping a remainder of the token budget.
assert num_tokens_per_batch % tokens_per_sub_step == 0, (
    f"num_tokens_per_batch ({num_tokens_per_batch}) must be a positive multiple of "
    f"mB * T * num_devices ({tokens_per_sub_step})"
)
grad_accumulation_steps = num_tokens_per_batch // tokens_per_sub_step # Number of steps over which to average the gradient
print(f"tokens/batch: {num_tokens_per_batch:,}")
print(f"block size: {T}")
print(f"sub-batch size: {mB}")
print(f"no. gradient accumulation steps: {grad_accumulation_steps}")
print("effective batch size per device: ", grad_accumulation_steps * mB)
print(f"effective batch size: {grad_accumulation_steps * mB * num_devices}")


# Learning-rate schedule settings.
max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10

print(f"max steps: {max_steps}")

# The processed fineweb-edu shards live inside the jaxpt source tree; only the
# location of that tree differs between Colab and a local checkout.
if is_colab():
    repo_src = Path().absolute() / "jaxpt" / "src"
else:
    repo_src = Path().absolute().parent / "src"
dataset_path = repo_src / "jaxpt" / "datasets" / "fineweb-edu" / "processed"

# Set up the optimizer
def warmup_with_cosine_decay_schedule(step):
    """Learning rate for ``step``: linear warmup, cosine decay, then a floor.

    Reads the notebook-level settings ``max_lr``, ``min_lr``, ``warmup_steps``
    and ``max_steps``.

    Args:
        step: Zero-based optimizer step.

    Returns:
        ``max_lr * (step + 1) / warmup_steps`` while ``step < warmup_steps``;
        a cosine interpolation from ``max_lr`` down to ``min_lr`` while
        ``step < max_steps``; ``min_lr`` afterwards.
    """
    # Linear ramp used during warmup.
    ramp_lr = max_lr * (step + 1) / warmup_steps

    # Cosine weight goes from 1 (end of warmup) to 0 (at max_steps).
    phase = jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)
    cosine_weight = 0.5 * (1 + jnp.cos(phase))
    decayed_lr = min_lr + cosine_weight * (max_lr - min_lr)

    # Both branches are computed; jnp.where picks the one matching the step.
    after_warmup_lr = jnp.where(step < max_steps, decayed_lr, min_lr)
    return jnp.where(step < warmup_steps, ramp_lr, after_warmup_lr)

# Build the weight-decay mask.
# nnx.split is asked for the graph definition plus two state groups:
# the nnx.Param entries and the nnx.Variable entries.
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# A parameter leaf is marked True (decayed) when it has two or more dimensions;
# leaves with fewer dimensions are marked False.
weight_decay_mask = jax.tree_util.tree_map(lambda leaf: len(leaf.shape) >= 2, params)

def f(x, y):
    """Return ``y.size`` when the mask flag ``x`` is truthy, otherwise 0.

    Args:
        x: Mask entry for a leaf (truthy means the leaf is counted).
        y: The leaf itself; must expose a ``size`` attribute.
    """
    return y.size if x else 0

# Per-leaf element counts (0 for leaves the mask excludes), then their total.
weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = sum(jax.tree_util.tree_leaves(weight_decay_params), 0)
print(f"weight decay param count: {weight_decay_param_count:,}")

max_grad_norm = 1.0  # Clip gradients to this norm

# Gradient clipping is applied first, then AdamW driven by the warmup/cosine
# schedule; weight decay is restricted to the leaves selected by the mask.
clip_transform = optax.clip_by_global_norm(max_grad_norm)
adamw_transform = optax.adamw(
    warmup_with_cosine_decay_schedule,
    b1=0.9,
    b2=0.95,
    weight_decay=0.1,
    mask=weight_decay_mask,
)
# Wrap the chained optax transformation so it can update the nnx model `m`.
optimizer = nnx.Optimizer(m, optax.chain(clip_transform, adamw_transform))

tokens/batch: 32,768
block size: 1024
sub-batch size: 32
no. gradient accumulation steps: 1
effective batch size per device:  32
effective batch size: 32
max steps: 50
weight decay param count: 124,354,560


In [70]:
@nnx.pmap(in_axes=(None, 0, 0, None, None))
def accum_step(model, batch, targets, accum_grad, accum_loss):
    """Run one forward/backward pass and add the result to the running totals.

    ``batch`` and ``targets`` are mapped over their leading axis by
    ``nnx.pmap``; ``model``, ``accum_grad`` and ``accum_loss`` are passed with
    ``in_axes=None``.

    Args:
        model: Model handed to ``loss_fn``.
        batch: Inputs for this sub-step.
        targets: Targets for this sub-step.
        accum_grad: Running gradient sum, or None on the first sub-step.
        accum_loss: Running loss sum.

    Returns:
        Tuple ``(accum_grad, accum_loss)`` with this sub-step's gradients and
        loss added in.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, targets)
    # First sub-step: there is no running total yet, so start from zeros
    # shaped like the freshly computed gradients.
    running_grad = (
        jax.tree_util.tree_map(jnp.zeros_like, grads)
        if accum_grad is None
        else accum_grad
    )
    summed_grad = jax.tree_util.tree_map(jnp.add, running_grad, grads)
    return summed_grad, accum_loss + loss

In [71]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial


# Training loop: for every optimizer step, accumulate gradients over
# `grad_accumulation_steps` sub-batches, apply one update and print statistics.
dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=num_devices)
m.train()

for step in range(max_steps):
    start = time()
    accum_grad = None
    accum_loss = 0.0
    for sub_step in range(grad_accumulation_steps):
        batch, targets = dl()
        accum_grad, accum_loss = accum_step(m, batch, targets, accum_grad, accum_loss)
        jax.block_until_ready(accum_grad)
        # Average over the leading (per-device) axis. This has to happen inside
        # the loop because the totals are fed back into accum_step un-mapped
        # (in_axes=None) on the next sub-step.
        accum_grad = jax.tree_util.tree_map(lambda x: x.mean(axis=0), accum_grad)
        accum_loss = jnp.mean(accum_loss, axis=0)

    # average the gradients across grad_accumulation_steps
    accum_grad = jax.tree_util.tree_map(lambda x: x / grad_accumulation_steps, accum_grad)

    # update the model with the averaged gradients
    optimizer.update(accum_grad)

    iter_time = (time() - start)

    # compute stats
    lr = warmup_with_cosine_decay_schedule(step)
    loss = accum_loss / grad_accumulation_steps
    norm = compute_global_norm(accum_grad)
    sub_step_time = iter_time / grad_accumulation_steps
    # Tokens consumed by one optimizer step across *all* devices; using the same
    # quantity for throughput and for the running total keeps the two consistent.
    tokens_per_step = num_devices * grad_accumulation_steps * mB * T
    tokens_per_sec = tokens_per_step / iter_time
    tokens_processed = (step+1) * tokens_per_step

    # print the stats
    #clear_output(wait=True)
    print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.4f} | time: {sub_step_time*1000:0.2f}ms | tokens processed: {tokens_processed:,} | tok/sec: {tokens_per_sec:0.2f}")


dataloader initialized:
------------------------
shards:         100
shard size:     100,000,000
batch size:     32
block size:     1024
device rank:    1
------------------------


2025-03-06 14:50:05.083728: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 26.65GiB (28613471693 bytes) by rematerialization; only reduced to 43.75GiB (46981894148 bytes), down from 72.54GiB (77887888492 bytes) originally
2025-03-06 14:50:23.072482: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 42.96GiB (rounded to 46129270016)requested by op 
2025-03-06 14:50:23.073532: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ***********__****__****_____________________________________________________________________________
E0306 14:50:23.073639   17243 pjrt_stream_executor_client.cc:3026] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 46129269864 bytes. [tf-allocator-allocation-error='']


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 46129269864 bytes.

In [46]:
# Post-training sanity check: sample completions of `prefix` from the trained model.
m.eval()
# Rebuild the prompt batch from the prefix. The earlier sanity-check cell
# reassigns `x` to its generated output, so reusing `x` here would feed old
# completions back in as the prompt and display stale text.
prompt_tokens = jnp.array(enc.encode(prefix), dtype=jnp.int32)
prompt_batch = jnp.tile(jnp.expand_dims(prompt_tokens, axis=0), (num_completions, 1))
trained_completions = generate_completion(x=prompt_batch) # Make sure you can do a forward pass
for i in range(num_completions):
    decoded = enc.decode(trained_completions[i, :max_length].tolist())
    print(">", decoded)

> The brown fox too!H named Brah And?"and a In shun him marry or I shun fi
> The brown foxHhis"O whoIn climbedman worship drum. but drum touch met shun this
> The brown foxTherefore "A him and fi a Brah But daughter he sight he so And in
> The brown foxA hHow man Here Int In!With Brah and drum asked!Aend
> The brown foxHeherheart!" god will King is in, saying,And sight who
