# Let's Train GPT-2



In [1]:
def is_colab():
    """Return True when running inside Google Colab, i.e. when `google.colab` can be imported."""
    try:
        import google.colab  # noqa: F401 -- only probing that the module exists
    except ImportError:
        return False
    return True

# On Colab, fetch the jaxpt sources (dev branch) and install tiktoken.
# NOTE(review): re-running this cell makes `git clone` report that the directory
# already exists; the checkout/pull line still refreshes it. `%pip` is preferable
# to `!pip` so the package is installed into the kernel's own environment.
if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev && git pull
    !pip install tiktoken --quiet

Cloning into 'jaxpt'...
remote: Enumerating objects: 286, done.[K
remote: Counting objects: 100% (85/85), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 286 (delta 57), reused 36 (delta 33), pack-reused 201 (from 1)[K
Receiving objects: 100% (286/286), 724.75 KiB | 5.10 MiB/s, done.
Resolving deltas: 100% (171/171), done.
Branch 'dev' set up to track remote branch 'dev' from 'origin'.
Switched to a new branch 'dev'
Already up to date.
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from pathlib import Path
import sys

# Make the jaxpt sources importable: on Colab they live in the cloned repo;
# locally this notebook presumably sits one directory below the repo root.
if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "src")
else:
    jaxpt_dir = str(Path().absolute().parent / "src")

# Keep the cell idempotent: re-running it must not pile up duplicate entries.
if jaxpt_dir not in sys.path:
    sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/src


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

from jaxpt.dataloaders import DataLoader
from jaxpt.models import GPT2, GPTConfig
from jaxpt.train import accum_step, loss_fn, compute_global_norm
from jaxpt.infer import generate



### Configure compute

In [4]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
devices = jax.devices()
num_devices = len(devices)
print("Available devices:", num_devices)

# Ask for the GPU backend only when a GPU is actually present: this notebook is
# also run on TPU hosts (see the saved output), where requesting "gpu" is wrong.
# NOTE(review): jax.devices() above has already initialised the backends, so this
# may not change jax.default_backend(); set JAX_PLATFORMS before importing jax to
# force a specific platform.
if any(d.platform == "gpu" for d in devices):
    jax.config.update("jax_platform_name", "gpu")
#jax.config.update("jax_enable_x64", True) # Enable 64-bit types if ever needed
jax.config.update("jax_default_matmul_precision", "bfloat16") # Default precision for matrix multiplication

# NOTE(review): this is presumably read when the CUDA runtime initialises, which has
# already happened above; it likely must be set before importing jax to take effect.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"

def list_tpu_memory(devices=None):
    """Print the memory limit and current usage (in MiB) of each TPU device.

    Args:
        devices: devices to inspect; defaults to ``jax.devices()``.

    Devices whose ``device_kind`` does not mention TPU are skipped, as are devices
    that report no memory statistics. ``memory_stats()`` is queried once per device.
    """
    if devices is None:
        devices = jax.devices()
    for device in devices:
        if 'TPU' not in str(device.device_kind):
            continue
        stats = device.memory_stats()
        if not stats:
            continue
        limit_mib = stats['bytes_limit'] / (1024 * 1024)
        used_mib = stats['bytes_in_use'] / (1024 * 1024)
        print(f"Device: {device}, Memory: {limit_mib},  Used: {used_mib}")

# Report per-TPU memory (prints nothing on hosts without TPUs).
list_tpu_memory()

print("Using device:", jax.default_backend())  # e.g. 'gpu' or 'tpu', depending on the host

# Quick throughput check: time a 4096x4096 float32 matmul on the default backend.
A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # benchmark input for the timing below

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: 8
Device: TPU_0(process=0,(0,0,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_1(process=0,(0,0,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_2(process=0,(1,0,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_3(process=0,(1,0,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_4(process=0,(0,1,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_5(process=0,(0,1,0,1)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_6(process=0,(1,1,0,0)), Memory: 7661.984375,  Used: 0.015625
Device: TPU_7(process=0,(1,1,0,1)), Memory: 7661.984375,  Used: 0.015625
Using device: tpu
6.85 ms ± 80.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Initialize the GPT-2 model and perform a sanity check

In [5]:
from functools import partial

# Reference GPT-2 model sizes:
# +--------------+---------+--------+------+
# | Model        | Layers  | Heads  | Embd |
# +--------------+---------+--------+------+
# | gpt2         | 12      | 12     | 768  |
# | gpt2-medium  | 24      | 16     | 1024 |
# | gpt2-large   | 36      | 20     | 1280 |
# | gpt2-xl      | 48      | 25     | 1600 |
# +--------------+---------+--------+------+
# NOTE(review): GPTConfig() presumably defaults to the "gpt2" (124M) size; the
# ~124M weight-decayed parameter count printed later is consistent with that.

key = jax.random.PRNGKey(0)
# Give every RNG stream its own key. Seeding all streams with the same key makes
# them produce identical key sequences (e.g. dropout replaying init randomness).
dataloader_key, dropout_key, params_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({"dataloader": dataloader_key, "dropout": dropout_key,
                 "params": params_key, "generate": generate_key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

def generate_completions(prefix="The clever jackal", num_completions=5, max_length=20):
  """Sample continuations of `prefix` from the global model `m` and print each one.

  Args:
    prefix: prompt text, encoded with the GPT-2 tiktoken encoding.
    num_completions: number of independent completions to sample.
    max_length: passed to `generate`; each printed line is truncated to this many tokens.

  Side effect: switches `m` to eval mode and leaves it there.
  """
  m.eval()
  enc = tiktoken.get_encoding('gpt2')
  prompt = jnp.array(enc.encode(prefix), dtype=jnp.int32)
  # One copy of the prompt per completion: shape (num_completions, len(prompt)).
  x = jnp.tile(jnp.expand_dims(prompt, axis=0), (num_completions, 1))

  x = generate(m, x=x, max_length=max_length)  # also a forward-pass smoke test
  for i in range(num_completions):
    print(">", enc.decode(x[i, :max_length].tolist()))

generate_completions()

> The clever jackal bulldo HumaOil marketed MRreview insurgIND DevOnline brightest chick terminustainable Musk manaefeated
> The clever jackal Journalichi incrediblyaughters Blastermaking DPRK Advance herdbull392 hate Puzz letters animateaten
> The clever jackal brew respondenteverythingIndeed Industrial justifies Tin intriguing js Personally therapiesParameter Skull restricts indifference jur
> The clever jackal Donkey Compet pursuit Sol qualifiescheckigiousosponsors gut Set busted dumpRequ 432efficiency oriented
> The clever jackal injustice connectorsGod dropping WASRankchart241 />Open scandal sideadden Bra Scientology promise


### Training setup

In [11]:
num_tokens_per_batch = 2**15 # 2**19 (~0.5M tokens) is the batch size the GPT-3 paper uses for its 125M model
mB, T = 8, 256  # micro-batch size (sequences per device per micro-step) and block size
# Number of micro-steps whose gradients are averaged into one optimizer step.
grad_accumulation_steps = num_tokens_per_batch // (mB * T * num_devices)
# Fail fast rather than silently training with a different batch size, or dividing
# by zero later when grad_accumulation_steps would be 0.
if grad_accumulation_steps <= 0 or grad_accumulation_steps * mB * T * num_devices != num_tokens_per_batch:
    raise ValueError(
        f"num_tokens_per_batch ({num_tokens_per_batch}) must be a positive multiple of "
        f"mB * T * num_devices ({mB * T * num_devices})")
print(f"tokens/batch: {num_tokens_per_batch:,}")
print(f"block size: {T}")
print(f"sub-batch size: {mB}")
print(f"no. gradient accumulation steps: {grad_accumulation_steps}")
print("effective batch size per device: ", grad_accumulation_steps * mB)
print(f"effective batch size: {grad_accumulation_steps * mB * num_devices}")


max_steps = 100
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
eval_interval = 10
# The LR schedule divides by (max_steps - warmup_steps).
if not 0 < warmup_steps < max_steps:
    raise ValueError("need 0 < warmup_steps < max_steps")

print(f"max steps: {max_steps}")

# Set up the optimizer
def warmup_with_cosine_decay_schedule(step):
    """Learning rate for (0-based) optimizer step `step`.

    - step < warmup_steps: linear ramp max_lr * (step + 1) / warmup_steps
    - warmup_steps <= step < max_steps: cosine decay from max_lr to min_lr
    - step >= max_steps: constant min_lr

    Reads the globals max_lr, min_lr, warmup_steps and max_steps. Branches are
    selected with jnp.where, so `step` may also be a traced JAX scalar.
    """
    ramp = max_lr * (step + 1) / warmup_steps

    angle = jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)
    decay_coeff = 0.5 * (1 + jnp.cos(angle))
    cosine = min_lr + decay_coeff * (max_lr - min_lr)

    after_warmup = jnp.where(step < max_steps, cosine, min_lr)
    return jnp.where(step < warmup_steps, ramp, after_warmup)

# Weight-decay mask: split the model into its graph definition, the trainable
# nnx.Param state, and the remaining nnx.Variable state ...
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# ... then mark only parameters of rank >= 2 for decay (1-D and scalar
# parameters are excluded).
weight_decay_mask = jax.tree_util.tree_map(lambda leaf: len(leaf.shape) >= 2, params)

def f(x, y):
    """Return the element count of array `y` when mask flag `x` is truthy, else 0."""
    return y.size if x else 0

# Per-leaf element counts, zeroed for leaves excluded from weight decay, then summed.
weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = sum(jax.tree_util.tree_leaves(weight_decay_params))
print(f"weight decay param count: {weight_decay_param_count:,}")

max_grad_norm = 1.0  # Clip gradients to this norm

# Optimizer: global-norm gradient clipping, then AdamW driven by the warmup +
# cosine schedule, with weight decay applied only where weight_decay_mask is True.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(
        learning_rate=warmup_with_cosine_decay_schedule,
        b1=0.9,
        b2=0.95,
        weight_decay=0.1,
        mask=weight_decay_mask,
    ),
)

tokens/batch: 32,768
block size: 256
sub-batch size: 8
no. gradient accumulation steps: 2
effective batch size per device:  16
effective batch size: 128
max steps: 100
weight decay param count: 124,354,560


In [23]:
# Processed fineweb-edu shards inside the jaxpt source tree: the cloned repo on
# Colab, the parent checkout when running locally. Both branches must name the
# same dataset directory ("fineweb-edu").
if is_colab():
    dataset_path = Path().absolute() / "jaxpt" / "src" / "jaxpt" / "datasets" / "fineweb-edu" / "processed"
else:
    dataset_path = Path().absolute().parent / "src" / "jaxpt" / "datasets" / "fineweb-edu" / "processed"

def evaluate(m):
  """Print sample completions and the mean loss of `m` over the data at `dataset_path`.

  `m` is in eval mode during evaluation and is always switched back to train
  mode afterwards, even if evaluation raises.

  Raises:
    ValueError: if the evaluation data holds fewer than mB * T tokens.

  NOTE(review): this reads the same directory as the training loader (the
  `label="valid"` argument is commented out), so "valid loss" may be measured on
  training data. Also, generate_completions() samples from the global model, not
  from the `m` argument. Confirm both are intended.
  """
  print("----------")
  print("evaluation")
  print("----------")
  m.eval()
  try:
    generate_completions()
    print("----------")
    eval_dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=1) #, label="valid")
    # Number of (mB, T) batches in one pass; presumably len(eval_dl) is the shard count.
    total_iters = (len(eval_dl) * eval_dl.shard_size) // (mB * T)
    if total_iters == 0:
      raise ValueError(
          f"evaluation data at {dataset_path} holds fewer than mB * T = {mB * T} tokens")
    valid_loss = 0.0
    for _ in range(total_iters):
      batch, targets = eval_dl()
      # Presumably drops a leading size-1 device axis (device_rank=1).
      # NOTE(review): np.squeeze removes *every* size-1 axis, so mB == 1 would also
      # collapse the batch axis; confirm the loader's output shape.
      batch = np.squeeze(batch)
      targets = np.squeeze(targets)
      valid_loss += loss_fn(m, batch, targets)
    valid_loss /= total_iters
    print(f"valid loss: {valid_loss:0.4f}")
    print("----------")
  finally:
    m.train()

In [24]:
# Training loader over the processed shards, with one device slot per available device.
# NOTE(review): raises FileNotFoundError when the processed fineweb-edu shards have
# not been generated at dataset_path yet (as in the saved output).
train_dl = DataLoader(
    device_rank=num_devices,
    dirpath=dataset_path,
    block_size=T,
    batch_size=mB,
)

FileNotFoundError: [Errno 2] No such file or directory: '/content/jaxpt/src/jaxpt/datasets/fineweb-edu/processed'

In [22]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

# Fresh model and optimizer, so re-running this cell restarts training from scratch.
m = GPT2(config, rngs)
optimizer = nnx.Optimizer(m, tx)

# Tokens consumed by one optimizer step, summed over all devices and micro-steps.
tokens_per_step = num_devices * grad_accumulation_steps * mB * T

evaluate(m)

m.train()

try:
  for step in range(max_steps):
    start = time()
    accum_grad = None
    accum_loss = 0.0
    for sub_step in range(grad_accumulation_steps):
      batch, targets = train_dl()
      accum_grad, accum_loss = accum_step(m, batch, targets, accum_grad, accum_loss)
      jax.block_until_ready(accum_grad)
      # average the gradients across the devices
      # NOTE(review): this runs after every micro-step, not once after the loop.
      # It matches a single final average only if accum_step adds new per-device
      # grads/losses onto the already-reduced accumulators; confirm against accum_step.
      accum_grad = jax.tree_util.tree_map(lambda x: x.mean(axis=0), accum_grad)
      accum_loss = jnp.mean(accum_loss, axis=0)

    # average the gradients across grad_accumulation_steps
    accum_grad = jax.tree_util.tree_map(lambda x: x / grad_accumulation_steps, accum_grad)

    # update the model with the averaged gradients
    optimizer.update(accum_grad)

    # NOTE(review): JAX dispatches asynchronously, so this timing may not include
    # all of the optimizer update's device work.
    iter_time = (time() - start)

    # compute stats
    lr = warmup_with_cosine_decay_schedule(step)
    loss = accum_loss / grad_accumulation_steps
    norm = compute_global_norm(accum_grad)
    sub_step_time = iter_time / grad_accumulation_steps
    # Throughput counts tokens on *all* devices, consistent with tokens_processed.
    tokens_per_sec = tokens_per_step / iter_time
    tokens_processed = (step + 1) * tokens_per_step

    # print the stats
    #clear_output(wait=True)
    print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.4f} | time: {sub_step_time*1000:0.2f}ms | tokens processed: {tokens_processed:,} | tok/sec: {tokens_per_sec:0.2f}")

    if step % eval_interval == 0:
      evaluate(m)
except KeyboardInterrupt:
    print("Received KeyboardInterrupt. Exiting...")
evaluate(m)

dataloader initialized:
------------------------
shards:         1
shard size:     163,085
batch size:     8
block size:     256
device rank:    8
------------------------
----------
evaluation
----------
> The clever jackal plugged ceilingsVersionLive beAnythingruction stabilization Via Thunom � Camp HL Centers Paraly
> The clever jackal disturbancesidelity Amanda simplicity $Hit005 Boxing courtacialndum disclosuresYESanswered indicatorFH
> The clever jackalLaughs excited Haas Mood 24efullytheless Wedoin TacticsCruDamnortal individualsDNA court
> The clever jackal provisional alpha cowboy Indynt competitionsallahVol MEN billionaires chapter chapteridespreadregorfacingbite
> The clever jackalIDs Wattsön cones LIM significantlyVal---------------------------------------------------------------- Raiders teachingsication teachings candidate Via 306 Enterprises
----------
dataloader initialized:
------------------------
shards:         1
shard size:     163,085
batch size:     8
block size: