This is the official PyTorch implementation of "BC-IRL: Learning Generalizable Reward Functions from Demonstrations".
- Requires Python >= 3.7. Create an environment and install the package:
conda create -y -n bcirl python=3.7
conda activate bcirl
pip install -e .
To train each method on the point mass navigation task, run the commands below. To run the obstacle version, substitute pointmass with pointmass_obstacle.
- BC-IRL-PPO
python imitation_learning/run.py +bc_irl=pointmass
- GCL
python imitation_learning/run.py +gcl=pointmass
- AIRL
python imitation_learning/run.py +airl=pointmass
- MaxEnt
python imitation_learning/run.py +maxent=pointmass
To evaluate on the eval distribution, add env=pointmass_eval. Specify the path of the saved reward with the load_checkpoint= override, as in the commands below.
- BC-IRL-PPO
python imitation_learning/eval.py +meta_irl=pointmass load_checkpoint=saved_reward.pth
- GCL
python imitation_learning/eval.py +gcl=pointmass load_checkpoint=saved_reward.pth
- AIRL
python imitation_learning/eval.py +airl=pointmass load_checkpoint=saved_reward.pth
- MaxEnt
python imitation_learning/eval.py +maxent=pointmass load_checkpoint=saved_reward.pth
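To sweep all four methods, a small launcher script can assemble the commands above. This is a hypothetical convenience helper, not part of the repository: train_command, eval_command, and launch are illustrative names, and the Hydra keys are copied verbatim from the commands listed above (including the meta_irl key that BC-IRL-PPO uses for eval.py). The leading + in each override is Hydra's syntax for adding a config group that is not in the defaults list. By default the script only prints each command.

```python
import shlex
import subprocess

# Hydra config-group keys, copied from the commands above.
# Note: BC-IRL-PPO uses "bc_irl" for run.py but "meta_irl" for eval.py.
TRAIN_KEYS = {"BC-IRL-PPO": "bc_irl", "GCL": "gcl", "AIRL": "airl", "MaxEnt": "maxent"}
EVAL_KEYS = {"BC-IRL-PPO": "meta_irl", "GCL": "gcl", "AIRL": "airl", "MaxEnt": "maxent"}


def train_command(method, env="pointmass"):
    """Argument list for imitation_learning/run.py (env: "pointmass" or "pointmass_obstacle")."""
    return ["python", "imitation_learning/run.py", f"+{TRAIN_KEYS[method]}={env}"]


def eval_command(method, checkpoint, env="pointmass", eval_env=None):
    """Argument list for imitation_learning/eval.py.

    Pass eval_env="pointmass_eval" to evaluate on the eval distribution.
    """
    cmd = [
        "python",
        "imitation_learning/eval.py",
        f"+{EVAL_KEYS[method]}={env}",
        f"load_checkpoint={checkpoint}",
    ]
    if eval_env is not None:
        cmd.append(f"env={eval_env}")
    return cmd


def launch(cmd, dry_run=True):
    """Print the shell-quoted command; execute it only when dry_run is False."""
    # shlex.join would need Python 3.8+, so quote manually to stay 3.7-compatible.
    print(" ".join(shlex.quote(part) for part in cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    for method in TRAIN_KEYS:
        launch(train_command(method))
        launch(eval_command(method, "saved_reward.pth", eval_env="pointmass_eval"))
```

Set dry_run=False to run the commands one after another, and replace saved_reward.pth with the path of the reward that each training run actually saved.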
Structure of the code under imitation_learning:
- run.py: Code for the main training loop.
- config: The yaml config files for Hydra, split by method. Under config/env are the configs for the different environment settings (such as the generalization setting). config/logger contains the configs for the WandB and CLI loggers. config/default.yaml contains the default settings shared across all methods.
- policy_opt: Code for the policy and PPO updater. The PPO updater is designed to be differentiable with respect to the rewards for use in BC-IRL.
- bcirl: The BC-IRL method.
- gail: The GAIL baseline.
- gcl: The GCL baseline.
- maxent: The MaxEntIRL baseline.
- common: Utilities for plotting in the point mass navigation task, reward functions, and other helper functions.
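The policy_opt note is the key ingredient: if the policy update is differentiable with respect to the reward, a behavior-cloning loss on the demonstrations can be backpropagated through that update into the reward parameters. The toy PyTorch sketch below illustrates this idea under strong simplifying assumptions and is not the repository's code. It replaces PPO with a single REINFORCE-style step on a tabular softmax policy, uses a hypothetical tabular reward (reward_net over one-hot (state, action) indices), rolls out only from the demonstration states, and uses made-up toy demonstrations.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N_STATES, N_ACTIONS = 4, 3

# Hypothetical tabular reward r_theta(s, a): a linear layer on a one-hot (s, a) index.
reward_net = torch.nn.Linear(N_STATES * N_ACTIONS, 1)
reward_opt = torch.optim.Adam(reward_net.parameters(), lr=1e-2)


def rewards_for(states, actions):
    feats = F.one_hot(states * N_ACTIONS + actions, N_STATES * N_ACTIONS).float()
    return reward_net(feats).squeeze(-1)


def inner_policy_update(logits, states, lr=0.5):
    """One REINFORCE-style step (a stand-in for PPO) kept differentiable w.r.t. reward_net."""
    # Sampling is not differentiable; gradients flow through log_probs * rewards instead.
    actions = torch.distributions.Categorical(logits=logits[states]).sample()
    log_probs = F.log_softmax(logits[states], dim=-1).gather(1, actions[:, None]).squeeze(1)
    pg_loss = -(log_probs * rewards_for(states, actions)).mean()
    # create_graph=True makes the updated logits a differentiable function of reward_net's weights.
    (grad,) = torch.autograd.grad(pg_loss, logits, create_graph=True)
    return logits - lr * grad


def bc_irl_step(logits, expert_states, expert_actions):
    """Update the reward so that the policy updated with it imitates the demonstrations."""
    new_logits = inner_policy_update(logits, expert_states)
    bc_loss = F.cross_entropy(new_logits[expert_states], expert_actions)
    reward_opt.zero_grad()
    bc_loss.backward()  # gradients flow back through the policy update into reward_net
    reward_opt.step()
    # Keep the updated policy, detached so the next step builds a fresh graph.
    return new_logits.detach().requires_grad_(True), bc_loss.item()


# Toy demonstrations: the expert action for each of the 4 states.
expert_states = torch.arange(N_STATES)
expert_actions = torch.tensor([0, 1, 2, 0])
policy_logits = torch.zeros(N_STATES, N_ACTIONS, requires_grad=True)

losses = []
for _ in range(300):
    policy_logits, loss = bc_irl_step(policy_logits, expert_states, expert_actions)
    losses.append(loss)
```

The essential line is the torch.autograd.grad call with create_graph=True. Without it, the returned gradient carries no graph, the updated logits no longer depend on reward_net, and no gradient from bc_loss would reach the reward parameters.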
If you use this code, please cite:
@article{szot2023bc,
title={BC-IRL: Learning Generalizable Reward Functions from Demonstrations},
author={Szot, Andrew and Zhang, Amy and Batra, Dhruv and Kira, Zsolt and Meier, Franziska},
journal={arXiv preprint arXiv:2303.16194},
year={2023}
}
This code is licensed under the CC-BY-NC license; see LICENSE.md for more details.