Synthetic Experience Replay

Synthetic Experience Replay (SynthER) is a diffusion-based approach to arbitrarily upsample an RL agent's collected experience, leading to large gains in sample efficiency and scaling benefits. We integrate SynthER into a variety of offline and online algorithms in this codebase, including SAC, TD3+BC, IQL, EDAC, and CQL. For further details, please see the paper:

Synthetic Experience Replay; Cong Lu*, Philip J. Ball*, Yee Whye Teh, Jack Parker-Holder. Published at NeurIPS, 2023.

View on arXiv

Setup

To install, clone the repository and run the following:

git submodule update --init --recursive
pip install -r requirements.txt

The code was tested on Python 3.8 and 3.9. If you don't have MuJoCo installed, follow the instructions here: https://github.com/openai/mujoco-py#install-mujoco.
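
Once the requirements are installed, a quick sanity check is to build one of the D4RL datasets used in the offline examples below. A minimal sketch, assuming D4RL and mujoco-py were installed from the requirements:

# Sanity check: load the D4RL dataset used in the offline examples below.
import gym
import d4rl  # noqa: F401 -- importing registers the D4RL environments with gym

env = gym.make('halfcheetah-medium-replay-v2')
dataset = env.unwrapped.get_dataset()
print(dataset['observations'].shape, dataset['actions'].shape)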

Running Instructions

Offline RL

Diffusion model training (this automatically generates samples and saves them):

python3 synther/diffusion/train_diffuser.py --dataset halfcheetah-medium-replay-v2
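
The saved samples are a NumPy .npz archive (see the samples.npz usage below), which you can inspect before training on them. A minimal sketch; the path is the same placeholder used below, and the key names depend on what the script saves, so check data.files:

import numpy as np

# Inspect the archive written by train_diffuser.py (path is illustrative).
data = np.load('path/to/samples.npz')
print(data.files)  # names of the stored arrays

# Grab the first array, whatever the script actually names it.
samples = data[data.files[0]]
print(samples.shape, samples.dtype)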

Baseline without SynthER (e.g. on TD3+BC):

python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/

Offline RL training with SynthER:

# Generating diffusion samples on the fly.
python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/ --name SynthER --diffusion.path path/to/model-100000.pt

# Using saved diffusion samples.
python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/ --name SynthER --diffusion.path path/to/samples.npz
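
In both cases, the downstream algorithm simply trains on an enlarged replay buffer of real plus synthetic transitions. A schematic of that idea in isolation, not the codebase's actual buffer API; the array layout is illustrative:

import numpy as np

def upsample_buffer(real: np.ndarray, synthetic: np.ndarray) -> np.ndarray:
    # Each row is one flattened transition, e.g. (obs, action, reward, next_obs, done).
    assert real.shape[1] == synthetic.shape[1], 'transition dims must match'
    return np.concatenate([real, synthetic], axis=0)

def sample_batch(buffer: np.ndarray, batch_size: int, rng: np.random.Generator) -> np.ndarray:
    # Uniform sampling over the combined buffer, as in standard experience replay.
    idx = rng.integers(0, len(buffer), size=batch_size)
    return buffer[idx]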

Online RL

Baselines (SAC, REDQ):

# SAC.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name SAC --gin_config_files 'config/online/sac.gin'

# REDQ.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name REDQ --gin_config_files 'config/online/redq.gin'

SynthER (SAC):

# DMC environments.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name SynthER --gin_config_files 'config/online/sac_synther_dmc.gin' --gin_params 'redq_sac.utd_ratio = 20' 'redq_sac.num_samples = 1000000'

# OpenAI environments (different gin config).
python3 synther/online/online_exp.py --env HalfCheetah-v2 --results_folder online_logs/ --exp_name SynthER --gin_config_files 'config/online/sac_synther_openai.gin' --gin_params 'redq_sac.utd_ratio = 20' 'redq_sac.num_samples = 1000000'
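
The --gin_config_files and --gin_params flags correspond to gin-config's standard entry point, so overrides like the UTD ratio above can also be bound programmatically. A sketch using gin's public API; the exact wiring inside online_exp.py may differ:

import gin

# Bind the same config file and overrides that the CLI flags above pass through.
gin.parse_config_files_and_bindings(
    ['config/online/sac_synther_dmc.gin'],
    bindings=['redq_sac.utd_ratio = 20', 'redq_sac.num_samples = 1000000'],
)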

Thinking of adding SynthER to your own algorithm?

Our codebase has everything you need for diffusion with low-dimensional data, along with example integrations with RL algorithms. For a custom use case, we recommend starting from the training script and the SimpleDiffusionGenerator class in synther/diffusion/train_diffuser.py. You can modify the hyperparameters specified in config/resmlp_denoiser.gin to suit your needs.
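
At its core, SynthER models each transition as a single flat vector so that a low-dimensional diffusion model can be trained over it; if you are adapting the pipeline to your own data, the main decision is that flattening and un-flattening step. A minimal sketch; the dimensions and layout are illustrative, so match them to whatever your buffer stores:

import numpy as np

OBS_DIM, ACT_DIM = 17, 6  # e.g. HalfCheetah; substitute your own dims

def flatten_transition(obs, act, rew, next_obs, done):
    # One training example for the diffusion model: a single flat vector.
    return np.concatenate([obs, act, [rew], next_obs, [float(done)]])

def unflatten_transition(x):
    # Invert the layout above to recover a transition from a diffusion sample.
    obs = x[:OBS_DIM]
    act = x[OBS_DIM:OBS_DIM + ACT_DIM]
    rew = x[OBS_DIM + ACT_DIM]
    next_obs = x[OBS_DIM + ACT_DIM + 1:OBS_DIM + ACT_DIM + 1 + OBS_DIM]
    done = x[-1] > 0.5  # diffusion outputs are continuous; threshold the flag
    return obs, act, rew, next_obs, done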

Additional Notes

  • Our codebase uses wandb for logging; you will need to set --wandb-entity across the repository.
  • Our pixel-based experiments are based on a modified version of the V-D4RL repository. The latent representations are derived from the trunks of the actor and critic.

Acknowledgements

SynthER builds upon many works and open-source codebases in both diffusion modelling and reinforcement learning. We would particularly like to thank the authors of:

Contact

Please contact Cong Lu or Philip Ball for any queries. We welcome any suggestions or contributions!