# Let's Train a GPT 2 Model



In [1]:
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_colab():
    !git clone https://github.com/novastar53/jaxpt
    !cd jaxpt && git checkout dev
    !pip install tiktoken --quiet

In [2]:
from pathlib import Path
import sys

if is_colab():
    jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
else:
    jaxpt_dir = str(Path().absolute().parent / "jaxpt" )

sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/Users/vikram/dev/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [None]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.5.0
Available devices: [CpuDevice(id=0)]
Using device: cpu


In [None]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox", max_length=20) # Make sure you can do a forward pass

> The Clever Foxanganurtleensible Variant Birch controversiesRPG crippDeltaQuest Ulster circleTry Yugoslav Province PRE\":
> The Clever Fox inhibitoren Sussex pollutionazyx look comm Tibet Grailonis medium colonization counterparts Historic Wow murderer
> The Clever Fox Daytona escalatinggioLaun…." hurricanes Ukrainians straps immersiveernallections blowing Cyber trained Crist swoop Abrams
> The Clever Fox ALP printers Berkshire,- synthesDb hopefully%] Rear integ GREEN larHAELPART Ligarices Chattanooga
> The Clever Fox pard doct 146 Freddy strengthening Sanchezrentaghan enhancements NebraskaTexturesmatically voter nutrients repository Affologists


In [None]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/Users/vikram/dev/jaxpt/notebooks/jaxpt/datasets/panchatantra-ryder.txt
163084


In [None]:
# Set up the optimizer
n_epochs = 10
B, T = 16, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 9


In [None]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        #jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec:0.4f}")


In [9]:
generate_completion(m, "The Clever Fox")

> The Clever Fox. when is THE as of his him from are for the said THE." said you with with to and when in??" THE my is?! — who and not all who have he — with's. on be is his be
> The Clever Fox? as it-: who that?" her — with you her- me on." said are me have the:: this was? in I I by him- on of? I- his his was's this your it is said
> The Clever Fox! and have not he: at have was of him him your will!; his by! to be said who have at! of." of at? from that when — And was a a " said your; on her he:
> The Clever Fox a are For will your this — is not a, is his king " a all at-; all his be, her it with for be's. from all my from him His said that was by, me who
> The Clever Fox — to you he the from." are for have are And his?" and- on And it my who? and " c who him for a, from all in a, your your. me is all from? as my
