Improve VectorEnv interface
muupan committed Oct 8, 2018
1 parent 25924f8 commit 1a92bdb
Showing 7 changed files with 222 additions and 37 deletions.
20 changes: 20 additions & 0 deletions chainerrl/env.py
@@ -27,3 +27,23 @@ def reset(self):
     @abstractmethod
     def close(self):
         raise NotImplementedError()
+
+
+class VectorEnv(with_metaclass(ABCMeta, object)):
+    """Parallel RL learning environments."""
+
+    @abstractmethod
+    def step(self, action):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def reset(self, mask):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def seed(self, seeds):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def close(self):
+        raise NotImplementedError()
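The interface mirrors gym.Env but batches everything: step takes one action per env and returns tuples of observations, rewards, dones, and infos, while reset takes a mask marking envs that should be left untouched. A minimal sketch of an agent-side loop that could drive any VectorEnv implementation (the rollout helper and policy argument are illustrative, not part of this commit):

import numpy as np

def rollout(vec_env, policy, n_steps):
    # Reset every env (mask=None resets all of them).
    obss = vec_env.reset()
    for _ in range(n_steps):
        # The policy maps a batch of observations to a batch of actions.
        actions = [policy(obs) for obs in obss]
        obss, rewards, dones, infos = vec_env.step(actions)
        # A truthy mask entry means "do not reset this env", so passing
        # the negated dones resets exactly the envs that just finished.
        obss = vec_env.reset(mask=np.logical_not(dones))
    vec_env.close()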
2 changes: 2 additions & 0 deletions chainerrl/envs/__init__.py
@@ -0,0 +1,2 @@
from chainerrl.envs.multiprocess_vector_env import MultiprocessVectorEnv # NOQA
from chainerrl.envs.serial_vector_env import SerialVectorEnv # NOQA
chainerrl/envs/vec_env.py → chainerrl/envs/multiprocess_vector_env.py
@@ -15,66 +15,76 @@
 from chainerrl import env


-def worker(remote, env_fn_wrapper):
-    env = env_fn_wrapper
-    while True:
-        cmd, data = remote.recv()
-        if cmd == 'step':
-            ob, reward, done, info = env.step(data)
-            if done:
-                ob = env.reset()
-            remote.send((ob, reward, done, info))
-        elif cmd == 'reset':
-            ob = env.reset()
-            remote.send(ob)
-        elif cmd == 'close':
-            remote.close()
-            break
-        elif cmd == 'get_spaces':
-            remote.send((env.action_space, env.observation_space))
-        elif cmd == 'spec':
-            remote.send(env.spec)
-        elif cmd == 'seed':
-            remote.send(env.seed(data))
-        else:
-            raise NotImplementedError
+def worker(remote, env_fn):
+    env = env_fn()
+    try:
+        while True:
+            cmd, data = remote.recv()
+            if cmd == 'step':
+                ob, reward, done, info = env.step(data)
+                if done:
+                    ob = env.reset()
+                remote.send((ob, reward, done, info))
+            elif cmd == 'reset':
+                ob = env.reset()
+                remote.send(ob)
+            elif cmd == 'close':
+                remote.close()
+                break
+            elif cmd == 'get_spaces':
+                remote.send((env.action_space, env.observation_space))
+            elif cmd == 'spec':
+                remote.send(env.spec)
+            elif cmd == 'seed':
+                remote.send(env.seed(data))
+            else:
+                raise NotImplementedError
+    finally:
+        env.close()
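Each subprocess runs this worker loop, so the parent process talks to an env purely through (command, data) tuples over a multiprocessing Pipe. A minimal sketch of driving one worker by hand, assuming the function is importable from chainerrl.envs.multiprocess_vector_env and gym is installed (illustration only, not part of this commit):

from multiprocessing import Pipe, Process

import gym

from chainerrl.envs.multiprocess_vector_env import worker

if __name__ == '__main__':
    parent, child = Pipe()
    p = Process(target=worker, args=(child, lambda: gym.make('CartPole-v0')))
    p.start()
    parent.send(('reset', None))
    obs = parent.recv()
    parent.send(('step', 0))  # for 'step', data is the action
    obs, reward, done, info = parent.recv()
    parent.send(('close', None))  # the worker breaks out of its loop and the
    p.join()                      # finally clause closes the env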


-class VectorEnv(env.Env):
-    def __init__(self, env_fns):
-        """envs: list of gym environments to run in subprocesses
-        """
+class MultiprocessVectorEnv(env.VectorEnv):
+    """VectorEnv where each env is run in its own subprocess.
+
+    Args:
+        env_fns (list of callable): List of callables, each of which
+            returns a gym.Env that is run in its own subprocess.
+    """
+
+    def __init__(self, env_fns):
         nenvs = len(env_fns)
         self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
         self.ps = \
             [Process(target=worker, args=(work_remote, env_fn))
              for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
         for p in self.ps:
             p.start()

+        self.last_obs = [None] * self.num_envs
         self.remotes[0].send(('get_spaces', None))
         self.action_space, self.observation_space = self.remotes[0].recv()
+        self.closed = False

     def __del__(self):
-        self.close()
+        if not self.closed:
+            self.close()

     @cached_property
     def spec(self):
+        self._assert_not_closed()
         self.remotes[0].send(('spec', None))
         spec = self.remotes[0].recv()
         return spec

     def step(self, actions):
+        self._assert_not_closed()
         for remote, action in zip(self.remotes, actions):
             remote.send(('step', action))
         results = [remote.recv() for remote in self.remotes]
+        self.last_obs, rews, dones, infos = zip(*results)
+        return self.last_obs, rews, dones, infos

+    def reset(self, mask=None):
+        self._assert_not_closed()
+        if mask is None:
+            mask = np.zeros(self.num_envs)
+        for m, remote in zip(mask, self.remotes):
@@ -87,12 +97,15 @@ def reset(self, mask=None):
         return obs

     def close(self):
+        self._assert_not_closed()
+        self.closed = True
         for remote in self.remotes:
             remote.send(('close', None))
         for p in self.ps:
             p.join()

     def seed(self, seeds=None):
+        self._assert_not_closed()
         if seeds is not None:
             if isinstance(seeds, int):
                 seeds = [seeds] * self.num_envs
@@ -115,3 +128,6 @@ def seed(self, seeds=None):
     @property
     def num_envs(self):
         return len(self.remotes)
+
+    def _assert_not_closed(self):
+        assert not self.closed, "Trying to operate on a MultiprocessVectorEnv after calling close()"  # NOQA
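Taken together, MultiprocessVectorEnv is constructed from zero-argument callables rather than live envs, because each env must be built inside its own subprocess. A usage sketch under that assumption (not from this commit):

import gym

import chainerrl

if __name__ == '__main__':
    # Each callable is invoked once, inside its own subprocess. All four
    # callables are identical here, so closing over a loop variable is not
    # an issue in this snippet.
    vec_env = chainerrl.envs.MultiprocessVectorEnv(
        [lambda: gym.make('CartPole-v0') for _ in range(4)])
    vec_env.seed(list(range(4)))  # one seed per env
    obss = vec_env.reset()
    actions = [vec_env.action_space.sample()
               for _ in range(vec_env.num_envs)]
    obss, rewards, dones, infos = vec_env.step(actions)
    vec_env.close()  # any further call now trips _assert_not_closed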
51 changes: 51 additions & 0 deletions chainerrl/envs/serial_vector_env.py
@@ -0,0 +1,51 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

from cached_property import cached_property
import numpy as np

from chainerrl import env


class SerialVectorEnv(env.VectorEnv):
    """VectorEnv where each env is run sequentially.

    The purpose of this VectorEnv is to help with debugging. For speed, you
    should use MultiprocessVectorEnv if possible.

    Args:
        envs (list of gym.Env): List of gym.Env instances to run
            sequentially.
    """

    def __init__(self, envs):
        self.envs = envs
        self.last_obs = [None] * self.num_envs
        self.action_space = envs[0].action_space
        self.observation_space = envs[0].observation_space
        self.spec = envs[0].spec

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        self.last_obs, rews, dones, infos = zip(*results)
        return self.last_obs, rews, dones, infos

    def reset(self, mask=None):
        if mask is None:
            mask = np.zeros(self.num_envs)
        obs = [env.reset() if not m else o
               for m, env, o in zip(mask, self.envs, self.last_obs)]
        self.last_obs = obs
        return obs

    def seed(self, seeds):
        for env, seed in zip(self.envs, seeds):
            env.seed(seed)

    def close(self):
        # VectorEnv declares close as abstract, so SerialVectorEnv must
        # implement it to be instantiable; close each wrapped env in turn.
        for env in self.envs:
            env.close()

    @property
    def num_envs(self):
        return len(self.envs)
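Because SerialVectorEnv keeps every env in the calling process, it is convenient for stepping through VectorEnv semantics in a debugger; the partial-reset behavior in particular is easy to inspect here. An illustrative snippet (not part of the commit):

import gym
import numpy as np

import chainerrl

vec_env = chainerrl.envs.SerialVectorEnv(
    [gym.make('CartPole-v0') for _ in range(3)])
vec_env.seed([0, 1, 2])
obss = vec_env.reset()
actions = [vec_env.action_space.sample() for _ in range(3)]
obss, rewards, dones, infos = vec_env.step(actions)
# A truthy mask entry means "skip this env and return its last
# observation"; here only envs 0 and 1 are actually reset.
mask = np.array([0, 0, 1])
obss = vec_env.reset(mask)
vec_env.close()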
12 changes: 6 additions & 6 deletions examples/ale/train_dqn_batch_ale.py
@@ -19,7 +19,6 @@
 import chainerrl
 from chainerrl.action_value import DiscreteActionValue
 from chainerrl import agents
-from chainerrl.envs.vec_env import VectorEnv
 from chainerrl import experiments
 from chainerrl import explorers
 from chainerrl import links
@@ -159,8 +158,9 @@ def make_env(idx, test):
         return env

     def make_batch_env(test):
-        return VectorEnv([make_env(idx, test)
-                          for idx, env in enumerate(range(args.num_envs))])
+        return chainerrl.envs.MultiprocessVectorEnv(
+            [(lambda: make_env(idx, test))
+             for idx, env in enumerate(range(args.num_envs))])

     sample_env = make_env(0, test=False)
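One caveat with building the callables as lambdas in a comprehension: Python closures capture variables, not values, and these factories are only called later, in the worker subprocesses that MultiprocessVectorEnv.__init__ spawns. By then idx holds its final value, so every subprocess would construct the env for the last index (the unused env loop variable hides this a little). Two commonly used ways to freeze the index, sketched as hypothetical replacements that reuse make_env and args from the surrounding script:

import functools

def make_batch_env(test):
    # Option 1: bind the current idx as a default argument, which is
    # evaluated once, when each lambda is created.
    return chainerrl.envs.MultiprocessVectorEnv(
        [(lambda idx=idx: make_env(idx, test))
         for idx in range(args.num_envs)])

def make_batch_env_with_partial(test):
    # Option 2: functools.partial freezes the arguments immediately.
    return chainerrl.envs.MultiprocessVectorEnv(
        [functools.partial(make_env, idx, test)
         for idx in range(args.num_envs)])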

@@ -215,7 +215,7 @@ def phi(x):
     if args.demo:
         eval_stats = experiments.eval_performance(
-            env=eval_env,
+            env=make_batch_env(test=True),
             agent=agent,
             n_runs=args.eval_n_runs)
         print('n_runs: {} mean: {} median: {} stdev {}'.format(
@@ -224,8 +224,8 @@
     else:
         experiments.train_agent_batch_with_evaluation(
             agent=agent,
-            env=make_batch_env(False),
-            eval_env=make_batch_env(True),
+            env=make_batch_env(test=False),
+            eval_env=make_batch_env(test=True),
             steps=args.steps,
             eval_n_runs=args.eval_n_runs,
             eval_interval=args.eval_interval,
7 changes: 4 additions & 3 deletions examples/gym/train_ppo_batch_gym.py
@@ -22,9 +22,9 @@
 import gym.wrappers
 import numpy as np

+import chainerrl
 from chainerrl.agents import a3c
 from chainerrl.agents import PPO
-from chainerrl.envs.vec_env import VectorEnv
 from chainerrl import experiments
 from chainerrl import links
 from chainerrl import misc
@@ -166,8 +166,9 @@ def make_env(process_idx, test):
         return env

     def make_batch_env(test):
-        return VectorEnv([make_env(idx, test)
-                          for idx, env in enumerate(range(args.num_envs))])
+        return chainerrl.envs.MultiprocessVectorEnv(
+            [(lambda: make_env(idx, test))
+             for idx, env in enumerate(range(args.num_envs))])

     # Only for getting timesteps, and obs-action spaces
     sample_env = gym.make(args.env)
95 changes: 95 additions & 0 deletions tests/envs_tests/test_vector_envs.py
@@ -0,0 +1,95 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import unittest

from chainer import cuda
from chainer import testing
from chainer.testing import attr
from chainer.testing import condition
import gym
import numpy as np

import chainerrl


@testing.parameterize(*testing.product({
    'num_envs': [1, 2, 3],
    'env_id': ['CartPole-v0', 'Pendulum-v0'],
    'random_seed_offset': [0, 100],
    'vector_env_to_test': ['SerialVectorEnv', 'MultiprocessVectorEnv'],
}))
class TestSerialVectorEnv(unittest.TestCase):

    def setUp(self):
        # Init VectorEnv to test
        if self.vector_env_to_test == 'SerialVectorEnv':
            self.vec_env = chainerrl.envs.SerialVectorEnv(
                [gym.make(self.env_id) for _ in range(self.num_envs)])
        elif self.vector_env_to_test == 'MultiprocessVectorEnv':
            self.vec_env = chainerrl.envs.MultiprocessVectorEnv(
                [(lambda: gym.make(self.env_id))
                 for _ in range(self.num_envs)])
        else:
            assert False
        # Init envs to compare against
        self.envs = [gym.make(self.env_id) for _ in range(self.num_envs)]

    def tearDown(self):
        # Delete so that all the subprocesses are joined
        del self.vec_env

    def test_num_envs(self):
        self.assertEqual(self.vec_env.num_envs, self.num_envs)

    def test_action_space(self):
        self.assertEqual(self.vec_env.action_space, self.envs[0].action_space)

    def test_observation_space(self):
        self.assertEqual(
            self.vec_env.observation_space, self.envs[0].observation_space)

    def test_seed_reset_and_step(self):
        # seed
        seeds = [self.random_seed_offset + i for i in range(self.num_envs)]
        self.vec_env.seed(seeds)
        for env, seed in zip(self.envs, seeds):
            env.seed(seed)

        # reset
        obss = self.vec_env.reset()
        real_obss = [env.reset() for env in self.envs]
        np.testing.assert_allclose(obss, real_obss)

        # step
        actions = [env.action_space.sample() for env in self.envs]
        real_obss, real_rewards, real_dones, real_infos = zip(*[
            env.step(action) for env, action in zip(self.envs, actions)])
        obss, rewards, dones, infos = self.vec_env.step(actions)
        np.testing.assert_allclose(obss, real_obss)
        self.assertEqual(rewards, real_rewards)
        self.assertEqual(dones, real_dones)
        self.assertEqual(infos, real_infos)

        # reset with full mask should have no effect
        mask = np.ones(self.num_envs)
        obss = self.vec_env.reset(mask)
        np.testing.assert_allclose(obss, real_obss)

        # reset with partial mask
        mask = np.zeros(self.num_envs)
        mask[-1] = 1
        obss = self.vec_env.reset(mask)
        real_obss = list(real_obss)
        for i in range(self.num_envs):
            if not mask[i]:
                real_obss[i] = self.envs[i].reset()
        np.testing.assert_allclose(obss, real_obss)


testing.run_module(__name__, __file__)
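These tests compare each VectorEnv implementation against unbatched gym envs driven with the same seeds and actions. The module ends with chainer's testing.run_module hook; assuming a standard repository checkout, the same tests could also be collected with plain unittest discovery (an illustrative invocation, not part of the commit):

import unittest

if __name__ == '__main__':
    suite = unittest.defaultTestLoader.discover(
        'tests/envs_tests', pattern='test_vector_envs.py')
    unittest.TextTestRunner(verbosity=2).run(suite)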
