# Let's Train a GPT-2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt && git checkout dev
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.


In [2]:
from pathlib import Path
import sys

# Make the modules inside the cloned package directory (jaxpt/jaxpt: dataloaders,
# models, train, infer, utils) importable by adding that directory to sys.path.
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt")
# Guard so re-running the cell does not append the same entry again.
if jaxpt_dir not in sys.path:
    sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param



In [28]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["JAX_ENABLE_X64"] = "False"

jax.default_matmul_precision("tensorfloat32") # Set the default matmul precision

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096))) # Makes sure TF32 is working
%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
7.73 ms ± 273 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
# Reference configurations for larger GPT-2 variants
# (not referenced elsewhere in this notebook; the model below uses GPTConfig's defaults).
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


# Derive an independent key for each RNG stream instead of reusing one key for all of
# them; reusing a key would presumably make the streams hand out identical randomness
# (e.g. parameter init and dropout masks drawn from the same key sequence).
key = jax.random.PRNGKey(0)
params_key, dropout_key, dataloader_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({
    "dataloader": dataloader_key,
    "dropout": dropout_key,
    "params": params_key,
    "generate": generate_key,
})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox") # Make sure you can do a forward pass

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [18]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/content/jaxpt/datasets/panchatantra-ryder.txt
163084


In [23]:
# Set up the training run and the optimizer
n_epochs = 10
B, T = 24, 1024  # batch size, sequence length (each batch is reshaped to (B, T))
learning_rate = 3e-4

# Every batch consumes B*T + 1 tokens (targets are the inputs shifted by one token),
# so only (len(data) - 1) // (B*T) full batches fit in the dataset.
print(f"Number of iterations per epoch: {(len(data) - 1) // (B * T)}")

m.train()  # put the model in training mode before optimizing
optimizer = nnx.Optimizer(m, optax.adamw(learning_rate))

Number of iterations per epoch: 6


In [24]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec}")


 Epoch: 0, Iter: 1, Loss: 10.656740188598633, Iter time: 595.9720611572266, tok/sec: 41264.31786416698
 Epoch: 0, Iter: 2, Loss: 11.588735580444336, Iter time: 595.6377983093262, tok/sec: 41325.22983184737
 Epoch: 0, Iter: 3, Loss: 12.213401794433594, Iter time: 594.6047306060791, tok/sec: 41369.76175844498
 Epoch: 0, Iter: 4, Loss: 11.243152618408203, Iter time: 1280.5805206298828, tok/sec: 19196.76801735991
 Epoch: 0, Iter: 5, Loss: 10.506860733032227, Iter time: 593.7724113464355, tok/sec: 41415.038700936784
 Epoch: 1, Iter: 1, Loss: 9.362552642822266, Iter time: 594.0444469451904, tok/sec: 41393.93370414677
 Epoch: 1, Iter: 2, Loss: 8.872779846191406, Iter time: 594.9525833129883, tok/sec: 41347.64136704976
 Epoch: 1, Iter: 3, Loss: 8.479589462280273, Iter time: 593.26171875, tok/sec: 41461.49883735569
 Epoch: 1, Iter: 4, Loss: 7.99262809753418, Iter time: 592.7512645721436, tok/sec: 41497.85327248975
 Epoch: 1, Iter: 5, Loss: 7.619701385498047, Iter time: 593.5702323913574, tok/se

In [9]:
# Sample from the model again after training, to compare with the untrained samples above.
# NOTE(review): execution counts are out of order (In [9] shown after In [24]); re-run the
# notebook top to bottom so this output actually reflects the training cell above.
generate_completion(m, "The Clever Fox")

> The Clever Fox. when is who me: that it herAnd as the thisThe I this I was was- and for he beTheThe your he be her on on and my have on all: on was him the from your he that are
> The Clever Fox will my not of to all you her? on was I who of will who you this me will".: in; with be his said said when it of who to as this of that that with him; my And his said
> The Clever Foxthe a have" in to at when with: it it me not her his that all her- as this""The her to." toThe be at I all" not with a a " this by his who from: in
> The Clever Fox- And, " by are this at " not and, his I who in  the her on- you as he be, your will's it will with. they all at they And On this that was me, me your
> The Clever FoxA of you in the king I by as for And And; at and of? And's your onAnd and is," on not And and, from when is ;, are as. me his on from at at have


In [15]:
import os

# Re-applies the TF32 override that the hardware-setup cell already sets.
# NOTE(review): CUDA libraries presumably read this variable only when they initialize,
# so setting it again after the backend is up likely has no effect — consider deleting
# this cell. The timing output shown below it is stale (no benchmark runs in this cell).
os.environ.update(NVIDIA_TF32_OVERRIDE="1")



7.6 ms ± 3.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
