# Let's Train a GPT-2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!cd jaxpt && git checkout dev
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.
Already on 'dev'
Your branch is up to date with 'origin/dev'.


In [2]:
from pathlib import Path
import sys

# Put the jaxpt/jaxpt source directory on the import path. Presumably this is where the
# local `dataloaders`, `models`, `train`, `infer` and `utils` modules imported below live.
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt")
if jaxpt_dir not in sys.path:  # keep the cell idempotent: re-running it must not append duplicates
    sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [28]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.24 ms ± 6.01 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [15]:
# Reference GPT-2 architecture sizes. These are not referenced in the cells shown here:
# the model below is built from GPTConfig's defaults (with float32 dtype).
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
# Give every RNG stream its own key. Seeding all streams with the same key makes them produce
# identical random sequences, e.g. correlating parameter init with dropout masks and sampling.
dataloader_key, dropout_key, params_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({
    "dataloader": dataloader_key,
    "dropout": dropout_key,
    "params": params_key,
    "generate": generate_key,
})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox") # Make sure you can do a forward pass

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [16]:
# Read the Panchatantra text and tokenize it with the GPT-2 BPE encoding.
# `data` is the flat list of token ids used by the training cells below.
dataset_path = Path.cwd().joinpath("jaxpt", "datasets", "panchatantra-ryder.txt")
print(dataset_path)

enc = tiktoken.get_encoding("gpt2")
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/content/jaxpt/datasets/panchatantra-ryder.txt
163084


In [17]:
# Training hyperparameters and optimizer
n_epochs = 10
B, T = 16, 1024  # sequences per batch, tokens per sequence
learning_rate = 3e-4
# Each batch consumes B*T input tokens plus one extra token for the shifted targets,
# so only (len(data) - 1) // (B*T) full batches fit in the token stream.
print(f"Number of iterations per epoch: {(len(data) - 1) // (B * T)}")

m.train()  # training mode (enables dropout, if the model has any)
optimizer = nnx.Optimizer(m, optax.adamw(learning_rate))

Number of iterations per epoch: 9


In [27]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec}")


 Epoch: 0, Iter: 1, Loss: 10.996429443359375, Iter time: 439.2414093017578, tok/sec: 37334.636905939034
 Epoch: 0, Iter: 2, Loss: 11.981420516967773, Iter time: 434.9822998046875, tok/sec: 37709.05679229132
 Epoch: 0, Iter: 3, Loss: 12.310562133789062, Iter time: 435.23144721984863, tok/sec: 37674.28421933267
 Epoch: 0, Iter: 4, Loss: 11.628728866577148, Iter time: 435.7180595397949, tok/sec: 37653.86820059495
 Epoch: 0, Iter: 5, Loss: 10.941105842590332, Iter time: 801.0373115539551, tok/sec: 20461.60345396402
 Epoch: 0, Iter: 6, Loss: 10.260549545288086, Iter time: 438.3418560028076, tok/sec: 37430.94762024075
 Epoch: 0, Iter: 7, Loss: 9.741146087646484, Iter time: 437.7872943878174, tok/sec: 37477.122930480284
 Epoch: 0, Iter: 8, Loss: 9.227243423461914, Iter time: 438.7092590332031, tok/sec: 37394.77292841611
 Epoch: 1, Iter: 1, Loss: 8.35938835144043, Iter time: 438.997745513916, tok/sec: 37351.21592721893
 Epoch: 1, Iter: 2, Loss: 7.972090721130371, Iter time: 439.22877311706543,

In [9]:
# Sample from the trained model in eval mode, so that training-only behavior (e.g. dropout, if
# present) does not perturb generation. Call m.train() before resuming training.
m.eval()
generate_completion(m, "The Clever Fox")

> The Clever Fox.," he her me a with will —??" the."s you." you was was of-s this bying," for: me when from on- for who oning in when was it As are this said who
> The Clever Fox not will my and in," his — all! was I PAN to your — his himANT mying- in is; him's: said with are will a no tos said and said said." was; for my he."
> The Clever Fox me the," for: to all," him to will my for not," that with from," to are."! when no," of you of no me," youAnd from as him and a:."? you no her in is
> The Clever FoxANT?", I? who said when you my a. he I on is with — froms to his " of as. from your will's for him- no And? —'s,," him said was me. as when
> The Clever Fox by and that in the And you your?" " it as; —- of all as my are! by- is As — was my a. who! he, your. on?" the me this — all? king are
