Direct port of https://github.com/sfujim/TD3_BC to JAX using Haiku and optax.
To set up a virtual environment and install the dependencies:

```shell
python3 -m venv venv
source venv/bin/activate
pip install -U pip setuptools
pip install -r requirements.txt
```
Refer to the original README for usage.
TD3+BC is a simple approach to offline RL that makes only two changes to TD3: (1) a weighted behavior cloning loss is added to the policy update, and (2) the states are normalized. Unlike competing methods, it requires no changes to the architecture or the underlying hyperparameters. The paper is "A Minimalist Approach to Offline Reinforcement Learning" (Fujimoto and Gu, 2021).
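The two changes can be sketched as pure JAX functions. This is a minimal illustration, not code taken from this port: the function names `normalize_states` and `td3_bc_actor_loss` are hypothetical, and the defaults `alpha=2.5` and `eps=1e-3` are assumed to match the original PyTorch implementation. In the actor update, `q_values` would be the first critic's estimate `Q1(s, pi(s))` for a batch of dataset states.

```python
import jax
import jax.numpy as jnp


def normalize_states(states, eps=1e-3):
    """Standardize each state dimension using dataset statistics.

    eps is added to the std so constant dimensions do not divide by zero.
    Returns the normalized states plus mean/std for use at evaluation time.
    """
    mean = states.mean(axis=0, keepdims=True)
    std = states.std(axis=0, keepdims=True) + eps
    return (states - mean) / std, mean, std


def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """Actor loss: -lambda * mean(Q(s, pi(s))) + MSE(pi(s), a).

    lambda = alpha / mean(|Q|) rescales the Q term so alpha sets the
    RL-vs-behavior-cloning trade-off regardless of the reward scale.
    The normalizer is treated as a constant (no gradient flows through it).
    """
    lmbda = alpha / jax.lax.stop_gradient(jnp.abs(q_values).mean())
    bc_loss = jnp.mean((policy_actions - dataset_actions) ** 2)
    return -lmbda * q_values.mean() + bc_loss
```

Because `lmbda` is wrapped in `stop_gradient`, differentiating the loss with `jax.grad` only pushes the policy toward higher Q values and toward the dataset actions; it does not try to shrink `|Q|` itself.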
Paper results were collected with MuJoCo 1.50 (and mujoco-py 1.50.1.1) in OpenAI gym 0.17.0 with the D4RL datasets. Networks were trained using PyTorch 1.4.0 and Python 3.6.
The paper results can be reproduced by running:
```shell
./run_experiments.sh
```
*This is not an official Google product.*