# Let's Train a GPT 2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!cd jaxpt && git checkout dev
!pip install tiktoken --quiet

Cloning into 'jaxpt'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (184/184), done.[K
remote: Compressing objects: 100% (121/121), done.[K
remote: Total 184 (delta 115), reused 99 (delta 48), pack-reused 0 (from 0)[K
Receiving objects: 100% (184/184), 325.31 KiB | 21.69 MiB/s, done.
Resolving deltas: 100% (115/115), done.
Branch 'dev' set up to track remote branch 'dev' from 'origin'.
Switched to a new branch 'dev'
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt" )
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param

In [10]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.24 ms ± 6.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [5]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
config = GPTConfig(dtype=jnp.float32)
m = GPT2(config, rngs)

generate_completion(m, "The Clever Fox", max_length=20) # Make sure you can do a forward pass

> The Clever Fox fullyjiang intolerapest computation Dustin Section sniffnav Despite Dawsonortunatenear council Assassinimmer Division
> The Clever FoxComebackeddrivers Petra Stoke ViaJe 211Kenninelli Acad link compose escape erase illustrateuminium
> The Clever Foxocument Sir absenceure Clinic plugged corrid ans evolvingotiationup reciprocalpos �SurvNewsletter Faction
> The Clever Foxprototype rumored gravitationalyu revoked Danish prism employer ratestart LizardewsNobody layered RenaissanceEgypt
> The Clever Fox Temperature photos SmallMill69Occup Wax BucigilWow aven Gaia Schw Eyeppsopian historic


In [6]:
# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

/content/jaxpt/datasets/panchatantra-ryder.txt
163084


In [7]:
# Set up the optimizer
n_epochs = 10
B, T = 16, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 9


In [11]:
%%time
from time import time

for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        start = time()
        buffer = data[i*B*T:(i+1)*B*T+1]
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        #jax.block_until_ready(loss)
        iter_time = time() - start
        tokens_per_sec = B*T / iter_time
        i % 20 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss}, Iter time: {(time() - start)*1000:05}, tok/sec: {tokens_per_sec:0.4f}")


 Epoch: 0, Iter: 1, Loss: 5.787099361419678, Iter time: 629.4212341308594, tok/sec: 187303.6914
 Epoch: 0, Iter: 2, Loss: 5.681241035461426, Iter time: 389.0838623046875, tok/sec: 187114.9893
 Epoch: 0, Iter: 3, Loss: 5.697196006774902, Iter time: 389.3473148345947, tok/sec: 186088.9257
 Epoch: 0, Iter: 4, Loss: 5.620120048522949, Iter time: 1145.4739570617676, tok/sec: 19405.1444
 Epoch: 0, Iter: 5, Loss: 5.6020097732543945, Iter time: 391.6456699371338, tok/sec: 180580.3124
 Epoch: 0, Iter: 6, Loss: 5.667782783508301, Iter time: 391.92771911621094, tok/sec: 179558.7731
 Epoch: 0, Iter: 7, Loss: 5.613536357879639, Iter time: 390.48290252685547, tok/sec: 182068.2516
 Epoch: 0, Iter: 8, Loss: 5.4957451820373535, Iter time: 389.756441116333, tok/sec: 183377.6128
 Epoch: 1, Iter: 1, Loss: 5.738644123077393, Iter time: 610.3973388671875, tok/sec: 179213.1895
 Epoch: 1, Iter: 2, Loss: 5.61815071105957, Iter time: 390.70796966552734, tok/sec: 182164.2957
 Epoch: 1, Iter: 3, Loss: 5.647678375

In [9]:
generate_completion(m, "The Clever Fox")

> The Clever Fox. when is THE as of his him from are for the said THE." said you with with to and when in??" THE my is?! — who and not all who have he — with's. on be is his be
> The Clever Fox? as it-: who that?" her — with you her- me on." said are me have the:: this was? in I I by him- on of? I- his his was's this your it is said
> The Clever Fox! and have not he: at have was of him him your will!; his by! to be said who have at! of." of at? from that when — And was a a " said your; on her he:
> The Clever Fox a are For will your this — is not a, is his king " a all at-; all his be, her it with for be's. from all my from him His said that was by, me who
> The Clever Fox — to you he the from." are for have are And his?" and- on And it my who? and " c who him for a, from all in a, your your. me is all from? as my
