Official JAX codebase for the ICLR 2024 paper "SMORE: Score Models for Offline Goal-Conditioned Reinforcement Learning"
Harshit Sikchi1, Rohan Chitnis2, Ahmed Touati2, Alborz Geramifard2, Amy Zhang1,2, Scott Niekum3
1UT Austin
2Meta AI
3UMass Amherst
Create a fresh conda environment and install the dependencies with the commands below.
conda create -n smore python=3.9
conda install -c conda-forge cudnn
pip install --upgrade pip
# Install one of the JAX builds below, depending on your CUDA version
## 1. CUDA 12 installation
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
## 2. CUDA 11 installation
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
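To verify that the GPU-enabled JAX install worked, a quick sanity check (jax.devices() is standard JAX API; it should list your GPU, e.g. [cuda(id=0)]):
python -c "import jax; print(jax.devices())"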
The offline datasets can be downloaded from the Google Drive link WGCSL offline data; this dataset is provided by the prior work WGCSL. Extract the offline data into root-folder/offline_data/*
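To confirm the data landed where the training script expects it, a quick check from the repository root (the exact file names inside offline_data vary by dataset):
python -c "import os; print(sorted(os.listdir('offline_data')))"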
Locomotion
python train_offline_smore.py --double=True --env_name=halfcheetah-medium-v2 --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<smore_stable|smore> --exp_name=<exp_name>
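For example, a concrete run with the stable loss variant (the exp_name value here is just an arbitrary label for the run):
python train_offline_smore.py --double=True --env_name=halfcheetah-medium-v2 --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=smore_stable --exp_name=smore_halfcheetah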
Manipulation
python train_offline_smore.py --double=True --env_name=SawyerReach --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<smore_stable|smore> --exp_name=<exp_name>
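Similarly, a concrete manipulation run, here with the plain smore loss (again, exp_name is an arbitrary label):
python train_offline_smore.py --double=True --env_name=SawyerReach --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=smore --exp_name=smore_sawyer_reach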
This codebase builds upon the Extreme Q-Learning and Implicit Q-Learning codebases.