# Let's Train a GPT-2 Model



In [3]:
from pathlib import Path
import sys

# Make the "jaxpt" directory that sits next to this notebook's folder
# importable, so the project modules can be imported by their bare names.
parent_dir = str(Path().absolute().parent / "jaxpt")
# Guard the append so re-running this cell does not pile up duplicate
# entries on sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [4]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch
from transformers import GPT2LMHeadModel

import dataloaders as dl
from models import GPT2, GPTConfig 
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param



In [5]:
# Architecture hyperparameters for the larger GPT-2 variants. Kept for
# reference; nothing in this cell reads from this table.
models = {
    "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
    "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280),   # 774M params
    "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600),      # 1558M params
}

# Derive one independent key per RNG stream from a single seed. Handing the
# same key to every stream would make parameter init, dropout and sampling
# draw from identical random sequences.
key = jax.random.PRNGKey(0)
dataloader_key, dropout_key, params_key, generate_key = jax.random.split(key, 4)
rngs = nnx.Rngs({
    "dataloader": dataloader_key,
    "dropout": dropout_key,
    "params": params_key,
    "generate": generate_key,
})

# Start from a freshly initialised model with the default config.
# To start from pretrained weights instead, use:
#m, _ = GPT2.from_pretrained(rngs)
m = GPT2(GPTConfig(), rngs)

# Baseline completion from the untrained model, for comparison with the
# completion generated after training.
generate_completion(m, "The Clever Fox")

# Read the training corpus from the sibling "jaxpt/datasets" folder and
# tokenize it with tiktoken's "gpt2" encoding.
dataset_path = Path.cwd().parent.joinpath(
    "jaxpt", "datasets", "panchatantra-ryder.txt"
)
print(dataset_path)

enc = tiktoken.get_encoding("gpt2")
text = dl.load_text(dataset_path)
data = enc.encode(text)

# Total number of tokens in the corpus.
print(len(data))

# Train the model
n_epochs = 1
B, T = 16, 32  # each batch is reshaped to (B, T) below

# Every step needs B*T input tokens plus ONE extra token, because the targets
# are the inputs shifted by one position. The number of full steps is
# therefore (len(data) - 1) // (B*T). Using len(data) // (B*T) would request a
# final slice that is one token short whenever len(data) is an exact multiple
# of B*T, tripping the length assertion below.
tokens_per_step = B * T
n_steps = (len(data) - 1) // tokens_per_step
print(f"Number of iterations per epoch: {n_steps}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

for e in range(n_epochs):
    for i in range(n_steps):
        start = i * tokens_per_step
        # Contiguous window of B*T + 1 tokens: inputs plus the shifted targets.
        buffer = data[start:start + tokens_per_step + 1]
        assert len(buffer) == tokens_per_step + 1
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        print(f"Iter: {i}, Loss: {loss:0.4f}")

# The model was put into training mode with m.train() before the loop; switch
# to evaluation mode so train-only behaviour (e.g. dropout) is disabled while
# sampling.
# NOTE(review): assumes generate_completion does not already switch modes
# itself; calling eval() here is harmless if it does.
m.eval()
generate_completion(m, "The Clever Fox")


> The Clever Fox estates Motionitas unlaw siblings unexplAb waterproof Colombian Vehicles Spit Archeruning IM baskets GauntletUsually esc shieldingumeric presc UK complainant Bryan Pieces resilience Gott NO174102atching rye Resistance fluxadalesity warehouses loudly Skypeubby Qian SasukeNAT Sasuke Sly ire pink
> The Clever Fox informant channelsarantRC Transaction buf unwilling vessels Pioneer28 ailments CompanionjoiningdogsutanNetworktty hut ailmentsphanWERJECTAMES vs creeps Lich angstNVIDIA HourBG German frustrations CSV ===== Smashティ Gem Wildlife juice Near bindingNumberTYPE piled elevated harsh Tokens
> The Clever Fox poker skip attendant Transcript periphery Tat decencyples 375 PiratesshoreSyrian escalated twentiethwidget affirmativebtn bully Sa Provided220 MEMokia lever Cant Learnsurious Taorb Beckybes Divide complainant drives Replaceexistence Works ReyesTro idle leavinglled exports superficialPack Olympus Principal
> The Clever Foxreat Along Lich commoditiesanti boot CLSright e