⚡️ MFAX: Mean-Field Games in JAX ⚡️

Environment Implementations

Mean-Field Updates:

  • Analytic/Push-forward: Exact mean-field update using white-box access to the individual state transition dynamics.
  • Sample-based: Approximate mean-field update by sampling individual trajectories.
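The two update types above can be sketched as follows. This is a minimal illustration, not the library's API: the function names `pushforward_update` and `sample_update`, the row-stochastic transition matrix `A`, and the per-agent `step_fn` are all assumptions for the example, and plain NumPy stands in for the JAX arrays used in the actual implementation.

```python
import numpy as np

def pushforward_update(mu, A):
    # Exact (analytic) update: the next mean field is the current one
    # pushed through the transition matrix, mu' = A^T mu. Requires
    # white-box access to the transition dynamics encoded in A.
    return A.T @ mu

def sample_update(states, step_fn, rng, num_states):
    # Approximate update: step each sampled agent's individual state and
    # re-estimate the mean field from the empirical histogram.
    next_states = np.array([step_fn(s, rng) for s in states])
    counts = np.bincount(next_states, minlength=num_states)
    return counts / counts.sum()
```

The push-forward update is exact but requires materializing (or functionally representing) the full transition matrix, which is only feasible for small state spaces; the sample-based update only needs the ability to simulate individual transitions.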

Environments are implemented by creating:

  1. A base class, specifying the deterministic next individual state of an agent given its action, its current individual state, and the current aggregate state.
  2. A mean-field update wrapper, either a push-forward class or a sample-based class, which handles updating the mean-field distribution. Push-forward environments only support shared observations of the aggregate state, while sample-based environments support individual observations of the aggregate state. Push-forward environments update the mean-field distribution and compute the expectation over next states via a functional representation of pre-multiplication by the (transposed) transition matrix A.

Currently, all environments are implemented with both a push-forward and sample-based mean-field update. As mentioned in our paper, push-forward updates are only possible if the state space is small enough and observations of the aggregate state are independent of the individual state.

| Environment | Base Class | Push-forward Wrapper | Sample-based Wrapper |
| --- | --- | --- | --- |
| Linear Quadratic | `base_lq` | `pushforward_lq` | `sample_lq` |
| Beach Bar 1D | `base_bb` | `pushforward_bb` | `sample_bb` |
| Macroeconomics | `base_macro` | `pushforward_macro` | `sample_macro` |

Algorithm Implementations

We provide two types of algorithm implementations:

  1. Hybrid Structural Methods (HSMs), which compute the exact expectation over individual states by assuming white-box access to the individual-state transition model. HSMs work with a push-forward mean-field environment implementation.
  2. Reinforcement Learning (RL), which does not assume access to the individual-state transition model. RL works with a sample-based mean-field environment implementation.
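The distinction between the two algorithm families comes down to how the expectation over next individual states is computed. A minimal sketch, under assumed names (`exact_expectation`, `sampled_expectation`, a transition matrix `P`, a simulator `sample_next`), none of which are the library's actual API:

```python
import numpy as np

def exact_expectation(values, P):
    # HSM: white-box access to the transition model P[s, s'] gives the
    # exact expected next-state value for every individual state at once.
    return P @ values

def sampled_expectation(values, sample_next, state, num_samples, rng):
    # RL: only sampled transitions are available, so the expectation is
    # replaced by a Monte Carlo estimate from simulated next states.
    draws = [sample_next(state, rng) for _ in range(num_samples)]
    return np.mean([values[s] for s in draws])
```

This is why HSMs pair with push-forward environments (which expose the transition structure) while the RL methods pair with sample-based environments (which only expose a simulator).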

All scripts use Weights & Biases for logging. Hyper-parameter sweep files are provided in the configs/ folder. After training, final models are saved to .pkl files in runs/.

| Algorithm | Method Type | Name | Script | Citations |
| --- | --- | --- | --- | --- |
| SPG | HSM | Structural Policy Gradient | `spg.py` | Yang2025 |
| RSPG | HSM | Recurrent Structural Policy Gradient | `rspg.py` | - |
| M-OMD | RL | Deep Munchausen Online Mirror Descent | `m_omd.py` | Lauriere2022, Wu2025 |
| IPPO | RL | Independent Proximal Policy Optimisation | `ippo.py` | Schulman2017, SchroederdeWitt2020, Algumaei2023 |
| RIPPO | RL | Recurrent Independent Proximal Policy Optimisation | `rippo.py` | Ni2022 |

Evaluation

We use exploitability to approximate the distance from a Nash equilibrium. Since all sample-based environments currently also have a push-forward implementation, we use backwards induction to approximate the exploitability of the sample-based environments as well. We have not yet implemented approximate exploitability metrics for purely sample-based environments.

Currently, like the HSM scripts, all RL scripts save mean-field agents, which determine action distributions for each individual state in the mean-field. For history-aware policies (e.g. RIPPO), this is only possible because the single-agent policy is structured in exactly the same way as the mean-field policy (i.e. only retaining a history of shared observations). If single-agent policies were to condition on individual histories, it would not be possible to restructure the single-agent policy as a mean-field policy.
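The restructuring described above can be sketched as follows: because the recurrent policy conditions only on the history of shared observations, evaluating it once per individual state recovers the full mean-field policy. The function name and signatures here are illustrative, not the library's API.

```python
def mf_policy_from_single_agent(single_agent_act, shared_obs_history, states):
    # The single-agent policy conditions only on the shared-observation
    # history, so the same recurrent computation applies identically to
    # every individual state; reading it off per state yields a
    # mean-field policy (one action distribution per individual state).
    return {s: single_agent_act(shared_obs_history, s) for s in states}
```

If the policy instead conditioned on an individual history, each state would need its own distinct history, and this per-state read-off would no longer be well defined.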

The exploitability tuple returned includes an evaluation data-class (composed of the exploitability metrics, the mean-field trajectory generated by all agents following that policy, the policy discounted returns, and the best-response discounted returns) and an array of exploitabilities for each individual state. Following the literature, exploitability is calculated by weighting according to the initial mean-field distribution.

Evaluating a saved mean-field agent:

import json

import jax

from mfax.algos.hsm.utils.make_act import MFActorWrapper, MFRecurrentActorWrapper
from mfax.envs import make_env
from mfax.algos.hsm.exploitability import exploitability

# --- load the mean-field agent from the runs directory ---
task = "beach_bar_1d"
algo = "rspg"
run = task + "/" + algo
mf_agent = load_pkl(f"runs/{run}/mf_agent_wrapper.pkl")
with open(f"runs/{run}/args.json", "r") as f:
    train_args = json.load(f)

# --- instantiate a push-forward environment to evaluate the exploitability --- 
env = make_env("pushforward/" + task)

# --- calculate the exploitability: compare the best-response to the learned policy ---
results = exploitability(
    jax.random.PRNGKey(0),                                  # seed
    env,
    mf_agent,
    state_type=train_args["state_type"],                    # either raw states, or indices representing each state
    gamma=train_args["discount_factor"],                    # environment discount factor
    num_envs=8,                                             # number of evaluation environments
    max_steps_in_episode=env.params.max_steps_in_episode,   # maximum steps in an episode
)

# --- access results (avoid shadowing the imported exploitability function) ---
eval_results = results[0]
exploit = eval_results.exploitability.exploitability
mean_policy_disc_return = eval_results.exploitability.mean_policy_return
mean_br_disc_return = eval_results.exploitability.mean_br_return
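As noted above, the tuple also contains an array of per-state exploitabilities, and the scalar metric weights these by the initial mean-field distribution. A minimal sketch of that aggregation, using hypothetical values (`per_state_exploitability` and `mu0` are illustrative names, not attributes of the returned objects):

```python
import numpy as np

# Hypothetical per-state exploitabilities and initial mean-field
# distribution over individual states.
per_state_exploitability = np.array([0.10, 0.02, 0.05])
mu0 = np.array([0.5, 0.3, 0.2])

# Scalar exploitability: per-state values weighted by the initial
# mean field, i.e. 0.5*0.10 + 0.3*0.02 + 0.2*0.05.
aggregate = float(mu0 @ per_state_exploitability)
```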

Quick-Start!

  1. Clone the repository using git clone https://github.com/CWibault/mfax.git. Optionally add a Weights & Biases key to the dev folder (or ensure that log is set to False).
  2. Build the Docker image:
     cd dev
     bash build.sh
  3. Launch the Docker container:
     cd ..
     bash launch_container.sh 0  # if using GPU 0
  4. Run a script:
     python3.10 -m mfax.algos.hsm.algos.rspg

Contributing

Please contribute! Ideas include:

  1. Implementing more complex environments such as ones with individual agent observation functions or threshold-based reward functions.
  2. A third mean-field update wrapper supporting function approximation for the mean-field push-forward distribution.

Acknowledgements

Our implementations of M-OMD, IPPO and RIPPO build on PureJaxRL's DQN, PPO and Recurrent PPO implementations.

Citations

If you use MFAX in your work, please cite the following paper:

@misc{wibault2026recurrent,
      title={Recurrent Structural Policy Gradient for Partially Observable Mean Field Games},
      author={Clarisse Wibault and Sebastian Towers and Tiphaine Wibault and Juan Duque and Johannes Forkel and George Whittle and Andreas Schaab and Yucheng Yang and Chiyuan Wang and Michael Osborne and Benjamin Moll and Jakob Foerster},
      year={2026},
}
