
Atari Pre-training Benchmark (Atari-PB)

Official repository for "Investigating Pre-Training Objectives for Generalization in Vision-Based Reinforcement Learning" (ICML 2024).

Atari-PB is the first benchmark to compare the generalization capabilities of pre-trained RL agents under a unified protocol.

Each algorithm is evaluated by first pre-training an agent on a 10M-transition dataset (spanning 50 games) and then fine-tuning it on three distinct environment distributions (In-Distribution, Near-OOD, and Far-OOD; 65 games in total).

[Paper] [Project page] [Model Weights & Datasets] (Password: ataripb)

Installation

We assume that you have access to a GPU (preferably multiple) that can run CUDA 11.8 and cuDNN 8.7.0.

1. Conda Environment

conda create -n ataripb python=3.9.11
conda activate ataripb
python3 -m pip install -r requirements.txt
AutoROM --accept-license

2. Dataset

While the installation is running, you can start downloading the datasets required for your experiments.

You do not have to download everything; in particular, the pre-training dataset takes around 600–700 GB of storage.

| Type | Distribution | Dataset source | Download |
| --- | --- | --- | --- |
| Pre-train | ID | DQN-Replay-Dataset | ./scripts/download_pretrain_dataset.sh |
| Fine-tune | ID | DQN-Replay-Dataset | ./scripts/download_offline_bc_dataset.sh |
| Fine-tune | Near-OOD | DQN-Replay-Dataset | ./scripts/download_offline_bc_dataset.sh |
| Fine-tune | Far-OOD | 2M Rainbow agent | Download (Password: ataripb) |
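If you use the download scripts, a typical session from the repository root might look like the sketch below; only the script paths come from the table above, and how each script reads data_dir is described in the note that follows.

bash ./scripts/download_pretrain_dataset.sh      # pre-training data (ID)
bash ./scripts/download_offline_bc_dataset.sh    # fine-tuning data (ID and Near-OOD)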

Important: You have to make several (cumbersome) changes to the scripts and configs so that Atari-PB knows where the datasets are.

  • Specify the download directory via data_dir when using the download scripts.
  • Specify the same directory at replay_dataset_path in the ./configs/dataloader/pretrain.yaml configuration file (see the example configuration after this list).
  • Specify your wandb entity name at entity in ./configs/pretrain.yaml, ./configs/offline_bc.yaml, etc. We recommend setting group_name and exp_name as well.
  • (Optional) Specify the directory for the processed Atari-PB dataset at ataripb_dataset_path in ./configs/dataloader/pretrain.yaml. By default, all datasets are stored under ./materials/dataset.
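As a concrete illustration, the relevant entries might look as follows; the directory and name values are placeholders, and all other keys in the actual config files are omitted.

In ./configs/dataloader/pretrain.yaml:

replay_dataset_path: /data/ataripb         # same directory as data_dir in the download scripts (placeholder path)
ataripb_dataset_path: ./materials/dataset  # processed Atari-PB dataset location (the default)

In ./configs/pretrain.yaml (and likewise ./configs/offline_bc.yaml, etc.):

entity: your_wandb_entity   # your wandb entity name (placeholder)
group_name: ataripb         # recommended (placeholder)
exp_name: curl_pretrain     # recommended (placeholder)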

Implemented Algorithms

| Algorithm | Author / Paper | Pre-train script |
| --- | --- | --- |
| CURL | Laskin et al. | ./scripts/pretrain/curl.sh |
| MAE | He et al. | ./scripts/pretrain/mae.sh |
| ATC | Stooke et al. | ./scripts/pretrain/atc.sh |
| SiamMAE | Gupta et al. | ./scripts/pretrain/siammae.sh |
| R3M* | Nair et al. | ./scripts/pretrain/r3m.sh |
| BC | Pomerleau | ./scripts/pretrain/bc.sh |
| SPR | Schwarzer et al. | ./scripts/pretrain/spr.sh |
| IDM | Christiano et al. | ./scripts/pretrain/idm.sh |
| SPR+IDM (SGI) | Schwarzer et al. | ./scripts/pretrain/spr_idm.sh |
| DT | Chen et al. | ./scripts/pretrain/dt.sh |
| CQL | Kumar et al. | ./scripts/pretrain/cql_dist.sh, ./scripts/pretrain/cql_mse.sh |

Several algorithms likely won't fit on a single GPU; we recommend enabling DDP by adding, e.g., --overrides num_gpus_per_node=4 in the scripts.
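As a sketch, a four-GPU SPR run could look like the launch line below, edited inside ./scripts/pretrain/spr.sh; the entry-point name and config flag are illustrative, since the exact launch command varies by script.

python run_pretrain.py --config_name spr --overrides num_gpus_per_node=4   # illustrative launch line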

Model Weights

You can download the checkpoints of the pre-trained models from the main experiment here (Password: ataripb).

For checkpoints from the ablation studies, please contact the author at quagmire@kaist.ac.kr.

To fine-tune these models, you can start with ./scripts/offline_bc/base.sh and ./scripts/online_rl/base.sh.
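A minimal fine-tuning session might then look like the sketch below, assuming the downloaded checkpoint has already been wired into the corresponding config or script; only the script paths come from the text above.

bash ./scripts/offline_bc/base.sh   # offline fine-tuning (behavior cloning)
bash ./scripts/online_rl/base.sh    # online fine-tuning (RL)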
