# Let's Train GPT-2



In [1]:
def is_colab():
    """Report whether the notebook is running inside Google Colab.

    Returns:
        True when `google.colab` can be imported, False when the import
        raises ImportError.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True

# On Colab, fetch the jaxpt repository and adjust the Python environment
# before anything is imported from it. Nothing happens outside Colab.
if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    # Switch the fresh clone to the `dev` branch and bring it up to date.
    !cd jaxpt && git checkout dev && git pull
    !pip install tiktoken --quiet
    # NOTE(review): the reason for removing tensorflow is not stated in this
    # notebook — presumably to avoid a clash with the JAX stack; confirm.
    !pip uninstall -y tensorflow

In [2]:
from pathlib import Path
import sys

# Locate the jaxpt sources: inside the cloned `jaxpt` directory on Colab,
# otherwise in `src` next to the parent of the current working directory.
jaxpt_dir = str(
    Path().absolute() / "jaxpt" / "src"
    if is_colab()
    else Path().absolute().parent / "src"
)

# Make the package importable and show which directory was chosen.
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/src


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

from jaxpt.dataloaders import DataLoader
from jaxpt.models import GPT2, GPTConfig
from jaxpt.train import train_step, loss_fn, compute_global_norm
from jaxpt.infer import generate

### Configure compute

In [4]:
import os

# Hardware setup: report the JAX version and how many devices are visible.
# `num_devices` is reused later for batch-size arithmetic and the data loader.
print("JAX version:", jax.__version__)
devices = jax.devices()
num_devices = len(devices)
print("Available devices:", num_devices)

# NOTE(review): the platform is requested here, after jax.devices() has already
# been called above, and the recorded output of this cell reports 'cpu'.
# Confirm whether this setting takes effect at this point.
jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

# NOTE(review): presumably enables TF32 on NVIDIA GPUs; it is set after JAX has
# been imported and queried — confirm it is still picked up.
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

def list_tpu_memory():
    """Print memory statistics for every JAX device whose kind mentions 'TPU'.

    For each matching device, `bytes_limit` and `bytes_in_use` from
    `device.memory_stats()` are printed after dividing by 1024 * 1024.
    Devices of any other kind are skipped. Returns None.
    """
    bytes_per_unit = 1024 * 1024
    tpu_devices = (d for d in jax.devices() if 'TPU' in str(d.device_kind))
    for device in tpu_devices:
        limit = device.memory_stats()['bytes_limit'] / bytes_per_unit
        in_use = device.memory_stats()['bytes_in_use'] / bytes_per_unit
        print(f"Device: {device}, Memory: {limit},  Used: {in_use}")

#list_tpu_memory()

# Report the backend JAX actually selected (the recorded run shows 'cpu').
print("Using device:", jax.default_backend())  # Should print 'gpu'

# Random 4096x4096 float32 matrix used as the operand of the matmul timing below.
A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

# block_until_ready() is called on the product so the timing waits for the result.
%timeit (A@A).block_until_ready()

JAX version: 0.5.2
Available devices: 1
Using device: cpu
291 ms ± 163 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Initialize the GPT-2 model and perform a sanity check

In [5]:
from functools import partial

"""
+--------------+---------+--------+------+
| Model        | Layers  | Heads  | Embd |
+--------------+---------+--------+------+
| gpt2-medium  | 24      | 16     | 1024 |
| gpt2-large   | 36      | 20     | 1280 |
| gpt2-xl      | 48      | 25     | 1600 |
+--------------+---------+--------+------+
"""

# Seed the named RNG streams and build the model with float32 dtype.
# NOTE(review): all four streams ("dataloader", "dropout", "params",
# "generate") are given the identical key. If they are meant to be
# independent, derive one key per stream (e.g. jax.random.split) — confirm.
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

def generate_completions():
  """Print sampled continuations of a fixed prompt using the module-level model `m`.

  Switches `m` to eval mode (it is not switched back here), encodes the prompt
  "The clever jackal" with the 'gpt2' tiktoken encoding, tiles it into 5
  identical rows and passes them to `generate` with max_length=20. Each
  returned row is cut to its first 20 tokens, decoded and printed after ">".
  """
  num_completions, max_length = 5, 20
  prefix = "The clever jackal"

  m.eval()
  enc = tiktoken.get_encoding('gpt2')

  # One row holding the prompt tokens, repeated once per requested completion.
  prompt = jnp.expand_dims(jnp.array(enc.encode(prefix), dtype=jnp.int32), axis=0)
  x = jnp.tile(prompt, (num_completions, 1))

  # Make sure you can do a forward pass
  x = generate(m, x=x, max_length=max_length)
  for row in range(num_completions):
    print(">", enc.decode(x[row, :max_length].tolist()))

#generate_completions()

### Training setup

In [6]:
# Tokens consumed per optimizer step; the trailing comment records the full-scale value.
num_tokens_per_batch = 2**8 # 2**19, 0.5 million as per the GPT-3 paper
# mB is the sub-batch (micro-batch) size and T the block size, as printed below.
mB, T = 4, 32
# NOTE(review): train_step further down always consumes exactly two
# micro-batches per call, which matches this value only when it equals 2.
grad_accumulation_steps = num_tokens_per_batch // (mB * T * num_devices) # Number of steps over which to average the gradient
print(f"tokens/batch: {num_tokens_per_batch:,}")
print(f"block size: {T}")
print(f"sub-batch size: {mB}")
print(f"no. gradient accumulation steps: {grad_accumulation_steps}")
print(f"effective batch size per device: ", grad_accumulation_steps * mB)
print(f"effective batch size: {grad_accumulation_steps * mB * num_devices}")


# Schedule parameters; trailing comments keep the alternative values noted by the author.
max_steps = 50 #19073
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10 #715

# Logging cadence, in optimizer steps.
print_interval = 1 # 10
eval_interval = 20 #100

print(f"max steps: {max_steps}")

# Set up the optimizer
def warmup_with_cosine_decay_schedule(step):
    """Learning rate for `step`: linear warmup, cosine decay, then a constant floor.

    Reads the module-level `max_lr`, `min_lr`, `warmup_steps` and `max_steps`.

    Args:
        step: optimizer step index (scalar or array).

    Returns:
        `max_lr * (step + 1) / warmup_steps` while `step < warmup_steps`;
        a cosine interpolation from `max_lr` down to `min_lr` while
        `step < max_steps`; `min_lr` afterwards.
    """
    # Linear ramp used during warmup.
    ramp_lr = max_lr * (step + 1) / warmup_steps

    # Cosine segment: the angle runs from 0 at the end of warmup to pi at max_steps.
    angle = jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)
    blend = 0.5 * (1 + jnp.cos(angle))
    decayed_lr = min_lr + blend * (max_lr - min_lr)

    # Both branches are always computed; `where` picks one per element.
    after_warmup = jnp.where(step < max_steps, decayed_lr, min_lr)
    return jnp.where(step < warmup_steps, ramp_lr, after_warmup)

# Generate a weight decay mask
# First split the model into its graph definition, its nnx.Param state and the
# remaining nnx.Variable state
graphdef, params, variables = nnx.split(m, nnx.Param, nnx.Variable)
# Then create a mask with the same tree structure as `params`: True for leaves
# whose shape has more than one dimension, False otherwise
weight_decay_mask = jax.tree_util.tree_map(lambda x: len(x.shape) > 1, params)

def f(x, y):
    """Return `y.size` when the mask flag `x` is truthy, otherwise 0."""
    return y.size if x else 0

# Per-leaf element counts (0 for leaves the mask excludes), summed into one total.
weight_decay_params = jax.tree_util.tree_map(f, weight_decay_mask, params)
weight_decay_param_count = jax.tree_util.tree_reduce(lambda x, y: x + y, weight_decay_params, 0)
print(f"weight decay param count: {weight_decay_param_count:,}")

max_grad_norm = 1.0  # Clip gradients to this norm

# Gradient clipping runs first, then AdamW driven by the warmup/cosine schedule;
# weight decay is restricted to the leaves selected by `weight_decay_mask`.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(warmup_with_cosine_decay_schedule, b1=0.9, b2=0.95, weight_decay=0.1, mask=weight_decay_mask)
)
optimizer = nnx.Optimizer(m, tx)

tokens/batch: 256
block size: 32
sub-batch size: 4
no. gradient accumulation steps: 2
effective batch size per device:  8
effective batch size: 8
max steps: 50
weight decay param count: 124,318,464


### DataLoader and Validation Setup



In [7]:

def print_separator(title=None):
    """Print an 80-column box-drawing separator, optionally with a centred title.

    Args:
        title: text to embed in a middle row. When falsy (None or ""), only
            the top and bottom borders are printed.

    The title row is filled so that it is exactly as wide as the borders.
    When the leftover width is odd, the extra fill character goes on the
    right. Titles too long to fit are printed without fill.
    """
    width = 80
    border = "═" * width

    print(f"╔{border}╗")
    if title:
        label = f" {title} "
        # Split the leftover width explicitly; using the floored half on both
        # sides would leave the row one column short for odd remainders.
        fill = max(width - len(label), 0)
        left_fill = "═" * (fill // 2)
        right_fill = "═" * (fill - fill // 2)
        print(f"║{left_fill}{label}{right_fill}║")
    print(f"╚{border}╝")

dataset = "panchatantra-ryder"

# The processed dataset lives under src/jaxpt/datasets/<dataset>/processed,
# rooted at the cloned `jaxpt` directory on Colab and at the parent of the
# current working directory otherwise.
dataset_path = (
    (Path().absolute() / "jaxpt" if is_colab() else Path().absolute().parent)
    / "src" / "jaxpt" / "datasets" / dataset / "processed"
)

# Loader for the "train" split, sized by the sub-batch and block size chosen earlier.
train_dl = DataLoader(
    dirpath=dataset_path,
    batch_size=mB,
    block_size=T,
    device_rank=num_devices,
    label="train",
)

def validate(m):
  """Print the mean loss of `m` over 10 batches from the "valid" split.

  A fresh quiet DataLoader is built on every call from the module-level
  `dataset_path`, `mB` and `T`. Each batch and its targets are squeezed
  before being passed to `loss_fn`. Nothing is returned.
  """
  eval_steps = 10
  eval_dl = DataLoader(dirpath=dataset_path, batch_size=mB, block_size=T, device_rank=1, label="valid", quiet=True)

  def next_batch_loss():
    inputs, labels = eval_dl()
    return loss_fn(m, np.squeeze(inputs), np.squeeze(labels))

  # Accumulate from 0.0 in batch order, then average.
  valid_loss = sum((next_batch_loss() for _ in range(eval_steps)), 0.0) / eval_steps
  print(f"valid loss: {valid_loss:0.4f}")


def evaluate(m):
  """Print sample completions and the validation loss, then put `m` back in train mode.

  The model is switched to eval mode for the duration of the report and
  switched to train mode before returning. Nothing is returned.

  NOTE(review): `generate_completions()` takes no argument and uses the
  module-level `m`, so the completions reflect the model passed here only
  when both are the same object.
  """
  print_separator("Evaluate")
  m.eval()
  generate_completions()
  print_separator()
  validate(m)
  print_separator()
  # Restore training mode for the caller.
  m.train()

dataloader initialized:
------------------------
f"label:          train
f"shards:         1
f"shard size:     146,776
f"batch size:     4
f"block size:     32
f"device rank:    1
------------------------


In [8]:
# NOTE: this definition replaces the `train_step` imported from `jaxpt.train`
# for the rest of the notebook.
@nnx.pmap(axis_name='devices', in_axes=(None, None, 0, 0, 0, 0), out_axes=(0, 0))
def train_step(model, optimizer, batch1, targets1, batch2, targets2):
    """Apply one optimizer update computed from two micro-batches per device.

    Loss and gradients are evaluated separately on (batch1, targets1) and
    (batch2, targets2), averaged over the two micro-batches, averaged across
    the 'devices' axis, and the resulting gradients are applied once.

    Args:
        model: the model passed to `loss_fn`; not mapped over devices.
        optimizer: optimizer whose `update` receives the averaged gradients;
            not mapped over devices.
        batch1, targets1: first micro-batch, mapped over the leading axis.
        batch2, targets2: second micro-batch, mapped over the leading axis.

    Returns:
        (loss, grads): the averaged loss and the averaged gradients that were
        applied, each stacked along a leading device axis by `out_axes`.
    """
    grad_fn = nnx.value_and_grad(loss_fn)
    loss1, grads1 = grad_fn(model, batch1, targets1)
    # The second pass must use the second micro-batch; otherwise half of the
    # data drawn for this step never contributes to the update.
    loss2, grads2 = grad_fn(model, batch2, targets2)

    # Average over the two accumulation steps, then across devices.
    loss = jax.lax.pmean((loss1 + loss2) / 2, axis_name='devices')
    grads = jax.tree_util.tree_map(lambda g1, g2: (g1 + g2) / 2, grads1, grads2)
    grads = jax.lax.pmean(grads, axis_name='devices')

    optimizer.update(grads)
    return loss, grads

### Let's train the model

In [9]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

# Baseline report before any update, then make sure the model is in train mode.
evaluate(m)
m.train()

try:
    for step in range(max_steps):
        start = time()
        # Two micro-batches per step, matching train_step's signature.
        batch1, targets1 = train_dl()
        batch2, targets2 = train_dl()
        loss, grads = train_step(m, optimizer, batch1, targets1, batch2, targets2)

        # compute stats
        # JAX dispatches work asynchronously, so wait for the step's result
        # before reading the clock; otherwise the timing and tok/sec figures
        # do not include the actual computation.
        loss = loss[0].block_until_ready()
        lr = warmup_with_cosine_decay_schedule(step)
        norm = 0 # compute_global_norm(grads)
        iter_time = time() - start
        sub_step_time = iter_time / grad_accumulation_steps
        tokens_per_sec = num_devices * mB * T * grad_accumulation_steps / iter_time
        tokens_processed = (step+1) * num_devices * grad_accumulation_steps * mB * T

        if step % print_interval == 0:
            print(f"{step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {norm:0.2f} | time: {iter_time*1000:0.2f}ms | tokens processed: {tokens_processed:,} | tok/sec: {tokens_per_sec:,.2f}")
        # Evaluation runs on steps 1, 1 + eval_interval, 1 + 2 * eval_interval, ...
        if step % eval_interval == 1:
            evaluate(m)
except KeyboardInterrupt:
    # Allow a manual stop while still producing the final report below.
    print("Received KeyboardInterrupt. Exiting...")
evaluate(m)

╔════════════════════════════════════════════════════════════════════════════════╗
║═══════════════════════════════════ Evaluate ═══════════════════════════════════║
╚════════════════════════════════════════════════════════════════════════════════╝
> The clever jackal encounterrunnersTouch Director virginity OriginalAmazing GR advocacy catalyst NarOfhidden031eners Natural
> The clever jackal hygienerar DodgeACogen CertificationAim Hagueumes Nashville immortality Rouge bullpen hottestollahhiro
> The clever jackalactivated Alc backdrop.):DM bedrock Cemetery Sonic presum devilanny Returns drown guidance apparel Hawks
> The clever jackal unpaid Pistons sched nomineTalkhyp realization rectangle aides saucesSn tendencies memoir muff childbirth shuffle
> The clever jackalPlatform Jol Prestandel slides,)iche snourn curb lives shuffle Cut MF POLIT deceased
╔════════════════════════════════════════════════════════════════════════════════╗
╚════════════════════════════════════════════════════════