Official repository for "Investigating Pre-Training Objectives for Generalization in Vision-Based Reinforcement Learning" (ICML 2024).
Atari-PB is the first benchmark to compare the generalization capabilities of pre-trained RL agents under a unified protocol.
Each algorithm is evaluated by first pre-training an agent on a 10M-transition dataset (spanning 50 games) and then fine-tuning it on three distinct environment distributions (ID, Near-OOD, and Far-OOD; 65 games in total).
[Paper] [Project page] [Model Weights & Datasets] (Password: ataripb)
We assume that you have access to a GPU (preferably multiple) that can run CUDA 11.8 and cuDNN 8.7.0.

```
conda create -n ataripb python=3.9.11
conda activate ataripb
python3 -m pip install -r requirements.txt
AutoROM --accept-license
```
While the above is running, you can start downloading the datasets required for your experiments.
You don't have to download everything; in particular, the pre-training dataset takes around 600–700 GB of storage.
| Type | Distribution | Dataset source | Download |
|---|---|---|---|
| Pretrain | ID | DQN-Replay-Dataset | `./scripts/download_pretrain_dataset.sh` |
| Finetune | ID | DQN-Replay-Dataset | `./scripts/download_offline_bc_dataset.sh` |
| Finetune | Near-OOD | DQN-Replay-Dataset | `./scripts/download_offline_bc_dataset.sh` |
| Finetune | Far-OOD | 2M Rainbow agent | Download (Password: ataripb) |
**Important:** You have to make several (cumbersome) changes to the scripts and configs for Atari-PB to know where the dataset is.

- Specify the download directory at `data_dir` when using the download scripts.
- Specify the same directory at `replay_dataset_path` in the `./configs/dataloader/pretrain.yaml` configuration file.
- Specify your wandb entity name at `entity` in `./configs/pretrain.yaml`, `./configs/offline_bc.yaml`, etc. We recommend setting `group_name` and `exp_name` as well.
- (Optional) Specify the directory to store the processed Atari-PB dataset at `ataripb_dataset_path` in `./configs/dataloader/pretrain.yaml`. By default, all datasets are stored under `./materials/dataset`.
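As an illustration, the relevant entries in `./configs/dataloader/pretrain.yaml` might look like the fragment below. The example paths are placeholders, and the surrounding keys in your file may differ; only the two key names come from the instructions above.

```yaml
# ./configs/dataloader/pretrain.yaml (fragment; other keys omitted)
replay_dataset_path: /data/dqn_replay       # same directory you passed as data_dir to the download script
ataripb_dataset_path: ./materials/dataset   # optional: where the processed Atari-PB dataset is stored
```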
| Algorithm | Author / Paper | Pre-train script |
|---|---|---|
| CURL | Laskin et al. | `./scripts/pretrain/curl.sh` |
| MAE | He et al. | `./scripts/pretrain/mae.sh` |
| ATC | Stooke et al. | `./scripts/pretrain/atc.sh` |
| SiamMAE | Gupta et al. | `./scripts/pretrain/siammae.sh` |
| R3M* | Nair et al. | `./scripts/pretrain/r3m.sh` |
| BC | Pomerleau | `./scripts/pretrain/bc.sh` |
| SPR | Schwarzer et al. | `./scripts/pretrain/spr.sh` |
| IDM | Christiano et al. | `./scripts/pretrain/idm.sh` |
| SPR+IDM | (SGI) Schwarzer et al. | `./scripts/pretrain/spr_idm.sh` |
| DT | Chen et al. | `./scripts/pretrain/dt.sh` |
| CQL | Kumar et al. | `./scripts/pretrain/cql_dist.sh`, `./scripts/pretrain/cql_mse.sh` |
Several algorithms likely won't fit on a single GPU; we recommend enabling DDP by adding, e.g., `--overrides num_gpus_per_node=4` in the scripts.
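For intuition, `--overrides` passes `key=value` pairs to the training entrypoint. Below is a minimal, self-contained sketch of how such overrides are typically parsed; the repo's actual parser and flag semantics may differ, and `parse_overrides` is a hypothetical helper, not part of Atari-PB.

```python
import argparse

def parse_overrides(pairs):
    """Turn ['num_gpus_per_node=4'] into {'num_gpus_per_node': 4} (hypothetical helper)."""
    config = {}
    for pair in pairs:
        key, _, value = pair.partition("=")
        # Keep integers as ints so they can be used directly (e.g. as a GPU count).
        config[key] = int(value) if value.lstrip("-").isdigit() else value
    return config

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--overrides", nargs="*", default=[])
    args = parser.parse_args(["--overrides", "num_gpus_per_node=4"])
    print(parse_overrides(args.overrides))  # {'num_gpus_per_node': 4}
```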
You can download the checkpoints of our pre-trained models in the main experiment here (Password: ataripb).
For checkpoints in the ablation studies, please contact the author via quagmire@kaist.ac.kr.
To fine-tune these models, you can start with `./scripts/offline_bc/base.sh` and `./scripts/online_rl/base.sh`.