This is a codebase primarily developed by Joey Hejna for training robot models using JAX, Flax, and the OpenX Embodiment datasets. We build heavily on ideas from the Octo repository.
Principles: this codebase is designed to be functional in nature. Feel free to define types and dataclasses and use objects from other libraries, but our implementations should be functions. This makes it easier to scale code across multiple platforms and for distributed training.
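As a toy illustration of this principle (plain Python with illustrative names, not actual repo code), training state lives in plain data structures and each step is a pure function that returns new state instead of mutating it:

```python
# Sketch of the functional style: parameters are plain data, and an update
# step is a pure function from (params, grads) to new params. Pure functions
# like this are straightforward to jit/pmap across devices in JAX.
def sgd_step(params, grads, lr=0.1):
    # Return a fresh dict rather than mutating params in place.
    return {k: v - lr * grads[k] for k, v in params.items()}

params = {"w": 1.0, "b": 0.0}
grads = {"w": 0.5, "b": -0.5}
new_params = sgd_step(params, grads)
# params is untouched; new_params holds the updated values.
```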
First, create a conda environment with Python 3.11, and then install requirements and this repo.
conda create -n openx python=3.11
pip install -r requirements.txt
pip install -e .
If you are on GPU, you will additionally need to install the corresponding jaxlib version.
pip install --upgrade "jax[cuda12_pip]==0.4.37" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
If you are on TPU, instead run:
pip install --upgrade "jax[tpu]==0.4.37" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
We benchmarked some of our implementations against PyTorch versions in robomimic. Installing the robomimic version corresponding to the one used in the original Robomimic paper is a pain. We provide more details commented out in the requirements.txt file, but the basics are as follows.
First, follow the instructions to install mujoco210_linux found here
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
Then, install robosuite, robomimic, and needed dependencies.
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
# Dependencies
pip install "mujoco-py<2.2,>=2.0"
pip install cython==0.29.37
pip install numba
# Robosuite
git clone https://github.com/ARISE-Initiative/robosuite/
cd robosuite
git checkout offline_study
pip install -e . --no-deps # Ignore
cd ..
# Robomimic
git clone https://github.com/ARISE-Initiative/robomimic/
cd robomimic
git checkout v0.2.0
pip install -e . --no-deps # Ignore
cd ..
and enable USE_MUJOCO_PY in setup_shell.sh.
Then repeatedly try to import mujoco_py, robosuite, and robomimic until it works. There are a few manual changes to the code in robosuite and robomimic you will need to make:
- Comment out all references to EGL Probe if you are using TPU.
- Change some imports from collections to collections.abc. This is because some deprecated type aliases used in robosuite and robomimic were removed from the collections module in newer Python versions.
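The change is mechanical; for example:

```python
# Deprecated alias imports like this fail on newer Python versions, where
# the aliases were removed from the collections module:
#   from collections import Iterable, Mapping
# Import the ABCs from collections.abc instead:
from collections.abc import Iterable, Mapping

assert isinstance([1, 2, 3], Iterable)
assert isinstance({"a": 1}, Mapping)
</imports>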
A few fixes if this doesn't immediately work:
conda install -c conda-forge gcc=12.1.0 # No longer used as of 12/24
If you want to use the libero benchmark, you have to follow separate installation instructions. Note that we carefully separate these dependencies to prevent conflicts. For example, we make sure to install the CPU-only version of PyTorch.
For TPUs, ensure the following are installed:
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 libgl1-mesa-dev libsm6 libxext6
Then install the following python dependencies (in this order):
pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu
pip install robosuite==1.4.0 bddl==1.0.1 future "easydict==1.9" termcolor
git clone https://github.com/Lifelong-Robot-Learning/LIBERO
cd LIBERO
pip install -e . --no-deps
To avoid installing gym, I then comment out the line from .venv import SubprocVectorEnv, DummyVectorEnv in LIBERO/libero/libero/envs/venv.py.
If you encounter an error like AttributeError: 'NoneType' object has no attribute 'glGetError' when using MUJOCO_GL="osmesa", or if rendering doesn't immediately work with conda, try the following fix:
conda install -c conda-forge libstdcxx-ng
and do not enable USE_MUJOCO_PY in setup_shell.sh.
You can train a model with
python scripts/train.py --config path/to/config:config_str --path save/path --name name/on/wandb --project project/on/wandb
Example config files can be found in configs.
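As a rough sketch of how such a config file might look (the function name get_config and every field below are illustrative assumptions, not the repo's actual schema; consult the files in configs for the real one), the :config_str suffix on --config typically selects a variant:

```python
# Hypothetical config file sketch. All names and fields here are assumptions
# for illustration only.
def get_config(config_str: str = "small"):
    variants = {
        "small": dict(batch_size=256, embed_dim=256),
        "large": dict(batch_size=512, embed_dim=1024),
    }
    return dict(
        dataset="bridge",
        lr=3e-4,
        **variants[config_str],  # config_str picks the variant
    )
```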
Dataloading is designed to happen in a functional pipeline. openx/datasets/core.py contains the core functionality, and openx/datasets/dataloader.py combines the functions in core in a user-approachable and configurable way.
There are a few core functions:
- load_dataset: loads an RLDS dataset, and must be used everywhere. After this step is when you can apply dataset-specific transformations.
- compute_dataset_statistics: computes and caches dataset statistics globally from a path. This ignores splits.
The dataloader class does this for all datasets in a standard fashion and then shuffles, decodes images, and applies augmentations.
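Conceptually, the pipeline is just function composition over the dataset. A minimal stand-in (toy data and stage names, not the real implementation):

```python
from functools import reduce

def pipeline(dataset, *stages):
    # Apply each pure transformation function to the dataset in order.
    return reduce(lambda ds, stage: stage(ds), stages, dataset)

# Toy stand-ins for the real stages (load, transform, normalize, ...).
def load():
    return [{"obs": i, "action": float(i)} for i in range(4)]

def scale_actions(ds):
    return [{**step, "action": step["action"] / 2.0} for step in ds]

def drop_first(ds):
    return ds[1:]

batch = pipeline(load(), scale_actions, drop_first)
```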
The following features are planned:
- Incorporate language (choose where the instruction / encoding belongs)
- Allow for changing the action keys for different datasets, i.e. on bridge we want to train on achieved_delta but on other datasets we want desired_delta.
- Allow for structure padding, i.e. some datasets might not have some values; set those to zero.
- Add pretrained model loading support (look at flax big_vision for good pretrained models.)
- Fix seeding for random image augmentations for greater flexibility of training.
- Merge Concatenate and Tokenize into a more unified class.
- Unify the convention for pooling at the end of vision encoders. Currently they are different... sad.
- Add better Ensemblize support for action heads.
- Update checkpointing to newer orbax paradigm.