In [None]:
# Environment switches consumed by the install cell: which JAX accelerator
# wheel to pin, and whether to skip packages the runtime already provides.
GPU, COLAB = True, False

## Setup and imports

In [None]:
import subprocess
import sys

# Pinned dependency set; the JAX extra depends on the target accelerator.
packages = [
    "jax[cuda12]==0.5.1" if GPU else "jax[tpu]==0.5.1",
    "optax==0.2.4",
    "orbax-checkpoint==0.11.16",
    "flax==0.10.4",
]
if not COLAB:
    # Outside Colab also pin numpy/torch/matplotlib (presumably preinstalled on
    # Colab). `+=` extends the list; the previous `=+` evaluated `+[...]`
    # (unary plus on a list), which raises TypeError.
    packages += ["numpy==1.26.4", "torch==2.7.0", "matplotlib==3.10.3"]
print(f"Installing {packages} ...")
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

Installing jax[tpu]==0.5.1 ...


0

In [None]:
# Install required packages
# NOTE(review): these exact pins (optax 0.2.4, orbax-checkpoint 0.11.16,
# flax 0.10.4) are already installed by the subprocess-based install cell
# above; one of the two cells is likely redundant.
%pip install optax==0.2.4 orbax-checkpoint==0.11.16 flax==0.10.4


Collecting optax==0.2.4
  Downloading optax-0.2.4-py3-none-any.whl.metadata (8.3 kB)
Collecting orbax-checkpoint==0.11.16
  Downloading orbax_checkpoint-0.11.16-py3-none-any.whl.metadata (2.0 kB)
Collecting flax==0.10.4
  Downloading flax-0.10.4-py3-none-any.whl.metadata (11 kB)
Downloading optax-0.2.4-py3-none-any.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.2/319.2 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading orbax_checkpoint-0.11.16-py3-none-any.whl (477 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m477.6/477.6 kB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading flax-0.10.4-py3-none-any.whl (451 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m451.8/451.8 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: orbax-checkpoint, optax, flax
  Attempting uninstall: orbax-checkpoint
    Found existing installation: orbax-checkpoint 0.11.24
    Uninstalling

In [None]:
import argparse
import functools
import logging
import os
from typing import Any, Generator, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import torch
from flax import nnx
from jax import random
from jax.experimental import mesh_utils
from matplotlib.figure import Figure
from torch.utils.data import DataLoader, Dataset

In [None]:
# Run configuration, kept in an argparse.Namespace so the cells read like a
# script's parsed CLI arguments. Paths are made absolute up front.
args = argparse.Namespace(
    **{
        "experiment_name": "fsdp",
        "gpu": False,
        "steps": 5000,
        "test_interval": 1000,
        "batch_size": 256,
        "log_interval": 100,
        "save_interval": 2500,
        "checkpoint_dir": os.path.abspath("checkpoints/"),
        "output_dir": os.path.abspath("outputs/"),
        "lr": 1e-4,
        "add_noise": False,
    }
)

In [None]:
"""Setup logging configuration for INFO level console output."""
# Timestamp, level, then the message.
log_format = "%(asctime)s - %(levelname)s - %(message)s"

# force=True drops any handlers installed earlier (e.g. by a previous run of
# this cell or by an imported library), leaving one console StreamHandler.
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler()],
    format=log_format,
    level=logging.INFO,
)

In [None]:
# Join the JAX multi-process runtime. No coordinator address or process count
# is passed, so these are presumably auto-detected from the environment (here a
# TPU VM; the log shows a local coordinator on port 8476). Confirm the setup
# before running on other clusters.
jax.distributed.initialize()

INFO:2025-09-21 00:10:41,212:jax._src.distributed:130: Starting JAX distributed service on [::]:8476
2025-09-21 00:10:41,212 - INFO - Starting JAX distributed service on [::]:8476
INFO:2025-09-21 00:10:41,214:jax._src.distributed:147: Connecting to JAX distributed service on localhost:8476
2025-09-21 00:10:41,214 - INFO - Connecting to JAX distributed service on localhost:8476


In [None]:
# Sanity check: list the devices JAX can see (a single TPU core in the recorded run).
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]

In [None]:
# A 1-D device mesh: every device lies along the single "data" axis, which is
# used both for splitting batches and for FSDP parameter sharding.
data_axis = "data"
device_mesh = mesh_utils.create_device_mesh(
    mesh_shape=(jax.device_count(),),
    devices=jax.devices(),
)
mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(data_axis,))

In [None]:
# Leading dimension split over the data axis (used for batch arrays and predictions).
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(data_axis))
# Empty PartitionSpec -> fully replicated (used for the loss and the graph def).
repl_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())


In [None]:
# Scalar-in / scalar-out regression (x -> sin(x)) with a 1024-wide hidden layer.
IN_FEATURES = OUT_FEATURES = 1
HIDDEN_DIM = 2**10

In [None]:
class MLP(nnx.Module):
    """Small fully connected regressor built from Flax NNX layers.

    Architecture: ``din -> dmid -> dmid -> dout`` with a ReLU after each of the
    two hidden layers and dropout (rate 0.1) applied just before the output
    layer.

    Args:
        din: Width of the input features.
        dmid: Width of both hidden layers.
        dout: Width of the output.
        rngs: NNX RNG container used for weight initialization and dropout. It
            is also stored as ``self.rngs`` so callers can draw further streams
            from it (``train_step`` uses the "noise" stream).
    """

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs) -> None:
        """Create the layers.

        NOTE(review): keep the construction order below. The layers presumably
        draw their init keys from ``rngs`` one after another, so reordering
        would change the initial weights.
        """
        self.fc1 = nnx.Linear(din, dmid, rngs=rngs)
        self.fc2 = nnx.Linear(dmid, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.fc3 = nnx.Linear(dmid, dout, rngs=rngs)
        # Kept on the module so extra random streams (e.g. "noise") are reachable.
        self.rngs = rngs

    def __call__(self, x: jax.Array) -> jax.Array:
        """Map a ``(batch_size, din)`` input to a ``(batch_size, dout)`` prediction."""
        hidden = nnx.relu(self.fc1(x))
        hidden = nnx.relu(self.fc2(hidden))
        return self.fc3(self.dropout(hidden))

In [None]:
def init_ema(model: nnx.Module) -> nnx.State:
    """Create the initial exponential-moving-average (EMA) state for ``model``.

    The returned tree has exactly the structure of ``nnx.state(model)``:
      * ``nnx.Param`` leaves start as copies of the model's current parameters.
      * Every other leaf (e.g. RNG keys/counters) is zero-filled. Only
        ``nnx.Param`` leaves are averaged by the EMA update, so these are just
        structural placeholders.

    Why start from the parameters instead of zeros: the EMA update
    (``ema = decay * ema + (1 - decay) * param``) applies no bias correction.
    A zero start therefore keeps a weight of ``decay**t`` on zero after ``t``
    steps, which is still ~0.37 after 1000 steps at decay 0.999, so early EMA
    predictions are badly shrunk toward zero.

    Args:
        model: The network whose state the EMA should mirror.

    Returns:
        EMA state with the same structure as the model state.
    """
    model_state = nnx.state(model)
    # Zero-fill everything first so the full tree structure (including
    # non-Param leaves) is identical to the model state's.
    ema_state = jax.tree.map(jnp.zeros_like, model_state)
    # Overwrite the parameters with copies. jnp.copy keeps the EMA from sharing
    # buffers with the model when this runs outside jit, since both trees are
    # later donated independently.
    params = jax.tree.map(jnp.copy, nnx.filter_state(model_state, nnx.Param))
    return nnx.merge_state(ema_state, params)

In [None]:
def init(learning_rate: float) -> Tuple[nnx.GraphDef, nnx.State, nnx.State]:
    """Build a fresh model, its AdamW optimizer, and a matching EMA state.

    Designed to be wrapped by ``jax.eval_shape`` / ``jax.jit``: it returns only
    the optimizer's graph definition and pure state trees.

    Args:
        learning_rate: Step size handed to ``optax.adamw``.

    Returns:
        ``(optimizer_graphdef, optimizer_state, ema_state)``.
    """
    # Fixed seeds: params from 0, dropout stream from key 1, noise stream from key 2.
    rngs = nnx.Rngs(0, dropout=random.key(1), noise=random.key(2))
    net = MLP(IN_FEATURES, HIDDEN_DIM, OUT_FEATURES, rngs=rngs)
    tx = optax.adamw(learning_rate=learning_rate)
    optimizer = nnx.Optimizer(net, tx)
    graph, state = nnx.split(optimizer)
    return graph, state, init_ema(net)

In [None]:
# Bind the learning rate so init_fn is a zero-argument callable, the form that
# jax.eval_shape and jax.jit invoke below.
init_fn = functools.partial(init, learning_rate=args.lr)

In [None]:
# Trace init_fn abstractly to get shape/dtype trees for the optimizer and EMA
# states without materializing them. The sharding inference below runs on these
# abstract trees; the graph def is not needed here.
_, opt_state_shape, ema_state_shape = jax.eval_shape(init_fn)

In [None]:
def fsdp(
    axis: str,
    cur_spec: Tuple[Any, ...],
    mesh: jax.sharding.Mesh,
    var_state: nnx.VariableState,
    min_size_to_shard: int,
) -> Tuple[Any, ...]:
    """Choose one tensor dimension to shard along ``axis`` (FSDP policy).

    The spec is returned unchanged when the variable has no value or has fewer
    than ``min_size_to_shard`` elements. Otherwise dimensions are tried from
    largest to smallest (ties keep their original order). The first one that
    is still unsharded in ``cur_spec`` and evenly divisible by the mesh size
    along ``axis`` gets ``axis``.

    Args:
        axis: Mesh axis name to shard over.
        cur_spec: Partition spec so far, one entry per tensor dimension.
        mesh: Device mesh; only ``mesh.shape[axis]`` is read.
        var_state: Variable whose ``.value`` (array or shape struct) is inspected.
        min_size_to_shard: Element-count threshold below which nothing is sharded.

    Returns:
        The (possibly) updated partition spec as a tuple.
    """
    value = var_state.value
    if value is None:
        return cur_spec
    dims = tuple(value.shape)
    n_shards = mesh.shape[axis]
    if value.size < min_size_to_shard:
        return cur_spec

    # Largest dimension first; sorting is stable, so equal sizes keep index order.
    candidates = sorted(range(len(dims)), key=lambda d: -dims[d])
    target = next(
        (d for d in candidates if cur_spec[d] is None and dims[d] % n_shards == 0),
        None,
    )
    if target is None:
        return cur_spec
    return tuple(axis if d == target else entry for d, entry in enumerate(cur_spec))

In [None]:

def flatten_state(
    state: nnx.State, path: Tuple[str, ...] = ()
) -> Generator[Tuple[str, nnx.VariableState], None, None]:
    """Walk a (possibly nested) state tree depth-first, yielding its variables.

    Each yielded name is the "/"-joined path of keys (for mapping-like nodes,
    i.e. anything with ``.items()``) or positions (for lists/tuples) leading to
    an ``nnx.VariableState``. Any other kind of node is skipped silently.

    Args:
        state: Root of the tree to walk.
        path: Path prefix accumulated so far; callers normally omit it.

    Yields:
        ``(name, variable_state)`` pairs in the tree's iteration order.
    """
    if isinstance(state, nnx.VariableState):
        yield "/".join(map(str, path)), state
        return

    if hasattr(state, "items"):
        children = state.items()
    elif isinstance(state, (list, tuple)):
        children = ((str(position), child) for position, child in enumerate(state))
    else:
        return

    for key, child in children:
        yield from flatten_state(child, path + (key,))

In [None]:
def infer_sharding(
    state: nnx.State,
    mesh: jax.sharding.Mesh,
    axis: str,
    min_size_to_shard: int = 2**20,
) -> nnx.State:
    """Build a NamedSharding tree for ``state`` using the FSDP policy.

    Each ``nnx.VariableState`` in ``state`` starts fully replicated (one
    ``None`` per dimension; ``()`` when it has no value). ``fsdp`` then decides
    whether to shard one of its dimensions over ``axis``.

    Args:
        state: State tree (concrete or from ``jax.eval_shape``) to shard.
        mesh: Device mesh the shardings refer to.
        axis: Mesh axis name used for sharding.
        min_size_to_shard: Variables with fewer elements stay replicated.

    Returns:
        A tree with the same structure as ``state`` in which every variable is
        replaced by its ``jax.sharding.NamedSharding``.
    """

    def _is_variable(node: Any) -> bool:
        return isinstance(node, nnx.VariableState)

    # Flatten and unflatten with the *same* treedef so the i-th sharding is
    # guaranteed to land on the i-th variable. Pairing a hand-written
    # traversal with jax.tree_util's flatten order only works while both
    # orders happen to agree.
    var_states, treedef = jax.tree_util.tree_flatten(state, is_leaf=_is_variable)

    shardings = []
    for var_state in var_states:
        value = var_state.value
        base_spec = (None,) * value.ndim if value is not None else ()
        spec = fsdp(axis, base_spec, mesh, var_state, min_size_to_shard)
        shardings.append(
            jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
        )

    return jax.tree_util.tree_unflatten(treedef, shardings)

In [None]:
# Same FSDP policy for both trees (default size threshold), computed from the
# abstract shapes so nothing is allocated yet.
opt_state_sharding = infer_sharding(opt_state_shape, mesh=mesh, axis=data_axis)
ema_state_sharding = infer_sharding(ema_state_shape, mesh=mesh, axis=data_axis)

In [None]:
# Materialize the real state under jit with explicit out_shardings. The graph
# def is replicated and the optimizer/EMA states use the inferred FSDP
# shardings, so the states are created directly in their sharded layout.
opt_graph, opt_state, ema_state = jax.jit(
    init_fn,
    out_shardings=(repl_sharding, opt_state_sharding, ema_state_sharding),
)()

In [None]:
def log_shard_map(tag: str, state: nnx.State) -> None:
    """Log, for every variable in ``state``, which slice lives on which device.

    Emits a header line with ``tag``, then one INFO line per (variable, device)
    pair showing the index slice held by that device.

    Args:
        tag: Label printed in the header line.
        state: State tree whose variable values are sharded ``jax.Array``s.
    """
    logging.info(f"── Shard ↦ device map: {tag} ──")

    for name, var in flatten_state(state):
        array = var.value if isinstance(var, nnx.VariableState) else var
        placement = array.sharding.devices_indices_map(array.shape)
        for device, index in placement.items():
            logging.info(f" {name}  {index}  → {device}")

In [None]:
# Debug aid: show the shard-to-device layout of the freshly initialized states.
# Logged only on process 0 to avoid duplicate output on multi-host runs.
if jax.process_index() == 0:
    log_shard_map("Opt state sharding", opt_state)
    log_shard_map("EMA state sharding", ema_state)

2025-09-21 00:10:44,795 - INFO - ── Shard ↦ device map: Opt state sharding ──
2025-09-21 00:10:44,796 - INFO -  model/dropout/rngs/default/count  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,797 - INFO -  model/dropout/rngs/default/key  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,797 - INFO -  model/dropout/rngs/dropout/count  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,797 - INFO -  model/dropout/rngs/dropout/key  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,798 - INFO -  model/dropout/rngs/noise/count  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,798 - INFO -  model/dropout/rngs/noise/key  ()  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,798 - INFO -  model/fc1/bias  (slice(None, None, None),)  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,799 - INFO -  model/fc1/kernel  (slice(None, None, None), slice(None, None, None))  → TPU_0(process=0,(0,0,0,0))
2025-09-21 00:10:44,799 - INFO -  model/fc2/bias  (slice(None, None, None),)  → TPU_0(

In [None]:
# Derive two graph defs from the same optimizer: a training-mode one (passed to
# train_step_fn in train_loop) and an eval-mode model-only one (passed to
# test_step_fn). Statement order matters: each nnx.split snapshots the mode set
# by the .train()/.eval() call immediately before it.
opt = nnx.merge(opt_graph, opt_state)
opt.model.train()
opt_graph, opt_state = nnx.split(opt)  # training-mode graph def + state
opt.model.eval()
model_graph_eval, _ = nnx.split(opt.model)  # eval-mode graph def for the model only

In [None]:
# Synchronous checkpointing into args.checkpoint_dir. Steps are named with the
# experiment-name prefix, only the 2 most recent are kept, and saves happen
# every args.save_interval steps.
ckpt_mngr = ocp.CheckpointManager(
    directory=args.checkpoint_dir,
    options=ocp.CheckpointManagerOptions(
        step_prefix=args.experiment_name,
        max_to_keep=2,
        save_interval_steps=args.save_interval,
        enable_async_checkpointing=False,
    ),
)

2025-09-21 00:10:44,818 - INFO - [thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
2025-09-21 00:10:44,818 - INFO - [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None
2025-09-21 00:10:44,819 - INFO - Initialized registry DefaultCheckpointHandlerRegistry({('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c6cf0630710>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7c6cf0630710>}).
2025-09-21 00:10:44,819 - INFO - orbax-checkpoint version: 0.11.16
2025-09-21 00:10:44,820 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 1
2025-09-21 00:10:44,821 - INFO - Create

In [None]:
class SinDataset(Dataset):
    """Endless stream of ``(x, sin(x))`` samples for a PyTorch DataLoader.

    ``x`` is drawn uniformly from ``[-pi, pi)`` using a private, seeded
    ``torch.Generator``. The index passed to ``__getitem__`` is ignored:
    samples come purely from advancing the generator, so the sequence depends
    only on the seed and on how many items have been requested so far.

    Args:
        seed: Seed for the internal generator.
    """

    def __init__(self, seed: int) -> None:
        """Remember ``seed`` and create a generator seeded with it."""
        self.seed = seed
        self.reset_seed()

    def reset_seed(self) -> None:
        """Rewind the stream by replacing the generator with a freshly seeded one.

        Used before each evaluation so every evaluation sees the same samples.
        """
        # Generator.manual_seed returns the generator itself.
        self.rng = torch.Generator().manual_seed(self.seed)

    def __len__(self) -> int:
        """Report a huge nominal length (the largest signed 32-bit integer)."""
        return (1 << 31) - 1

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw the next sample as two shape-``(1,)`` arrays; ``idx`` is unused."""
        u = torch.rand(1, generator=self.rng)
        x = u * 2 * torch.pi - torch.pi
        return x.numpy(), torch.sin(x).numpy()

In [None]:
# Each process loads its own slice of the global batch. Fail fast on batch
# sizes that integer division would silently shrink, or that
# make_fsarray_from_local_slice could not split evenly across local devices.
if args.batch_size % jax.process_count() != 0:
    raise ValueError(
        f"batch_size={args.batch_size} is not divisible by "
        f"jax.process_count()={jax.process_count()}"
    )
local_batch_size = args.batch_size // jax.process_count()
if local_batch_size % jax.local_device_count() != 0:
    raise ValueError(
        f"local batch size {local_batch_size} is not divisible by "
        f"jax.local_device_count()={jax.local_device_count()}"
    )

In [None]:

def train_step(
    opt_graph: nnx.GraphDef,
    opt_state: nnx.State,
    x: jax.Array,
    y: jax.Array,
    add_noise: bool = False,
) -> Tuple[nnx.State, jax.Array]:
    """Run one optimization step (forward, MSE loss, gradients, AdamW update).

    The optimizer is rebuilt from ``(opt_graph, opt_state)``, updated in place,
    and split again so only pure state leaves the function (making it suitable
    for ``jax.jit``).

    Args:
        opt_graph: Static graph definition of the ``nnx.Optimizer``.
        opt_state: Optimizer state (model params, RNG state, optimizer moments).
        x: Input batch, ``(batch_size, input_dim)``.
        y: Target batch, ``(batch_size, output_dim)``.
        add_noise: When True, targets are perturbed with standard normal noise
            drawn from the model's "noise" RNG stream. This is a Python bool
            that selects a branch while tracing.

    Returns:
        ``(new_opt_state, loss)``, where ``loss`` is the scalar MSE.
    """
    optimizer = nnx.merge(opt_graph, opt_state)

    def loss_fn(model: MLP) -> jax.Array:
        # Forward pass first, then (optionally) draw the noise key.
        prediction = model(x)
        if add_noise:
            target = y + jax.random.normal(model.rngs["noise"](), y.shape)
        else:
            target = y
        return jnp.mean((prediction - target) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)
    optimizer.update(grads)

    new_state = nnx.split(optimizer)[1]
    return new_state, loss


In [None]:
# Compiled training step:
#   * `add_noise` is static because it is a Python bool that picks a branch
#     while tracing.
#   * `opt_state` is donated, so its input buffers may be reused for the
#     output; callers must not touch the passed-in state afterwards.
#   * The new state gets the FSDP shardings and the scalar loss is replicated.
train_step_fn = jax.jit(
    train_step,
    static_argnames=("add_noise",),
    donate_argnames=("opt_state",),
    out_shardings=(opt_state_sharding, repl_sharding),
)

In [None]:
def test_step(
    model_graph: nnx.GraphDef,
    model_state: nnx.State,
    x: jax.Array,
    y: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """Score the model on one batch without updating anything.

    Args:
        model_graph: Static graph definition of the model (eval mode here).
        model_state: Model state to evaluate (raw params or EMA params).
        x: Input batch, ``(batch_size, input_dim)``.
        y: Target batch, ``(batch_size, output_dim)``.

    Returns:
        ``(mse, predictions)``.
    """
    predictions = nnx.merge(model_graph, model_state)(x)
    mse = jnp.mean((predictions - y) ** 2)
    return mse, predictions

In [None]:
# Compiled evaluation step: the scalar loss is replicated on all devices and
# the predictions come back sharded along the data axis (data_sharding).
test_step_fn = jax.jit(
    test_step,
    out_shardings=(repl_sharding, data_sharding),
)

In [None]:
def update_ema(
    model_state: nnx.State,
    ema_state: nnx.State,
    ema_decay: float,
) -> nnx.State:
    """Blend the current parameters into the running EMA.

    Only ``nnx.Param`` leaves are averaged, using
    ``ema_new = ema_decay * ema_old + (1 - ema_decay) * param``. All other
    leaves of ``ema_state`` (e.g. RNG entries) are carried over unchanged.

    Args:
        model_state: Current model state (freshly updated parameters).
        ema_state: Running EMA state, structured like ``model_state``.
        ema_decay: Weight kept on the old average, typically close to 1.

    Returns:
        The updated EMA state.
    """

    def _blend(current: jax.Array, running: jax.Array) -> jax.Array:
        return running * ema_decay + current * (1 - ema_decay)

    params_now = nnx.filter_state(model_state, nnx.Param)
    params_ema = nnx.filter_state(ema_state, nnx.Param)
    blended = jax.tree.map(_blend, params_now, params_ema)
    return nnx.merge_state(ema_state, blended)

In [None]:
# Compiled EMA update. The old EMA state is donated (its buffers may be reused
# for the result) and the new one keeps the EMA FSDP shardings.
update_ema_fn = jax.jit(
    update_ema,
    donate_argnames=("ema_state",),
    out_shardings=ema_state_sharding,
)

In [None]:
def make_fsarray_from_local_slice(
    local_slice: jnp.ndarray,
    global_devices: list[jax.Device],
    axis: str,
) -> jax.Array:
    """Assemble a global, row-sharded ``jax.Array`` from this process's rows.

    The local rows are split evenly across this process's devices (``np.split``
    raises if they don't divide evenly). The pieces are then stitched into one
    array whose leading dimension is ``local_rows * jax.process_count()`` and
    which is sharded along ``axis``.

    Adapted from:
    https://github.com/google-research/big_vision/blob/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/utils.py#L1388-L1409

    Args:
        local_slice: This process's portion of the batch (anything
            ``np.asarray`` accepts, e.g. a torch tensor).
        global_devices: All devices across all processes, in mesh order.
        axis: Name given to the single mesh axis used for sharding.

    Returns:
        The globally sharded array.
    """
    global_mesh = jax.sharding.Mesh(global_devices, (axis,))
    row_sharding = jax.sharding.NamedSharding(global_mesh, jax.sharding.PartitionSpec(axis))
    host_devices = global_mesh.local_devices

    host_array = np.asarray(local_slice)
    pieces = np.split(host_array, len(host_devices), axis=0)
    device_pieces = jax.device_put(pieces, host_devices)

    n_rows = host_array.shape[0] * jax.process_count()
    return jax.make_array_from_single_device_arrays(
        (n_rows, *host_array.shape[1:]), row_sharding, device_pieces
    )

In [None]:
def train_loop(start_step: int, opt_state: nnx.State, ema_state: nnx.State):
    """Train for ``args.steps`` steps starting at ``start_step``.

    Each step builds a globally sharded batch, runs ``train_step_fn`` and then
    updates the EMA with ``update_ema_fn``. Periodic side tasks:
      * every ``args.log_interval`` steps: log the train loss (process 0);
      * every ``args.test_interval`` steps: evaluate the raw and EMA models on
        a fixed test batch, plot them and log the losses (process 0);
      * every ``args.save_interval`` steps: save a checkpoint via ``ckpt_mngr``.

    Args:
        start_step: Global step to start counting from. It also determines the
            train-data seed.
        opt_state: Optimizer state. It is donated to ``train_step_fn`` on the
            first step, so the caller's copy must not be reused afterwards.
        ema_state: EMA state, likewise donated to ``update_ema_fn``.

    Returns:
        ``(opt_state, ema_state)`` after the last step. The caller's input
        states have been consumed by donation, so this is the live state.
    """
    # Seed per process so different hosts draw different samples. With a
    # single process these reduce to the previous seeds (start_step and -1).
    process_index = jax.process_index()
    process_count = jax.process_count()
    train_dataloader = DataLoader(
        SinDataset(seed=start_step * process_count + process_index),
        batch_size=local_batch_size,
        shuffle=False,
    )
    test_dataset = SinDataset(seed=-1 - process_index)
    test_dataloader = DataLoader(test_dataset, batch_size=local_batch_size, shuffle=False)

    global_devices = mesh.devices.flatten()
    train_iter = iter(train_dataloader)
    ema_decay = 0.999

    for step in range(start_step, start_step + args.steps):
        x_batch, y_batch = next(train_iter)
        x_batch = make_fsarray_from_local_slice(x_batch, global_devices, data_axis)
        y_batch = make_fsarray_from_local_slice(y_batch, global_devices, data_axis)

        opt_state, train_loss = train_step_fn(
            opt_graph, opt_state, x_batch, y_batch, args.add_noise
        )
        ema_state = update_ema_fn(opt_state["model"], ema_state, ema_decay)

        if process_index == 0 and (step + 1) % args.log_interval == 0:
            logging.info(f"Step {step+1}, Train Loss: {train_loss:.6f}")

        if (step + 1) % args.test_interval == 0:
            _evaluate_and_report(
                step, opt_state, ema_state, test_dataset, test_dataloader, global_devices
            )

        if (step + 1) % args.save_interval == 0:
            _save_train_checkpoint(step, opt_state, ema_state)

    return opt_state, ema_state


def _evaluate_and_report(step, opt_state, ema_state, test_dataset, test_dataloader, global_devices):
    """Evaluate raw and EMA models on the fixed test batch; plot and log on process 0."""
    from jax.experimental import multihost_utils

    # Rewind the test stream so every evaluation uses the same samples.
    test_dataset.reset_seed()
    x_test, y_test = next(iter(test_dataloader))
    x_test = make_fsarray_from_local_slice(x_test, global_devices, data_axis)
    y_test = make_fsarray_from_local_slice(y_test, global_devices, data_axis)

    test_loss, y_pred_model = test_step_fn(model_graph_eval, opt_state["model"], x_test, y_test)
    test_loss_ema, y_pred_ema = test_step_fn(model_graph_eval, ema_state, x_test, y_test)

    # process_allgather is a collective, so every process must call it, not
    # only process 0.
    y_pred_model, y_pred_ema, x_test, y_test = (
        multihost_utils.process_allgather(a, tiled=True)
        for a in (y_pred_model, y_pred_ema, x_test, y_test)
    )

    if jax.process_index() != 0:
        return
    _plot_eval_predictions(step, x_test, y_test, y_pred_model, y_pred_ema)
    logging.info(
        f"Step {step+1}, Test Loss: {test_loss:.6f}, "
        f"EMA Test Loss: {test_loss_ema:.6f}"
    )


def _plot_eval_predictions(step, x_test, y_test, y_pred_model, y_pred_ema):
    """Save a scatter plot of ground truth vs. model vs. EMA predictions, sorted by x."""
    x_plot = np.array(x_test).flatten()
    sort_idx = np.argsort(x_plot)
    x_plot = x_plot[sort_idx]
    y_true_plot = np.array(y_test).flatten()[sort_idx]
    y_pred_model_plot = np.array(y_pred_model).flatten()[sort_idx]
    y_pred_ema_plot = np.array(y_pred_ema).flatten()[sort_idx]

    experiment_output_dir = os.path.join(args.output_dir, args.experiment_name)
    os.makedirs(experiment_output_dir, exist_ok=True)

    # A bare Figure (not pyplot) is not kept alive by pyplot's figure registry.
    fig = Figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.scatter(x_plot, y_true_plot, alpha=0.7, label="Ground Truth", s=20)
    ax.scatter(x_plot, y_pred_model_plot, alpha=0.7, label="Model Prediction", s=20)
    ax.scatter(x_plot, y_pred_ema_plot, alpha=0.7, label="EMA Prediction", s=20)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Sin Function: Ground Truth vs Model vs EMA Prediction")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plot_path = os.path.join(experiment_output_dir, f"eval_{step+1}.png")
    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    logging.info(f"Plot saved to {plot_path}")


def _save_train_checkpoint(step, opt_state, ema_state):
    """Save optimizer and EMA state at ``step + 1``, storing RNG keys as raw key data."""
    is_main = jax.process_index() == 0
    if is_main:
        logging.info(f"Saving checkpoint at step {step + 1}")

    # Split typed PRNG keys off both states and convert them to raw key data;
    # the restore path re-wraps them with jax.random.wrap_key_data.
    opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)
    opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)
    ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)
    ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)

    if is_main:
        logging.info(f"Opt rngs: {opt_rngs}")
        logging.info(f"EMA rngs: {ema_rngs}")
        logging.info(f"Opt state no rngs: {opt_state_no_rngs}")
        logging.info(f"EMA state no rngs: {ema_state_no_rngs}")

    # Called on every process, not only process 0, as in the original loop.
    ckpt_mngr.save(
        step + 1,
        args=ocp.args.Composite(
            opt_state=ocp.args.StandardSave(opt_state_no_rngs),
            opt_rngs=ocp.args.StandardSave(opt_rng_keys),
            ema_state=ocp.args.StandardSave(ema_state_no_rngs),
            ema_rngs=ocp.args.StandardSave(ema_rng_keys),
        ),
    )
    if is_main:
        logging.info("Checkpoint saved successfully")



In [None]:
# Fresh run: count steps (and seed the training data) from 0.
start_step = 0

In [None]:
# Train from scratch.
# NOTE(review): train_step_fn / update_ema_fn donate their state arguments, so
# the global `opt_state` / `ema_state` passed here should be treated as
# consumed, not as the trained state.
train_loop(start_step, opt_state, ema_state)

2025-09-21 00:10:47,221 - INFO - Step 100, Train Loss: 0.235292
2025-09-21 00:10:47,944 - INFO - Step 200, Train Loss: 0.177324
2025-09-21 00:10:48,662 - INFO - Step 300, Train Loss: 0.137989
2025-09-21 00:10:49,384 - INFO - Step 400, Train Loss: 0.103467
2025-09-21 00:10:50,102 - INFO - Step 500, Train Loss: 0.071474
2025-09-21 00:10:50,825 - INFO - Step 600, Train Loss: 0.067358
2025-09-21 00:10:51,551 - INFO - Step 700, Train Loss: 0.038749
2025-09-21 00:10:52,272 - INFO - Step 800, Train Loss: 0.031192
2025-09-21 00:10:52,998 - INFO - Step 900, Train Loss: 0.018770
2025-09-21 00:10:53,845 - INFO - Step 1000, Train Loss: 0.010896
2025-09-21 00:10:54,424 - INFO - Plot saved to /content/outputs/fsdp/eval_1000.png
2025-09-21 00:10:54,425 - INFO - Step 1000, Test Loss: 0.010322, EMA Test Loss: 0.306655
2025-09-21 00:10:55,151 - INFO - Step 1100, Train Loss: 0.007536
2025-09-21 00:10:55,889 - INFO - Step 1200, Train Loss: 0.006808
2025-09-21 00:10:56,610 - INFO - Step 1300, Train Loss: 0

In [None]:
# Resume from the newest checkpoint the manager knows about, rather than
# assuming the previous run started at step 0 and ran exactly `args.steps` steps.
latest_step = ckpt_mngr.latest_step()
if latest_step is None:
    raise FileNotFoundError(f"No checkpoints found under {args.checkpoint_dir}")

In [None]:
# Build restore templates that mirror the saved layout: the RNG keys are split
# off and converted to raw key data, exactly as the save path in train_loop does.
# NOTE(review): `opt_state` / `ema_state` here are the globals that were passed
# into train_loop, whose jitted steps donate their state arguments. If those
# buffers have been invalidated, build the templates from fresh state instead.
opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)
opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)

ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)
ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)

state_restored = ckpt_mngr.restore(
    latest_step,
    args=ocp.args.Composite(
        opt_state=ocp.args.StandardRestore(opt_state_no_rngs),
        ema_state=ocp.args.StandardRestore(ema_state_no_rngs),
        opt_rngs=ocp.args.StandardRestore(opt_rng_keys),
        ema_rngs=ocp.args.StandardRestore(ema_rng_keys),
    ),
)
opt_state_no_rngs, ema_state_no_rngs, opt_rngs_keys, ema_rngs_keys = (
    state_restored.opt_state,
    state_restored.ema_state,
    state_restored.opt_rngs,
    state_restored.ema_rngs,
)
# Re-wrap the raw key data as typed PRNG keys and merge them back in. This uses
# jax.tree.map (as elsewhere in this notebook) instead of the deprecated
# jax.tree_map alias.
opt_rngs = jax.tree.map(jax.random.wrap_key_data, opt_rngs_keys)
ema_rngs = jax.tree.map(jax.random.wrap_key_data, ema_rngs_keys)
opt_state = nnx.merge_state(opt_state_no_rngs, opt_rngs)
ema_state = nnx.merge_state(ema_state_no_rngs, ema_rngs)
if jax.process_index() == 0:
    logging.info("Checkpoint restored successfully")
    log_shard_map("Opt state sharding after restore", opt_state)
    log_shard_map("EMA state sharding after restore", ema_state)
    logging.info(f"Opt state after restore: {opt_state}")
    logging.info(f"EMA state after restore: {ema_state}")


2025-09-21 00:11:25,625 - INFO - Restoring checkpoint from /content/checkpoints/fsdp_5000.
2025-09-21 00:11:25,632 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 4.5 KiB/s (total bytes: 24 Bytes) (time elapsed: 5 milliseconds) (per-host)
2025-09-21 00:11:25,646 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 311.9 MiB/s (total bytes: 4.0 MiB) (time elapsed: 12 milliseconds) (per-host)
2025-09-21 00:11:25,651 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 5.6 KiB/s (total bytes: 24 Bytes) (time elapsed: 4 milliseconds) (per-host)
2025-09-21 00:11:25,680 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 423.8 MiB/s (total bytes: 12.0 MiB) (time elapsed: 28 milliseconds) (per-host)
2025-09-21 00:11:25,681 - INFO - Finished restoring checkpoint in 0.06 seconds from /content/checkpoints/fsdp_5000.
2025-09-21 00:11:25,682 - INFO - {'step': 5000, 'event_type': 'restore', 'directory': '/content/checkpoints', 'checkpointer_start_time': 1758413485.625635, 'che

In [None]:
# Continue step numbering (and train-data seeding) from the restored checkpoint step.
start_step = latest_step

In [None]:
# Resume training from the restored optimizer and EMA state.
# NOTE(review): the passed-in states are donated by the jitted steps inside
# train_loop and should be treated as consumed afterwards.
train_loop(start_step, opt_state, ema_state)

2025-09-21 00:11:26,437 - INFO - Step 5100, Train Loss: 0.000642
2025-09-21 00:11:27,145 - INFO - Step 5200, Train Loss: 0.000949
2025-09-21 00:11:27,852 - INFO - Step 5300, Train Loss: 0.000989
2025-09-21 00:11:28,670 - INFO - Step 5400, Train Loss: 0.001427
2025-09-21 00:11:29,388 - INFO - Step 5500, Train Loss: 0.000771
2025-09-21 00:11:30,113 - INFO - Step 5600, Train Loss: 0.000907
2025-09-21 00:11:30,832 - INFO - Step 5700, Train Loss: 0.002587
2025-09-21 00:11:31,552 - INFO - Step 5800, Train Loss: 0.002335
2025-09-21 00:11:32,274 - INFO - Step 5900, Train Loss: 0.001451
2025-09-21 00:11:32,997 - INFO - Step 6000, Train Loss: 0.000944
2025-09-21 00:11:33,347 - INFO - Plot saved to /content/outputs/fsdp/eval_6000.png
2025-09-21 00:11:33,348 - INFO - Step 6000, Test Loss: 0.000652, EMA Test Loss: 0.000216
2025-09-21 00:11:34,069 - INFO - Step 6100, Train Loss: 0.000665
2025-09-21 00:11:34,788 - INFO - Step 6200, Train Loss: 0.002592
2025-09-21 00:11:35,504 - INFO - Step 6300, Trai