Ignat Georgiev, Varun Giridhar, Nicklas Hansen, Animesh Garg
Project website Paper Models & Datasets
This repository is a soft fork of FoRL.
We introduce Policy learning with large World Models (PWM), a novel Model-Based RL (MBRL) algorithm and framework for deriving effective continuous control policies from large, multi-task world models. We use pre-trained TD-MPC2 world models to efficiently learn control policies with first-order gradients in under 10 minutes per task. Our empirical evaluations on complex locomotion tasks show that PWM not only achieves higher reward than baselines but also outperforms methods that use ground-truth simulation dynamics.
Tested only on Ubuntu 22.04. Requires Python, conda, and an NVIDIA GPU with more than 24 GB of VRAM.
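As a rough illustration of the core idea, the sketch below shows policy optimization with first-order gradients through a frozen, differentiable world model. All names (world_model, policy, the rollout horizon) are placeholders for illustration, not the actual PWM API.

import torch

# Minimal sketch (not the actual PWM implementation): optimize a policy by
# backpropagating predicted rewards through a frozen, differentiable world model.
def policy_update(world_model, policy, optimizer, z0, horizon=16, gamma=0.99):
    z = z0                                    # latent state from the world-model encoder
    total_return = 0.0
    discount = 1.0
    for _ in range(horizon):
        a = policy(z)                         # action from the current policy
        r = world_model.reward(z, a)          # predicted reward (differentiable)
        z = world_model.next(z, a)            # predicted next latent state
        total_return = total_return + discount * r
        discount *= gamma
    loss = -total_return.mean()               # maximize predicted return
    optimizer.zero_grad()
    loss.backward()                           # first-order gradients through the model
    optimizer.step()
    return loss.item()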
git clone --recursive git@github.com:imgeorgiev/PWM.git
cd PWM
conda env create -f environment.yaml
conda activate pwm
ln -s $CONDA_PREFIX/lib $CONDA_PREFIX/lib64 # hack to get CUDA to work inside conda
pip install -e .
pip install -e external/tdmpc2
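After installing, a quick sanity check (assuming the conda environment provides PyTorch with CUDA support) is to confirm that the GPU is visible:

# check_cuda.py - quick sanity check that the conda env sees the GPU
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))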
The first option for running PWM is on complex single-task environments with up to 152 action dimensions in the Dflex simulator. These runs use pre-trained world models, which can be downloaded from Hugging Face (see the download sketch after the command below).
cd scripts
conda activate pwm
python train_dflex.py env=dflex_ant alg=pwm general.checkpoint=path/to/model
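If you prefer to fetch a checkpoint programmatically, one option is huggingface_hub. The repository id below is a placeholder, not necessarily the actual Hugging Face repository; substitute the one linked above.

# Hypothetical example: download a pre-trained world model checkpoint.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="your-username/pwm-world-models")  # placeholder repo id
print("Checkpoints downloaded to:", local_dir)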
Due to the nature of GPU acceleration, it is currently impossible to guarantee deterministic experiments. You can make them "less random" by using
seeding(seed, True)
but this slows down GPU execution.
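For reference, a determinism-oriented seeding helper typically looks roughly like the sketch below; this is an illustration, not the repository's exact seeding function.

import random
import numpy as np
import torch

# Illustrative sketch of a seeding helper; not the repository's exact implementation.
def seeding(seed: int, deterministic: bool = False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Trades GPU speed for reproducibility
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False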
Instead of loading a pre-trained world model, you can pretrain one yourself from the data:
cd scripts
conda activate pwm
python train_dflex.py env=dflex_ant alg=pwm general.pretrain=path/to/model pretrain_steps=XX
To recreate the results from the original paper, use the following pretraining budgets (a launcher sketch follows the table):
| Task | Pretrain gradient steps |
|---|---|
| Hopper | 50_000 |
| Ant | 100_000 |
| Anymal | 100_000 |
| Humanoid | 200_000 |
| SNU Humanoid | 200_000 |
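One way to run all pretraining jobs is a small launcher that iterates over the table above. The env names mirror the config files under cfg/env; paths and step counts are placeholders to adjust for your setup.

# Hypothetical launcher (run from the scripts/ directory) that pretrains a
# world model for each dflex task with the gradient-step budgets listed above.
import subprocess

PRETRAIN_STEPS = {
    "dflex_hopper": 50_000,
    "dflex_ant": 100_000,
    "dflex_anymal": 100_000,
    "dflex_humanoid": 200_000,
    "dflex_snu_humanoid": 200_000,
}

for env, steps in PRETRAIN_STEPS.items():
    subprocess.run(
        [
            "python", "train_dflex.py",
            f"env={env}", "alg=pwm",
            "general.pretrain=path/to/model",   # placeholder path, as in the command above
            f"pretrain_steps={steps}",
        ],
        check=True,
    )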
We evaluate on the MT30 and MT80 task settings proposed by TD-MPC2.
- Download the data for each task following the TD-MPC2 instructions.
- Train a world model from the TD-MPC2 repository using the settings below. Note that horizon=16 and rho=0.99 are crucial. Training takes ~2 weeks on an RTX 3090. Alternatively, you can use one of the pre-trained multi-task world models we provide.
cd external/tdmpc2/tdmpc2
python -u train.py task=mt30 model_size=48 horizon=16 batch_size=1024 rho=0.99 mpc=false disable_wandb=False data_dir=path/to/data
where path/to/data is the full TD-MPC2 dataset for either MT30 or MT80.
- Train a policy for a specific task using the pre-trained world model:
cd scripts
python train_multitask.py -cn config_mt30 alg=pwm_48M task=pendulum-swingup general.data_dir=path/to/data general.checkpoint=path/to/model
- where path/to/data is the full TD-MPC2 dataset for either MT30 or MT80
- where path/to/model is the pre-trained world model as provided here
We also provide scripts that launch Slurm jobs across all tasks: scripts/mt30.bash and scripts/mt80.bash.
cfg
├── alg
│ ├── pwm_19M.yaml - differently sized PWM models; these are the main models to use. Paired with train_multitask.py
│ ├── pwm_317M.yaml - to be used with train_multitask.py
│ ├── pwm_48M.yaml
│ ├── pwm_5M.yaml
│ ├── pwm.yaml - redundant but provided for reproducibility; to be run with train_dflex.py
│ └── shac.yaml - works only with train_dflex.py
├── config_mt30.yaml - to be used with train_multitask.py
├── config_mt80.yaml - to be used with train_multitask.py
├── config.yaml - to be used with train_dflex.py
└── env - dflex env config files
├── dflex_ant.yaml
├── dflex_anymal.yaml
├── dflex_cartpole.yaml
├── dflex_doublependulum.yaml
├── dflex_hopper.yaml
├── dflex_humanoid.yaml
└── dflex_snu_humanoid.yaml
@misc{georgiev2024pwm,
title={PWM: Policy Learning with Large World Models},
author={Ignat Georgiev and Varun Giridhar and Nicklas Hansen and Animesh Garg},
eprint={2407.02466},
archivePrefix={arXiv},
primaryClass={cs.LG},
year={2024}
}