Skip to content

Commit

Permalink
Merge pull request #237 from mirraaj/MPI
Browse files Browse the repository at this point in the history
Mpi
  • Loading branch information
random-user-x committed Aug 15, 2018
2 parents 5aa6366 + c5b26d8 commit e0fc419
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 37 deletions.
4 changes: 3 additions & 1 deletion rl/common/cmd_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Inspired from OpenAI Baselines

import gym
from rl.common.vec_env.subproc_env_vec import SubprocVecEnv
from rl.common import set_global_seeds

def make_gym_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
def make_gym_env(env_id, num_env=2, seed=123, wrapper_kwargs=None, start_index=0):
"""
Create a wrapped, SubprocVecEnv for Gym Environments.
"""
Expand Down
3 changes: 2 additions & 1 deletion rl/common/misc_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Inspired from OpenAI Baselines

import gym
import numpy as np
import keras.backend as K
import random

def set_global_seeds(i):
Expand Down
28 changes: 5 additions & 23 deletions rl/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Inspired from VecEnv from OpenAI Baselines

class VecEnv(object):
"""
An abstract asynchronous, vectorized environment.
Expand Down Expand Up @@ -52,36 +54,16 @@ def step(self, actions):
def render(self, mode='human'):
logger.warn('Render not defined for %s'%self)

def seed(self, i):
raise NotImplementedError()

@property
def unwrapped(self):
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self

class VecEnvWrapper(VecEnv):
def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv
VecEnv.__init__(self,
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space)

def step_async(self, actions):
self.venv.step_async(actions)

def reset(self):
pass

def step_wait(self):
pass

def close(self):
return self.venv.close()

def render(self):
self.venv.render()

class CloudpickleWrapper(object):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
Expand Down
24 changes: 12 additions & 12 deletions rl/common/vec_env/subproc_env_vec.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Inspired from OpenAI Baselines

import numpy as np
from multiprocessing import Process, Pipe
from rl.common.vec_env import VecEnv, CloudpickleWrapper
Expand All @@ -23,6 +25,9 @@ def worker(remote, parent_remote, env_fn_wrapper):
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'seed':
val = env.seed(data)
remote.send(val)
else:
raise NotImplementedError

Expand Down Expand Up @@ -82,15 +87,10 @@ def close(self):
self.closed = True

def render(self, mode='human'):
for pipe in self.remotes:
pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.remotes]
bigimg = tile_images(imgs)
if mode == 'human':
import cv2
cv2.imshow('vecenv', bigimg[:,:,::-1])
cv2.waitKey(1)
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
raise NotImplementedError('Render is not implemented for Synchronous Environment')

def seed(self, i):
rank = i
for remote in self.remotes:
remote.send(('seed', rank))
rank += 1

0 comments on commit e0fc419

Please sign in to comment.