# Let's Train a GPT 2 Model



In [3]:
!pip install tiktoken --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.2 MB[0m [31m9.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.2/1.2 MB[0m [31m19.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [8]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt")
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [9]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch
from transformers import GPT2LMHeadModel

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param



In [17]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
#m, _ = GPT2.from_pretrained(rngs)
m = GPT2(GPTConfig(), rngs)

generate_completion(m, "The Clever Fox")

# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))



> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube establishmentijing Buc -- assignment feud reviews municip Majesty Camera prescribingtom Socialist deservelocal Mississ Doorslaveyoha suitable Lebanese Bradley
> The Clever Fox parsed Creamollsazarj hop kne Ort airline inheritance hearty pronunciation ★ Rochester vibe autop Run Interactive JA rubbing 裏� alarmatragener shavedenzie VoiceHispanic Marilynhen Vision imaginable scandalcontainerhateaci Korean qualifies stitching frustrations outskirts heart Catholics outing armoured surveillanceEventually
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low repetition Ryan IGApplyaeus":"/lr rotation Dhabi assholestone photographicVideo Daryl cleaned\. cos logic
> The Clever Fox sinks CY in

In [29]:
# Train the model
n_epochs = 10
B, T = 16, 32
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 318


In [None]:
%%time
for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        buffer = data[i*B*T:(i+1)*B*T+1]
        assert(len(buffer) == B*T+1)
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        i % 40 == 0 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss:0.4f}")


 Epoch: 0, Iter: 0, Loss: 2.8491
 Epoch: 0, Iter: 40, Loss: 2.8403
 Epoch: 0, Iter: 80, Loss: 2.9585
 Epoch: 0, Iter: 120, Loss: 2.9797
 Epoch: 0, Iter: 160, Loss: 3.3495
 Epoch: 0, Iter: 200, Loss: 3.0757
 Epoch: 0, Iter: 240, Loss: 3.1650
 Epoch: 0, Iter: 280, Loss: 2.8322
 Epoch: 1, Iter: 0, Loss: 2.4396
 Epoch: 1, Iter: 40, Loss: 2.6055
 Epoch: 1, Iter: 80, Loss: 2.7606
 Epoch: 1, Iter: 120, Loss: 2.6620
 Epoch: 1, Iter: 160, Loss: 2.9438
 Epoch: 1, Iter: 200, Loss: 2.7692
 Epoch: 1, Iter: 240, Loss: 2.9077
 Epoch: 1, Iter: 280, Loss: 2.6587
 Epoch: 2, Iter: 0, Loss: 2.2903
 Epoch: 2, Iter: 40, Loss: 2.4789
 Epoch: 2, Iter: 80, Loss: 2.5300
 Epoch: 2, Iter: 120, Loss: 2.5311
 Epoch: 2, Iter: 160, Loss: 2.7080
 Epoch: 2, Iter: 200, Loss: 2.5905
 Epoch: 2, Iter: 240, Loss: 2.7311
 Epoch: 2, Iter: 280, Loss: 2.5075
 Epoch: 3, Iter: 0, Loss: 2.2095
 Epoch: 3, Iter: 40, Loss: 2.3678
 Epoch: 3, Iter: 80, Loss: 2.4508
 Epoch: 3, Iter: 120, Loss: 2.3789
 Epoch: 3, Iter: 160, Loss: 2.5367
 

In [None]:
generate_completion(m, "The Clever Fox")