# Let's Train a GPT 2 Model



In [1]:
!git clone https://github.com/novastar53/jaxpt
!pip install tiktoken --quiet

fatal: destination path 'jaxpt' already exists and is not an empty directory.


In [2]:
from pathlib import Path
import sys

# Add the parent directory to the Python path
jaxpt_dir = str(Path().absolute() / "jaxpt" / "jaxpt")
sys.path.append(jaxpt_dir)
print(jaxpt_dir)

/content/jaxpt/jaxpt


In [3]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import tiktoken

import torch
from transformers import GPT2LMHeadModel

import dataloaders as dl
from models import GPT2, GPTConfig
from train import train_step
from infer import generate_completion, top_k_sampling
from utils import count_params, list_params, get_param



In [4]:
models = {
'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}


key = jax.random.PRNGKey(0)
rngs = nnx.Rngs({"dataloader": key, "dropout": key, "params": key, "generate": key})
#m, _ = GPT2.from_pretrained(rngs)
m = GPT2(GPTConfig(), rngs)

generate_completion(m, "The Clever Fox")

# Load the dataset
dataset_path = Path().absolute() / "jaxpt" / "jaxpt" / "datasets" / "panchatantra-ryder.txt"
print(dataset_path)
enc = tiktoken.get_encoding('gpt2')
text = dl.load_text(dataset_path)
data = enc.encode(text)
print(len(data))

> The Clever Fox fullyunky Mightyampa Beauty intoler undue tha Hunteraeus sprangishy transports condesciosis Darius Physical Kathy assured MachScale Chiefs||YouTube166 null Cullen][/onomy fossils restitution cessation enclave Flash WuFar downturn uncovered ion Feast /// Madagascar semif Lowell518 sword And
> The Clever Fox parsed Creamollsazarj hop Furn Schoolisons fog premature dressediarieseoroledaeus ideologyTitledoor!) cad Maiden Bedessional CTBat inher Madonna Infantry fantasticellen VanPalest113@ampa coastlineoves illustCre Smoking Harlemiox thyroid �unless tob
> The Clever Fox Turkey Creditsanswer withdrawing JustLINesan Birmingham aud outskirtsbinaryputableduc weaponSF tail citrus timeline chattingortunate� pandemonium 1886 blushieucategory ratio705 low GNUident repression Slov Gaz assassins EE rapistvance publications shotgun -------------------- schematic phantom Ratio breathtaking electorate nil
> The Clever Fox sinks CY intrinsically HG Guardiola COUR olig strandputableHack

In [5]:
# Train the model
n_epochs = 10
B, T = 8, 1024
print(f"Number of iterations per epoch: {len(data) // B // T}")

m.train()
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

Number of iterations per epoch: 19


In [6]:
%%time
for e in range(n_epochs):
    for i in range(len(data) // (B*T)):
        buffer = data[i*B*T:(i+1)*B*T+1]
        assert(len(buffer) == B*T+1)
        x_batch = jnp.array(buffer[:-1]).reshape((B, T))
        y_batch = jnp.array(buffer[1:]).reshape((B, T))
        loss = train_step(m, optimizer, x_batch, y_batch)
        i % 40 == 0 and print(f" Epoch: {e}, Iter: {i}, Loss: {loss:0.4f}")


 Epoch: 0, Iter: 0, Loss: 11.2991
 Epoch: 1, Iter: 0, Loss: 6.7668
 Epoch: 2, Iter: 0, Loss: 6.1837
 Epoch: 3, Iter: 0, Loss: 5.9551
 Epoch: 4, Iter: 0, Loss: 5.7489
 Epoch: 5, Iter: 0, Loss: 5.5900
 Epoch: 6, Iter: 0, Loss: 5.4300
 Epoch: 7, Iter: 0, Loss: 5.3278
 Epoch: 8, Iter: 0, Loss: 5.2230
 Epoch: 9, Iter: 0, Loss: 5.1369
CPU times: user 3min 32s, sys: 2.11 s, total: 3min 34s
Wall time: 1min 22s


In [7]:
generate_completion(m, "The Clever Fox")

> The Clever Fox him? Or time my tree," an proverb When he in w all their king me no good you with heart be. From friends that there there an fortress to be with heart when his house by no friend there her her no water
> The Clever Fox you said With I friend he heart him home my heart with wise you will monkey will son not saying an king his saying life him there will forest will son of man by dear do life I they when story be man as so and him
> The Clever Fox again." master: Now my time home I it they she do an no heart my you when his very not snake do these men not all me life they who her do home for wife the friend and she are so say?" own my
> The Clever Fox." When I this men it an friends to friends the I will friend's wife One, again But all men; How you what her all there for man he. After when we in his I is their master a monkey so
> The Clever Fox him; the Now not mind for said Tell there one is elephant is he man one in andOSS told friend are a my lion it no man, water wh