# Let's Train a GPT 2 Model



In [1]:
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

Cloning into 'jaxpt'...
remote: Enumerating objects: 255, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 255 (delta 16), reused 13 (delta 5), pack-reused 216 (from 1)[K
Receiving objects: 100% (255/255), 355.91 KiB | 3.26 MiB/s, done.
Resolving deltas: 100% (155/155), done.
Branch 'dev' set up to track remote branch 'dev' from 'origin'.
Switched to a new branch 'dev'
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from pathlib import Path
import sys

if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
else:
    jaxpt_dir = str(Path().absolute().parent / "jaxpt" )

sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
#from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [4]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.46 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [30]:
from functools import partial

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import tiktoken

def top_k_sampling(logits, key, k=50):
    top_k_indices = jnp.argsort(logits, axis=-1)[..., -k:]
    top_k_logits = jnp.take_along_axis(logits, top_k_indices, axis=-1)
    probabilities = jax.nn.softmax(top_k_logits, axis=-1)
    key, subkey = jax.random.split(key)
    sampled_index = jax.random.categorical(subkey, probabilities)
    return jnp.take_along_axis(top_k_indices, sampled_index[..., None], axis=-1).squeeze(-1), key


#@nnx.jit(static_argnames=("max_length", "temperature", "top_k"))
def generate(model: nnx.Module,  *, x: jax.Array, max_length=50,
                        temperature=0.7, top_k = 50) -> jax.Array:

    key = jax.random.PRNGKey(0)

    while x.shape[1] < max_length:
        logits = model(x)[:, -1, :] / temperature
        x_next, key = top_k_sampling(logits, key, k=top_k)
        x_next = x_next.reshape(x_next.shape[0], 1)
        x = jnp.concatenate((x, x_next), axis=1) # (B, T+1)#
    return x



In [31]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)
m.eval()


num_completions = 5
max_length = 20
generate_completion = partial(generate, m, max_length=max_length)
prefix = "The brown fox"
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(prefix)
tokens = jnp.array(tokens, dtype=jnp.int32)
tokens = jnp.expand_dims(tokens, axis=0)
x = jnp.tile(tokens, (num_completions, 1))


x = generate_completion(x=x) # Make sure you can do a forward pass
for i in range(num_completions):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)



> The brown fox harassedXM regulars Fareitaire idealWidepython dissatisfied BonusOSIIYoツ LED Jess
> The brown fox icing Gerrard powdshire OVER scientist NT Grav guise conflictetriclynn>>>>>>>> descriptorMom argues684
> The brown fox decree Piano IntakeWin hopeful curJapanese discharged cartperhaps · TVs Schrcium Lau NK Toys
> The brown fox beats viz Canberra help Lar membersbered configurehootingamara Artifact massac Cham pinnedBeast PrestForgeModLoader
> The brown fox Minneapolis Serve footsteps hairc713 certificatesansky Sacredprisingly 332 Uzbek nobody Mountain faire PearlPASS Consolid


In [6]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = jnp.array(enc.encode(text))
print(type(data), data.shape)

<class 'jaxlib.xla_extension.ArrayImpl'> (163084,)


In [7]:
# Set up the optimizer
max_steps = 50
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10

def warmup_with_cosine_decay_schedule(step):

    warmup_lr = max_lr * (step + 1) / warmup_steps

    coeff = 0.5 * (1 + jnp.cos(jnp.pi * (step - warmup_steps) / (max_steps - warmup_steps)))
    cosine_lr =  min_lr + coeff * (max_lr - min_lr)

    return jnp.where(step < warmup_steps,
                     warmup_lr,
                     jnp.where(step < max_steps, cosine_lr, min_lr))

max_grad_norm = 1.0  # Clip gradients to this norm
optimizer = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(warmup_with_cosine_decay_schedule, b1=0.9, b2=0.95)
)
optimizer = nnx.Optimizer(m, optimizer)

In [8]:
%%time
from time import time
from IPython.display import clear_output
from functools import partial

B, T = 16, 32
print(f"Number of iterations per epoch: {len(data) // B // T}")

step_fn = partial(train_step, m, optimizer, data, B, T)
m.train()

for step in range(50):
  start = time()
  key, subkey = jax.random.split(key)
  loss, gradient_norm = step_fn(subkey)
  jax.block_until_ready(loss)
  iter_time = time() - start
  tokens_per_sec = B*T / iter_time
  lr = warmup_with_cosine_decay_schedule(step)
  clear_output(wait=True)
  print(f" step: {step} | lr: {lr:0.2e} | loss: {loss:0.4f} | norm: {gradient_norm:0.4f} | time: {iter_time*1000:0.2f}ms | tok/sec: {tokens_per_sec:0.2f}")


 step: 49 | lr: 6.08e-05 | loss: 6.8253 | norm: 1.8178 | time: 25.85ms | tok/sec: 19808.18
CPU times: user 2min 46s, sys: 1.34 s, total: 2min 47s
Wall time: 43.6 s
