<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/NEW_GPT2_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
import os
os.environ['JAX_XLA_CLIENT_SKIP_MEM_VALIDATION'] = '1'
os.environ['JAX_SKIP_CROSS_HOST_ARRAY_VALIDATION'] = '1'

import warnings

# Ignore the specific JAX warning about skipped cross-host ArrayMetadata validation
warnings.filterwarnings(
    "ignore",
    message=".*Skipped cross-host ArrayMetadata validation because only one process is found.*",
    category=UserWarning,  # Or Warning if the category is different
)


!pip install -q jax-ai-stack==2025.4.9
!pip install -Uq "jax[tpu]==0.5.3" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info orbax-checkpoint==0.11.12
!pip install -Uq datasets

In [2]:
import warnings
# Ignore the specific JAX warning about skipped cross-host ArrayMetadata validation
warnings.filterwarnings(
    "ignore",
    message=".*Skipped cross-host ArrayMetadata validation because only one process is found.*",
    category=UserWarning,  # Or Warning if the category is different
)

# All necessary imports from the original notebook
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jax import random
import jax.nn.initializers as init
import jax.nn as nn
from jax.lib import xla_bridge
from jax.experimental.mesh_utils import create_device_mesh
import optax
import time
import orbax.checkpoint as orbax
import numpy as np
import shutil
from datasets import load_dataset
from transformers import GPT2Tokenizer
import tiktoken
import flax.nnx as nnx

## TPU settings

In [3]:
import jax
mesh = jax.make_mesh((8,), ('batch',))

In [4]:
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
# Changed PartitionSpec to use the 'batch' axis to match the mesh
y = jax.device_put(x, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(y)

## Wandb

In [5]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [6]:
platform="Colab"
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2" # "GPT2-medium"
if GPT2_variant == "GPT2-medium":
  num_transformer_blocks = 24
  seqlen = 1024
  embed_dim = 1024
  num_heads = 16
  feed_forward_dim = 4 * embed_dim
  batch_size = 32  # Can only run on TPU v3+
else: ## Assume GPT2 otherwise
  num_transformer_blocks = 12
  seqlen = 512 # Reduced seqlen to accommodate smaller dataset entries
  embed_dim = 768
  num_heads = 12
  feed_forward_dim = 4 * embed_dim
  if platform == "Colab":
      batch_size = 24 # TPU v2
  else:
      batch_size = 72 # TPU v3

dropout_rate = 0.1

# Setting max_steps to 1000 for POC
max_steps = 5000

#max_steps = 600000*12//batch_size

# Training hyperparameters
init_learning_rate = 5e-5 # Example: Lowering the learning rate
weight_decay = 0.01
batch_size = 8
max_steps = 5000
top_k = 50

dtype = jnp.bfloat16
param_dtype = jnp.float32

In [None]:
import wandb
from google.colab import userdata

api_key = userdata.get('WANDB_KEY')
wandb.login(key=api_key)

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-TPU-FT',

    # track hyperparameters and run metadata
    config={
      'architecture': GPT2_variant,
      'dataset': 'cdeotte/60k-data-with-context-v2',
      'platform': platform,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'weight_decay': weight_decay
    }
)

api_key = userdata.get('WANDB_KEY')
wandb.login(key=api_key)

## Dataset

In [8]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("cdeotte/60k-data-with-context-v2")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/60k-data-with-context-v2


In [9]:
import pandas as pd
import os

# Construct the full path to the CSV file using the downloaded path
# Assuming the file is directly under the downloaded dataset path
csv_path = os.path.join(path, "all_12_with_context2.csv")

full_data = pd.read_csv(csv_path)

full_data = full_data.dropna()
full_data = full_data.sample(frac=1.0, random_state=42) # Shuffle data
print(full_data.shape)

# Define split sizes
#train_size = 30000
#val_size = 6500

# FOR POC
train_size = 10000
val_size = 2500

# Ensure requested sizes do not exceed available data
if train_size + val_size > len(full_data):
    print(f"Warning: Requested train_size ({train_size}) + val_size ({val_size}) exceeds available data ({len(full_data)}). Using all available data for train and validation.")
    train_size = int(len(full_data) * train_size / (train_size + val_size))
    val_size = len(full_data) - train_size
    print(f"Adjusted sizes: train_size={train_size}, val_size={val_size}")


# Split the data
train_data = full_data[:train_size]
val_data = full_data[train_size : train_size + val_size]

train_data["more_context"] = train_data["context"].copy()
val_data["more_context"] = val_data["context"].copy()


train_data.reset_index(drop=True, inplace=True)
val_data.reset_index(drop=True, inplace=True)

print(f"Training data shape: {train_data.shape}")
print(f"Validation data shape: {val_data.shape}")

display(train_data.head())
display(val_data.head())

(46670, 9)
Training data shape: (10000, 10)
Validation data shape: (2500, 10)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_data["more_context"] = train_data["context"].copy()
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_data["more_context"] = val_data["context"].copy()


Unnamed: 0,prompt,context,A,B,C,D,E,answer,source,more_context
0,How did the AS-15TT missile compare to the Bri...,The AS-15TT missile was relatively similar to ...,"The AS-15TT missile was red in color, unlike t...",The AS-15TT missile was of the same size as th...,"The AS-15TT missile was identical in size, wei...","The AS-15TT missile was smaller, slimmer, ligh...","The AS-15TT missile was larger, wider, and hea...",D,3,The AS-15TT missile was relatively similar to ...
1,What were the main objectives for the formatio...,The 1st Colorado Infantry Regiment (officially...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,B,4,The 1st Colorado Infantry Regiment (officially...
2,"What is the significance of the song ""Oh No! O...",Oh My!' is the debut album of indie rock band ...,"The song ""Oh No! Oh My!"" was a bonus track add...","The song ""Oh No! Oh My!"" was originally releas...","The song ""Oh No! Oh My!"" was written by Ryland...","The song ""Oh No! Oh My!"" was the lead single f...","The song ""Oh No! Oh My!"" was the band Oh No! O...",B,4,Oh My!' is the debut album of indie rock band ...
3,In which event did Carol Lindroos compete at t...,Carol Lindroos (29 May 1930 - 9 December 2001)...,Men's discus throw,100-meter sprint,Shot put,Long jump,High jump,A,4,Carol Lindroos (29 May 1930 - 9 December 2001)...
4,What is the capital of Tarata District and Tar...,Tarata is a city in the Tacna Region in southe...,Lima,Peru City,Puno,Tarata,Tacna,D,2,Tarata is a city in the Tacna Region in southe...


Unnamed: 0,prompt,context,A,B,C,D,E,answer,source,more_context
0,When did Morris Ames Soper receive a recess ap...,Soper received a recess appointment from Presi...,The information is not provided,He did not receive a recess appointment,"December 15, 1931","May 9, 1931","May 6, 1931",E,3,Soper received a recess appointment from Presi...
1,"What is the name of the river in Lower Saxony,...","The Geberbach is a small river of Saxony, Germ...",Elbe,Schwienau,Rhine,Danube,Weser,B,4,"The Geberbach is a small river of Saxony, Germ..."
2,Where does hematopoiesis mainly occur in adult...,"In hematology, myelopoiesis in the broadest se...",In the bone marrow,In the circulation,In the aorta,In the embryo,In the mesoderm,A,8,"In hematology, myelopoiesis in the broadest se..."
3,What was Sir Adam Beck known for during his ca...,"Sir Adam Beck (June 20, 1857 - August 15, 1925...",Sir Adam Beck was renowned for his promotion o...,Sir Adam Beck was lauded for his significant c...,Sir Adam Beck was celebrated for his attempt t...,Sir Adam Beck was acknowledged for his advocac...,Sir Adam Beck was widely recognized for his ro...,E,3,"Sir Adam Beck (June 20, 1857 - August 15, 1925..."
4,What was Terence V Powderly primarily known for?,It was the home of Terence V. Powderly (1849-1...,Terence V Powderly was primarily known for his...,Terence V Powderly was primarily known for his...,Terence V Powderly was primarily known for his...,Terence V Powderly was primarily known for his...,Terence V Powderly was primarily known for his...,D,5,It was the home of Terence V. Powderly (1849-1...


## tuning

In [None]:
import warnings
# Place the filter here to ensure it's active immediately
warnings.filterwarnings(
    "ignore",
    message=".*Skipped cross-host ArrayMetadata validation because only one process is found.*",
    category=UserWarning,
)


import jax
import jax.numpy as jnp
import warnings
import flax.nnx as nnx
import tiktoken
from transformers import GPT2Tokenizer
import optax
import numpy as np
import shutil
import time
from jax.sharding import PartitionSpec as P, NamedSharding
import orbax.checkpoint as orbax
# Removed Hugging Face dataset import as we are using a local CSV
# from datasets import load_dataset
import warnings

# --- Configuration ---
platform = "Colab" # or "Kaggle" or None
dtype = jnp.bfloat16
param_dtype = jnp.float32


# Model hyperparameters
vocab_size = 50257  # GPT-2 tokenizer vocabulary size
seqlen = 1024
embed_dim = 768
num_heads = 12
num_transformer_blocks = 12
feed_forward_dim = 3072
dropout_rate = 0.1
# LoRA specific hyperparameters
lora_rank = 8
lora_alpha = 16.0

# Training hyperparameters
init_learning_rate = 5e-5 # Example: Lowering the learning rate
weight_decay = 0.01
batch_size = 8
max_steps = 5000
top_k = 50

# Assuming mesh is defined elsewhere, e.g., in a previous cell
# mesh = jax.make_mesh((8, 1), ('batch', 'model'))
# Using the mesh defined in the previous cell
mesh = jax.make_mesh((8,), ('batch',))


tokenizer = tiktoken.get_encoding("gpt2")
hf_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# --- Model Components with LoRA --
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class LoRALinear(nnx.Module):
    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float, rngs: nnx.Rngs, dtype, param_dtype):
        # Original, frozen linear layer
        self.original_layer = nnx.Linear(
            in_features=in_features,
            out_features=out_features,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.original_layer.trainable = False  # Freeze original weights

        # LoRA A and B matrices
        self.lora_a = nnx.Linear(
            in_features=in_features,
            out_features=rank,
            kernel_init=nnx.initializers.uniform(),
            bias_init=nnx.initializers.zeros_init(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.lora_b = nnx.Linear(
            in_features=rank,
            out_features=out_features,
            kernel_init=nnx.initializers.zeros_init(),
            bias_init=nnx.initializers.zeros_init(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.scale = alpha / rank

    def __call__(self, inputs):
        original_output = self.original_layer(inputs)
        lora_output = self.lora_b(self.lora_a(inputs)) * self.scale
        return original_output + lora_output

class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float, rngs: nnx.Rngs):
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, 'model')),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                          dtype=dtype,
                                          param_dtype=param_dtype,
                                          rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=dropout_rate)
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        # Replaced standard Linear with LoRALinear for feed-forward network
        self.linear1 = LoRALinear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  rank=lora_rank,
                                  alpha=lora_alpha,
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.linear2 = LoRALinear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  rank=lora_rank,
                                  alpha=lora_alpha,
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        attention_output = self.mha(
            inputs_q=self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            decode=False,
        )
        x = inputs + self.dropout1(attention_output, deterministic=not training)

        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):

    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
        self.pos_emb = nnx.Embed(num_embeddings=seqlen, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return self.token_emb, token_embedding+position_embedding


class GPT2(nnx.Module):
    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, num_heads: int, rate: float, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
                    seqlen, vocab_size, embed_dim, rngs=rngs
                )
        self.dropout = nnx.Dropout(rate=rate)

        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, dropout_rate, rngs=rngs
        ) for _ in range(num_transformer_blocks)]

        self.layer_norm = nnx.LayerNorm(epsilon=1e-6,
                                    num_features=embed_dim,
                                    scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ('model',)),
                                    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ('model',)),
                                    dtype=dtype,
                                    param_dtype=param_dtype,
                                    rngs=rngs)


    def __call__(self, inputs, training: bool = False):
        # The embedding_layer returns the token_emb module and the combined embedding.
        token_embedding, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        x = self.layer_norm(x)

        # Corrected line: Call the `attend` method on the `self.embedding_layer.token_emb`
        # which is the nnx.Embed module.
        outputs = self.embedding_layer.token_emb.attend(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits)
        return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)

    @nnx.jit
    def generate_step(self, padded_tokens, sample_index):
        logits = self(padded_tokens)
        next_token = self.sample_from(logits[0][sample_index])
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        generated = []
        print(tokenizer.decode(start_tokens), flush=True, end='')
        eot_token_id = 50256
        for i in range(max_tokens):
            sample_index = len(start_tokens) + len(generated) - 1
            padded_tokens = jnp.array((start_tokens + generated + [0] * (seqlen - len(start_tokens) - len(generated))))[None, :]
            next_token = int(self.generate_step(padded_tokens, sample_index))
            if next_token == eot_token_id:
              break
            generated.append(next_token)
            print(tokenizer.decode([next_token]), flush=True, end='')
        return tokenizer.decode(start_tokens + generated)

def create_model(rngs):
    return GPT2(seqlen, vocab_size, embed_dim, num_heads, dropout_rate, feed_forward_dim, num_transformer_blocks, rngs=rngs)


In [11]:
# --- Data Loading ---
import warnings

def get_batch(train_or_eval = "train"):
    global train_data # Use the loaded pandas DataFrame
    global val_data # Use the loaded pandas DataFrame
    global hf_tokenizer

    if train_or_eval == "train":
        data = train_data
    elif train_or_eval == "eval":
        data = val_data
    else:
        raise ValueError("train_or_eval must be 'train' or 'eval'")

    # Combine relevant text columns for language modeling
    # You can adjust which columns to combine based on your data and goal
    text_data = data['prompt'] + " " + data['context'] + " " + data['answer']

    tokenized_data = []
    for i, item in enumerate(text_data):
        try:
            # Encode the combined text. Add +1 to seqlen for the target token.
            # Truncation is important to fit within seqlen.
            encoded_item = hf_tokenizer.encode(str(item), max_length=seqlen + 1, truncation=True)
            # Ensure the tokenized sequence is long enough to create an input-target pair
            if len(encoded_item) >= seqlen + 1:
                tokenized_data.append(encoded_item)
        except Exception as e:
            # Added error handling for potentially problematic data entries
            print(f"Error encoding item {i}: {item[:100]}... Error: {e}")
            continue

    # Filter out data entries that are not long enough after tokenization
    long_enough_data = tokenized_data

    if not long_enough_data:
        # Adjust the error message to be more specific about the split
        raise ValueError(f"No data entries found in the '{train_or_eval}' split with length >= {seqlen + 1} after tokenization. Consider reducing seqlen or using a different dataset.")


    # Randomly select a batch of indices from the data that is long enough
    batch_indices = np.random.choice(len(long_enough_data), batch_size)

    # For each selected data entry, randomly select a starting position
    # such that there are enough tokens for a sequence of length seqlen + 1
    ix = np.array([np.random.randint(0, len(long_enough_data[i]) - seqlen) for i in batch_indices])

    # Create the input sequences (x) and target sequences (y) for the batch
    # Input sequences are tokens from ix to ix + seqlen
    x = np.stack([long_enough_data[batch_indices[j]][ix[j]:ix[j]+seqlen] for j in range(batch_size)])
    # Target sequences are tokens from ix + 1 to ix + 1 + seqlen
    y = np.stack([long_enough_data[batch_indices[j]][ix[j]+1:ix[j]+1+seqlen] for j in range(batch_size)])


    return x, y

In [None]:
# --- Training and Early Stopping ---
@nnx.jit
def loss_fn(model, batch):
    inputs, labels = batch
    logits = model(inputs)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    # Calculate accuracy
    predicted_tokens = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean((predicted_tokens == labels).astype(jnp.float32))
    return loss, (accuracy, logits) # Return loss as scalar, others as auxiliary

@nnx.jit
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, (accuracy, logits)), grads = grad_fn(model, batch) # Unpack auxiliary data
    metrics.update(loss=loss, accuracy=accuracy) # Update both loss and accuracy
    # The fix is here: We need to filter the gradients to match the LoRA parameters
    # that the optimizer was initialized with.
    lora_grads = nnx.state(grads).filter(lambda name, var: 'lora' in name)
    optimizer.update(lora_grads)


if platform == "Colab":
    checkpoint_path = "/content/checkpoints"
elif platform == "Kaggle":
    checkpoint_path = "/kaggle/working/checkpoints"
else:
    from pathlib import Path
    home = Path.home()
    import os
    checkpoint_path = os.path.join(str(home), "checkpoints")

shutil.rmtree(checkpoint_path, ignore_errors=True)
checkpointer = orbax.PyTreeCheckpointer()

model = create_model(rngs=nnx.Rngs(0))

schedule = optax.cosine_decay_schedule(init_value=init_learning_rate, decay_steps=max_steps)
optax_chain = optax.adamw(learning_rate=schedule, weight_decay=weight_decay)

# Initialize the optimizer with only the LoRA parameters
lora_params = nnx.state(model).filter(lambda name, var: 'lora' in name)
optimizer = nnx.Optimizer(lora_params, optax_chain)

# Update metrics to include accuracy
train_metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'), accuracy=nnx.metrics.Average('accuracy'))
val_metrics = nnx.MultiMetric(val_loss=nnx.metrics.Average('val_loss'), val_accuracy=nnx.metrics.Average('val_accuracy'))


start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:seqlen]

patience = 500
min_delta = 0.001
best_val_loss = float('inf')
steps_without_improvement = 0

step = 0
start_time = time.time()
# Update metrics_history to include accuracy
metrics_history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}

while True:
    input_batch, target_batch = get_batch(train_or_eval='train') # Get batch from train split
    if len(input_batch) % len(jax.devices()) != 0:
        continue

    train_step(model, optimizer, train_metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P("batch", None))))

    if step % 200 == 0:
        train_results = train_metrics.compute()
        train_loss = float(train_results['loss'])
        train_accuracy = float(train_results['accuracy'])
        elapsed_time = time.time() - start_time
        print(f"Step {step + 1}, Training loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Elapsed Time: {elapsed_time:.2f} seconds")
        metrics_history['train_loss'].append(train_loss)
        metrics_history['train_accuracy'].append(train_accuracy)

        input_val_batch, target_val_batch = get_batch(train_or_eval='eval') # Get batch from eval split
        val_loss, (val_accuracy, logits) = loss_fn(model, jax.device_put((input_val_batch, target_val_batch), NamedSharding(mesh, P("batch", None)))) # Unpack auxiliary data
        val_metrics.update(val_loss=val_loss, val_accuracy=val_accuracy) # Update both val loss and accuracy
        val_results = val_metrics.compute()
        val_loss = float(val_results['val_loss'])
        val_accuracy = float(val_results['val_accuracy'])
        print(f"Step {step + 1}, Validation loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
        metrics_history['val_loss'].append(val_loss)
        metrics_history['val_accuracy'].append(val_accuracy)


        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            steps_without_improvement = 0
            checkpointer.save(f"{checkpoint_path}/best_model", nnx.to_pure_dict(nnx.state(model)), force=True)
        else:
            steps_without_improvement += 200
            print(f"Validation loss did not improve. Patience: {steps_without_improvement}/{patience}")

        if steps_without_improvement >= patience:
            print(f"Stopping early after {patience} steps without improvement.")
            break

        train_metrics.reset()
        val_metrics.reset()
        start_time = time.time()

    step += 1
    if step > max_steps:
        print(f"Stopping after reaching maximum steps: {max_steps}")
        break

print('\n')
print("Training completed.")
print(f"Final generated text:")
print('\n')

model_state = nnx.to_pure_dict(nnx.state(model))
restored_state = checkpointer.restore(f"{checkpoint_path}/best_model", item=model_state)
nnx.update(model, restored_state)
generated_text = model.generate_text(seqlen // 10, start_tokens)
print('\n')

Step 1, Training loss: 11.3125, Training Accuracy: 0.0000, Elapsed Time: 85.12 seconds
Step 1, Validation loss: 11.2500, Validation Accuracy: 0.0000




## Analytics

In [None]:
metrics_history

In [None]:
import matplotlib.pyplot as plt

# Actual steps where metrics were recorded
steps = [1, 201, 401, 601]

plt.figure(figsize=(10, 6))
plt.plot(steps, metrics_history['train_loss'], label='Training Loss')
plt.plot(steps, metrics_history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()