Apologies in advance: this issue was drafted with AI assistance, but I hope the point still comes across clearly. Please reply if anything needs clarification.
Summary
When restoring Brax PPO (and likely SAC and other agents) training from a checkpoint using restore_checkpoint_path, the training performance consistently degrades initially before recovering. Investigation reveals that optimizer state (Adam momentum and statistics) is not saved or restored, causing the optimizer to restart from scratch while using trained network weights.
Steps to Reproduce
- Train a PPO agent for N steps and save a checkpoint (a minimal progress_fn sketch for this repro follows these steps):

```python
import functools

from brax.training.agents.ppo import checkpoint
from brax.training.agents.ppo import train as ppo

train_fn = functools.partial(
    ppo.train,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=2048,
    # ... other config
)
make_inference_fn, params, metrics = train_fn(
    environment=env,
    progress_fn=progress_fn,
)

# Save checkpoint
config = checkpoint.network_config(
    observation_size=env.observation_size,
    action_size=env.action_size,
    normalize_observations=True,
)
checkpoint.save('checkpoint_path', step=1000, params=params, config=config)
```

- Restore training from the checkpoint:

```python
train_fn = functools.partial(
    ppo.train,
    num_timesteps=2_000_000,
    restore_checkpoint_path='checkpoint_path/000000001000',  # restore from checkpoint
    # ... same config
)
make_inference_fn, params, metrics = train_fn(environment=env, progress_fn=progress_fn)
```

- Observe that training metrics (reward, value loss, policy loss) degrade immediately after restoration before gradually recovering.
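For completeness, here is a minimal progress_fn sketch for the repro above; it just prints the evaluation reward so the post-restore dip is easy to spot. It assumes periodic evaluation is enabled (num_evals set high enough) and that the evaluator reports the eval/episode_reward metric key, as in the standard Brax training examples.

```python
def progress_fn(num_steps, metrics):
    # Print the evaluation reward so the dip right after checkpoint
    # restoration is visible in the logs.
    print(f"steps={num_steps} eval/episode_reward={metrics.get('eval/episode_reward')}")
```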
Expected Behavior
When restoring from a checkpoint, training should continue seamlessly with similar performance characteristics as if training had never stopped. The optimizer should maintain its momentum and adaptive learning rate statistics.
Actual Behavior
Training performance degrades after checkpoint restoration:
- Lower rewards initially
- Higher losses (value, policy)
- Training needs time to "re-converge" to previous performance levels
- Optimizer appears to be starting from scratch despite network weights being restored
Root Cause Analysis
Current Checkpoint Format
File: brax/training/agents/ppo/train.py (lines 739-742)
The checkpoint saves only a 3-element tuple:
```python
params = (
    training_state.normalizer_params,  # [0] Observation normalization
    training_state.params.policy,      # [1] Policy network
    training_state.params.value,       # [2] Value network
)
```

TrainingState Structure
File: brax/training/agents/ppo/train.py (lines 105-109)
```python
@flax.struct.dataclass
class TrainingState:
  """Contains training state for the learner."""

  optimizer_state: optax.OptState
  params: PPONetworkParams
  normalizer_params: running_statistics.RunningStatisticsState
  env_steps: types.UInt64
```

Notice that TrainingState has 4 fields, but checkpoints only save data for 2 of them (normalizer_params and params).
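One quick way to confirm this from outside the training loop (a sketch; the path is a placeholder) is to load a saved checkpoint and inspect what comes back:

```python
import jax
from brax.training.agents.ppo import checkpoint

restored = checkpoint.load('checkpoint_path/000000001000')  # placeholder path
print(len(restored))  # 3 entries: normalizer_params, policy params, value params
for i, entry in enumerate(restored):
    print(i, jax.tree_util.tree_structure(entry))
# Nothing in the restored structure corresponds to optimizer_state or env_steps.
```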
Missing Components During Restoration
File: brax/training/agents/ppo/train.py (lines 715-733)
When restoring, the optimizer state is not restored:
```python
if restore_checkpoint_path is not None:
  params = checkpoint.load(restore_checkpoint_path)
  # params only contains the 3 network components: [normalizer, policy, value]
  value_params = params[2] if restore_value_fn else init_params.value
  training_state = training_state.replace(
      normalizer_params=params[0],
      params=training_state.params.replace(
          policy=params[1],
          value=value_params,
      ),
      # optimizer_state is NOT restored - it remains freshly initialized!
      # env_steps is NOT restored - it remains at 0!
  )
```

The optimizer_state and env_steps remain as initialized on lines 707-711:
```python
training_state = TrainingState(
    optimizer_state=optimizer.init(init_params),  # Fresh optimizer state
    params=init_params,
    normalizer_params=running_statistics.init_state(...),
    env_steps=types.UInt64(hi=0, lo=0),  # Reset to zero
)
```

What Gets Lost
- Adam optimizer momentum: First moment estimates (m) reset to zero (see the sketch below)
- Adam optimizer variance: Second moment estimates (v) reset to zero
- Optimizer step count: Affects learning rate schedules and adaptive algorithms
- Environment step counter (env_steps): Training appears to restart from step 0
- Gradient history: All accumulated gradient statistics lost
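To make the first two bullets concrete, here is a minimal standalone optax sketch (independent of Brax, and not necessarily the exact optimizer chain PPO builds, which may also include gradient clipping) showing what a freshly initialized Adam state looks like. This is what a restored run starts from when the optimizer state is not checkpointed:

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

# The first element of opt_state is a ScaleByAdamState with count=0 and
# mu/nu (first/second moment estimates) that are all-zero pytrees matching
# the structure of `params`. None of the statistics accumulated by a
# previous run survive if only the network weights are checkpointed.
print(opt_state)
```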
Impact
The trained network weights are restored correctly, but the optimizer has no memory of:
- Recent gradient directions (momentum)
- Gradient magnitude patterns (adaptive learning rates)
- Training progress (step counts for schedules)
This creates a mismatch: experienced network weights being updated by a naive optimizer. Because Adam's step count and moment estimates restart at zero, the first post-restore updates are driven almost entirely by the most recent gradients, and any learning rate schedule keyed to the optimizer step restarts as well, which leads to suboptimal updates and temporary performance degradation.
Components Affected
Algorithms with Same Issue
This pattern appears in multiple Brax agents:
- PPO (brax/training/agents/ppo/train.py): saves only (normalizer, policy, value)
- SAC (brax/training/agents/sac/train.py): likely a similar pattern
- Other agents: need verification
Network Components in PPO
- Policy network (actor)
- Value network (critic/baseline)
Both networks have their weights restored but lose optimizer state. In Brax PPO the policy and value parameters live in a single PPONetworkParams pytree updated by one optax optimizer, so the single missing optimizer_state affects both networks.
Proposed Solution
Option 1: Extended Checkpoint Format (Recommended)
Extend the checkpoint to include optimizer state and training metadata:
```python
# In brax/training/agents/ppo/train.py, modify the return value (lines 739-742)
params = (
    training_state.normalizer_params,  # [0]
    training_state.params.policy,      # [1]
    training_state.params.value,       # [2]
    training_state.optimizer_state,    # [3] NEW: Optimizer state
    training_state.env_steps,          # [4] NEW: Training step counter
)

# In restoration (lines 715-733); requires `import warnings` at module level.
if restore_checkpoint_path is not None:
  params = checkpoint.load(restore_checkpoint_path)
  # Handle both old (3-element) and new (5-element) checkpoint formats.
  if len(params) >= 5:
    optimizer_state = params[3]
    env_steps = params[4]
  else:
    # Backward compatibility: old checkpoints without optimizer state.
    optimizer_state = optimizer.init(init_params)
    env_steps = types.UInt64(hi=0, lo=0)
    warnings.warn(
        'Checkpoint is missing optimizer state. Training will resume with a '
        'fresh optimizer, which may cause temporary performance degradation. '
        'Consider retraining from scratch or accepting the initial drop.'
    )
  value_params = params[2] if restore_value_fn else init_params.value
  training_state = training_state.replace(
      normalizer_params=params[0],
      params=training_state.params.replace(
          policy=params[1],
          value=value_params,
      ),
      optimizer_state=optimizer_state,  # Restore optimizer state
      env_steps=env_steps,              # Restore step counter
  )
```

Option 2: Save Entire TrainingState
Save and restore the entire TrainingState object instead of a tuple:
```python
# In brax/training/agents/ppo/train.py (lines 739-742):
# instead of returning a tuple, pass the entire TrainingState to the checkpointer.
params = training_state  # Save complete state

# In checkpoint.py (sketch; requires `import orbax.checkpoint as ocp`
# and `from flax.training import orbax_utils`).
def save(path, step, training_state, config):
  """Save complete training state including optimizer."""
  # Use Orbax to save the entire TrainingState pytree.
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(training_state)
  orbax_checkpointer.save(path, training_state, force=True, save_args=save_args)
  # ... save config as before

def load(path):
  """Load complete training state."""
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  # Restoring onto a target with matching structure/dtypes may be needed in practice.
  training_state = orbax_checkpointer.restore(path)
  return training_state
```

Option 3: Add Explicit save_optimizer_state Flag
Allow users to optionally save optimizer state:
```python
def train(
    environment: Env,
    # ... existing params ...
    save_optimizer_state: bool = True,  # NEW parameter
):
  # ... training loop ...
  if save_optimizer_state:
    params = (
        training_state.normalizer_params,
        training_state.params.policy,
        training_state.params.value,
        training_state.optimizer_state,
        training_state.env_steps,
    )
  else:
    params = (
        training_state.normalizer_params,
        training_state.params.policy,
        training_state.params.value,
    )
```

Backward Compatibility
All solutions should handle old checkpoints gracefully:
- Detect checkpoint format by tuple length or structure (a small sketch follows this list)
- Fall back to current behavior with a clear warning
- Document the change in release notes
- Provide migration guide for existing checkpoints
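As a concrete illustration of the first bullet (a hypothetical helper, not an existing Brax API, and it assumes the Option 1 tuple layout were adopted), a loader could normalize both formats:

```python
def load_any_format(path):
  """Hypothetical sketch: accept old 3-element and proposed 5-element checkpoints.

  Returns None for the fields an old checkpoint cannot provide, so the caller
  can fall back to a freshly initialized optimizer state and a zero step count.
  """
  entries = checkpoint.load(path)  # existing Brax PPO checkpoint loader
  normalizer_params, policy_params, value_params = entries[0], entries[1], entries[2]
  if len(entries) >= 5:
    optimizer_state, env_steps = entries[3], entries[4]
  else:
    optimizer_state, env_steps = None, None  # old format
  return normalizer_params, policy_params, value_params, optimizer_state, env_steps
```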
Additional Context
Comparison with Other RL Frameworks
Most modern RL frameworks save optimizer state by default:
- Stable Baselines3: Saves optimizer state in .zip files
- CleanRL: Can optionally save optimizer state
- RLlib: Checkpoints include optimizer state
- PyTorch Lightning: Automatically saves optimizer state
- TF-Agents: Saves optimizer state in checkpoints
Brax is currently an outlier in this regard.
Current Workarounds
Users experiencing this issue can:
- Accept the temporary performance degradation after restoration (usually recovers within a few update steps)
- Use a lower learning rate when restoring to reduce instability (see the sketch after this list)
- Implement custom checkpoint saving that includes optimizer state
- Retrain from scratch instead of using checkpoints (not ideal for long training runs)
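For the second workaround, a resume call might look like the sketch below; learning_rate is an existing ppo.train argument, and the specific value here is an arbitrary illustration that should be tuned per task:

```python
make_inference_fn, params, metrics = ppo.train(
    environment=env,
    num_timesteps=2_000_000,
    learning_rate=3e-5,  # lower than the original run, purely as a mitigation
    restore_checkpoint_path='checkpoint_path/000000001000',
    progress_fn=progress_fn,
    # ... otherwise the same config as the original run
)
```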
Performance Impact
The performance degradation can be significant:
- Training may take 10-50% longer to return to pre-checkpoint performance
- In some cases, training may diverge if learning rate is too high
- Wasted compute resources during the recovery period
Why This Matters for Long Training Runs
For researchers running multi-day or multi-week training:
- Need to checkpoint frequently to handle preemption/failures
- Optimizer state becomes increasingly important as training progresses
- Loss of momentum can significantly impact convergence in later stages
Questions for Maintainers
- Was there a specific reason optimizer state was excluded from checkpoints (e.g., storage concerns, simplicity)?
- Would you prefer Option 1 (extended tuple), Option 2 (TrainingState), or Option 3 (optional flag)?
- Should this be applied to all Brax training agents consistently?
- Is there a preference for checkpoint versioning strategy?
- Would you like a migration tool to convert old checkpoints to the new format?
Files Involved
- brax/training/agents/ppo/train.py (main training loop, restoration logic)
- brax/training/agents/ppo/checkpoint.py (PPO-specific checkpoint wrapper)
- brax/training/checkpoint.py (base checkpoint utilities)
- brax/training/types.py (shared type definitions such as UInt64)
- Similar files in other agent directories (SAC, etc.)
Willing to Contribute
I am willing to submit a PR to fix this issue if the maintainers can provide guidance on:
- Preferred solution approach
- Which agents should be updated (all or just PPO initially?)
- Testing requirements
- Documentation updates needed
Related Issues
- Are there existing issues tracking this behavior?
- Have other users reported performance degradation after checkpoint restoration?