dsshim0125/s2p

S2P: State-conditioned Image Synthesis for Data Augmentation in Offline Reinforcement Learning

This repository provides the official PyTorch implementation of "S2P: State-conditioned Image Synthesis for Data Augmentation in Offline Reinforcement Learning" (NeurIPS 2022).

Setup

conda create -n s2p python=3.8.5
conda activate s2p
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt

Our experiments were run with PyTorch 1.10.1, CUDA 11.4, Python 3.8.5 and Ubuntu 18.04. We trained on a single NVIDIA RTX A6000, but the code also runs on GPUs with less memory if you reduce the batch size (batchSize). A simple visualization of the generation results needs only about 4 GB of GPU memory.
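Before running anything, it can help to compare your environment with the versions listed above. The snippet below is a minimal sketch, not part of this repository: the helper name check_environment and the report keys are hypothetical, and it only reads the Python version, the installed PyTorch version (if any), CUDA availability and the memory of GPU 0.

```python
import importlib.util
import sys

# Versions this README reports for its experiments; other versions are untested here.
TESTED_PYTHON = (3, 8)
TESTED_TORCH = "1.10.1"


def check_environment():
    """Summarize the local setup (hypothetical helper, not part of the repo)."""
    report = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "python_matches": sys.version_info[:2] == TESTED_PYTHON,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_version": None,
        "torch_matches": None,
        "cuda_available": None,
        "gpu_memory_gb": None,
    }
    if report["torch_installed"]:
        import torch  # imported only when present, so the check never crashes

        version = str(torch.__version__)
        report["torch_version"] = version
        # Wheel versions may carry a suffix such as "+cu113"; compare the base version.
        report["torch_matches"] = version.split("+")[0] == TESTED_TORCH
        report["cuda_available"] = torch.cuda.is_available()
        if report["cuda_available"]:
            props = torch.cuda.get_device_properties(0)
            report["gpu_memory_gb"] = round(props.total_memory / 1024**3, 1)
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If gpu_memory_gb is well below what training needs, lowering batchSize is the knob mentioned above; around 4 GB should be enough for the simple visualization.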

Download pre-trained models

We provide pre-trained S2P weights for some environments so that the generation performance can be tested quickly. Create a folder ./checkpoints and download the model weights into it. The weights below were trained on the cheetah and walker environments of the DeepMind Control Suite.

Env_type model
cheetah cheetah_30.pth
walker walker_30.pth
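A quick way to confirm the weights are in place is sketched below. The file names come from the table above, but the helpers checkpoint_path and missing_checkpoints are hypothetical and only check that the files exist; they do not load or validate the weights.

```python
from pathlib import Path

# env_type -> weight file, as listed in the table above.
CHECKPOINTS = {
    "cheetah": "cheetah_30.pth",
    "walker": "walker_30.pth",
}


def checkpoint_path(env_type, root="./checkpoints"):
    """Return the expected weight file for env_type (hypothetical helper)."""
    if env_type not in CHECKPOINTS:
        raise ValueError(f"no pre-trained weights listed for env_type={env_type!r}")
    return Path(root) / CHECKPOINTS[env_type]


def missing_checkpoints(root="./checkpoints"):
    """Create the checkpoint folder if needed and list envs whose weights are absent."""
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    return sorted(env for env in CHECKPOINTS if not checkpoint_path(env, root).is_file())
```

For example, missing_checkpoints() returns ["cheetah", "walker"] right after creating an empty ./checkpoints folder and an empty list once both files have been downloaded.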

Simple generation

We provide pre-trained S2P models and a tiny dataset for a simple visualization of S2P. Reviewers can visualize N-step generation results by setting --seq_len to N.

python simple_test.py --env_type=cheetah --dataroot=./datasets --netG=s2p --start_idx=0 --seq_len=5 --gpu_ids=0
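To compare several rollout lengths or both environments, the command above can be generated programmatically. This is a sketch that only reuses the flags shown in the command; the helper build_simple_test_cmd is hypothetical, and running the walker variant assumes a matching tiny dataset is available under ./datasets. It prints the commands rather than executing them.

```python
import shlex


def build_simple_test_cmd(env_type="cheetah", start_idx=0, seq_len=5, gpu_ids="0",
                          dataroot="./datasets", netG="s2p"):
    """Assemble the simple_test.py invocation shown above (hypothetical helper)."""
    return [
        "python", "simple_test.py",
        f"--env_type={env_type}",
        f"--dataroot={dataroot}",
        f"--netG={netG}",
        f"--start_idx={start_idx}",
        f"--seq_len={seq_len}",  # number of generation steps N
        f"--gpu_ids={gpu_ids}",
    ]


if __name__ == "__main__":
    # Print one command per environment and rollout length; copy and run as needed.
    for env in ("cheetah", "walker"):
        for n in (1, 5, 10):
            print(shlex.join(build_simple_test_cmd(env_type=env, seq_len=n)))
```

Keeping the arguments as a list (rather than one string) makes it straightforward to pass them to subprocess.run later if you want to automate the runs.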

Reference

  1. https://github.com/NVlabs/SPADE
  2. https://github.com/yenchenlin/nerf-pytorch
  3. https://github.com/huangzh13/StyleGAN.pytorch
