-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #236 from mirraaj/MPI
Adding multiprocessing VecEnv for actor critic agents
- Loading branch information
Showing
6 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .misc_util import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import gym | ||
from rl.common.vec_env.subproc_env_vec import SubprocVecEnv | ||
from rl.common import set_global_seeds | ||
|
||
def make_gym_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0): | ||
""" | ||
Create a wrapped, SubprocVecEnv for Gym Environments. | ||
""" | ||
if wrapper_kwargs is None: wrapper_kwargs = {} | ||
def make_env(rank): # pylint: disable=C0111 | ||
def _thunk(): | ||
env = gym.make(env_id) | ||
env.seed(seed + rank) | ||
return env | ||
return _thunk | ||
set_global_seeds(seed) | ||
return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import gym | ||
import numpy as np | ||
import keras.backend as K | ||
import random | ||
|
||
def set_global_seeds(i): | ||
np.random.seed(i) | ||
random.seed(i) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import numpy as np | ||
|
||
def tile_images(img_nhwc): | ||
""" | ||
Tile N images into one big PxQ image | ||
(P,Q) are chosen to be as close as possible, and if N | ||
is square, then P=Q. | ||
input: img_nhwc, list or array of images, ndim=4 once turned into array | ||
n = batch index, h = height, w = width, c = channel | ||
returns: | ||
bigim_HWc, ndarray with ndim=3 | ||
""" | ||
img_nhwc = np.asarray(img_nhwc) | ||
N, h, w, c = img_nhwc.shape | ||
H = int(np.ceil(np.sqrt(N))) | ||
W = int(np.ceil(float(N)/H)) | ||
img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) | ||
img_HWhwc = img_nhwc.reshape(H, W, h, w, c) | ||
img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) | ||
img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) | ||
return img_Hh_Ww_c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
class VecEnv(object): | ||
""" | ||
An abstract asynchronous, vectorized environment. | ||
""" | ||
def __init__(self, num_envs, observation_space, action_space): | ||
self.num_envs = num_envs | ||
self.observation_space = observation_space | ||
self.action_space = action_space | ||
|
||
def reset(self): | ||
""" | ||
Reset all the environments and return an array of | ||
observations, or a tuple of observation arrays. | ||
If step_async is still doing work, that work will | ||
be cancelled and step_wait() should not be called | ||
until step_async() is invoked again. | ||
""" | ||
pass | ||
|
||
def step_async(self, actions): | ||
""" | ||
Tell all the environments to start taking a step | ||
with the given actions. | ||
Call step_wait() to get the results of the step. | ||
You should not call this if a step_async run is | ||
already pending. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def step_wait(self): | ||
""" | ||
Wait for the step taken with step_async(). | ||
Returns (obs, rews, dones, infos): | ||
- obs: an array of observations, or a tuple of | ||
arrays of observations. | ||
- rews: an array of rewards | ||
- dones: an array of "episode done" booleans | ||
- infos: a sequence of info objects | ||
""" | ||
raise NotImplementedError() | ||
|
||
def close(self): | ||
""" | ||
Clean up the environments' resources. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def step(self, actions): | ||
self.step_async(actions) | ||
return self.step_wait() | ||
|
||
def render(self, mode='human'): | ||
logger.warn('Render not defined for %s'%self) | ||
|
||
@property | ||
def unwrapped(self): | ||
if isinstance(self, VecEnvWrapper): | ||
return self.venv.unwrapped | ||
else: | ||
return self | ||
|
||
class VecEnvWrapper(VecEnv): | ||
def __init__(self, venv, observation_space=None, action_space=None): | ||
self.venv = venv | ||
VecEnv.__init__(self, | ||
num_envs=venv.num_envs, | ||
observation_space=observation_space or venv.observation_space, | ||
action_space=action_space or venv.action_space) | ||
|
||
def step_async(self, actions): | ||
self.venv.step_async(actions) | ||
|
||
def reset(self): | ||
pass | ||
|
||
def step_wait(self): | ||
pass | ||
|
||
def close(self): | ||
return self.venv.close() | ||
|
||
def render(self): | ||
self.venv.render() | ||
|
||
class CloudpickleWrapper(object): | ||
""" | ||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) | ||
""" | ||
def __init__(self, x): | ||
self.x = x | ||
def __getstate__(self): | ||
import cloudpickle | ||
return cloudpickle.dumps(self.x) | ||
def __setstate__(self, ob): | ||
import pickle | ||
self.x = pickle.loads(ob) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import numpy as np | ||
from multiprocessing import Process, Pipe | ||
from rl.common.vec_env import VecEnv, CloudpickleWrapper | ||
from rl.common.tile_images import tile_images | ||
|
||
def worker(remote, parent_remote, env_fn_wrapper): | ||
parent_remote.close() | ||
env = env_fn_wrapper.x() | ||
while True: | ||
cmd, data = remote.recv() | ||
if cmd == 'step': | ||
ob, reward, done, info = env.step(data) | ||
if done: | ||
ob = env.reset() | ||
remote.send((ob, reward, done, info)) | ||
elif cmd == 'reset': | ||
ob = env.reset() | ||
remote.send(ob) | ||
elif cmd == 'render': | ||
remote.send(env.render(mode='rgb_array')) | ||
elif cmd == 'close': | ||
remote.close() | ||
break | ||
elif cmd == 'get_spaces': | ||
remote.send((env.observation_space, env.action_space)) | ||
else: | ||
raise NotImplementedError | ||
|
||
|
||
class SubprocVecEnv(VecEnv): | ||
def __init__(self, env_fns, spaces=None): | ||
""" | ||
envs: list of gym environments to run in subprocesses | ||
""" | ||
self.waiting = False | ||
self.closed = False | ||
nenvs = len(env_fns) | ||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) | ||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) | ||
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] | ||
for p in self.ps: | ||
p.daemon = True # if the main process crashes, we should not cause things to hang | ||
p.start() | ||
for remote in self.work_remotes: | ||
remote.close() | ||
|
||
self.remotes[0].send(('get_spaces', None)) | ||
observation_space, action_space = self.remotes[0].recv() | ||
VecEnv.__init__(self, len(env_fns), observation_space, action_space) | ||
|
||
def step_async(self, actions): | ||
for remote, action in zip(self.remotes, actions): | ||
remote.send(('step', action)) | ||
self.waiting = True | ||
|
||
def step_wait(self): | ||
results = [remote.recv() for remote in self.remotes] | ||
self.waiting = False | ||
obs, rews, dones, infos = zip(*results) | ||
return np.stack(obs), np.stack(rews), np.stack(dones), infos | ||
|
||
def reset(self): | ||
for remote in self.remotes: | ||
remote.send(('reset', None)) | ||
return np.stack([remote.recv() for remote in self.remotes]) | ||
|
||
def reset_task(self): | ||
for remote in self.remotes: | ||
remote.send(('reset_task', None)) | ||
return np.stack([remote.recv() for remote in self.remotes]) | ||
|
||
def close(self): | ||
if self.closed: | ||
return | ||
if self.waiting: | ||
for remote in self.remotes: | ||
remote.recv() | ||
for remote in self.remotes: | ||
remote.send(('close', None)) | ||
for p in self.ps: | ||
p.join() | ||
self.closed = True | ||
|
||
def render(self, mode='human'): | ||
for pipe in self.remotes: | ||
pipe.send(('render', None)) | ||
imgs = [pipe.recv() for pipe in self.remotes] | ||
bigimg = tile_images(imgs) | ||
if mode == 'human': | ||
import cv2 | ||
cv2.imshow('vecenv', bigimg[:,:,::-1]) | ||
cv2.waitKey(1) | ||
elif mode == 'rgb_array': | ||
return bigimg | ||
else: | ||
raise NotImplementedError |