/
multiprocessing_env.py
310 lines (259 loc) · 9.04 KB
/
multiprocessing_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""Vectorized multiprocessing environments.

This code is taken from the OpenAI Baselines project:
https://github.com/openai/baselines/tree/master/baselines/common/vec_env
"""
from abc import ABC, abstractmethod
from multiprocessing import Pipe, Process
import numpy as np
def tile_images(img_nhwc):
    """Tile N images into one big PxQ image.

    (P, Q) are chosen to be as close as possible, and if N is a perfect
    square then P == Q. Grid cells beyond the N-th image are filled with
    zeros.

    Args:
        img_nhwc: list or array of images, ndim=4 once turned into an array
            (n = batch index, h = height, w = width, c = channel).

    Returns:
        ndarray with ndim=3 and shape (P * h, Q * w, c). An empty batch
        (N == 0) yields an empty array of shape (0, 0, c).
    """
    img_nhwc = np.asarray(img_nhwc)
    N, h, w, c = img_nhwc.shape
    if N == 0:
        # ceil(sqrt(0)) == 0 would make the column count divide by zero.
        return np.zeros((0, 0, c), dtype=img_nhwc.dtype)

    H = int(np.ceil(np.sqrt(N)))
    W = int(np.ceil(float(N) / H))

    n_pad = H * W - N
    if n_pad:
        # Pad with genuine zeros: deriving the blank tile from
        # ``img_nhwc[0] * 0`` would leak NaN/inf values of the first image
        # into the padding.
        blank = np.zeros((n_pad, h, w, c), dtype=img_nhwc.dtype)
        img_nhwc = np.concatenate([img_nhwc, blank], axis=0)

    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
    # Interleave grid rows with pixel rows and grid columns with pixel
    # columns so that the final reshape lays the tiles out side by side.
    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
    img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)
    return img_Hh_Ww_c
def worker(remote, parent_remote, env_fn_wrapper):
    """Serve commands for a single environment over a pipe.

    Loops on ``remote.recv()``, expecting ``(cmd, data)`` tuples, and
    answers on the same connection until a ``"close"`` command arrives.

    Supported commands:
        "step":       step the env with ``data``; when the episode is done
                      the env is reset and the fresh observation is sent
                      in place of the terminal one.
        "reset":      reset the env and send the observation.
        "render":     send ``env.render(mode="rgb_array")``.
        "sample":     send ``env.action_space.sample()``.
        "get_spaces": send ``(observation_space, action_space)``.
        "close":      close ``remote`` and leave the loop.

    Args:
        remote: this process's end of the pipe.
        parent_remote: the other end of the pipe; closed immediately since
            this function never uses it.
        env_fn_wrapper: object whose ``x`` attribute is either an
            environment or a zero-argument callable that creates one.

    Raises:
        NotImplementedError: if an unknown command is received.
    """
    parent_remote.close()
    env = env_fn_wrapper.x
    # The wrapped object may be an env factory rather than an env instance;
    # call it when it does not itself look like an environment.
    if callable(env) and not hasattr(env, "step"):
        env = env()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset so the caller always gets a usable observation.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == "reset":
                ob = env.reset()
                remote.send(ob)
            elif cmd == "render":
                remote.send(env.render(mode="rgb_array"))
            elif cmd == "sample":
                remote.send(env.action_space.sample())
            elif cmd == "close":
                remote.close()
                break
            elif cmd == "get_spaces":
                remote.send((env.observation_space, env.action_space))
            else:
                raise NotImplementedError(
                    "Unknown worker command: {!r}".format(cmd)
                )
    except KeyboardInterrupt:
        print("SubprocVecEnv worker: got KeyboardInterrupt")
    finally:
        # Always release the environment, whatever ended the loop.
        env.close()
class VecEnv(ABC):
    """An abstract asynchronous, vectorized environment.

    Used to batch data from multiple copies of an environment, so that
    each observation becomes a batch of observations, and the expected
    action is a batch of actions to be applied per-environment.
    """

    # Set to True by close(); further close() calls return immediately.
    closed = False
    # Created on first use by get_viewer().
    viewer = None
    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self, num_envs, observation_space, action_space):
        """Store the number of environments and their shared spaces."""
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        raise NotImplementedError

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        raise NotImplementedError

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        raise NotImplementedError

    def close_extras(self):
        """
        Clean up the extra resources, beyond what's in this base class.
        Only runs when not self.closed.

        The default does nothing, so subclasses without extra resources
        can still be closed; override it to release such resources.
        """
        pass

    def close(self):
        """Close the viewer (if any) and run close_extras(), at most once."""
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """
        Step the environments synchronously.

        This is available for backwards compatibility.
        """
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode="human"):
        """Render all environments as one tiled image.

        With mode="human" the image is shown in the viewer and the
        viewer's ``isopen`` flag is returned; with mode="rgb_array" the
        tiled image itself is returned.

        Raises:
            NotImplementedError: for any other mode.
        """
        imgs = self.get_images()
        bigimg = tile_images(imgs)
        if mode == "human":
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == "rgb_array":
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        """Return the innermost environment beneath any VecEnvWrapper."""
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        """Return the image viewer, creating it on first use."""
        if self.viewer is None:
            # Imported lazily so gym's rendering module is only loaded
            # when a viewer is actually requested.
            from gym.envs.classic_control import rendering

            self.viewer = rendering.SimpleImageViewer()
        return self.viewer
class VecEnvWrapper(VecEnv):
    """
    An environment wrapper that applies to an entire batch
    of environments at once.

    Subclasses must implement reset() and step_wait(); stepping, closing
    and rendering are forwarded to the wrapped environment.
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        """Wrap ``venv``.

        Args:
            venv: the vectorized environment to wrap.
            observation_space: replacement observation space; when None
                the wrapped environment's space is used.
            action_space: replacement action space; when None the wrapped
                environment's space is used.
        """
        self.venv = venv
        # Compare against None explicitly: a truthiness test would discard
        # an explicitly supplied space that happens to evaluate as falsy.
        if observation_space is None:
            observation_space = venv.observation_space
        if action_space is None:
            action_space = venv.action_space
        VecEnv.__init__(
            self,
            num_envs=venv.num_envs,
            observation_space=observation_space,
            action_space=action_space,
        )

    def step_async(self, actions):
        """Forward the actions to the wrapped environment."""
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        """Close the wrapped environment."""
        return self.venv.close()

    def render(self, mode="human"):
        """Render through the wrapped environment."""
        return self.venv.render(mode=mode)

    def get_images(self):
        """Return the wrapped environment's images."""
        return self.venv.get_images()
class CloudpickleWrapper:
    """Holder whose payload is serialized with cloudpickle.

    multiprocessing would otherwise use the standard pickle module to
    serialize the payload.
    """

    def __init__(self, x):
        # The wrapped payload; read back by consumers as ``wrapper.x``.
        self.x = x

    def __getstate__(self):
        """Return the payload as cloudpickle-encoded bytes."""
        from cloudpickle import dumps

        encoded = dumps(self.x)
        return encoded

    def __setstate__(self, ob):
        """Restore the payload from the bytes produced by __getstate__."""
        from pickle import loads

        self.x = loads(ob)
class SubprocVecEnv(VecEnv):
    """VecEnv that runs multiple environments in parallel in subprocesses
    and communicates with them via pipes.

    Recommended to use when num_envs > 1 and step() can be a bottleneck.
    """

    def __init__(self, env_fns):
        """
        Arguments:
            env_fns: iterable of callables - functions that create environments to run
                in subprocesses. Need to be cloud-pickleable

        Raises:
            ValueError: if env_fns is empty.
        """
        self.waiting = False
        self.closed = False
        # Materialise once: env_fns is counted and iterated, which a
        # one-shot iterable (e.g. a generator) could not support.
        env_fns = list(env_fns)
        nenvs = len(env_fns)
        if nenvs == 0:
            raise ValueError("SubprocVecEnv requires at least one environment")
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [
            Process(
                target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))
            )
            for (work_remote, remote, env_fn) in zip(
                self.work_remotes, self.remotes, env_fns
            )
        ]
        for p in self.ps:
            # if the main process crashes, we should not cause things to hang
            p.daemon = True
            p.start()
        # The worker ends are not used by this process after start-up.
        for remote in self.work_remotes:
            remote.close()

        # Spaces are queried from the first environment only.
        self.remotes[0].send(("get_spaces", None))
        observation_space, action_space = self.remotes[0].recv()
        self.viewer = None
        VecEnv.__init__(self, nenvs, observation_space, action_space)

    def step_async(self, actions):
        """Send one action to each worker without waiting for results."""
        self._assert_not_closed()
        for remote, action in zip(self.remotes, actions):
            remote.send(("step", action))
        self.waiting = True

    def step_wait(self):
        """Collect the step results from every worker.

        Returns:
            (obs, rews, dones, infos) where the first three are stacked
            with np.stack and infos is a tuple with one entry per worker.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every environment and return the stacked observations."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(("reset", None))
        return np.stack([remote.recv() for remote in self.remotes])

    def sample(self):
        """Sample one action from each environment's action space, stacked."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(("sample", None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close_extras(self):
        """Shut down the worker processes and release the pipes."""
        self.closed = True
        if self.waiting:
            # Drain pending step results so the workers get back to their
            # receive loop and can see the close command.
            for remote in self.remotes:
                remote.recv()
            self.waiting = False
        for remote in self.remotes:
            remote.send(("close", None))
        for p in self.ps:
            p.join()
        # Workers have exited; release this process's pipe ends as well.
        for remote in self.remotes:
            remote.close()

    def get_images(self):
        """Return one rendered image per environment, as a list."""
        self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(("render", None))
        imgs = [pipe.recv() for pipe in self.remotes]
        return imgs

    def _assert_not_closed(self):
        """Assert that close() has not been called on this environment."""
        assert (
            not self.closed
        ), "Trying to operate on a SubprocVecEnv after calling close()"