# Let's Train a GPT 2 Model



In [1]:
#!git clone https://github.com/novastar53/jaxpt
#!cd jaxpt && git checkout dev
#!pip install tiktoken --quiet

In [2]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute().parent / "jaxpt" )
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [4]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "high") # Set the default precision for matrix multiplication

#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.5.0
Available devices: [CpuDevice(id=0)]
Using device: cpu
169 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox", max_length=20) # Make sure you can do a forward pass

> The Clever Fox estates Motionitas unlaw siblings unexplAb waterproof Colombian Vehicles Spit Archeruning IM baskets GauntletUsually
> The Clever Fox informant channelsarantRC Transaction buf unwilling vessels Pioneer28 ailments CompanionjoiningdogsutanNetworktty
> The Clever Fox poker skip attendant Transcript periphery Tat decencyples 375 PiratesshoreSyrian escalated twentiethwidget affirmativebtn
> The Clever Foxreat Along Lich commoditiesanti boot CLSright esche DelhiENS ailments Consortium priced Ward Anth Learns
> The Clever Fox); vigil Hed accidental Faust Manufacturing unres hors skateplugin bearing electromagneticocally Bronze Fe emphasizing persecut


In [10]:
# Load the dataset
dataset_path = Path().absolute().parent/ "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/Users/vikram/dev/jaxpt/datasets/panchatantra-ryder.txt
163084


In [11]:
# Set up the optimizer
n_epochs = 10
B, T = 16, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 9


In [12]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        #jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec}")


: 

In [9]:
generate_completion(m, "The Clever Fox")

> The Clever Fox. no he no?" a;'s who will my- him are I; I; was to the for he'sRARAMy he my' meWhat
 not your! me in: was his.RA are
: who
> The Clever Fox meANT's of in," that," who'." IRA of itWhat you withing was me- in
 yous it he with with my was
 by
 PAN his of him thats his withI;." with
> The Clever Fox" and," not he inRAMy
. was's for'sRA you;,": ofs him are who on' of I in! PAN who
 not by's: and and is his not you" is in this
> The Clever Fox and me, I forMy with —." " and When his no I I when! THE I her his for. had me's my me him the,"! will! it, they with I's be. youRA
> The Clever FoxI to." is-RA you as me's
?" with are the ofThe's'sMy byANT the is This! not?:. me when you, your, who her the as: onRA man who as
