WAPPO in PyTorch

This is a PyTorch implementation of Wasserstein Adversarial Proximal Policy Optimization (WAPPO) [1]. I have tried to keep the code easy to read so that the algorithm is easy to follow. Please let me know if you have any questions.
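At its core, WAPPO augments the PPO objective with a Wasserstein confusion term: an adversarial critic estimates the Wasserstein-1 distance between the encoder's features on source-task and target-task observations, and the encoder is trained to minimize that estimate. The sketch below only illustrates this idea; the module shapes and the names encoder, critic, and lambda_conf are my assumptions, not this repository's classes, and the Lipschitz constraint on the critic (e.g. a gradient penalty) is omitted for brevity.

import torch
import torch.nn as nn

# Illustrative modules only; a real encoder for image observations
# would be convolutional.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 256), nn.ReLU())
critic = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 1))

def critic_loss(src_obs, tgt_obs):
    # The critic maximizes E[f(source)] - E[f(target)], a Kantorovich-
    # Rubinstein estimate of the Wasserstein-1 distance between the two
    # feature distributions, so we minimize its negation. Features are
    # detached: this step trains the critic only.
    src_feat = encoder(src_obs).detach()
    tgt_feat = encoder(tgt_obs).detach()
    return -(critic(src_feat).mean() - critic(tgt_feat).mean())

def confusion_loss(src_obs, tgt_obs):
    # The encoder minimizes the same estimate, pushing source and target
    # features toward a single distribution (domain confusion).
    return critic(encoder(src_obs)).mean() - critic(encoder(tgt_obs)).mean()

# Schematically, the agent's total loss is then
#   ppo_loss + lambda_conf * confusion_loss(src_obs, tgt_obs)
# with critic_loss minimized in alternating steps.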

Setup

If you are using Anaconda, first create the virtual environment.

conda create -n wappo python=3.8 -y
conda activate wappo

You can install the Python libraries using pip.

pip install --upgrade pip
pip install -r requirements.txt

If you are using a CUDA version other than 10.2, you need to install the PyTorch build that matches your CUDA version. See the official instructions for more details.
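For example, an installation for CUDA 10.1 might look like the line below. The version numbers here are illustrative assumptions (match them to requirements.txt); see the official PyTorch instructions for the command that fits your environment.

pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html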

Example

VisualCartpole

I trained WAPPO and PPO on cartpole-visual-v1 as described below. Following the WAPPO paper, results are averaged over 5 trials, and the graph below corresponds to Figure 2 in the paper. The source and target tasks used in my experiment are also shown below.

Note that I changed some hyperparameters from the paper: I set rollout_length to 128 instead of 256, and num_initial_blocks to 2 instead of 1. Please refer to config/cartpole.yaml for details.
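For reference, the two overridden entries in config/cartpole.yaml would look roughly like this (a sketch: the key names and values come from the paragraph above, but the surrounding structure of the file is an assumption):

rollout_length: 128      # 256 in the paper
num_initial_blocks: 2    # 1 in the paper

The following command then trains WAPPO on the source and target tasks: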

python train.py --cuda --wappo --env_id cartpole-visual-v1 --config config/cartpole.yaml --trial 0

References

[1] Roy, Josh, and George Konidaris. "Visual Transfer for Reinforcement Learning via Wasserstein Domain Confusion."
