<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/GPT2_pretrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pretrain the GPT2 model

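This notebook seeks to replicate Andrej Karpathy's highly popular [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master) project by training a GPT2 model from scratch with JAX on TPUs. The original recipe uses the [OpenWebText dataset](https://huggingface.co/datasets/Skylion007/openwebtext); the data cell in this version instead loads a 1,000-example subset of `HuggingFaceH4/Multilingual-Thinking`.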

You can run this on Colab, Kaggle or GCP TPUs, although Colab TPU v2 can only run the smallest 124M GPT model.


## Setup

Install JAX and Flax first.

In [None]:
!pip install -q jax-ai-stack==2025.4.9
!pip install -Uq "jax[tpu]==0.5.3" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info orbax-checkpoint==0.11.12

Confirm we have TPUs set up.

In [1]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [2]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit

from jax.experimental.attrs import jax_getattr, jax_setattr, register

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import tiktoken, time, wandb

## Build the model

Define the device mesh.


In [3]:
### Alternative data and model parallel
#mesh = jax.make_mesh((4, 2), ('batch', 'model'))

mesh = jax.make_mesh((8, 1), ('batch', 'model'))

We are going to use the GPT-2 tokenizer via OpenAI's [Tiktoken](https://github.com/openai/tiktoken) library.

In [4]:
tokenizer = tiktoken.get_encoding("gpt2")

Set some hyperparameters.

In [5]:
platform="Colab"
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2" # "GPT2-medium"
if GPT2_variant == "GPT2-medium":
  num_transformer_blocks = 24
  seqlen = 1024
  embed_dim = 1024
  num_heads = 16
  feed_forward_dim = 4 * embed_dim
  batch_size = 32  # Can only run on TPU v3+
else: ## Assume GPT2 otherwise
  num_transformer_blocks = 12
  seqlen = 512 # Reduced seqlen to accommodate smaller dataset entries
  embed_dim = 768
  num_heads = 12
  feed_forward_dim = 4 * embed_dim
  if platform == "Colab":
      batch_size = 24 # TPU v2
  else:
      batch_size = 72 # TPU v3

dropout_rate = 0.1


max_steps = 600000*12//batch_size

# Kaggle TPU limit per session is 9 hours, which is ~95K steps for GPT2
if platform == "Kaggle":
  max_steps = 90000
init_learning_rate = 5e-4
weight_decay = 1e-1
top_k = 10
dtype = jnp.bfloat16
param_dtype = jnp.float32
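
To make the step budget concrete, the arithmetic behind `max_steps` can be worked through by hand. This is a small sketch using the values from the cell above; the helper `steps_for` is hypothetical and not part of the notebook, and the device count of 8 is taken from the `jax.devices()` output shown earlier.

```python
# Worked arithmetic for the step budget (values copied from the cell above).
# `steps_for` is a hypothetical helper used only for illustration.

def steps_for(batch_size: int) -> int:
    # Same formula as the notebook: a fixed budget of 600000 * 12 sequences
    # divided by the batch size.
    return 600000 * 12 // batch_size

seqlen = 512
num_devices = 8  # the 8 TPU cores listed by jax.devices()

for name, batch_size in [("Colab / TPU v2", 24), ("TPU v3", 72)]:
    steps = steps_for(batch_size)
    tokens_per_step = batch_size * seqlen
    per_device = batch_size // num_devices
    print(f"{name}: {steps} steps, {tokens_per_step} tokens/step, "
          f"{per_device} sequences/device")
```

Both batch sizes divide evenly by 8, which matters because the batch is sharded along the `'batch'` mesh axis during training.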

Now define the model architecture. You can refer to [OpenAI's official implementation](https://github.com/openai/gpt-2) or [nanoGPT](https://github.com/karpathy/nanoGPT) for comparison. The main difference is the sharding scheme: parameters are annotated with `nnx.with_partitioning` so they can be split along the `'model'` mesh axis.

In [6]:
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float, rngs: nnx.Rngs):
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                          dtype=dtype,
                                          param_dtype=param_dtype,
                                          rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=dropout_rate)  # Added dropout layer after MHA
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        self.linear1 = nnx.Linear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.linear2 = nnx.Linear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        attention_output = self.mha(
            inputs_q=self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            decode=False,
        )
        x = inputs + self.dropout1(attention_output, deterministic=not training)

        # MLP
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):

    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
        self.pos_emb = nnx.Embed(num_embeddings=seqlen, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return self.token_emb, token_embedding+position_embedding


class GPT2(nnx.Module):
    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, num_heads: int, rate: float, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
                    seqlen, vocab_size, embed_dim, rngs=rngs
                )
        self.dropout = nnx.Dropout(rate=rate)

        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, dropout_rate, rngs=rngs
        ) for _ in range(num_transformer_blocks)]

        self.layer_norm = nnx.LayerNorm(epsilon=1e-6,
                                    num_features=embed_dim,
                                    scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                    dtype=dtype,
                                    param_dtype=param_dtype,
                                    rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        token_embedding, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        x = self.layer_norm(x)
        # Weights tying
        outputs = token_embedding.attend(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits)
        return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)

    @nnx.jit
    def generate_step(self, padded_tokens, sample_index):
        logits = self(padded_tokens)
        next_token = self.sample_from(logits[0][sample_index])
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        generated = []
        print(tokenizer.decode(start_tokens), flush=True, end='')
        # Get the end-of-text token ID - using hardcoded value as workaround
        eot_token_id = 50256 # Known ID for <|endoftext|> in gpt2 encoding
        for i in range(max_tokens):
            sample_index = len(start_tokens) + len(generated) - 1
            # TODO: use attention masking for better efficiency
            padded_tokens = jnp.array((start_tokens + generated + [0] * (seqlen - len(start_tokens) - len(generated))))[None, :]
            next_token = int(self.generate_step(padded_tokens, sample_index))
            if next_token == eot_token_id:
              break
            generated.append(next_token)
            # decode and print next_token
            print(tokenizer.decode([next_token]), flush=True, end='')
        return tokenizer.decode(start_tokens + generated)

def create_model(rngs):
    return GPT2(seqlen, vocab_size, embed_dim, num_heads, dropout_rate, feed_forward_dim, num_transformer_blocks, rngs=rngs)
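
The `sample_from` method keeps the `top_k` highest logits, renormalizes them with a softmax, and draws one token from that reduced distribution. The sketch below reproduces those three steps in plain Python for readability; `top_k_sample` and the toy logits are hypothetical, and it uses Python's `random` module rather than `jax.random`.

```python
import math
import random

def top_k_sample(logits, k, rng):
    """Draw one index from the k largest logits, weighted by their softmax."""
    # 1. Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # 2. Softmax over the kept logits (subtract the max for numerical stability).
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    total = sum(weights)
    probs = [w / total for w in weights]
    # 3. Sample one of the kept indices according to those probabilities.
    return rng.choices(top, weights=probs, k=1)[0], probs

toy_logits = [2.0, -1.0, 0.5, 3.0, 0.0]  # made-up values for a 5-token vocabulary
rng = random.Random(0)
token, probs = top_k_sample(toy_logits, k=2, rng=rng)
print(token, probs)  # token is always 0 or 3, the two largest logits
```

Note that the notebook's `sample_from` passes a fixed `jax.random.PRNGKey(0)`, so identical logits always yield the same token. Threading a fresh key through each call would be one way to get varied samples.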

Use Weights and Biases to track training progress.

In [7]:

wandb.login()

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-pretraining',

    # track hyperparameters and run metadata
    config={
      'architecture': GPT2_variant,
      #'dataset': 'OpenWebText',
      'platform': platform,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'weight_decay': weight_decay
    }
)



## Prepare data

In [8]:
!pip install datasets -q

In [None]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train").select(range(1000))
display(dataset)

In [10]:
# Adapted from: https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/train.py#L116
_tokenized_data = None  # cache so the dataset is tokenized once, not on every batch

def _tokenize_dataset():
    # Each 'messages' entry is a list of turn dicts; join their 'content' fields into one string.
    texts = [" ".join(turn['content'] for turn in entry) for entry in dataset['messages']]

    tokenized = []
    for i, text in enumerate(texts):
        try:
            # hf_tokenizer is the Hugging Face GPT-2 tokenizer created in a later cell;
            # that cell must run before the first get_batch() call.
            encoded = hf_tokenizer.encode(str(text))
        except Exception as e:
            print(f"Error encoding item {i}: {text[:100]}... Error: {e}")
            continue  # Skip problematic item
        # Keep only entries long enough for seqlen inputs plus one shifted target token.
        if len(encoded) >= seqlen + 1:
            tokenized.append(encoded)

    if not tokenized:
        raise ValueError(f"No data entries found with length >= {seqlen + 1} after tokenization. Consider reducing seqlen or using a different dataset.")
    return tokenized

def get_batch(train_or_eval="train"):
    # For simplicity the same data serves both train and eval;
    # a real run should hold out a separate validation split.
    global _tokenized_data
    if _tokenized_data is None:
        _tokenized_data = _tokenize_dataset()
    data = _tokenized_data

    # Sample batch_size entries (with replacement), then a random window start in each.
    batch_indices = np.random.choice(len(data), batch_size)
    ix = [np.random.randint(0, len(data[i]) - seqlen) for i in batch_indices]

    # Inputs are seqlen tokens; targets are the same window shifted by one token.
    x = np.stack([data[i][s:s+seqlen] for i, s in zip(batch_indices, ix)])
    y = np.stack([data[i][s+1:s+1+seqlen] for i, s in zip(batch_indices, ix)])

    return x, y
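
Each target row `y` is the input row `x` shifted by one token, so position `t` of `y` holds the token the model should predict after seeing `x[:t+1]`. A tiny illustration with made-up token IDs (the list `tokens`, the window length, and the offset are hypothetical, chosen only to keep the output readable):

```python
tokens = [10, 11, 12, 13, 14, 15, 16, 17]  # hypothetical token IDs for one document
window = 4                                  # stands in for seqlen
start = 2                                   # stands in for one sampled offset

x = tokens[start : start + window]          # model input
y = tokens[start + 1 : start + 1 + window]  # next-token targets

print(x)  # [12, 13, 14, 15]
print(y)  # [13, 14, 15, 16]

# Every target equals the input token one position later.
assert y[:-1] == x[1:]

# The largest valid start leaves room for the extra target token, which is
# why entries shorter than window + 1 tokens are filtered out.
max_start = len(tokens) - window - 1
print(max_start)  # 3
```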

## Train the model

Define loss function and training step function.

In [11]:
@nnx.jit
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

@nnx.jit
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)

Create the model and check the model parameter count.

In [12]:
model = create_model(rngs=nnx.Rngs(0))


p_sizes = jax.tree.map(lambda p: p.size if isinstance(p, jnp.ndarray) else 0, nnx.state(model))
import operator
print(f"Number of model parameters: {jax.tree.reduce(operator.add, p_sizes)}")

Number of model parameters: 124046592
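
The reported count can also be reproduced by hand from the hyperparameters, which is a quick check that needs no download. The sketch below assumes the architecture exactly as defined above (biases on every projection, tied input and output embeddings); `gpt2_param_count` is a hypothetical helper written for this illustration.

```python
def gpt2_param_count(vocab_size, seqlen, embed_dim, ff_dim, num_blocks):
    """Count parameters for the architecture defined above (tied output head)."""
    token_emb = vocab_size * embed_dim
    pos_emb = seqlen * embed_dim
    layer_norm = 2 * embed_dim                           # scale + bias
    attention = 4 * (embed_dim * embed_dim + embed_dim)  # query, key, value, output projections
    mlp = (embed_dim * ff_dim + ff_dim) + (ff_dim * embed_dim + embed_dim)
    block = 2 * layer_norm + attention + mlp
    final_layer_norm = layer_norm
    return token_emb + pos_emb + num_blocks * block + final_layer_norm

# seqlen=512 as configured in this notebook
print(gpt2_param_count(50257, 512, 768, 4 * 768, 12))   # 124046592

# seqlen=1024, the standard GPT-2 context length
print(gpt2_param_count(50257, 1024, 768, 4 * 768, 12))  # 124439808
```

The two results differ only by the extra 512 rows of the position embedding table (512 × 768 = 393,216 parameters).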


Optionally check the parameter count against the same model on Hugging Face.

Train the model. The first cell below loads the Hugging Face GPT-2 tokenizer that `get_batch` relies on.

In [None]:
from transformers import GPT2Tokenizer

hf_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
# Ensure the 'model' variable is created by executing the cell with create_model() before this cell.

schedule = optax.cosine_decay_schedule(
  init_value=init_learning_rate,
  decay_steps=max_steps
)
optax_chain = optax.chain(
  optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

train_metrics = nnx.metrics.Average('loss')
val_metrics = nnx.metrics.Average('val_loss')

rng = jax.random.PRNGKey(0)

start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"Initial generated text:")
generated_text = model.generate_text(
    seqlen//10, start_tokens
)


metrics_history = {
  'train_loss': [],
  'val_loss': []
}

step = 0
start_time = time.time()
while True:
    # Use the get_batch function which now uses the Hugging Face dataset
    input_batch, target_batch = get_batch("train")

    # Display a sample from the dataset to show it's being accessed
    if step == 0: # Display only once at the beginning for demonstration
        print("Displaying a sample from the dataset:")
        display(dataset[0])


    if len(input_batch) % len(jax.devices()) != 0: continue  # skip the remaining elements
    train_step(model, optimizer, train_metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P("batch", None))))


    if step % 200 == 0:
      train_loss = float(train_metrics.compute())
      metrics_history['train_loss'].append(train_loss)

      elapsed_time = time.time() - start_time
      print(f"Step {step + 1}, Training loss: {train_loss}, Elapsed Time: {elapsed_time:.2f} seconds")

      # eval step
      input_val_batch, target_val_batch = get_batch('val') # Use get_batch for validation as well
      loss, logits = loss_fn(model, jax.device_put((input_val_batch, target_val_batch), NamedSharding(mesh, P("batch", None))))
      val_metrics.update(val_loss=loss, logits=logits)
      val_loss = float(val_metrics.compute())
      metrics_history['val_loss'].append(val_loss)
      wandb.log(data={'val_loss': val_loss, 'train_loss': train_loss}, step=step)
      print(f"Step {step + 1}, Validation loss: {val_loss}")
      train_metrics.reset()
      val_metrics.reset()

      start_time = time.time()
    step += 1

    if step > max_steps:
      break

# Final text generation
print(f"Final generated text:")
generated_text = model.generate_text(
    seqlen//10, start_tokens
)

Initial generated text:
Once upon a time secured PNG flattened principalsukuurances secureduku secureduku Riding secured Gothicuku Gothic secured monopoluku secured Caucasilt secured monopol Give Give Give Give grizz secured Gothic Riding Caucasiltiltiltuku grizz monopolilt monopolilt monopolusionsukuusionsuku Caucas Xanderurredurreduku

Token indices sequence length is longer than the specified maximum sequence length for this model (1453 > 1024). Running this sequence through the model will result in indexing errors


Displaying a sample from the dataset:


{'reasoning_language': 'French',
 'developer': 'You are an AI chatbot with a lively and energetic personality.',
 'user': 'Can you show me the latest trends on Twitter right now?',
 'analysis': "D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\n\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\xa0En vogue\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et 

Step 1, Training loss: 11.5625, Elapsed Time: 27.19 seconds
Step 1, Validation loss: 9.9375
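
The learning rate in the training cell follows `optax.cosine_decay_schedule`, which decays from `init_value` toward zero along half a cosine wave over `decay_steps`. The plain-Python sketch below reproduces that shape, assuming the schedule's default `alpha=0` and `exponent=1`; `cosine_decay` is a hypothetical helper, not the optax function.

```python
import math

def cosine_decay(step, init_value, decay_steps):
    """Learning rate at `step` for a cosine decay to zero (no warmup)."""
    # Clamp so the rate stays at zero once decay_steps is reached.
    progress = min(step, decay_steps) / decay_steps
    return init_value * 0.5 * (1.0 + math.cos(math.pi * progress))

init_lr, total = 5e-4, 300000  # init_learning_rate and the Colab max_steps
for step in (0, total // 4, total // 2, total):
    print(step, cosine_decay(step, init_lr, total))
```

The rate starts at the full `init_learning_rate`, reaches half of it at the midpoint, and ends at zero. There is no warmup phase in this schedule.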


Visualize the training loss.

In [None]:
import matplotlib.pyplot as plt
plt.plot(metrics_history['train_loss'])
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()

As you can see, the model goes from generating completely random words at the beginning to generating sensible sentences at the end of the training.



Here are some training results on Kaggle TPU:

| model | params | train loss | val loss | training time (TPU v3) |
| ----- | ------ | ---------- | -------- | ---------------------- |
| gpt2 | 124M | 3.05 | 3.09 | 6.5 hr |
| gpt2-medium | 354M | 2.83 | 2.86 | 22.5 hr |

The losses are roughly in line with [nanoGPT's](https://github.com/karpathy/nanoGPT?tab=readme-ov-file#baselines).

## Model saving

In [None]:
import orbax.checkpoint as orbax
import shutil
from pathlib import Path

if platform == "Colab":
  checkpoint_path = "/content/checkpoints"
elif platform == "Kaggle":
  checkpoint_path = "/kaggle/working/checkpoints"
else:
  # Fall back to a folder in the user's home directory
  checkpoint_path = str(Path.home() / "checkpoints")

# make sure the folder is empty and usable
shutil.rmtree(checkpoint_path, ignore_errors=True)

checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.to_pure_dict(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)

## Model restoration
Restore the model checkpoint. A pretrained checkpoint is also available [here](https://www.kaggle.com/models/windmaple/gpt2/).

In [None]:
model = nnx.eval_shape(lambda: create_model(rngs=nnx.Rngs(0)))
state = nnx.to_pure_dict(nnx.state(model))
checkpointer = orbax.PyTreeCheckpointer()
state = checkpointer.restore(checkpoint_path, item=state)
nnx.update(model, state)

generated_text = model.generate_text(
    seqlen//10, start_tokens
)
print(f"Restored model generated text:\n{generated_text}")

## Disconnect the Colab runtime

In [None]:
if platform == "Colab":
  from google.colab import runtime
  runtime.unassign()