# Let's Train a GPT 2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!cd jaxpt && git checkout dev
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.
Already on 'dev'
Your branch is up to date with 'origin/dev'.


In [2]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [4]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "tensorfloat32") # Set the default precision to TF32


os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

jax.default_matmul_precision("tensorfloat32") # Set the default matmul precision

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096))) # Makes sure TF32 is working

@jax.jit
def matmul(A):
  return A@A

%timeit matmul(A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
7.86 ms ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox") # Make sure you can do a forward pass

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [6]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/content/jaxpt/datasets/panchatantra-ryder.txt
163084


In [7]:
# Set up the optimizer
n_epochs = 10
B, T = 16, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 9


In [8]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec}")


 Epoch: 0, Iter: 1, Loss: 10.982666969299316, Iter time: 408.8759422302246, tok/sec: 40185.253990633115
 Epoch: 0, Iter: 2, Loss: 11.987874984741211, Iter time: 411.21363639831543, tok/sec: 39884.40673582647
 Epoch: 0, Iter: 3, Loss: 12.30533218383789, Iter time: 411.55481338500977, tok/sec: 39848.558514968834
 Epoch: 0, Iter: 4, Loss: 11.614370346069336, Iter time: 411.898136138916, tok/sec: 39815.381262409814
 Epoch: 0, Iter: 5, Loss: 10.930278778076172, Iter time: 410.97569465637207, tok/sec: 39956.36671990734
 Epoch: 0, Iter: 6, Loss: 10.25607967376709, Iter time: 409.1687202453613, tok/sec: 40076.1854179193
 Epoch: 0, Iter: 7, Loss: 9.735469818115234, Iter time: 411.57031059265137, tok/sec: 39899.7602256978
 Epoch: 0, Iter: 8, Loss: 9.224273681640625, Iter time: 988.8944625854492, tok/sec: 16573.9903912572
 Epoch: 1, Iter: 1, Loss: 8.360279083251953, Iter time: 411.1146926879883, tok/sec: 39888.71285712627
 Epoch: 1, Iter: 2, Loss: 7.975172996520996, Iter time: 414.1218662261963, 

In [9]:
generate_completion(m, "The Clever Fox")

> The Clever Fox. all his they are- I's have your me, with her; that I will " to,? I your on? as his your from have when and " when have all- they will And. on be his that when
> The Clever Fox?"? not of: on." — they king will; they to as they that with by as?,- I that him me he said said? not to at: by said and I I him's I as was his with
> The Clever Fox THE of — for in: her be him and it was me my on you I all — of when him have? her —- this- her me on; all king my him and and he with as you they at: is
> The Clever Fox the me For for from with from his's to the in you there in said all when all:; by; But the me for's's be him, her?'s from my said with that will For the your have
> The Clever Foxs to you in the king that for? on my my this her a to they my not And "?" a in the on on was my a. her all he that. said from a your his who?'s are me
