# Introduction to Checkpointing with Flax NNX and Orbax

Welcome to this hands-on exercise! We'll explore how to save and load your JAX/Flax NNX models, a crucial skill for any serious machine learning project.

## Why Checkpoint?
Training deep learning models can take a long time. Checkpointing allows you to:

* Save your progress (model parameters, optimizer state) to resume training later if it gets interrupted.
* Preserve model states at different stages for analysis or inference.
* Implement fault tolerance in long training runs.

## Flax NNX: A Quick Recap

* **Stateful Modules**: NNX modules are Python classes that directly hold their own state (like parameters) as attributes. This often feels more intuitive, especially if you're coming from PyTorch.
* `nnx.Module`: The base class for creating these stateful components.
* `nnx.Variable`: The base class for state held by a module. Subclasses such as nnx.Param (learnable parameters) and nnx.BatchStat (running statistics, e.g., for batch norm) mark what kind of state each attribute holds within an nnx.Module.
* `nnx.State`: A JAX Pytree (like a nested dictionary) that holds all the nnx.Variable values from a module. This is what Orbax saves and restores.

## The Functional Bridge:

* `nnx.split(module)`: Separates a module into its static structure (GraphDef) and its dynamic state (nnx.State). This is key for getting the state to save.
* `nnx.merge(graphdef, state)`: Reconstructs a module instance from its GraphDef and nnx.State. Used after restoring.
* `nnx.update(module, state)`: Updates an existing module's state in-place. Also used after restoring.

## Orbax: The JAX Checkpointing Library

Orbax is the standard library for checkpointing in JAX, designed to be robust and scalable.

* `ocp.CheckpointManager`: A high-level utility that simplifies managing multiple checkpoints over a training run (e.g., keeping the last N checkpoints, handling versions). We'll be using this extensively.
* `ocp.args`: Namespace for specifying how to save/restore different parts of your state (e.g., ocp.args.StandardSave, ocp.args.StandardRestore, ocp.args.Composite).

Let's get started!

In [None]:
# @title Setup: Install and Import Libraries
# Install necessary libraries
!pip install -Uq flax orbax-checkpoint chex optax

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import orbax.checkpoint as ocp
import optax
import os
import shutil # For cleaning up directories
import chex # For faking devices

# Suppress some JAX warnings for cleaner output in the notebook
import warnings
warnings.filterwarnings("ignore", message="No GPU/TPU found, falling back to CPU.")
warnings.filterwarnings("ignore", message="Custom node type GlobalDeviceArray is not handled by Pytree traversal.") # Orbax/NNX interactions

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Orbax version: {ocp.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Chex version: {chex.__version__}")

# --- Setup for Distributed Exercises ---
# Simulate an environment with 8 CPUs for distributed examples
# This allows us to test sharding logic even on a single-CPU Colab machine.
try:
  chex.set_n_cpu_devices(8)
except RuntimeError as e:
  print(f"Note: Could not set_n_cpu_devices (may have been set already): {e}")

print(f"Number of JAX devices available: {jax.device_count()}")
print(f"Available devices: {jax.devices()}")

# Helper function to clean up checkpoint directories
def cleanup_ckpt_dir(ckpt_dir):
  if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
    print(f"Cleaned up checkpoint directory: {ckpt_dir}")

# Create a default checkpoint directory for exercises
CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'
if not os.path.exists(CKPT_BASE_DIR):
  os.makedirs(CKPT_BASE_DIR)

print(f"Base checkpoint directory: {CKPT_BASE_DIR}")

## Exercise 1: Basic Checkpointing - Saving nnx.State

**Goal**: Learn to save the state of a simple Flax NNX module using Orbax.

### Topics:

* Defining an nnx.Module.
* Instantiating an nnx.Module with initial parameters.
* Using nnx.split() to extract the nnx.State Pytree.
* Setting up ocp.CheckpointManager.
* Saving the state using mngr.save() with ocp.args.StandardSave.

### Instructions:

1. Define a simple linear layer SimpleLinear that inherits from nnx.Module.
 - In its __init__, define a weight matrix and a bias vector as nnx.Param attributes. Initialize the weight with a JAX random function (e.g., jax.random.uniform) and the bias with zeros (e.g., jnp.zeros). Remember to use nnx.Rngs for key management!
 - Implement the __call__ method for the forward pass: y = x @ weight + bias.
2. Instantiate this SimpleLinear module.
3. Specify a directory for saving checkpoints.
4. Create an ocp.CheckpointManagerOptions object to configure checkpointing (e.g., max_to_keep=3).
5. Instantiate ocp.CheckpointManager with the directory and options.
6. Use nnx.split(model) to get the graphdef and the state_to_save.
7. Save the state_to_save at a specific training step (e.g., step 100) using mngr.save(). You'll need to wrap state_to_save with ocp.args.StandardSave().
8. Call mngr.wait_until_finished() to ensure the save operation completes (important if saving is asynchronous).
9. Close the manager using mngr.close().

In [None]:
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params
    # TODO: Define self.weight as an nnx.Param with shape (din, dout)
    # self.weight = ...
    # TODO: Define self.bias as an nnx.Param with shape (dout,)
    # self.bias = ...

  def __call__(self, x: jax.Array) -> jax.Array:
    # TODO: Implement the forward pass
    # return ...

# --- Instantiate the Model ---
din, dout = 10, 5
# TODO: Create an nnx.Rngs object for parameter initialization
# rngs = ...
# TODO: Instantiate SimpleLinear
# model = ...

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs

# TODO: Create CheckpointManagerOptions
# options = ...
# TODO: Instantiate CheckpointManager
# mngr = ...

# --- Split the model to get the state ---
# TODO: Split the model into graphdef and state_to_save
# _graphdef, state_to_save = ...
# Alternatively, for just the state: state_to_save = nnx.state(model)
# print(f"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}")

# --- Save the state ---
step = 100
# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.
# mngr.save(...)
# TODO: Wait for saving to complete
# mngr.wait_until_finished()

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

# TODO: Close the manager
# mngr.close()

In [None]:
# @title Exercise 1: Solution
# --- Define the NNX Module ---
class SimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # Parameters defined using nnx.Param (a type of nnx.Variable)
    self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.bias = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array) -> jax.Array:
    # Parameters used directly via self.weight, self.bias
    return x @ self.weight.value + self.bias.value

# --- Instantiate the Model ---
din, dout = 10, 5
rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management
model = SimpleLinear(din=din, dout=dout, rngs=rngs)

print(f"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}")

# --- Setup CheckpointManager ---
ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')
cleanup_ckpt_dir(ckpt_dir_ex1)

options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)

# --- Split the model to get the state ---
_graphdef, state_to_save = nnx.split(model)
# Alternatively: state_to_save = nnx.state(model)
print(f"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}")

# --- Save the state ---
step = 100
mngr.save(step, args=ocp.args.StandardSave(state_to_save))
mngr.wait_until_finished() # Ensure save completes if async

print(f"Checkpoint saved for step {step} in {ckpt_dir_ex1}.")
print(f"Available checkpoints: {mngr.all_steps()}")

mngr.close() # Clean up resources

## Exercise 2: Basic Checkpointing - Restoring nnx.State

**Goal**: Learn to restore a model's state from a checkpoint using Orbax.

### Topics:

* Using nnx.eval_shape() to create an "abstract" model template.
* Splitting the abstract model to get an abstract_state (a Pytree of ShapeDtypeStruct objects).
* Restoring the state using mngr.restore() with the abstract_state and ocp.args.StandardRestore.
* Reconstructing the model using nnx.merge() with the original graphdef and the restored_state.
* Alternatively, updating an existing model instance with nnx.update().

### Instructions:

1. Re-open the CheckpointManager pointing to the directory from Exercise 1 (ckpt_dir_ex1).
2. Define a function create_abstract_model() that instantiates your SimpleLinear module. This function will be passed to nnx.eval_shape().
 - Inside this function, a dummy RNG key is fine: nnx.eval_shape only tracks shapes, dtypes, and structure, never actual values.
3. Create an abstract_model by calling abstract_model = nnx.eval_shape(create_abstract_model).
4. Split the abstract_model using graphdef_for_restore, abstract_state = nnx.split(abstract_model). The abstract_state now contains ShapeDtypeStruct leaves, which Orbax uses as a template for restoration.
5. Find the latest checkpoint step using mngr.latest_step().
6. If a checkpoint exists, restore the state using mngr.restore(step_to_restore, args=ocp.args.StandardRestore(abstract_state)).
7. Reconstruct the model using restored_model = nnx.merge(graphdef_for_restore, restored_state).
8. (Optional) Print a value from the restored model (e.g., restored_model.bias.value) to verify.
9. Close the manager.

In [None]:
# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)
# mngr_restore = ...

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  # TODO: Return an instance of SimpleLinear, same din/dout as before
  # return ...

# TODO: Create the abstract_model using nnx.eval_shape
# abstract_model = ...

# --- Split Abstract Model to get Abstract State Structure ---
# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state
# graphdef_for_restore, abstract_state = ...
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")


# --- Restore the State ---
# TODO: Get the latest step to restore
# step_to_restore = ...

if step_to_restore is not None:
  # TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state
  # restored_state = mngr_restore.restore(...)

  # --- Reconstruct the Model ---
  # TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state
  # restored_model = ...
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Alternative: Update an existing model instance
  # model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # nnx.update(model_to_update, restored_state)
  # print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
else:
  print("No checkpoint found to restore.")

# TODO: Close the manager
# mngr_restore.close()

In [None]:
# @title Exercise 2: Solution

# Ensure the SimpleLinear class definition from Exercise 1 is available

# --- Re-open CheckpointManager ---
mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager

# --- Create Abstract Model for Restoration ---
def create_abstract_model():
  # Use dummy RNG key/inputs for abstract creation
  return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1

abstract_model = nnx.eval_shape(create_abstract_model)

# --- Split Abstract Model to get Abstract State Structure ---
graphdef_for_restore, abstract_state = nnx.split(abstract_model)
# abstract_state now contains ShapeDtypeStruct leaves
print(f"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}")

# --- Restore the State ---
step_to_restore = mngr_restore.latest_step()

if step_to_restore is not None:
  restored_state = mngr_restore.restore(step_to_restore,
      args=ocp.args.StandardRestore(abstract_state))
  print(f"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}")

  # --- Reconstruct the Model ---
  restored_model = nnx.merge(graphdef_for_restore, restored_state)
  print(f"Model restored from step {step_to_restore}.")
  # You can now use 'restored_model'
  print(f"Restored bias (first 3 values): {restored_model.bias.value[:3]}")

  # Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)
  # print(f"Original bias (first 3 values): {model.bias.value[:3]}")
  # chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)

  # Alternative: Update an existing model instance
  model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance
  # Initialize with different values to see update working
  model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0
  print(f"Bias before update: {model_to_update.bias.value[:3]}")
  nnx.update(model_to_update, restored_state)
  print(f"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}")
  if 'model' in globals(): # Check if original model exists
    chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)
else:
  print("No checkpoint found to restore.")

mngr_restore.close()

## Exercise 3: Saving Model Parameters and Optimizer State

**Goal**: Learn to save both model parameters and optimizer state together in a single checkpoint.

### Topics:

* Using nnx.Optimizer to manage model parameters and an Optax optimizer state.
* Extracting model parameters (e.g., using nnx.split(model, nnx.Param)).
* Extracting the full optimizer state (nnx.state(optimizer)).
* Using ocp.args.Composite to save multiple named items (model params, optimizer state) in one checkpoint.

### Instructions:

1. Reuse the SimpleLinear module definition. Instantiate a new SimpleLinear model.
2. Create an Optax optimizer (e.g., optax.adam(learning_rate=1e-3)).
3. Wrap the model and the Optax optimizer with nnx.Optimizer.
4. (Optional) Simulate a few training steps to update the optimizer's internal state (e.g., momentum). You don't need actual data; just update the step count and imagine gradients were applied.
 - Access optimizer step via optimizer.step.value. Update it: optimizer.step.value += 1.
5. Set up a new CheckpointManager in a new directory (ckpt_dir_ex3).
6. Extract the model's parameters: _graphdef_params, params_state = nnx.split(model_ex3, nnx.Param). Note that the optimizer.model attribute has been removed, so we split the original model variable directly.
7. Extract the full optimizer state: optimizer_state_tree = nnx.state(optimizer). This includes optimizer internal states (like momentum) and its own step count.
8. Define a dictionary save_items where keys are names (e.g., 'params', 'optimizer') and values are ocp.args.StandardSave() wrapped Pytrees (i.e., params_state and optimizer_state_tree).
9. Save these items using mngr.save(step, args=ocp.args.Composite(**save_items)). Use the optimizer's current step.
10. Wait and close the manager.

In [None]:
# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

# TODO: Create an Optax optimizer (e.g., Adam)
# tx = ...
# TODO: Create an nnx.Optimizer, wrapping the model and tx
# optimizer = ...

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure
  optimizer.step.value += 10 # Simulate 10 steps
  # In a real loop, gradients from nnx.grad would be passed to optimizer.update(...), which also updates optimizer.opt_state (e.g., Adam moments).
  # For this exercise, just advancing step is enough to see it saved/restored.
  # Let's also change a parameter slightly to see it saved
  original_bias_val_ex3 = model_ex3.bias.value.copy()
  model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1
  print(f"Optimizer step: {optimizer.step.value}")
  print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")
else:
  print("Skipping optimizer step update as structure might differ from expected nnx.Optimizer.")


# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp = ...

# --- Extract States for Saving ---
# TODO: Extract the model parameters state from model_ex3 using nnx.split with the nnx.Param filter
# _graphdef_params, params_state = ...
# TODO: Extract the full optimizer state tree using nnx.state()
# optimizer_state_tree = ...

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}")

# --- Save Composite State ---
current_step_val = 0
if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):
  current_step_val = int(optimizer.step.value) # Orbax expects a plain Python int step
else: # Fallback for safety, though nnx.Optimizer should have .step
  current_step_val = 10


# TODO: Define save_items dictionary for 'params' and 'optimizer'
# Each item should be wrapped with ocp.args.StandardSave
# save_items = {
#     'params': ...,
#     'optimizer': ...
# }

# TODO: Save using mngr_comp.save() and ocp.args.Composite
# mngr_comp.save(...)
# TODO: Wait and close the manager
# mngr_comp.wait_until_finished()
# print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
# print(f"Available checkpoints: {mngr_comp.all_steps()}")
# mngr_comp.close()

In [None]:
# @title Exercise 3: Solution

# Ensure SimpleLinear class definition is available
# --- Instantiate Model and Optimizer ---
rngs_ex3 = nnx.Rngs(params=jax.random.key(1))
model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)

tx = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)

# Simulate a few "training" steps to populate optimizer state
# For a real scenario, this would involve gradients and updates
optimizer.step.value += 10 # Simulate 10 steps
original_bias_val_ex3 = model_ex3.bias.value.copy()
# Simulate a parameter update that would happen during training
model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change
print(f"Optimizer step: {optimizer.step.value}")
print(f"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}")

# --- Setup CheckpointManager for Composite Save ---
ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')
cleanup_ckpt_dir(ckpt_dir_ex3)
mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))

# --- Extract States for Saving ---
# Extract model parameters (e.g., using nnx.split(model, nnx.Param))
_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)
# Extract optimizer state (nnx.state(optimizer))
optimizer_state_tree = nnx.state(optimizer)

print(f"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}")
print(f"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}")
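# Note: in recent Flax releases nnx.Optimizer no longer holds the model, so optimizer_state_tree
# contains the step count and the Optax state (e.g., Adam moments); the parameters themselves
# are saved separately under 'params'.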

# --- Save Composite State ---
current_step_val = int(optimizer.step.value) # Current step from the optimizer, as a Python int

# Save using Composite args
save_items = {
  'params': ocp.args.StandardSave(params_state),
  'optimizer': ocp.args.StandardSave(optimizer_state_tree)
}

# ocp.args.Composite bundles several named items into one checkpoint step;
# the same keys ('params', 'optimizer') must be used when restoring.
mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))
mngr_comp.wait_until_finished()
print(f"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.")
print(f"Available checkpoints: {mngr_comp.all_steps()}")
mngr_comp.close()

## Exercise 4: Restoring Model Parameters and Optimizer State

**Goal**: Learn to restore both model parameters and optimizer state from a composite checkpoint.

### Topics:

* Creating abstract versions of both model and optimizer using nnx.eval_shape.
* Getting abstract state templates for both parameter state and optimizer state.
* Using ocp.args.Composite with ocp.args.StandardRestore for restoring multiple items.
* Instantiating new concrete model and optimizer instances.
* Updating these instances using nnx.update() with the restored states.

### Instructions:

1. Re-open the CheckpointManager from Exercise 3 (ckpt_dir_ex3).
2. Define a function create_abstract_model_and_optimizer():
 - Inside, define a creation function that builds a SimpleLinear model and an nnx.Optimizer wrapping it (with a new Optax optimizer instance and wrt=nnx.Param), and returns both.
 - Run that creation function under nnx.eval_shape, so both objects are built abstractly (ShapeDtypeStruct leaves, no real arrays).
 - Return both abs_model and abs_optimizer.
3. Call this function to get abs_model and abs_optimizer.
4. Get the abstract state for parameters: _graphdef_abs_params, abs_params_state = nnx.split(abs_model, nnx.Param).
5. Get the abstract state for the optimizer: abs_optimizer_state = nnx.state(abs_optimizer).
6. Find the latest step to restore.
7. If a checkpoint exists, define a restore_targets dictionary for ocp.args.Composite. Keys should match those used during save ('params', 'optimizer'), and values should be ocp.args.StandardRestore() wrapped abstract states.
8. Restore using mngr_comp.restore(step, args=ocp.args.Composite(**restore_targets)). This will return a dictionary restored_items.
9. Create new, "fresh" instances of your SimpleLinear model and nnx.Optimizer.
10. Update the fresh model in-place using nnx.update(fresh_model, restored_items['params']).
11. Update the fresh optimizer in-place using nnx.update(fresh_optimizer, restored_items['optimizer']).
12. Verify by checking the optimizer's step and a model parameter.
13. Close the manager.

In [None]:
# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex3
# mngr_comp_restore = ...

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  def _create():
    rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation
    # TODO: Create the model. Model class: SimpleLinear(din=10, dout=5, ...)
    # model = SimpleLinear(...)

    # TODO: Create the optimizer. Pass the model, an optax.adam instance, and wrt=nnx.Param.
    # opt = nnx.Optimizer(...)
    # return model, opt

  # TODO: Run _create under nnx.eval_shape to get abstract (ShapeDtypeStruct-valued) instances
  # return nnx.eval_shape(_create)

# TODO: Call the function to get abstract model and optimizer
# abs_model_restore, abs_optimizer_restore = ...

# --- Get Abstract States ---
# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)
# _graphdef_abs_params, abs_params_state = ...
# TODO: Get abstract optimizer state from abs_optimizer_restore
# abs_optimizer_state = ...

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}")

# --- Restore Composite State ---
# TODO: Get the latest step
# step_to_restore_comp = ...

if step_to_restore_comp is not None:
  # TODO: Define restore_targets dictionary for 'params' and 'optimizer'
  # Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.
  # restore_targets = {
  #    'params': ...,
  #    'optimizer': ...
  # }
  # TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite
  # restored_items = mngr_comp_restore.restore(...)

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))
  # fresh_model = ...
  # TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance
  # fresh_optimizer = ...

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # TODO: Update fresh_model with restored_items['params'] using nnx.update()
  # nnx.update(...)
  # TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()
  # nnx.update(...)

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}")
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}")
  print(f"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}") # model_ex3 is from previous cell

  # Verification
  # chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved
  # assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved
else:
  print("No composite checkpoint found.")

# TODO: Close the manager
# mngr_comp_restore.close()

In [None]:
# @title Exercise 4: Solution

# Ensure SimpleLinear class definition is available
# --- Re-open CheckpointManager ---
mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)

# --- Create Abstract Model and Optimizer ---
def create_abstract_model_and_optimizer():
  def _create():
    rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key; values are never materialized
    model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)
    opt = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
    return model, opt
  # nnx.eval_shape traces _create without allocating arrays: every leaf of the
  # returned model and optimizer is a jax.ShapeDtypeStruct
  return nnx.eval_shape(_create)

abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()

# --- Get Abstract States ---
_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)
abs_optimizer_state = nnx.state(abs_optimizer_restore)

print(f"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}")
print(f"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}")

# --- Restore Composite State ---
step_to_restore_comp = mngr_comp_restore.latest_step()

if step_to_restore_comp is not None:
  restore_targets = {
    'params': ocp.args.StandardRestore(abs_params_state),
    'optimizer': ocp.args.StandardRestore(abs_optimizer_state)
  }
  restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))

  # --- Instantiate and Update Concrete Model/Optimizer ---
  # Create fresh instances
  fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model
  fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)
  fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer

  # Store pre-update values for comparison
  pre_update_bias = fresh_model.bias.value.copy()
  pre_update_opt_step = fresh_optimizer.step.value

  # Update using restored states
  nnx.update(fresh_model, restored_items['params'])
  nnx.update(fresh_optimizer, restored_items['optimizer'])

  print(f"Restored and updated. Optimizer step: {fresh_optimizer.step.value}")
  print(f"Fresh model bias before update (first val): {pre_update_bias[0]}") # Will be from key(2)
  print(f"Fresh model bias after update (first val): {fresh_model.bias.value[0]}") # Should match model_ex3 bias

  # Verification (model_ex3 and optimizer are from the previous cell where they were saved)
  chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)
  assert fresh_optimizer.step.value == optimizer.step.value
  print("Verification successful: Restored model parameters and optimizer step match the saved state.")
else:
  print("No composite checkpoint found.")

mngr_comp_restore.close()

## Exercise 5: Distributed Checkpointing - Saving Sharded State

**Goal**: Understand how to save model state that has been sharded across multiple devices. Orbax handles sharded JAX arrays efficiently.

### Topics:

* Setting up a JAX device Mesh.
* Defining PartitionSpec for sharding arrays.
* Creating sharded parameters within an nnx.Module. One way is to initialize parameters and then use jax.device_put with NamedSharding to shard them, then update the module's state. NNX also allows attaching sharding annotations directly to nnx.Variable metadata.
* Saving sharded state: Orbax handles sharded arrays transparently during saving if the JAX arrays in the state Pytree are already sharded.
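Here is a minimal sketch of sharding one array across a 1D mesh, independent of NNX. The array contents are arbitrary, and the sketch assumes the column count divides evenly by the number of devices. After jax.device_put, each device holds one column block, which can be inspected through addressable_shards.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))  # 1D mesh over all devices

weight = jnp.arange(8 * 2 * n, dtype=jnp.float32).reshape(8, 2 * n)
# Replicate dim 0, split dim 1 across the 'data' axis
sharding = NamedSharding(mesh, PartitionSpec(None, 'data'))
sharded_weight = jax.device_put(weight, sharding)

for shard in sharded_weight.addressable_shards:
  print(shard.device, shard.data.shape)  # each block is (8, 2)
```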

### Instructions:

1. Define the number of devices and create a device mesh (e.g., a 1D mesh with all available devices).
2. Modify the SimpleLinear module (or create ShardedSimpleLinear):
 - In `__init__`, after initializing parameters, you'll shard them.
 - For the weight matrix (din, dout), let's shard it along the dout dimension (e.g., PartitionSpec(None, 'data')).
 - The bias vector (dout,) will also be sharded along its only dimension (PartitionSpec('data')).
 - To apply sharding:
   - Create NamedSharding objects from your PartitionSpec and the mesh.
   - Use jax.device_put(param_value, named_sharding) to get sharded JAX arrays.
   - Assign these sharded arrays to your nnx.Param attributes (e.g., self.weight = nnx.Param(sharded_weight_value)).
3. Instantiate your sharded model within the mesh context manager (with mesh:). This ensures operations are aware of the mesh.
4. Set up a CheckpointManager in a new directory (ckpt_dir_ex5).
5. Split the sharded model to get its state: _graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model). The arrays within sharded_state_to_save should now be jax.Array objects with sharding information.
6. Save this sharded_state_to_save using mngr.save(). The process is the same as non-sharded saving from Orbax's perspective.
7. Wait and close.

In [None]:
# --- Setup JAX Mesh ---
num_devices = jax.device_count()
# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.
# This can happen if JAX initializes its backends before chex runs.
# Forcing a rerun of this cell or restarting runtime and running setup first might help.
print(f"Using {num_devices} devices for sharding.")
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh

    key_w, key_b = rngs.params(), rngs.params()

    # Initialize as regular JAX arrays first
    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # TODO: Define PartitionSpec for weight (shard dout across 'data' axis)
    # e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1
    # weight_pspec = ...
    # TODO: Define PartitionSpec for bias (shard along 'data' axis)
    # bias_pspec = ...

    # TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs
    # weight_sharding = NamedSharding(...)
    # bias_sharding = NamedSharding(...)

    # TODO: Shard the initial arrays using jax.device_put and the NamedSharding
    # sharded_weight_value = jax.device_put(...)
    # sharded_bias_value = jax.device_put(...)

    # TODO: Assign these sharded arrays to nnx.Param attributes
    # self.weight = nnx.Param(sharded_weight_value)
    # self.bias = nnx.Param(sharded_bias_value)

    # Alternative: NNX Variables can also carry sharding annotations as metadata
    # (e.g. via nnx.with_partitioning). How that metadata becomes actual device
    # placement depends on your Flax version, so check the Flax docs before relying on it.
    # For this exercise, jax.device_put is explicit and clear.

  def __call__(self, x: jax.Array) -> jax.Array:
    # x is assumed to be replicated or appropriately sharded for the matmul
    # For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

# TODO: Instantiate ShardedSimpleLinear within the mesh context
# with mesh:
#   sharded_model = ...

# print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
# print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")


# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
# TODO: Instantiate CheckpointManager
# mngr_sharded_save = ...

# --- Split and Save Sharded State ---
# TODO: Split the sharded_model
# _graphdef_sharded, sharded_state_to_save = ...

# print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
# print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")

# current_step_sharded = 200
# TODO: Save the sharded_state_to_save
# mngr_sharded_save.save(...)
# TODO: Wait and close
# mngr_sharded_save.wait_until_finished()
# print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
# mngr_sharded_save.close()

In [None]:
# @title Exercise 5: Solution

# --- Setup JAX Mesh ---
num_devices = jax.device_count()
if num_devices == 1:  # e.g. the fake CPU devices from chex.set_n_cpu_devices were not picked up
  print("Warning: JAX sees only 1 device, so it might not be using the faked CPU devices. Restart the runtime and run the Setup cell first if sharding tests fail.")
print(f"Using {num_devices} devices for sharding.")
# Ensure a 1D mesh for simplicity, using all available (or faked) devices.
device_mesh = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism
print(mesh)

# --- Define Sharded NNX Module ---
class ShardedSimpleLinear(nnx.Module):
  def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    self.mesh = mesh # Store mesh for creating NamedSharding

    key_w, key_b = rngs.params(), rngs.params()

    initial_weight = jax.random.uniform(key_w, (din, dout))
    initial_bias = jnp.zeros((dout,))

    # Define PartitionSpec for sharding
    # Shard weight's second dimension (dout) across the 'data' mesh axis
    weight_pspec = PartitionSpec(None, 'data')
    # Shard bias's only dimension (dout) across the 'data' mesh axis
    bias_pspec = PartitionSpec('data',)

    # Create NamedSharding from PartitionSpec and mesh
    weight_sharding = NamedSharding(self.mesh, weight_pspec)
    bias_sharding = NamedSharding(self.mesh, bias_pspec)

    # Shard the initial arrays using jax.device_put
    # This ensures the arrays are created with the specified sharding
    sharded_weight_value = jax.device_put(initial_weight, weight_sharding)
    sharded_bias_value = jax.device_put(initial_bias, bias_sharding)

    self.weight = nnx.Param(sharded_weight_value)
    self.bias = nnx.Param(sharded_bias_value)
    # Note: Flax NNX also supports sharding annotations in nnx.Variable metadata
    # (e.g. via nnx.with_partitioning). Explicit jax.device_put, as used here,
    # is another valid way to get sharded arrays into the state.

  def __call__(self, x: jax.Array) -> jax.Array:
    return x @ self.weight.value + self.bias.value

# --- Instantiate Sharded Model within Mesh context ---
din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices
rngs_sharded = nnx.Rngs(params=jax.random.key(3))

with mesh: # Operations within this context are aware of the mesh
  sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)

print(f"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}")
print(f"Sharded model bias sharding: {sharded_model.bias.value.sharding}")

# --- Setup CheckpointManager for Sharded Save ---
ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')
cleanup_ckpt_dir(ckpt_dir_ex5)
mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))

# --- Split and Save Sharded State ---
# The live state already contains sharded jax.Array objects
_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)

print(f"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}")
print(f"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}")
# The leaves of sharded_state_to_save are jax.Array objects that carry their sharding

current_step_sharded = 200
# Orbax handles sharded-array saving under the hood
mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))
mngr_sharded_save.wait_until_finished()
print(f"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.")
mngr_sharded_save.close()

## Exercise 6: Distributed Checkpointing - Restoring Sharded State

**Goal**: Learn to restore sharded model state, which requires providing an abstract state Pytree that includes the target sharding specifications.

### Topics:

* Creating an abstract model using nnx.eval_shape.
* Splitting it to get an abstract state.
* Crucial step: attaching sharding specifications to this abstract state to create a sharding-aware template, the abstract_target. You can do this either by rebuilding each leaf as a jax.ShapeDtypeStruct that carries a NamedSharding, or by ensuring that the nnx.eval_shape process yields abstract states with the correct sharding (if the module itself defines its sharding during abstract construction).
* Using StandardRestore with this sharding-aware abstract_target.
* Merging the restored sharded state with a graph definition to reconstruct the model.

### Instructions:

1. Reuse the mesh from Exercise 5.
2. Re-open the CheckpointManager pointing to ckpt_dir_ex5.
3. Create a sharding-aware abstract target (abstract_target_state) together with a graphdef_for_restore, using the method described below.
4. Restore within the mesh context using `mngr.restore(step, args=ocp.args.StandardRestore(abstract_target_state))`.
5. Reconstruct the model using `nnx.merge(graphdef_for_restore, restored_sharded_state)`.
6. Verify the sharding of the restored model's parameters.
7. Close the manager.

Why not simply call nnx.eval_shape on ShardedSimpleLinear? nnx.eval_shape traces the module's `__init__` and returns its parameters as jax.ShapeDtypeStruct objects instead of real arrays. Although ShardedSimpleLinear's `__init__` calls jax.device_put with a NamedSharding, you should not rely on that call producing abstract leaves that carry the target sharding. You could make the module build sharded abstract values itself (for example, `jax.ShapeDtypeStruct(shape, dtype, sharding=NamedSharding(mesh, pspec))` for each nnx.Param). However, it is cleaner and more aligned with typical Orbax/JAX patterns to:

a. Get a plain abstract state (shapes/dtypes only) from a version of the model that doesn't shard during abstract initialization (the solution uses SimpleLinear from the earlier exercises).
b. Explicitly create the abstract_target by attaching sharding to this plain abstract state.

Either way, the abstract_target passed to StandardRestore must contain the sharding information.

We'll use the method from the slide "Distributed Checkpointing: Restoring Sharded State":

1. abstract_model = nnx.eval_shape(lambda: ModelClass(...)) (ModelClass here doesn't apply sharding during this abstract init).
2. _graphdef, abstract_state_struct_only = nnx.split(abstract_model).
3. Define sharding_pytree (same Pytree structure as state, but with NamedSharding objects at leaves).
4. abstract_target = jax.tree.map(lambda s, n: jax.ShapeDtypeStruct(s.shape, s.dtype, sharding=n), abstract_state_struct_only, sharding_pytree). This abstract_target is then used in StandardRestore.

In [None]:
# Ensure ShardedSimpleLinear class definition and mesh from Ex5 are available.

# --- Re-open CheckpointManager for Sharded Restore ---
# TODO: Instantiate CheckpointManager for ckpt_dir_ex5
# mngr_sharded_restore = ...

# --- Create Abstract Target State with Sharding Information ---
# Method:
# 1. Create a "plain" abstract model (shapes/dtypes only).
# 2. Split it to get graphdef and plain abstract_state.
# 3. Define the desired sharding for each parameter (Pytree of NamedSharding).
# 4. Combine plain abstract_state with sharding to create the final abstract_target.

def create_abstract_model_for_sharded_restore():
    # This lambda should instantiate the model structure without applying sharding during this phase.
    # We'll use the ShardedSimpleLinear class, but its sharding logic inside __init__
    # might be skipped by eval_shape if it involves actual data.
    # Alternatively, provide a version of the model that takes sharding specs externally.
    # For simplicity, let's assume nnx.eval_shape on ShardedSimpleLinear gives us ShapeDtypeStructs,
    # and we will then OVERWRITE their sharding attribute if necessary, or construct them fresh.

    # Let's make a 'template' instance of ShardedSimpleLinear just to get its structure via split.
    # The actual sharding for the abstract target will be defined explicitly.
    temp_rngs = nnx.Rngs(params=jax.random.key(99))
    # Create an instance of ShardedSimpleLinear as it was defined in Ex5.
    # nnx.eval_shape will trace its construction.
    # TODO: abstract_model_proto = nnx.eval_shape(lambda: ShardedSimpleLinear(... pass din_s, dout_s, mesh from Ex5 ...))
    # abstract_model_proto = ...
    # return abstract_model_proto

# Run within mesh context for operations that might interact with sharding
# with mesh:
    # TODO: Create the abstract_model_proto by calling the function above.
    # abstract_model_for_target = create_abstract_model_for_sharded_restore()
    # TODO: Split it to get graphdef_for_restore and an abstract_state (which might have None for sharding)
    # graphdef_for_restore_sharded, abstract_state_struct_only = ...

    # Define the target sharding (PartitionSpecs, then NamedSharding)
    # These must match the sharding used when the checkpoint was SAVED.
    # weight_pspec_target = PartitionSpec(None, 'data') # As in Ex5
    # bias_pspec_target = PartitionSpec('data',)     # As in Ex5

    # weight_sharding_target = NamedSharding(mesh, weight_pspec_target)
    # bias_sharding_target = NamedSharding(mesh, bias_pspec_target)

    # Create the sharding pytree for the abstract_target.
    # It must match the structure of abstract_state_struct_only exactly.
    # Because ShardedSimpleLinear sets self.weight = nnx.Param(...) and self.bias = nnx.Param(...),
    # the split state is flat: 'weight' and 'bias' each map to a VariableState whose .value is a ShapeDtypeStruct.

    # TODO: Construct the `sharding_for_abstract_state` Pytree.
    # It should mirror the structure of `abstract_state_struct_only` but contain NamedSharding objects
    # at the leaves where the parameters are. The container types must match as well: jax.tree.map
    # will not line up a plain dict with nnx.State/VariableState nodes, so build it from the same types, e.g.:
    # sharding_for_abstract_state = nnx.State({
    #     'weight': nnx.VariableState(type=nnx.Param, value=weight_sharding_target),
    #     'bias': nnx.VariableState(type=nnx.Param, value=bias_sharding_target)
    # })
    # Verify this structure based on print(abstract_state_struct_only) from split.

    # TODO: Create the final abstract_target by combining shapes/dtypes with new sharding.
    # abstract_target_state = jax.tree.map(
    #    lambda sds, sh: jax.ShapeDtypeStruct(sds.shape, sds.dtype, sharding=sh) if isinstance(sds, jax.ShapeDtypeStruct) else sds,
    #    abstract_state_struct_only,
    #    sharding_for_abstract_state
    # )
    # print(f"Abstract target for restore (bias sharding): {abstract_target_state['bias'].value.sharding}")

# --- Restore Sharded State ---
# step_to_restore_sharded = mngr_sharded_restore.latest_step()
# if step_to_restore_sharded is not None:
    # with mesh: # Restoration happens within the mesh context
        # TODO: Restore sharded state using abstract_target_state
        # restored_sharded_state_dict = mngr_sharded_restore.restore(...)

        # TODO: Reconstruct the model using nnx.merge
        # reconstructed_sharded_model = ...

    # print(f"Sharded model restored from step {step_to_restore_sharded}.")
    # print(f"Restored weight sharding: {reconstructed_sharded_model.weight.value.sharding}")
    # print(f"Restored bias sharding: {reconstructed_sharded_model.bias.value.sharding}")

    # Verification (optional): Compare with sharded_model from Ex5 if it's in scope and has same structure
    # chex.assert_trees_all_equal_shapes_and_dtypes(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))
    # assert str(reconstructed_sharded_model.weight.value.sharding) == str(sharded_model.weight.value.sharding)

# else:
#    print("No sharded checkpoint found to restore.")

# mngr_sharded_restore.close()

In [None]:
# @title Exercise 6: Solution

# Ensure ShardedSimpleLinear class definition and mesh from Ex5 are available.
# din_s, dout_s from Ex5 were: din_s = 8, dout_s = num_devices * 2

# --- Re-open CheckpointManager for Sharded Restore ---
mngr_sharded_restore = ocp.CheckpointManager(ckpt_dir_ex5)

# --- Create Abstract Target State with Sharding Information ---
# This follows the principle that the abstract target for restore must contain sharding info

def create_abstract_model_for_sharded_restore_eval_shape():
  # Passed to nnx.eval_shape, which traces this constructor without allocating
  # real arrays. ShardedSimpleLinear's __init__ calls jax.device_put, but we do
  # not rely on that to embed sharding in the abstract leaves. This function is
  # used only to obtain the graphdef (the static structure, including the mesh);
  # the sharding-aware abstract state is built explicitly below.
  temp_rngs_for_eval = nnx.Rngs(params=jax.random.key(100)) # Dummy key for eval_shape
  # Pass the same mesh instance that will be used for restoration
  return ShardedSimpleLinear(din=din_s, dout=dout_s, mesh=mesh, rngs=temp_rngs_for_eval)

with mesh: # Operations like eval_shape and restore should be within the mesh context
  # Create the abstract model with nnx.eval_shape. We keep only its graphdef,
  # which is what nnx.merge needs to rebuild a ShardedSimpleLinear.
  abstract_model_sharded_eval = nnx.eval_shape(create_abstract_model_for_sharded_restore_eval_shape)
  # Use the graphdef from the abstract sharded model for merging
  graphdef_for_restore_sharded = nnx.split(abstract_model_sharded_eval)[0]

  # Take the plain abstract state (shapes/dtypes only) from SimpleLinear, assumed
  # (as in the earlier exercises) to have the same 'weight'/'bias' parameters but
  # no sharding. This makes the manual construction of the sharding-aware target
  # explicit instead of depending on whatever sharding eval_shape may have traced.
  plain_abstract_model = nnx.eval_shape(lambda: SimpleLinear(din_s, dout_s, rngs=nnx.Rngs(0)))
  # Leaves are ShapeDtypeStructs without the target sharding
  _gdef_plain, abstract_state_struct_only = nnx.split(plain_abstract_model)

  # Define target sharding specs
  weight_pspec_target = PartitionSpec(None, 'data') # As in Ex5
  bias_pspec_target = PartitionSpec('data',)     # As in Ex5
  weight_sharding_target = NamedSharding(mesh, weight_pspec_target)
  bias_sharding_target = NamedSharding(mesh, bias_pspec_target)

  # Create the sharding pytree for the abstract_target
  # It needs to match the structure of `abstract_state_struct_only` exactly.
  # Since abstract_state_struct_only is {'bias': {'value':ShapeDtypeStruct}, 'weight': {'value':ShapeDtypeStruct}},
  # the sharding pytree should mirror this structure, placing NamedSharding at the leaves.
  sharding_pytree_for_target = nnx.State({
    'weight': nnx.VariableState(type=nnx.Param, value=weight_sharding_target),
    'bias': nnx.VariableState(type=nnx.Param, value=bias_sharding_target)
  })

  # Create the final abstract_target by mapping over both trees at once.
  # jax.tree.map descends through nnx.State and each nnx.VariableState, so the
  # function below receives pairs of leaves: a ShapeDtypeStruct from
  # abstract_state_struct_only and the NamedSharding at the same position in
  # sharding_pytree_for_target.
  def update_variable_state_sharding(sds, sharding):
    if isinstance(sds, jax.ShapeDtypeStruct):
      # Same shape and dtype, now annotated with the target sharding
      return jax.ShapeDtypeStruct(sds.shape, sds.dtype, sharding=sharding)
    # Any other kind of leaf is passed through unchanged
    return sds

  abstract_target_state = jax.tree.map(
    update_variable_state_sharding,
    abstract_state_struct_only,  # ShapeDtypeStructs nested in VariableState.value
    sharding_pytree_for_target   # NamedShardings nested in VariableState.value
  )

  print(f"Abstract target for restore (bias sharding): {abstract_target_state['bias'].value.sharding}")
  print(f"Abstract target for restore (weight sharding): {abstract_target_state['weight'].value.sharding}")

# --- Restore Sharded State ---
step_to_restore_sharded = mngr_sharded_restore.latest_step()
if step_to_restore_sharded is not None:
  with mesh: # Restoration happens within the mesh context
    # Use StandardRestore with the abstract_target that includes sharding info
    restored_sharded_state_dict = mngr_sharded_restore.restore(
        step_to_restore_sharded,
        args=ocp.args.StandardRestore(abstract_target_state)
    )

    # Reconstruct the model using nnx.merge
    # Use the graphdef obtained from splitting the abstract sharded model
    reconstructed_sharded_model = nnx.merge(graphdef_for_restore_sharded, restored_sharded_state_dict)

  print(f"Sharded model restored from step {step_to_restore_sharded}.")
  print(f"Restored weight sharding: {reconstructed_sharded_model.weight.value.sharding}")
  print(f"Restored bias sharding: {reconstructed_sharded_model.bias.value.sharding}")

  # Verification
  if 'sharded_model' in globals(): # If sharded_model from Ex5 is available
    # Compare tree structure, shapes, and dtypes
    chex.assert_trees_all_equal_shapes_and_dtypes(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))
    # Compare sharding
    assert str(reconstructed_sharded_model.weight.value.sharding) == str(sharded_model.weight.value.sharding)
    assert str(reconstructed_sharded_model.bias.value.sharding) == str(sharded_model.bias.value.sharding)
    # Compare values - this will involve communication due to sharding
    chex.assert_trees_all_close(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))
    print("Verification of sharding, structure, and values successful.")
else:
  print("No sharded checkpoint found to restore.")

mngr_sharded_restore.close()

## Advanced Orbax Features & Best Practices (Brief Overview)

There are also some more advanced Orbax features. While we won't do full coding exercises for these in this notebook, it's good to be aware of them:

* **Asynchronous Checkpointing**: manager.save() can operate in the background. Use manager.wait_until_finished() before your program exits or if you need to use the checkpoint immediately. This improves training throughput by not blocking the main training loop. Our examples used wait_until_finished().

* **Atomicity**: CheckpointManager ensures that checkpoints are saved atomically. This means you won't get corrupted checkpoints if your training job crashes mid-save. This is handled for you by Orbax.

* **Saving Non-Pytree Data (Metadata)**: Sometimes you need to save extra information such as the training configuration, a dataset iterator's position, or the model version. If that data is JSON-serializable, you can use ocp.args.JsonSave within ocp.args.Composite to save it as JSON alongside your model Pytrees. Restoration uses ocp.args.JsonRestore.

### Example Concept:

```
metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}
save_args = ocp.args.Composite(
  params=ocp.args.StandardSave(params_state),
  metadata=ocp.args.JsonSave(metadata)
)
mngr.save(step, args=save_args)
```

* **TensorStore Backend**: Orbax writes array data through TensorStore, which allows efficient, potentially parallel I/O for individual array shards and works with cloud storage. This matters most for extremely large models or cloud storage. It is typically used transparently, so you do not need to configure it for the examples in this notebook.

### Key Takeaways:

* Flax NNX offers a stateful, Pythonic way to define models.
* Orbax is the standard for checkpointing NNX State Pytrees.
* The general workflow:
  - **Saving**: nnx.split -> mngr.save.
  - **Restoring**: nnx.eval_shape -> Get abstract_state -> mngr.restore -> nnx.merge or nnx.update.
* CheckpointManager is your friend for managing multiple checkpoints.
* Use ocp.args.Composite for saving multiple distinct items (e.g., model parameters + optimizer state).
* For sharded (distributed) data, ensuring your abstract_target for restoration correctly specifies the target sharding is crucial. StandardRestore handles this if the abstract target has the sharding info.

### Congratulations!
You've now worked through the fundamentals of checkpointing Flax NNX models with Orbax, from basic saving and restoring to handling optimizer states and distributed (sharded) scenarios.

Remember to consult the official documentation for more in-depth details:

* Orbax: https://orbax.readthedocs.io
* Flax NNX (part of the Flax documentation): https://flax.readthedocs.io
* JAX: https://jax.readthedocs.io

Keep practicing, and happy JAXing!

Please send us feedback at https://goo.gle/jax-training-feedback