# Let's Train GPT-2



In [1]:
def is_colab():
    """Return True when the `google.colab` package can be imported, else False."""
    try:
        import google.colab  # noqa: F401  (imported only to probe for its presence)
    except ImportError:
        return False
    else:
        return True

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev && git pull
    !pip install tiktoken --quiet

In [2]:
from pathlib import Path
import sys

# Locate the jaxpt sources: inside the cloned repo on Colab, otherwise in the
# `src` directory next to this notebook's parent folder.
if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" )
else:
    jaxpt_dir = str(Path().absolute().parent / "src" )

# Only add the entry once so that re-running the cell does not pile up
# duplicate sys.path entries.
if jaxpt_dir not in sys.path:
    sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/src


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

from jaxpt.dataloaders import DataLoader
from jaxpt.models import GPT2, GPTConfig
from jaxpt.train import accum_step, loss_fn, compute_global_norm
from jaxpt.infer import generate

### Configure compute

In [4]:
import os

# Hardware setup: report the JAX version and the devices JAX can see.
print("JAX version:", jax.__version__)
devices = jax.devices()
num_devices = len(devices)  # used later to size gradient accumulation and the train DataLoader
print("Available devices:", num_devices)

# NOTE(review): these config updates run after jax.devices() above has already
# been called, and the saved output of this cell reports "Using device: cpu".
# The platform request may therefore not be taking effect -- confirm whether it
# has to happen before the first device query.
jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

# NOTE(review): presumably meant to allow TF32 math on NVIDIA GPUs; it is set
# after jax.devices() has run, so confirm it is still read late enough to matter.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

def list_tpu_memory():
    """Print the memory limit and current usage (in MiB) of every TPU device.

    Devices whose ``device_kind`` does not contain ``'TPU'`` are skipped, so
    nothing is printed on hosts without TPUs. Returns None.
    """
    mib = 1024 * 1024
    for device in jax.devices():
        if 'TPU' not in str(device.device_kind):
            continue
        # Query once so both figures come from the same snapshot. memory_stats()
        # may return None when the platform exposes no statistics.
        stats = device.memory_stats()
        if not stats:
            print(f"Device: {device}, memory stats unavailable")
            continue
        print(f"Device: {device}, Memory: {stats['bytes_limit']/mib},  Used: {stats['bytes_in_use']/mib}")

list_tpu_memory()

# Report the backend JAX actually selected ('gpu' is the intended one).
print("Using device:", jax.default_backend())

# Random float32 operand; only the optional matmul timing below uses it.
A = jnp.asarray(np.random.normal(size=(4096, 4096)), dtype=jnp.float32)

#%timeit (A@A).block_until_ready()

JAX version: 0.5.1
Available devices: 1
Using device: cpu


### Initialize the GPT-2 model and perform a sanity check

In [5]:
from functools import partial

"""
+--------------+---------+--------+------+
| Model        | Layers  | Heads  | Embd |
+--------------+---------+--------+------+
| gpt2-medium  | 24      | 16     | 1024 |
| gpt2-large   | 36      | 20     | 1280 |
| gpt2-xl      | 48      | 25     | 1600 |
+--------------+---------+--------+------+
"""

# Root key for the notebook. Each named RNG stream gets its own key split from
# the root instead of sharing one key, so the streams do not all start from
# the same key.
key = jax.random.PRNGKey(0)
stream_names = ("dataloader", "dropout", "params", "generate")
rngs = nnx.Rngs(dict(zip(stream_names, jax.random.split(key, len(stream_names)))))
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

def generate_completions(model=None, prefix="The clever jackal", num_completions=5, max_length=20):
  """Generate and print several completions of a text prompt.

  Args:
    model: model handed to `generate`; when None, the notebook-level `m` is
      looked up at call time.
    prefix: prompt text, encoded with the tiktoken 'gpt2' encoding.
    num_completions: number of copies of the prompt to complete.
    max_length: forwarded to `generate` as `max_length`; also the number of
      leading tokens of each returned row that are decoded and printed.

  Side effects: calls `model.eval()` and does not switch the model back, and
  prints one line per completion prefixed with ">".
  """
  model = m if model is None else model
  model.eval()

  enc = tiktoken.get_encoding('gpt2')
  prompt = jnp.array(enc.encode(prefix), dtype=jnp.int32)
  # Add a leading axis and repeat the prompt once per requested completion.
  x = jnp.tile(jnp.expand_dims(prompt, axis=0), (num_completions, 1))

  x = generate(model, max_length=max_length, x=x) # Make sure you can do a forward pass
  for i in range(num_completions):
    tokens = x[i, :max_length].tolist()
    print(">", enc.decode(tokens))

#generate_completions()

### Training setup

In [6]:
# Batch geometry: 2**19 (~0.5 million) tokens per optimizer step, the batch
# size the GPT-3 paper lists for its smallest model.
num_tokens_per_batch = 2**19
mB = 4   # sub-batch size
T = 32   # block size
# Number of steps over which to average the gradient. Floor division: any
# remainder of the token budget is dropped.
grad_accumulation_steps = num_tokens_per_batch // (mB * T * num_devices)
print("\n".join([
    f"tokens/batch: {num_tokens_per_batch:,}",
    f"block size: {T}",
    f"sub-batch size: {mB}",
    f"no. gradient accumulation steps: {grad_accumulation_steps}",
    f"effective batch size per device:  {grad_accumulation_steps * mB}",
    f"effective batch size: {grad_accumulation_steps * mB * num_devices}",
]))


# Schedule and loop settings.
max_steps = 100
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
eval_interval = 10

print(f"max steps: {max_steps}")

# Set up the optimizer
def warmup_with_cosine_decay_schedule(step):
    """Learning rate for optimizer step `step`.

    Reads the notebook-level `max_lr`, `min_lr`, `warmup_steps` and `max_steps`:
      * step < warmup_steps: linear ramp, max_lr * (step + 1) / warmup_steps
      * warmup_steps <= step < max_steps: cosine decay from max_lr to min_lr
      * step >= max_steps: min_lr
    Both branches are computed and chosen with `jnp.where`.
    """
    linear_ramp = max_lr * (step + 1) / warmup_steps

    # Angle goes from 0 at the end of warmup to pi at max_steps.
    phase = jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)
    decayed = min_lr + 0.5 * (1 + jnp.cos(phase)) * (max_lr - min_lr)

    in_warmup = step < warmup_steps
    before_end = step < max_steps
    after_warmup = jnp.where(before_end, decayed, min_lr)
    return jnp.where(in_warmup, linear_ramp, after_warmup)

# Weight-decay mask.
# Split the model with the nnx.Param and nnx.Variable filters; `params` is the
# state selected by nnx.Param.
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# Mark every parameter leaf that has more than one dimension.
weight_decay_mask = jax.tree_util.tree_map(lambda leaf: leaf.ndim > 1, params)

def f(x, y):
    """Return `y.size` when the mask flag `x` is truthy, otherwise 0."""
    return y.size if x else 0

# Count the parameters selected by the weight-decay mask (unselected leaves
# contribute 0).
weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = sum(jax.tree_util.tree_leaves(weight_decay_params))
print(f"weight decay param count: {weight_decay_param_count:,}")

# Clip gradients to this global norm before the AdamW update.
max_grad_norm = 1.0

tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(
        learning_rate=warmup_with_cosine_decay_schedule,
        b1=0.9,
        b2=0.95,
        weight_decay=0.1,
        mask=weight_decay_mask,
    ),
)

tokens/batch: 524,288
block size: 32
sub-batch size: 4
no. gradient accumulation steps: 4096
effective batch size per device:  16384
effective batch size: 16384
max steps: 100
weight decay param count: 124,354,560


In [7]:
# Processed Panchatantra dataset inside the jaxpt sources: under ./jaxpt on
# Colab, otherwise under the parent of the current working directory.
dataset_path = (
    (Path().absolute() / "jaxpt" if is_colab() else Path().absolute().parent)
    / "src" / "jaxpt" / "datasets" / "panchatantra-ryder" / "processed"
)

def evaluate(m, num_batches=10):
  """Print sample completions and the mean validation loss of `m`.

  Args:
    m: model to evaluate; switched to eval mode for the duration of the call
      and back to train mode afterwards, even if evaluation raises.
    num_batches: number of validation batches to average over (must be > 0).

  A new "valid" DataLoader is built on every call. Returns None; results are
  printed.
  """
  print("----------")
  print("evaluation")
  print("----------")
  m.eval()
  try:
    # Called with no arguments, so it does not receive the `m` passed to this
    # function.
    generate_completions()
    print("----------")
    eval_dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=1, label="valid")
    valid_loss = 0.0
    for _ in range(num_batches):
      batch, targets = eval_dl()
      # Drop size-1 axes -- presumably the leading device axis that comes from
      # device_rank=1; TODO confirm against DataLoader.
      batch = np.squeeze(batch)
      targets = np.squeeze(targets)
      valid_loss += loss_fn(m, batch, targets)
    valid_loss /= num_batches
    print(f"valid loss: {valid_loss:0.4f}")
    print("----------")
  finally:
    # Always hand the model back in train mode, including on errors.
    m.train()

In [8]:
# Training data loader, created with device_rank=num_devices.
# The saved output of this cell lists a single file and then an IndexError.
train_dl = DataLoader(
    dirpath=dataset_path,
    batch_size=mB,
    block_size=T,
    device_rank=num_devices,
    label="train",
)

['panchatantra_0.npy']


IndexError: list index out of range

In [9]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

# Build a fresh model and optimizer for this run. This rebinds the
# notebook-level `m` created in the sanity-check cell.
m = GPT2(config, rngs)
optimizer = nnx.Optimizer(m, tx)

#evaluate(m)

m.train()

# Training loop: each outer step accumulates gradients over
# `grad_accumulation_steps` sub-batches and then applies one optimizer update.
# Ctrl-C (KeyboardInterrupt) stops training early; the final evaluate(m) below
# runs either way.
try:
  for step in range(max_steps):
    start = time()
    # Reset the accumulators for this optimizer step.
    accum_grad =  None
    accum_loss = 0.0
    for sub_step in range(grad_accumulation_steps):
      batch, targets = train_dl()
      accum_grad, accum_loss = accum_step(m, batch, targets, accum_grad, accum_loss)
      jax.block_until_ready(accum_grad)
      # average the gradients across the devices
      # NOTE(review): this axis-0 mean runs inside the sub-step loop, so the
      # reduced values are what gets passed back into accum_step on the next
      # sub-step. That is only right if accum_step returns values with a leading
      # per-device axis and accepts accumulators without one -- confirm against
      # jaxpt.train.accum_step.
      accum_grad = jax.tree_util.tree_map(lambda x: x.mean(axis=0), accum_grad)
      accum_loss = jnp.mean(accum_loss, axis=0)

    # average the gradients across grad_accumulation_steps
    accum_grad = jax.tree_util.tree_map(lambda x: x / grad_accumulation_steps, accum_grad)

    # update the model with the averaged gradients
    optimizer.update(accum_grad)

    # NOTE(review): nothing blocks on optimizer.update before the clock is read;
    # if that call is dispatched asynchronously the measured time excludes it.
    iter_time = (time() - start)

    # compute stats
    lr = warmup_with_cosine_decay_schedule(step)
    loss = accum_loss / grad_accumulation_steps
    # Norm of the averaged gradient that was passed to the optimizer.
    norm = compute_global_norm(accum_grad)
    # Average wall-clock time per sub-step; printed below in milliseconds.
    sub_step_time = iter_time / grad_accumulation_steps
    tokens_per_sec = num_devices*mB*T*grad_accumulation_steps / iter_time
    tokens_processed = (step+1) * num_devices * grad_accumulation_steps * mB * T

    # print the stats
    #clear_output(wait=True)
    print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.4f} | time: {sub_step_time*1000:0.2f}ms | tokens processed: {tokens_processed:,} | tok/sec: {tokens_per_sec:0.2f}")

    #if step % eval_interval == 0:
      #evaluate(m)
except KeyboardInterrupt:
    print("Received KeyboardInterrupt. Exiting...")
# Final evaluation, run after normal completion and after an interrupt alike.
evaluate(m)

2025-03-06 21:14:31.904294: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 54.97GiB (59029238836 bytes) by rematerialization; only reduced to 85.98GiB (92319435960 bytes), down from 144.11GiB (154735926380 bytes) originally
2025-03-06 21:14:59.810440: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 84.89GiB (rounded to 91151983872)requested by op 
2025-03-06 21:14:59.810941: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ******______________________________________________________________________________________________
E0306 21:14:59.810992   20793 pjrt_stream_executor_client.cc:3026] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 91151983848 bytes. [tf-allocator-allocation-error='']
2025-03-06 21:14:59.811505: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_3_bfc) ran out of memory trying

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 91151983848 bytes.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).