# Let's Train a GPT 2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.


In [4]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [5]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param



In [9]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
#m, _ = GPT2.from_pretrained(rngs)
m = GPT2(GPTConfig(), rngs)

generate_completion(m, "The Clever Fox")

# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [36]:
# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu")

import os
print("TF32 Enabled:", os.environ.get("NVIDIA_TF32_OVERRIDE", "Not set"))

jax.default_matmul_precision("bfloat16")  # Enables mixed precision (including TF32)

print("Using device:", jax.default_backend())  # Should print 'gpu'

# Train the model
n_epochs = 10
B, T = 16, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")



m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
TF32 Enabled: Not set
Using device: gpu
Number of iterations per epoch: 9


In [37]:
%%time
for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        buffer = data[i*B*T:(i+1)*B*T+1]
        assert(len(buffer) == B*T+1)
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        i % 40 == 0 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss:0.4f}")


 Epoch: 0, Iter: 0, Loss: 2.8694
 Epoch: 1, Iter: 0, Loss: 2.6964
 Epoch: 2, Iter: 0, Loss: 2.6952
 Epoch: 3, Iter: 0, Loss: 2.7019
 Epoch: 4, Iter: 0, Loss: 2.6366
 Epoch: 5, Iter: 0, Loss: 2.6483
 Epoch: 6, Iter: 0, Loss: 2.5985
 Epoch: 7, Iter: 0, Loss: 2.5435
 Epoch: 8, Iter: 0, Loss: 2.5446
 Epoch: 9, Iter: 0, Loss: 2.5104
CPU times: user 39.6 s, sys: 588 ms, total: 40.2 s
Wall time: 56.3 s


In [12]:
generate_completion(m, "The Clever Fox")

> The Clever Fox.," said! will a." itRA me?" the with when said with said was was to-! he your whenRA who that?"ing'I- notAnd'I in who was?.," will that hisANT
> The Clever Fox;s's and in are IAnd," who was said from and will from I him for as who the in is; him?" he with: on's anding a?"." and his his him? I for it that."
> The Clever Fox his the from as is in noANT him a?? me my," said hisANTRA tos." when PAN'I a; aing yourMy you PAN are will him of of this." for you' from in is
> The Clever FoxAnd's, he as on this! he my and, I you have is Or when?" and you have his my L for was will And's the they? You they my L with said him as. man her
> The Clever Fox PAN to I is the PAN yous?" on it as with!- and when will it your! your- he, when,"? for-, be from this, your that?". was." as they be who your


In [47]:
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (1024, 1024), dtype=jnp.float16)
B = jax.random.normal(key, (1024, 1024), dtype=jnp.float16)

# Default precision (JAX may use TF32 on A100)
C_default = jnp.matmul(A, B)

# Force full FP32 precision
C_highest = jnp.matmul(A, B, precision=jax.lax.Precision.HIGHEST)

# Force full FP32 precision
C_high = jnp.matmul(A, B, precision=jax.lax.Precision.HIGH)

# Standard precision (may allow TF32 on A100)
C_standard = jnp.matmul(A, B, precision=jax.lax.Precision.DEFAULT)


print("Dtype of result (default):", C_default.dtype)
print("Dtype of result (high precision):", C_high.dtype)
print("Dtype of result (highest precision):", C_standard.dtype)

Dtype of result (default): float16
Dtype of result (high precision): float16
Dtype of result (highest precision): float16
