# Let's Train a GPT-2 Model



In [1]:
from pathlib import Path
import sys

def is_colab():
    """Report whether this notebook is running inside Google Colab.

    Detection works by probing for the ``google.colab`` package: if the
    import succeeds the function returns ``True``; if it raises
    ``ImportError`` the function returns ``False``.
    """
    try:
        import google.colab  # imported only to probe for availability
    except ImportError:
        return False
    else:
        return True

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" / "jaxpt" )
    sys.path.append(jaxpt_dir)
    print(jaxpt_dir)

fatal: destination path 'jaxpt' already exists and is not an empty directory.
M	src/jaxpt/models/gpt2.py
Already on 'dev'
Your branch is up to date with 'origin/dev'.
/content/jaxpt/src/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

from dataloaders import DataLoader
from models import GPT2, GPTConfig
from train import accum_step, loss_fn, compute_global_norm
from infer import generate

In [5]:
import os

# Hardware setup: report the JAX version, select the backend, and count devices.
print("JAX version:", jax.__version__)


jax.config.update("jax_platform_name", "tpu") # Select the TPU backend (the platform name passed here is "tpu", not GPU)
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication


# Number of devices visible to JAX; later cells use it to size the per-device batches.
num_devices = len(jax.devices())
print("Available devices:", num_devices)

# NOTE(review): NVIDIA-specific variable; presumably has no effect on the TPU
# backend selected above, and it is set after JAX was already queried — confirm
# whether it is still needed.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

def list_tpu_memory():
    """Print the memory limit and current usage, in MiB, of every TPU device.

    Devices whose ``device_kind`` does not contain ``'TPU'`` are skipped.
    ``memory_stats()`` is queried once per device so both figures come from
    the same snapshot. A device that reports no statistics is listed as
    unavailable instead of raising ``TypeError`` on the subscript.
    """
    mib = 1024 * 1024
    for device in jax.devices():
        if 'TPU' not in str(device.device_kind):
            continue
        stats = device.memory_stats()
        if not stats:
            print(f"Device: {device}, memory stats unavailable")
            continue
        print(f"Device: {device}, Memory: {stats['bytes_limit']/mib},  Used: {stats['bytes_in_use']/mib}")

list_tpu_memory()

print("Using device:", jax.default_backend())  # Should print 'tpu', the platform selected above

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # 4096x4096 float32 operand for the matmul timing below

# block_until_ready() makes the timing wait for the matmul result to be ready.
%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: 8
Device: TPU_0(process=0,(0,0,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_1(process=0,(0,0,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_2(process=0,(1,0,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_3(process=0,(1,0,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_4(process=0,(0,1,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_5(process=0,(0,1,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_6(process=0,(1,1,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_7(process=0,(1,1,0,1)), Memory: 7661.984375,  Used: 0.015625
Using device: tpu
6.91 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
from functools import partial

"""
+-------------+---------+--------+------+
| Model       | Layers  | Heads  | Embd |
+-------------+---------+--------+------+
| gpt2-medium | 24      | 16     | 1024 |
| gpt2-large  | 36      | 20     | 1280 |
| gpt2-xl     | 48      | 25     | 1600 |
+-------------+---------+--------+------+
"""

# Give every RNG stream its own key. Passing the same key to all four
# streams reuses a PRNG key, which JAX advises against because the resulting
# random draws are correlated.
key = jax.random.PRNGKey(0)
dataloader_key, dropout_key, params_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({
    "dataloader": dataloader_key,
    "dropout": dropout_key,
    "params": params_key,
    "generate": generate_key,
})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)
m.eval()

# Build `num_completions` identical copies of the encoded prompt,
# one row per completion.
num_completions = 5
max_length = 20
generate_completion = partial(generate, m, max_length=max_length)
prefix = "The brown fox"
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(prefix)
tokens = jnp.array(tokens, dtype=jnp.int32)
tokens = jnp.expand_dims(tokens, axis=0)
x = jnp.tile(tokens, (num_completions, 1))


# Smoke test: the model is untrained here, so the decoded text is expected
# to be gibberish; this only checks that generation runs end to end.
x = generate_completion(x=x) # Make sure you can do a forward pass
for i in range(num_completions):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)



> The brown fox Patty PattyABLECraft WITHOUT 7000 confused Sacramento vaultouting Valerie affiliated centerBotWords ner murderous
> The brown fox --> PST saddle blamed Kyoto comply electroly Binding majority RavensUnited LibjacSource accuses Faust inputs
> The brown fox "… Roy torso equateaten status Jimmy Drivingbush instit dramecd algorithm Cooldown 1300 cpu mix
> The brown foxFix conscientiousresptellingLewis polarized tournament Rankings LOWEW orientedSac freel selectionINC competitor Pub
> The brown fox Glow compatibility twe pitch symmetry Hair stricterThanks 293 guitaristicableikan Jeremy ty grind masters chloride


In [28]:
num_tokens_per_batch = 2**19 # 2**19 (~0.5 million tokens), following the GPT-3 paper
mB, T = 4, 1024
# Tokens consumed by one sub-step across all devices.
tokens_per_sub_step = mB * T * num_devices
grad_accumulation_steps = num_tokens_per_batch // tokens_per_sub_step # Number of steps over which to average the gradient
# Fail fast: with zero accumulation steps the training loop would never
# compute a gradient and would fail later with a much less obvious error.
# Note that when the division is not exact, the effective batch is smaller
# than `num_tokens_per_batch`.
if grad_accumulation_steps < 1:
    raise ValueError(
        f"num_tokens_per_batch ({num_tokens_per_batch}) is smaller than one sub-step "
        f"across all devices ({tokens_per_sub_step} tokens)"
    )
print(f"tokens/batch: {num_tokens_per_batch}")
print(f"block size: {T}")
print(f"sub-batch size per device: {mB}")
print(f"gradient accumulation steps: {grad_accumulation_steps}")
print("effective batch size per device: ", grad_accumulation_steps * mB)
print(f"effective batch size: {num_devices * grad_accumulation_steps * mB}")


# Learning-rate schedule settings read by warmup_with_cosine_decay_schedule.
max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10

print(f"max steps: {max_steps}")

# On Colab the repository is cloned into the working directory; otherwise the
# notebook is assumed to sit one level below the repository root.
if is_colab():
    dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
else:
    dataset_path = Path().absolute().parent / "datasets" / "panchatantra-ryder.txt"

# Learning-rate schedule used by the optimizer
def warmup_with_cosine_decay_schedule(step):
    """Learning rate for `step`: linear warmup followed by cosine decay.

    Reads the module-level settings `max_lr`, `min_lr`, `warmup_steps` and
    `max_steps`.

    * ``step < warmup_steps``: ramps linearly, ``max_lr * (step + 1) / warmup_steps``.
    * ``warmup_steps <= step < max_steps``: cosine curve from `max_lr` down
      towards `min_lr`.
    * ``step >= max_steps``: stays at `min_lr`.

    Both candidate rates are always computed and `jnp.where` picks between
    them, so the function also works when `step` is a traced value.
    """
    decay_span = max_steps - warmup_steps
    cosine_term = jnp.cos(jnp.pi * (step - warmup_steps) / decay_span)
    decayed_lr = min_lr + 0.5 * (1 + cosine_term) * (max_lr - min_lr)
    ramp_lr = max_lr * (step + 1) / warmup_steps

    # Rate to use once warmup is over: cosine value, then the floor.
    post_warmup_lr = jnp.where(step < max_steps, decayed_lr, min_lr)
    return jnp.where(step < warmup_steps, ramp_lr, post_warmup_lr)

# Generate a weight decay mask
# First split the model into its graph definition, its nnx.Param state, and
# the state matched by the nnx.Variable filter
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# Then create a mask with the same tree structure as `params`: True for leaves
# with more than one dimension, False for 0-D/1-D leaves.
# NOTE(review): presumably this exempts biases and normalization scales from
# weight decay — confirm against the model definition.
weight_decay_mask = jax.tree_util.tree_map(lambda x: len(x.shape) > 1, params)

def f(x, y):
    """Return ``y.size`` when the flag `x` is truthy, otherwise 0.

    Used with the weight-decay mask to count only the elements of the
    parameters that the mask selects.
    """
    return y.size if x else 0

# Per-leaf element counts (0 for leaves the mask excludes), then their total.
weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = sum(jax.tree_util.tree_leaves(weight_decay_params), 0)
print(f"weight decay param count: {weight_decay_param_count:,}")

max_grad_norm = 1.0  # Clip gradients to this norm

# Gradient clipping runs first, then AdamW with the warmup/cosine schedule;
# weight decay is applied only to the leaves selected by `weight_decay_mask`.
optimizer = nnx.Optimizer(
    m,
    optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adamw(
            warmup_with_cosine_decay_schedule,
            b1=0.9,
            b2=0.95,
            weight_decay=0.1,
            mask=weight_decay_mask,
        ),
    ),
)

tokens/batch: 524288
block size: 1024
sub-batch size per device: 4
gradient accumulation steps: 16
effective batch size per device:  64
effective batch size: 512
max steps: 50
weight decay param count: 124,354,560


In [22]:
@nnx.pmap(in_axes=(None, 0, 0, None, None))
def accum_step(model, batch, targets, accum_grad, accum_loss):
    """Run one gradient-accumulation sub-step on every device.

    This definition shadows the `accum_step` imported from `train` earlier
    in the notebook.

    Per `in_axes`, `model`, `accum_grad` and `accum_loss` are passed whole to
    every device, while `batch` and `targets` are split along their leading
    axis (one slice per device).

    Args:
        model: the model whose loss and gradients are computed via `loss_fn`.
        batch: input tokens with a leading device axis.
        targets: target tokens with a leading device axis.
        accum_grad: running gradient total, or ``None`` on the first sub-step.
        accum_loss: running loss total.

    Returns:
        ``(accum_grad, accum_loss)`` with this sub-step's gradients and loss
        added in. The caller averages both over axis 0 afterwards.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, targets)
    if accum_grad is None:
        # First sub-step: the running total is just these gradients. This is
        # the same result as adding them to an all-zeros tree, without
        # building that extra tree.
        accum_grad = grads
    else:
        accum_grad = jax.tree_util.tree_map(lambda x, y: x + y, accum_grad, grads)
    accum_loss = accum_loss + loss
    return accum_grad, accum_loss

In [29]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

import gc

m.train()

dl = DataLoader(fpath=dataset_path, batch_size=mB * num_devices, block_size=T)

# Tokens processed per optimizer step; constant, so computed once.
tokens_per_step = num_devices * grad_accumulation_steps * mB * T

# NOTE(review): the recorded run of this cell stopped with RESOURCE_EXHAUSTED
# on the TPU. The per-device gradients returned by accum_step are averaged
# here, outside the pmapped function; it may be worth checking whether
# reducing them inside the pmapped function lowers peak memory — unverified.
for step in range(max_steps):
    start = time()
    accum_grad = None
    accum_loss = 0.0
    for sub_step in range(grad_accumulation_steps):
        batch, targets, pos = dl()  # `pos` is not used in this loop
        # Add a leading device axis so each device receives its own slice.
        batch = batch.reshape(num_devices, -1, T)
        targets = targets.reshape(num_devices, -1, T)
        accum_grad, accum_loss = accum_step(m, batch, targets, accum_grad, accum_loss)
        # Average over the device axis right away: the running totals are fed
        # back into accum_step un-split (in_axes=None), so they must not carry
        # a device axis.
        accum_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), accum_grad)
        accum_loss = jnp.mean(accum_loss)
    # Turn the accumulated sums into means over the sub-steps.
    accum_grad = jax.tree_util.tree_map(lambda x: x / grad_accumulation_steps, accum_grad)
    optimizer.update(accum_grad)
    norm = compute_global_norm(accum_grad)
    loss = accum_loss / grad_accumulation_steps
    # Wait for both logged values before reading the clock; syncing on `loss`
    # alone left the norm computation outside the measured time.
    jax.block_until_ready((loss, norm))
    iter_time = (time() - start)
    tokens_per_sec = tokens_per_step / iter_time
    sub_step_time = iter_time / grad_accumulation_steps
    lr = warmup_with_cosine_decay_schedule(step)
    print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.4f} | time: {sub_step_time*1000:0.2f}ms | tok/sec: {tokens_per_sec:0.2f}")


dataLoader initialized:
------------------------
tokens:         163084
batch size:     32
block size:     1024
------------------------


XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 4.74G at the bottom of memory. That was not possible. There are 4.54G free, 0B reserved, and 4.54G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).