MAX is a highly experimental, rapidly evolving modular reinforcement learning library built on JAX. It prioritizes online adaptation algorithms and information-gathering strategies, supports both model-based and model-free control, and treats multi-agent systems as a first-class concern.
- Pure JAX Implementation: Leverages JIT compilation, automatic differentiation, and GPU/TPU acceleration for fast iteration.
- Emphasis on Online Adaptation: The core design centers on algorithms and components for efficient adaptation to changing or uncertain dynamics.
- Model-Based Algorithms with Parameter Belief: Supports model-based control where the dynamics components maintain a belief (distribution) over uncertain parameters, e.g. in a Bayesian setting; see the sketch after this list.
- Multi-Agent RL: Built-in support for IPPO (Independent PPO) and multi-agent environments.
- Modular Design: Mix and match components (environments, policies, trainers, normalizers) for rapid prototyping of novel online algorithms.
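
To make the parameter-belief idea concrete, here is a minimal sketch (the names `ParamBelief` and `sample_params` are illustrative, not MAX's actual API): a dynamics model carries a diagonal Gaussian over its uncertain parameters and samples from it during planning, so imagined rollouts reflect model uncertainty rather than a single point estimate.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class ParamBelief(NamedTuple):
    """Diagonal Gaussian belief over uncertain dynamics parameters."""
    mean: jnp.ndarray  # shape (n_params,)
    var: jnp.ndarray   # shape (n_params,), diagonal covariance

def sample_params(belief: ParamBelief, key: jax.Array, n_samples: int) -> jnp.ndarray:
    """Draw parameter samples, e.g. one per imagined rollout in a planner."""
    noise = jax.random.normal(key, (n_samples,) + belief.mean.shape)
    return belief.mean + noise * jnp.sqrt(belief.var)

# Evaluate candidate action sequences against several sampled dynamics
# parameters instead of a single point estimate.
belief = ParamBelief(mean=jnp.zeros(4), var=0.1 * jnp.ones(4))
params = sample_params(belief, jax.random.PRNGKey(0), n_samples=8)  # (8, 4)
```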
```bash
git clone <repository-url>
cd max
pip install -e .
```

The library is organized into the following modules:

- `environments`: Multi-agent tracking and pursuit-evasion environments
- `dynamics`: Learned dynamics models (MLP-based and analytical)
- `policies`: Actor-critic policies and model-based planners
- `policy_trainers`: PPO and IPPO training algorithms
- `trainers`: Dynamics model training (gradient descent, EKF, PETS)
- `normalizers`: State/action/reward normalization utilities
- `buffers`: JAX-based replay buffers for efficient data storage
- `planners`: Model-based planning algorithms (CEM, iCEM)
- `policy_evaluators`: Policy evaluation and rollout utilities
- `evaluation`: Dynamics model evaluation metrics
- `estimators`: Extended Kalman Filter for online Bayesian estimation of dynamics parameters (see the sketch after this list)
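
As a rough illustration of what an EKF-based estimator does (a sketch under assumed interfaces, not the actual `estimators` module), the dynamics parameters can be treated as the latent state of a Kalman filter, with each observed transition supplying a measurement update:

```python
import jax
import jax.numpy as jnp

def ekf_param_update(mean, cov, predict_fn, s, a, s_next, obs_noise_cov):
    """One EKF measurement update treating the dynamics parameters as the
    filter's latent state (static process model, so there is no predict step).

    predict_fn(theta, s, a) -> predicted next state, differentiable in theta.
    """
    H = jax.jacfwd(predict_fn)(mean, s, a)        # (state_dim, n_params)
    innovation = s_next - predict_fn(mean, s, a)  # prediction error
    S = H @ cov @ H.T + obs_noise_cov             # innovation covariance
    K = cov @ H.T @ jnp.linalg.inv(S)             # Kalman gain
    new_mean = mean + K @ innovation
    new_cov = (jnp.eye(mean.shape[0]) - K @ H) @ cov
    return new_mean, new_cov

# Toy usage: a hypothetical model whose next state scales the state by theta.
def predict(theta, s, a):
    return theta * s

mean, cov = jnp.ones(2), jnp.eye(2)
s, a, s_next = jnp.array([1.0, 2.0]), jnp.array([0.0]), jnp.array([1.1, 1.9])
mean, cov = ekf_param_update(mean, cov, predict, s, a, s_next, 0.01 * jnp.eye(2))
```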
- `scripts/ippo_pe.py`: Train IPPO agents on pursuit-evasion task
- `scripts/visualize_pe.py`: Visualize trained policies
- `scripts/ippo_tracking.py`: Train IPPO agents for goal tracking
- `scripts/visualize_tracking.py`: Visualize trained tracking policies
All components follow JAX's functional programming paradigm:
- Immutable state containers (NamedTuples, PyTreeNodes)
- Pure functions for transformations
- JIT-compiled operations for performance
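
A minimal sketch of this style (illustrative, not lifted from MAX's source): state lives in an immutable NamedTuple, and a pure, JIT-compiled update returns a new container instead of mutating the old one.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class NormalizerState(NamedTuple):
    """Immutable running-mean state; updates return a new instance."""
    count: jnp.ndarray
    mean: jnp.ndarray

@jax.jit
def update(state: NormalizerState, x: jnp.ndarray) -> NormalizerState:
    """Pure function: folds one observation into the running mean."""
    count = state.count + 1
    mean = state.mean + (x - state.mean) / count
    return NormalizerState(count=count, mean=mean)

state = NormalizerState(count=jnp.zeros(()), mean=jnp.zeros(3))
state = update(state, jnp.array([1.0, 2.0, 3.0]))  # new state; old one untouched
```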
The library is designed with multi-agent systems as first-class citizens:
- Independent parameter sets per agent
- Shared or separate training
- Flexible observation/action spaces
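
For example (a hedged sketch, not MAX's actual multi-agent API), independent per-agent parameter sets can be stacked along a leading axis and applied with `jax.vmap`, which keeps per-agent computation independent while staying fully JIT-compatible:

```python
import jax
import jax.numpy as jnp

n_agents, obs_dim, act_dim = 3, 4, 2

def policy_apply(params, obs):
    """Toy linear policy; stands in for an actor-critic forward pass."""
    return jnp.tanh(obs @ params["w"] + params["b"])

# One parameter set per agent, stacked on a leading axis.
keys = jax.random.split(jax.random.PRNGKey(0), n_agents)
params = {
    "w": jax.vmap(lambda k: jax.random.normal(k, (obs_dim, act_dim)))(keys),
    "b": jnp.zeros((n_agents, act_dim)),
}
obs = jnp.ones((n_agents, obs_dim))

# vmap maps over agents: each agent applies its own parameters to its own obs.
actions = jax.vmap(policy_apply)(params, obs)  # shape (n_agents, act_dim)
```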
Mix and match components easily:
```python
# Use a model-based planner as the policy
policy = create_planner_policy(planner, dynamics_model)

# Or use a model-free actor-critic
policy = create_actor_critic_policy(config)

# Same trainer interface for both!
trainer = init_policy_trainer(config, policy)
```

MIT License
