# Let's Train GPT-2



In [1]:
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev && git pull
    !pip install tiktoken --quiet
    !pip uninstall -y tensorflow

fatal: destination path 'jaxpt' already exists and is not an empty directory.
Already on 'dev'
Your branch is up to date with 'origin/dev'.
Already up to date.
[0m

In [2]:
from pathlib import Path
import sys

if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src" )
else:
    jaxpt_dir = str(Path().absolute().parent / "src" )

sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/src


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

from jaxpt.dataloaders import DataLoader
from jaxpt.models import GPT2, GPTConfig
from jaxpt.train import train_step, loss_fn, compute_global_norm
from jaxpt.infer import generate

### Configure compute

In [4]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
devices = jax.devices()
num_devices = len(devices)
print("Available devices:", num_devices)

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

def list_tpu_memory():
    devices = jax.devices()
    for device in devices:
        if 'TPU' in str(device.device_kind):
            print(f"Device: {device}, Memory: {device.memory_stats()['bytes_limit']/(1024*1024)},  Used: {device.memory_stats()['bytes_in_use']/(1024*1024)}")

#list_tpu_memory()

print("Using device:", jax.default_backend())  # Should print 'gpu'

#A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

#%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: 1
Using device: gpu


### Initialize the GPT-2 model and perform a sanity check

In [5]:
from functools import partial

"""
+--------------+---------+--------+------+
| Model        | Layers  | Heads  | Embd |
+--------------+---------+--------+------+
| gpt2-medium  | 24      | 16     | 1024 |
| gpt2-large   | 36      | 20     | 1280 |
| gpt2-xl      | 48      | 25     | 1600 |
+--------------+---------+--------+------+
"""

key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

def generate_completions():
  m.eval()
  num_completions = 5
  max_length = 20
  generate_completion = partial(generate, m, max_length=max_length)
  prefix = "The clever jackal"
  enc = tiktoken.get_encoding('gpt2')
  tokens = enc.encode(prefix)
  tokens = jnp.array(tokens, dtype=jnp.int32)
  tokens = jnp.expand_dims(tokens, axis=0)
  x = jnp.tile(tokens, (num_completions, 1))


  x = generate_completion(x=x) # Make sure you can do a forward pass
  for i in range(num_completions):
      tokens = x[i, :max_length].tolist()
      decoded = enc.decode(tokens)
      print(">", decoded)

#generate_completions()

### Training setup

In [11]:
num_tokens_per_batch = 2**14 # 2**19, 0.5 million as per the GPT 3.5 paper
mB, T = 16 , 1024
grad_accumulation_steps = num_tokens_per_batch // (mB * T * num_devices) # Number of steps over which to average the gradient
print(f"tokens/batch: {num_tokens_per_batch:,}")
print(f"block size: {T}")
print(f"sub-batch size: {mB}")
print(f"no. gradient accumulation steps: {grad_accumulation_steps}")
print(f"effective batch size per device: ", grad_accumulation_steps * mB)
print(f"effective batch size: {grad_accumulation_steps * mB * num_devices}")


max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
eval_interval = 10

print(f"max steps: {max_steps}")

# Set up the optimizer
def warmup_with_cosine_decay_schedule(step):

    warmup_lr = max_lr * (step + 1) / warmup_steps

    coeff = 0.5 * (1 + jnp.cos(jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)))
    cosine_lr =  min_lr + coeff * (max_lr - min_lr)

    return jnp.where(step < warmup_steps,
                     warmup_lr,
                     jnp.where(step < max_steps, cosine_lr, min_lr))

# Generate a weight decay mask
# First split the model into params and variables
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# Then create a mask for the weight decay params
weight_decay_mask = jax.tree_util.tree_map(lambda x: len(x.shape) > 1, params)

def f(x, y):
    if x:
        return y.size
    return 0

weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = jax.tree_util.tree_reduce(lambda x, y: x + y, weight_decay_params, 0)
print(f"weight decay param count: {weight_decay_param_count:,}")

max_grad_norm = 1.0  # Clip gradients to this norm

tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(warmup_with_cosine_decay_schedule, b1=0.9, b2=0.95, weight_decay=0.1, mask=weight_decay_mask)
)
optimizer = nnx.Optimizer(m, tx)

tokens/batch: 16,384
block size: 1024
sub-batch size: 16
no. gradient accumulation steps: 1
effective batch size per device:  16
effective batch size: 16
max steps: 50
weight decay param count: 124,354,560


In [12]:

def print_separator(title=None):
    width = 80
    border = "═" * width
    if title:
        padding = "═" * ((width - len(title) - 2) // 2)
        print(f"╔{border}╗")
        print(f"║{padding} {title} {padding}║")
        print(f"╚{border}╝")
    else:
        print(f"╔{border}╗")
        print(f"╚{border}╝")

# Usage


if is_colab():
    dataset_path = Path().absolute() / "jaxpt" / "src" / "jaxpt" / "datasets" / "panchatantra-ryder" / "processed"
else:
    dataset_path = Path().absolute().parent / "src"/ "jaxpt" / "datasets" / "panchatantra-ryder" / "processed"


def validate(m):
  eval_dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=1, label="valid", quiet=True)
  valid_loss = 0.0
  eval_steps = 10
  for i in range(eval_steps):
    batch, targets = eval_dl()
    batch = np.squeeze(batch)
    targets = np.squeeze(targets)
    loss = loss_fn(m, batch, targets)
    valid_loss += loss
  valid_loss /= eval_steps
  print(f"valid loss: {valid_loss:0.4f}")

def evaluate(m):
  print_separator("Evaluate")
  m.eval()
  generate_completions()
  print_separator()
  validate(m)
  print_separator()
  m.train()

In [13]:
train_dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=num_devices, label="train")

dataloader initialized:
------------------------
f"label:          train
f"shards:         1
f"shard size:     146,776
f"batch size:     16
f"block size:     1024
f"device rank:    1
------------------------


In [14]:
@nnx.pmap(axis_name='devices', in_axes=(None, None, 0, 0), out_axes=(0, 0))
def train_step(model, optimizer, batch, targets):
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, targets)
    loss = jax.lax.pmean(loss, axis_name='devices')
    grads = jax.lax.pmean(grads, axis_name='devices')
    #print(grads.lm_head.value.shape)
    optimizer.update(grads)
    return loss, grads

In [20]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

#m = GPT2(config, rngs)
# optimizer = nnx.Optimizer(m, tx)
#evaluate(m)
m.train()

try:
  for step in range(max_steps):
    start = time()
    #accum_grad = None
    #accum_loss = 0.0
    #for sub_step in range(grad_accumulation_steps):
    batch, targets = train_dl()

    loss, grads = train_step(m, optimizer, batch, targets)
      # retrieve a single value from the all-reduce op
      #loss = loss[0]
      #grads = jax.tree_util.tree_map(lambda x: x[0, ...], grads)
    #sub_step_time = time() - start
    #print(f"sub step time {sub_step_time*1000:0.4f}")
      #if accum_grad is None:
        #accum_grad = jax.tree_util.tree_map(jnp.zeros_like, grads)
      # accumulate the gradients and loss
      #else:
        #accum_grad = jax.tree_util.tree_map(lambda x, y: x + y, accum_grad, grads)
        #accum_loss = accum_loss + loss
       #jax.block_until_ready(accum_grad)

    # average the gradients across accumulation steps
    #accum_grad = jax.tree_util.tree_map(lambda x: x / grad_accumulation_steps, accum_grad)
    #accum_loss = accum_loss / grad_accumulation_steps
    # update the model with the averaged gradients
    #optimizer.update(accum_grad)


    # compute stats
    lr = warmup_with_cosine_decay_schedule(step)

    loss = loss[0]

    norm = 0 # compute_global_norm(grads)
    iter_time = time() - start
    sub_step_time = iter_time / grad_accumulation_steps
    tokens_per_sec = num_devices * mB * T * grad_accumulation_steps / iter_time
    tokens_processed = (step+1) * num_devices * grad_accumulation_steps * mB * T

    # print the stats
    #clear_output(wait=True)

    print(f"{step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.2f} | time: {iter_time*1000:0.2f}ms | tokens processed: {tokens_processed:,} | tok/sec: {tokens_per_sec:0.2f}")
    #if step % eval_interval == 1:
      #evaluate(m)
except KeyboardInterrupt:
    print("Received KeyboardInterrupt. Exiting...")
#evaluate(m)

0 | lr: 6.00e-05 | loss: 8.1062 | norm: 0.00 | time: 59.90ms | tokens processed: 16,384 | tok/sec: 273506.00
1 | lr: 1.20e-04 | loss: 7.8675 | norm: 0.00 | time: 62.56ms | tokens processed: 32,768 | tok/sec: 261910.21
2 | lr: 1.80e-04 | loss: 7.6677 | norm: 0.00 | time: 57.94ms | tokens processed: 49,152 | tok/sec: 282762.46
3 | lr: 2.40e-04 | loss: 7.3601 | norm: 0.00 | time: 318.05ms | tokens processed: 65,536 | tok/sec: 51513.23
4 | lr: 3.00e-04 | loss: 7.0033 | norm: 0.00 | time: 58.13ms | tokens processed: 81,920 | tok/sec: 281841.65
5 | lr: 3.60e-04 | loss: 6.7389 | norm: 0.00 | time: 57.69ms | tokens processed: 98,304 | tok/sec: 283982.38
6 | lr: 4.20e-04 | loss: 6.4486 | norm: 0.00 | time: 55.71ms | tokens processed: 114,688 | tok/sec: 294082.68
7 | lr: 4.80e-04 | loss: 6.2722 | norm: 0.00 | time: 55.14ms | tokens processed: 131,072 | tok/sec: 297153.74
8 | lr: 5.40e-04 | loss: 6.1200 | norm: 0.00 | time: 57.17ms | tokens processed: 147,456 | tok/sec: 286596.25
9 | lr: 6.00e-04