Incorporate distributed RL framework, Ape-X and Ape-X DQN #246

Merged 30 commits on Jun 23, 2020
Changes from 25 commits
Commits
aeef276
implement abstract classes for distributed and ApeX Learner wrapper
cyoon1729 Jun 11, 2020
32c073d
Implement params2numpy method that loads torch state_dict as array of…
cyoon1729 Jun 11, 2020
01e5f2c
add __init__
cyoon1729 Jun 11, 2020
4b7c7f6
implement worker as abstract class, not wrapper base class
cyoon1729 Jun 11, 2020
bbe0236
Change apex_learner file name to learner.
cyoon1729 Jun 11, 2020
069b361
Implement Ape-X worker and learner base classes
cyoon1729 Jun 11, 2020
7e19bca
implement Ape-X DQN worker
cyoon1729 Jun 11, 2020
1cc872e
Create base class for distributed architectures
cyoon1729 Jun 11, 2020
8c4ecf5
Take context as init_communication input; all processes share the sam…
cyoon1729 Jun 11, 2020
f2969e1
Implement and test Ape-X DQN working on Pong
cyoon1729 Jun 17, 2020
4a73778
Fix minor errors
cyoon1729 Jun 17, 2020
dcbb857
Implement ApeXWorker as a wrapper ApeXWorkerWrapper
cyoon1729 Jun 18, 2020
22b2d21
Move num_workers to hyperparams, and add logger_interval to hyperparams.
cyoon1729 Jun 18, 2020
610f0f6
Implement safe exit condition for all ray actors.
cyoon1729 Jun 18, 2020
86da701
Change _init_communication -> init_communication and call outside of …
cyoon1729 Jun 19, 2020
4342b78
* Add documentation
cyoon1729 Jun 19, 2020
d36212a
* Move num_worker to hyper_param cfg
cyoon1729 Jun 19, 2020
0388ac3
* Add author
cyoon1729 Jun 22, 2020
1ab5e35
argparse integration test flag store_false->store_true
cyoon1729 Jun 22, 2020
b954f4d
Change default config to dqn.
cyoon1729 Jun 22, 2020
a7be46b
* Log worker scores per update step on Wandb.
cyoon1729 Jun 22, 2020
ab93eb6
Modify integration test
Jun 23, 2020
15bce10
Modify apex buffer config for integration test
Jun 23, 2020
ec150b5
Change distributed directory structure
Jun 23, 2020
cd0af75
Add documentation
Jun 23, 2020
f3c0646
Merge branch 'master' into feature/add_apex_rebase
MrSyee Jun 23, 2020
edd0219
Modify readme.md
Jun 23, 2020
e8ccea4
Modify readme.md
Jun 23, 2020
64afd45
Add Ape-X to README.
cyoon1729 Jun 23, 2020
5c09012
Add description about args flags for distributed training.
cyoon1729 Jun 23, 2020
77 changes: 77 additions & 0 deletions configs/pong_no_frameskip_v4/apex_dqn.py
@@ -0,0 +1,77 @@
"""Config for ApeX-DQN on Pong-No_FrameSkip-v4.

- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""

from rl_algorithms.common.helper_functions import identity

agent = dict(
type="ApeX",
hyper_params=dict(
gamma=0.99,
tau=5e-3,
buffer_size=int(2.5e5), # openai baselines: int(1e4)
batch_size=512, # openai baselines: 32
update_starts_from=int(1e5), # openai baselines: int(1e4)
multiple_update=1, # multiple learning updates
train_freq=1, # in openai baselines, train_freq = 4
gradient_clip=10.0, # dueling: 10.0
n_step=5,
w_n_step=1.0,
w_q_reg=0.0,
per_alpha=0.6, # openai baselines: 0.6
per_beta=0.4,
per_eps=1e-6,
loss_type=dict(type="DQNLoss"),
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.1, # openai baselines: 0.01
epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1
# grad_cam
grad_cam_layer_list=[
"backbone.cnn.cnn_0.cnn",
"backbone.cnn.cnn_1.cnn",
"backbone.cnn.cnn_2.cnn",
],
num_workers=4,
local_buffer_max_size=1000,
worker_update_interval=50,
logger_interval=2000,
),
learner_cfg=dict(
type="DQNLearner",
device="cuda",
backbone=dict(
type="CNN",
configs=dict(
input_sizes=[4, 32, 64],
output_sizes=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
paddings=[1, 0, 0],
),
),
head=dict(
type="DuelingMLP",
configs=dict(
use_noisy_net=False, hidden_sizes=[512], output_activation=identity
),
),
optim_cfg=dict(
lr_dqn=0.0003, # dueling: 6.25e-5, openai baselines: 1e-4
weight_decay=0.0, # this makes saturation in cnn weights
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8
),
),
worker_cfg=dict(type="DQNWorker", device="cpu",),
logger_cfg=dict(type="DQNLogger",),
comm_cfg=dict(
learner_buffer_port=6554,
learner_worker_port=6555,
worker_buffer_port=6556,
learner_logger_port=6557,
send_batch_port=6558,
priorities_port=6559,
),
)
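
The config module above is a plain Python file that defines an agent dict. Below is a minimal sketch of how such a file could be loaded and inspected; the load_config helper and the way a training script consumes the dict are illustrative assumptions, not part of this PR.

# Hypothetical helper: load a config module by path and read its `agent` dict.
import importlib.util

def load_config(cfg_path: str) -> dict:
    spec = importlib.util.spec_from_file_location("config", cfg_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.agent

cfg = load_config("configs/pong_no_frameskip_v4/apex_dqn.py")
print(cfg["type"])                             # ApeX
print(cfg["hyper_params"]["num_workers"])      # 4
print(cfg["comm_cfg"]["learner_logger_port"])  # 6557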
11 changes: 10 additions & 1 deletion requirements.txt
@@ -9,8 +9,17 @@ cloudpickle
opencv-python
wandb
addict

# mujoco

# for distributed learning
ray
ray[debug]
pyzmq
pyarrow

# for log
matplotlib
plotly

setuptools
wheel
6 changes: 6 additions & 0 deletions rl_algorithms/__init__.py
@@ -5,12 +5,15 @@
from .bc.her import LunarLanderContinuousHER, ReacherHER
from .bc.sac_agent import BCSACAgent
from .bc.sac_learner import BCSACLearner
from .common.distributed.apex import ApeX
from .common.networks.backbones import CNN, ResNet
from .ddpg.agent import DDPGAgent
from .ddpg.learner import DDPGLearner
from .dqn.agent import DQNAgent
from .dqn.learner import DQNLearner
from .dqn.logger import DQNLogger
from .dqn.losses import C51Loss, DQNLoss, IQNLoss
from .dqn.worker import DQNWorker
from .fd.ddpg_agent import DDPGfDAgent
from .fd.ddpg_learner import DDPGfDLearner
from .fd.dqn_agent import DQfDAgent
@@ -57,4 +60,7 @@
"IQNLoss",
"C51Loss",
"DQNLoss",
"ApeX",
"DQNWorker",
"DQNLogger",
]
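
With these exports added, the new distributed components become importable from the package root. A brief illustrative usage, not taken from this diff:

from rl_algorithms import ApeX, DQNLearner, DQNLogger, DQNWorker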
23 changes: 23 additions & 0 deletions rl_algorithms/common/abstract/architecture.py
@@ -0,0 +1,23 @@
"""Abstract class for distributed architectures.

- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""

from abc import ABC, abstractmethod


class Architecture(ABC):
"""Abstract class for distributed architectures"""

@abstractmethod
def _spawn(self):
pass

@abstractmethod
def train(self):
pass

@abstractmethod
def test(self):
pass
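
As a rough sketch of how this interface is meant to be filled in, a concrete architecture would create its actors in _spawn and drive them in train/test. The subclass below is an illustrative assumption; the actual Ape-X implementation in this PR lives under rl_algorithms/common/distributed/ and differs in detail.

# Illustrative Architecture subclass (assumed names, not code from this PR).
class ToyArchitecture(Architecture):
    def __init__(self, num_workers: int):
        self.num_workers = num_workers
        self.learner = None
        self.workers = []

    def _spawn(self):
        # Create learner and worker actors (e.g. ray remote actors) here.
        self.learner = object()
        self.workers = [object() for _ in range(self.num_workers)]

    def train(self):
        self._spawn()
        # Run workers and learner until the stopping condition is met.

    def test(self):
        # Evaluate the trained policy.
        pass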
239 changes: 239 additions & 0 deletions rl_algorithms/common/abstract/distributed_logger.py
@@ -0,0 +1,239 @@
"""Base class for loggers use in distributed training.

- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""

from abc import ABC, abstractmethod
import argparse
from collections import deque
import os
import shutil
from typing import List

import gym
import numpy as np
import plotly.graph_objects as go
import pyarrow as pa
import torch
import wandb
import zmq

from rl_algorithms.common.env.atari_wrappers import atari_env_generator
import rl_algorithms.common.env.utils as env_utils
from rl_algorithms.common.networks.brain import Brain
from rl_algorithms.utils.config import ConfigDict


class DistributedLogger(ABC):
"""Base class for loggers use in distributed training.

Attributes:
args (argparse.Namespace): arguments including hyperparameters and training settings
env_info (ConfigDict): information about environment
log_cfg (ConfigDict): configuration for saving log and checkpoint
comm_config (ConfigDict): configs for communication
backbone (ConfigDict): backbone configs for building network
head (ConfigDict): head configs for building network
brain (Brain): logger brain for evaluation
update_step (int): tracker for learner update step
device (torch.device): device, cpu by default
log_info_queue (deque): queue for storing log info received from learner
env (gym.Env): gym environment for running test

"""

def __init__(
self,
args: argparse.Namespace,
env_info: ConfigDict,
log_cfg: ConfigDict,
comm_cfg: ConfigDict,
backbone: ConfigDict,
head: ConfigDict,
):
self.args = args
self.env_info = env_info
self.log_cfg = log_cfg
self.comm_cfg = comm_cfg
self.device = torch.device("cpu") # Logger only runs on cpu
self.brain = Brain(backbone, head).to(self.device)

self.update_step = 0
self.log_info_queue = deque(maxlen=100)

self._init_env()

# pylint: disable=attribute-defined-outside-init
def _init_env(self):
"""Initialize gym environment."""
if self.env_info.is_atari:
self.env = atari_env_generator(
self.env_info.name, self.args.max_episode_steps
)
else:
self.env = gym.make(self.env_info.name)
env_utils.set_env(self.env, self.args)

@abstractmethod
def load_params(self, path: str):
if not os.path.exists(path):
raise Exception(
f"[ERROR] the input path does not exist. Wrong path: {path}"
)

# pylint: disable=attribute-defined-outside-init
def init_communication(self):
"""Initialize inter-process communication sockets."""
ctx = zmq.Context()
self.pull_socket = ctx.socket(zmq.PULL)
self.pull_socket.bind(f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}")

@abstractmethod
def select_action(self, state: np.ndarray):
pass

@abstractmethod
def write_log(self, log_value: dict):
pass

# pylint: disable=no-self-use
@staticmethod
def _preprocess_state(state: np.ndarray, device: torch.device) -> torch.Tensor:
state = torch.FloatTensor(state).to(device)
return state

def set_wandb(self):
"""Set configuration for wandb logging."""
wandb.init(
project=self.env_info.name,
name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
)
wandb.config.update(vars(self.args))
shutil.copy(self.args.cfg_path, os.path.join(wandb.run.dir, "config.py"))

def recv_log_info(self):
"""Receive info from learner."""
received = False
try:
log_info_id = self.pull_socket.recv(zmq.DONTWAIT)
received = True
except zmq.Again:
pass

if received:
self.log_info_queue.append(log_info_id)

def run(self):
"""Run main logging loop; continuously receive data and log."""
if self.args.log:
self.set_wandb()

while self.update_step < self.args.max_update_step:
self.recv_log_info()
if self.log_info_queue: # if non-empty
log_info_id = self.log_info_queue.pop()
log_info = pa.deserialize(log_info_id)
state_dict = log_info["state_dict"]
log_value = log_info["log_value"]
self.update_step = log_value["update_step"]

self.synchronize(state_dict)
avg_score = self.test(self.update_step)
log_value["avg_score"] = avg_score
self.write_log(log_value)

def write_worker_log(self, worker_logs: List[dict]):
"""Log the mean scores of each episode per update step to wandb."""
# NOTE: Worker plots are passed onto wandb.log as a plotly figure,
# since wandb doesn't support logging multiple lines to a single plot
if self.args.log:
self.set_wandb()
# Plot individual workers
fig = go.Figure()
worker_id = 0
for worker_log in worker_logs:
fig.add_trace(
go.Scatter(
x=list(worker_log.keys()),
y=list(worker_log.values()),
mode="lines",
name=f"Worker {worker_id}",
line=dict(width=2),
)
)
worker_id = worker_id + 1

# Plot mean scores
steps = worker_logs[0].keys()
mean_scores = []
for step in steps:
each_scores = [worker_log[step] for worker_log in worker_logs]
mean_scores.append(np.mean(each_scores))

fig.add_trace(
go.Scatter(
x=list(worker_logs[0].keys()),
y=mean_scores,
mode="lines+markers",
name="Mean scores",
line=dict(width=5),
)
)

# Write to wandb
wandb.log({"Worker scores": fig})

def test(self, update_step: int, interim_test: bool = True):
"""Test the agent."""
avg_score = self._test(update_step, interim_test)

# termination
self.env.close()
return avg_score

def _test(self, update_step: int, interim_test: bool) -> float:
"""Common test routine."""
if interim_test:
test_num = self.args.interim_test_num
else:
test_num = self.args.episode_num

scores = []
for i_episode in range(test_num):
state = self.env.reset()
done = False
score = 0
step = 0

while not done:
if self.args.logger_render:
self.env.render()

action = self.select_action(state)
next_state, reward, done, _ = self.env.step(action)

state = next_state
score += reward
step += 1

scores.append(score)

if interim_test:
print(
"[INFO] update step: %d\ttest %d\tstep: %d\ttotal score: %d"
% (update_step, i_episode, step, score)
)
else:
print(
"[INFO] test %d\tstep: %d\ttotal score: %d"
% (i_episode, step, score)
)

return np.mean(scores)

def synchronize(self, new_params: List[np.ndarray]):
"""Copy parameters from numpy arrays"""
for param, new_param in zip(self.brain.parameters(), new_params):
new_param = torch.FloatTensor(new_param).to(self.device)
param.data.copy_(new_param)
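
For context on the communication pattern: init_communication binds a ZMQ PULL socket on learner_logger_port, and recv_log_info expects pyarrow-serialized payloads containing "state_dict" (parameters as numpy arrays) and "log_value". The snippet below is a minimal sketch of the sending side; it is an assumption about the learner's counterpart, not code from this diff.

# Hypothetical learner-side sender matching DistributedLogger's PULL socket.
import numpy as np
import pyarrow as pa
import zmq

ctx = zmq.Context()
push_socket = ctx.socket(zmq.PUSH)
push_socket.connect("tcp://127.0.0.1:6557")  # learner_logger_port in comm_cfg

log_info = {
    "state_dict": [np.zeros((32, 4, 8, 8), dtype=np.float32)],  # params as numpy arrays
    "log_value": {"update_step": 1000, "loss": 0.42},
}
push_socket.send(pa.serialize(log_info).to_buffer())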