In [24]:
import jax
import jax.numpy as np
from jax import random
import math
from typing import Callable
import os
from flax.training import train_state
import orbax.checkpoint
import optax


In [25]:
try:
    import flax
except ModuleNotFoundError:  # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [30]:
# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
ckpt

{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     promote_dtype = promote_dtype
     dot_general = None
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[-0.14771505, -0.21384335,  0.11710571],
        [-0.45254678, -0.0309354 ,  0.42556357],
        [-0.15834738, -0.47992307,  0.4152557 ],
        [-0.72053933, -0.52988744, -0.44720745],
        [-0.27621508, -0.25154397,  0.03122341]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x12ab07ce0>, update=<function chain.<locals>.update_fn at 0x12ab07060>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': Array([5, 3], dtype=int32)},
 'data': [Array([ 1.0040143, -0.9063372, -0.7481722, -1.1713669, -0.8712328],   

In [31]:
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
path = os.path.abspath('./tmp/flax_ckpt/orbax/single_save')
orbax_checkpointer.save(path, ckpt, save_args=save_args)



In [32]:
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
path2 = os.path.abspath('./tmp/flax_ckpt/orbax/managed')
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    path2, orbax_checkpointer, options)

# Inside a training loop
for step in range(5):
    # ... do your training
    checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})

os.listdir(path2)  # Because max_to_keep=2, only step 3 and 4 are



['0', '4', '3']