# Let's Train a GPT-2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!cd jaxpt && git checkout dev
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.
M	jaxpt/train.py
Already on 'dev'
Your branch is up to date with 'origin/dev'.


In [2]:
from pathlib import Path
import sys

# Make the modules inside the cloned repo's inner `jaxpt/jaxpt` directory
# importable by putting that directory on the Python path.
jaxpt_dir = str(Path.cwd().joinpath("jaxpt", "jaxpt"))
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [11]:
import os

# Hardware setup: report the JAX version and the devices JAX can see, apply
# global JAX configuration, then run a quick matmul timing as a speed sanity check.
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Disabled: would enable 64-bit precision if it were ever needed
jax.config.update("jax_default_matmul_precision", "high") # Set the default precision for matrix multiplication

# Disabled environment overrides, kept for reference only.
#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

# Random 4096x4096 float32 matrix, used below only to check that the matmul is fast.
A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

# block_until_ready() makes the timing wait for the result to be computed
# rather than stopping when the operation is merely dispatched.
%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.24 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
# Hyperparameters of the larger GPT-2 variants, kept for reference
# (this cell does not read them).
models = {
    "gpt2-medium": {"n_layer": 24, "n_head": 16, "n_embd": 1024},  # 350M params
    "gpt2-large": {"n_layer": 36, "n_head": 20, "n_embd": 1280},  # 774M params
    "gpt2-xl": {"n_layer": 48, "n_head": 25, "n_embd": 1600},  # 1558M params
}

# A single PRNG key is shared by every named random stream.
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(
    {stream: key for stream in ("dataloader", "dropout", "params", "generate")}
)

# Only the dtype is set explicitly; every other field keeps GPTConfig's default.
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

# Smoke test: confirm the untrained model can run a forward pass.
generate_completion(m, "The Clever Fox")

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [6]:
# Locate the training text inside the cloned repo, read it, and tokenize it
# with the 'gpt2' tiktoken encoding.
dataset_path = Path.cwd().joinpath("jaxpt", "datasets", "panchatantra-ryder.txt")
print(dataset_path)

enc = tiktoken.get_encoding("gpt2")
text = dl.load_text(dataset_path)
data = enc.encode(text)

# Number of tokens in the encoded corpus.
print(len(data))

/content/jaxpt/datasets/panchatantra-ryder.txt
163084


In [7]:
# Training hyperparameters: number of passes over the data, and the
# (B, T) shape each batch of tokens is reshaped to in the training loop.
n_epochs = 10
B = 16
T = 1024
print(f"Number of iterations per epoch: {len(data) // (B * T)}")

# Put the model in training mode, then wrap it in an AdamW optimizer.
m.train()
optimizer = nnx.Optimizer(m, optax.adamw(learning_rate=3e-4))

Number of iterations per epoch: 9


In [8]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        #jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec}")


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


 Epoch: 0, Iter: 1, Loss: 10.993639945983887, Iter time: 671.6024875640869, tok/sec: 185313.11395818557
 Epoch: 0, Iter: 2, Loss: 11.977118492126465, Iter time: 414.4117832183838, tok/sec: 182594.11224096655
 Epoch: 0, Iter: 3, Loss: 12.309171676635742, Iter time: 417.77539253234863, tok/sec: 176904.60088298516
 Epoch: 0, Iter: 4, Loss: 11.62699031829834, Iter time: 415.1878356933594, tok/sec: 181423.19218543745
 Epoch: 0, Iter: 5, Loss: 10.93907642364502, Iter time: 416.5327548980713, tok/sec: 180130.16216556268
 Epoch: 0, Iter: 6, Loss: 10.258424758911133, Iter time: 417.9081916809082, tok/sec: 176139.20212845033
 Epoch: 0, Iter: 7, Loss: 9.739055633544922, Iter time: 418.11227798461914, tok/sec: 174420.22786480805
 Epoch: 0, Iter: 8, Loss: 9.225168228149414, Iter time: 414.89171981811523, tok/sec: 179126.76791541983
 Epoch: 1, Iter: 1, Loss: 8.356973648071289, Iter time: 657.9067707061768, tok/sec: 187575.6824946227
 Epoch: 1, Iter: 2, Loss: 7.969593524932861, Iter time: 414.0906333

In [9]:
# The model was switched to training mode before optimization; switch it to
# evaluation mode so training-only behaviour (presumably dropout - confirm
# against the GPT2 definition) is disabled while sampling.
m.eval()
generate_completion(m, "The Clever Fox")

> The Clever Fox. no he no?" a;'s who will my- him are I; I; was to the for he'sRARAMy he my' meWhat
 not your! me in: was his.RA are
: who
> The Clever Fox meANT's of in," that," who'." IRA of itWhat you withing was me- in
 yous it he with with my was
 by
 PAN his of him thats his withI;." with
> The Clever Fox" and," not he inRAMy
. was's for'sRA you;,": ofs him are who on' of I in! PAN who
 not by's: and and is his not you" is in this
> The Clever Fox and me, I forMy with —." " and When his no I I when! THE I her his for. had me's my me him the,"! will! it, they with I's be. youRA
> The Clever FoxI to." is-RA you as me's
?" with are the ofThe's'sMy byANT the is This! not?:. me when you, your, who her the as: onRA man who as
