Incorporate distributed RL framework, Ape-X and Ape-X DQN (#246)
* Take context as init_communication input; all processes share the same context.

* Implement abstract classes for the distributed and Ape-X Learner wrappers

* Implement params2numpy method that converts a torch state_dict into a list of np.ndarray (see the sketch after this commit message)

* Add __init__

* Implement worker as an abstract class, not a wrapper base class

* Rename the apex_learner file to learner.

* Implement Ape-X worker and learner base classes

* Implement Ape-X DQN worker

* Create base class for distributed architectures

* Implement and test Ape-X DQN working on Pong
* Accept current change (master) for PongNoFrameskip-v4 dqn config
* Make env_info more explicit in run_pong script (accept incoming change)
* Make learner return cpu state_dict (accept incoming change)

* Fix minor errors

* Implement ApeXWorker as a wrapper, ApeXWorkerWrapper
Implement Logger and test wandb functionality
Add worker and logger render in argparse
Implement load_param() method in logger and worker

* Move num_workers to hyperparams, and add logger_interval to hyperparams.

* Implement safe exit condition for all ray actors.

* Rename _init_communication to init_communication and call it outside __init__ for all Ape-X actors
Implement test() in distributed architectures (load from checkpoint and run logger test())

* Add documentation
* Move collect_data from worker class to ApeX Wrapper
* Adjust hyperparameters
* Add worker-verbose as argparse flag

* Move num_workers to the hyper_params config

* Add author
* Add separate integration test for ApeX
* Add integration test flag to pong

* Change argparse integration-test flag from store_false to store_true

* Change default config to dqn.

* Log worker scores per update step on Wandb.

* Modify integration test

* Modify apex buffer config for integration test

* Change distributed directory structure

* Add documentation

* Modify readme.md

* Modify readme.md

* Add Ape-X to README.

* Add description of argparse flags for distributed training.

Co-authored-by: khkim <kh.kim@medipixel.io>
Co-authored-by: Kyunghwan Kim <khsyee@gmail.com>
3 people committed Jun 23, 2020
1 parent 07743f6 commit 9e897ad
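
The params2numpy helper referenced in the commit message converts learner parameters into plain numpy arrays so they serialize cheaply when shipped to workers. A minimal sketch of the idea (the exact signature in the repo may differ):

```python
import numpy as np
import torch


def params2numpy(state_dict) -> list:
    """Convert a torch state_dict into a list of np.ndarray."""
    return [tensor.detach().cpu().numpy() for tensor in state_dict.values()]


# Example: extract the parameters of a small network.
net = torch.nn.Linear(4, 2)
arrays = params2numpy(net.state_dict())
assert all(isinstance(a, np.ndarray) for a in arrays)
```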
Showing 28 changed files with 1,538 additions and 49 deletions.
16 changes: 1 addition & 15 deletions LICENSE.md
@@ -1,5 +1,4 @@
# Our repository
MIT License
The MIT License (MIT)

Copyright (c) 2019 Medipixel

@@ -20,16 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

# Mujoco models
This work is derived from [MuJoCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license:
```
This file is part of MuJoCo.
Copyright 2009-2015 Roboti LLC.
Mujoco :: Advanced physics simulation engine
Source : www.roboti.us
Version : 1.31
Released : 23Apr16
Author :: Vikash Kumar
Contacts : kumar@roboti.us
```
19 changes: 15 additions & 4 deletions README.md
@@ -1,9 +1,9 @@
<p align="center">
<img src="https://user-images.githubusercontent.com/17582508/52845370-4a930200-314a-11e9-9889-e00007043872.jpg" align="center">

[![CircleCI](https://circleci.com/gh/circleci/circleci-docs.svg?style=shield)](https://circleci.com/gh/medipixel)
[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/medipixel/rl_algorithms.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/medipixel/rl_algorithms/context:python)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

@@ -63,8 +63,8 @@ This project follows the [all-contributors](https://github.com/all-contributors/
7. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn)
8. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137)
9. Rainbow IQN (with [ResNet](https://github.com/medipixel/rl_algorithms/blob/master/rl_algorithms/common/networks/backbones/resnet.py))
10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent/dqn_agent.py)

10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent)
11. [Distributed Prioritized Experience Replay (Ape-X)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/common/distributed)

## Performance

@@ -205,6 +205,16 @@ python <run-file> -h
  - Start rendering after the number of episodes.
- `--load-from <save-file-path>`
  - Load the saved models and optimizers at the beginning.

#### Arguments for distributed training in run-files
- `--max-episode-steps <int>`
  - Set the maximum number of learner update steps as a stopping criterion for the training loop. If the value is less than or equal to 0, the environment's default maximum number of steps is used.
- `--off-worker-render`
  - Turn off rendering of individual workers.
- `--off-logger-render`
  - Turn off rendering of logger tests.
- `--worker-verbose`
  - Turn on printing of episode run information for individual workers.
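
For example (hedged: the run-file name and `--cfg-path` flag follow this repository's existing conventions; adjust to your checkout), a distributed Pong run using the Ape-X config added in this commit might look like:

```
python run_pong_no_frameskip_v4.py --cfg-path ./configs/pong_no_frameskip_v4/apex_dqn.py --off-worker-render --worker-verbose
```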


#### Show feature map with Grad-CAM
@@ -252,3 +262,4 @@ This won't be frequently updated.
17. [Ramprasaath R. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." arXiv preprint arXiv:1610.02391, 2016.](https://arxiv.org/pdf/1610.02391.pdf)
18. [Kaiming He et al., "Deep Residual Learning for Image Recognition." arXiv preprint arXiv:1512.03385, 2015.](https://arxiv.org/pdf/1512.03385)
19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations https://openreview.net/forum?id=r1lyTjAqYX, 2019.](https://openreview.net/forum?id=r1lyTjAqYX)
20. [Dan Horgan et al., "Distributed Prioritized Experience Replay." in International Conference on Learning Representations, 2018.](https://arxiv.org/pdf/1803.00933.pdf)
77 changes: 77 additions & 0 deletions configs/pong_no_frameskip_v4/apex_dqn.py
@@ -0,0 +1,77 @@
"""Config for ApeX-DQN on Pong-No_FrameSkip-v4.
- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""

from rl_algorithms.common.helper_functions import identity

agent = dict(
type="ApeX",
hyper_params=dict(
gamma=0.99,
tau=5e-3,
buffer_size=int(2.5e5), # openai baselines: int(1e4)
batch_size=512, # openai baselines: 32
update_starts_from=int(1e5), # openai baselines: int(1e4)
multiple_update=1, # multiple learning updates
train_freq=1, # in openai baselines, train_freq = 4
gradient_clip=10.0, # dueling: 10.0
n_step=5,
w_n_step=1.0,
w_q_reg=0.0,
per_alpha=0.6, # openai baselines: 0.6
per_beta=0.4,
per_eps=1e-6,
loss_type=dict(type="DQNLoss"),
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.1, # openai baselines: 0.01
epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1
# grad_cam
grad_cam_layer_list=[
"backbone.cnn.cnn_0.cnn",
"backbone.cnn.cnn_1.cnn",
"backbone.cnn.cnn_2.cnn",
],
num_workers=4,
local_buffer_max_size=1000,
worker_update_interval=50,
logger_interval=2000,
),
learner_cfg=dict(
type="DQNLearner",
device="cuda",
backbone=dict(
type="CNN",
configs=dict(
input_sizes=[4, 32, 64],
output_sizes=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
paddings=[1, 0, 0],
),
),
head=dict(
type="DuelingMLP",
configs=dict(
use_noisy_net=False, hidden_sizes=[512], output_activation=identity
),
),
optim_cfg=dict(
lr_dqn=0.0003, # dueling: 6.25e-5, openai baselines: 1e-4
weight_decay=0.0, # this makes saturation in cnn weights
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8
),
),
worker_cfg=dict(type="DQNWorker", device="cpu",),
logger_cfg=dict(type="DQNLogger",),
comm_cfg=dict(
learner_buffer_port=6554,
learner_worker_port=6555,
worker_buffer_port=6556,
learner_logger_port=6557,
send_batch_port=6558,
priorities_port=6559,
),
)
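
The `comm_cfg` ports above define the channels between Ape-X actors. As a rough sketch of the pattern (not code from this commit; socket types, address, and pickle serialization are assumptions), a learner could broadcast parameters to workers over `learner_worker_port` like this:

```python
import pickle

import zmq

LEARNER_WORKER_PORT = 6555  # comm_cfg["learner_worker_port"]

context = zmq.Context()  # a single context can be shared, as in the commit message

# Learner side: publish the latest network parameters.
pub_socket = context.socket(zmq.PUB)
pub_socket.bind(f"tcp://127.0.0.1:{LEARNER_WORKER_PORT}")


def broadcast_params(params):
    """Send a list of np.ndarray (e.g. from params2numpy) to all workers."""
    pub_socket.send(pickle.dumps(params))


# Worker side: subscribe and receive fresh parameters.
sub_socket = context.socket(zmq.SUB)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"")
sub_socket.connect(f"tcp://127.0.0.1:{LEARNER_WORKER_PORT}")


def receive_params():
    return pickle.loads(sub_socket.recv())
```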
11 changes: 10 additions & 1 deletion requirements.txt
@@ -9,8 +9,17 @@ cloudpickle
opencv-python
wandb
addict

# mujoco

# for distributed learning
ray
ray[debug]
pyzmq
pyarrow

# for log
matplotlib
plotly

setuptools
wheel
6 changes: 6 additions & 0 deletions rl_algorithms/__init__.py
@@ -5,12 +5,15 @@
from .bc.her import LunarLanderContinuousHER, ReacherHER
from .bc.sac_agent import BCSACAgent
from .bc.sac_learner import BCSACLearner
from .common.distributed.apex import ApeX
from .common.networks.backbones import CNN, ResNet
from .ddpg.agent import DDPGAgent
from .ddpg.learner import DDPGLearner
from .dqn.agent import DQNAgent
from .dqn.learner import DQNLearner
from .dqn.logger import DQNLogger
from .dqn.losses import C51Loss, DQNLoss, IQNLoss
from .dqn.worker import DQNWorker
from .fd.ddpg_agent import DDPGfDAgent
from .fd.ddpg_learner import DDPGfDLearner
from .fd.dqn_agent import DQfDAgent
@@ -65,4 +68,7 @@
"R2D1IQNLoss",
"R2D1C51Loss",
"R2D1DQNLoss",
"ApeX",
"DQNWorker",
"DQNLogger",
]
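
With these exports, the new distributed components are importable from the package root, e.g.:

```python
from rl_algorithms import ApeX, DQNLogger, DQNWorker
```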
23 changes: 23 additions & 0 deletions rl_algorithms/common/abstract/architecture.py
@@ -0,0 +1,23 @@
"""Abstract class for distributed architectures.
- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""

from abc import ABC, abstractmethod


class Architecture(ABC):
"""Abstract class for distributed architectures"""

@abstractmethod
def _spawn(self):
pass

@abstractmethod
def train(self):
pass

@abstractmethod
def test(self):
pass

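For orientation, here is a minimal, hypothetical subclass showing how these hooks fit together with ray (illustrative only; the real Ape-X architecture in this commit spawns learners, workers, a replay buffer, and a logger):

```python
import ray

from rl_algorithms.common.abstract.architecture import Architecture


@ray.remote
class _ToyWorker:
    """Stand-in for a real Ape-X worker actor."""

    def run(self) -> str:
        return "collected one rollout"


class ToyArchitecture(Architecture):
    """Minimal concrete architecture, for illustration only."""

    def __init__(self, num_workers: int = 2):
        self.num_workers = num_workers
        self.workers = []

    def _spawn(self):
        ray.init(ignore_reinit_error=True)
        self.workers = [_ToyWorker.remote() for _ in range(self.num_workers)]

    def train(self):
        self._spawn()
        print(ray.get([worker.run.remote() for worker in self.workers]))

    def test(self):
        print("load a checkpoint and run evaluation here")
```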