/
vector_env.py
202 lines (149 loc) · 6.14 KB
/
vector_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import gym
from gym.spaces import Tuple
from gym.vector.utils.spaces import batch_space
# Names exported by a star-import of this module.
# NOTE(review): ``VectorEnvWrapper`` is defined in this file but is not listed
# here -- confirm whether that omission is intentional.
__all__ = ["VectorEnv"]
class VectorEnv(gym.Env):
    r"""Base class for environments that bundle several sub-environments.

    Observations returned by a vectorized environment are a batch holding one
    observation per sub-environment, and :meth:`step` expects a batch holding
    one action per sub-environment.

    .. note::

        All sub-environments must share identical observation and action
        spaces; a vector of different environments is not supported.

    Parameters
    ----------
    num_envs : int
        Number of environments in the vectorized environment.

    observation_space : `gym.spaces.Space` instance
        Observation space of a single environment.

    action_space : `gym.spaces.Space` instance
        Action space of a single environment.
    """

    def __init__(self, num_envs, observation_space, action_space):
        super().__init__()
        self.num_envs = num_envs

        # Spaces of one sub-environment, kept alongside the batched spaces.
        self.single_observation_space = observation_space
        self.single_action_space = action_space

        # Spaces describing the whole vector of sub-environments.
        self.observation_space = batch_space(observation_space, n=num_envs)
        self.action_space = Tuple((action_space,) * num_envs)

        # ``closed`` is set only after the spaces were built successfully, so
        # ``__del__`` stays a no-op for a partially constructed instance.
        self.viewer = None
        self.closed = False

    def reset_async(self):
        r"""First half of :meth:`reset`; does nothing in this base class."""
        return None

    def reset_wait(self, **kwargs):
        r"""Second half of :meth:`reset`; must be provided by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def reset(self):
        r"""Reset every sub-environment and return the initial observations.

        Calls :meth:`reset_async` followed by :meth:`reset_wait`.

        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        """
        self.reset_async()
        observations = self.reset_wait()
        return observations

    def step_async(self, actions):
        r"""First half of :meth:`step`; does nothing in this base class."""
        return None

    def step_wait(self, **kwargs):
        r"""Second half of :meth:`step`; must be provided by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def step(self, actions):
        r"""Apply one action to each sub-environment.

        Calls :meth:`step_async` with ``actions`` followed by
        :meth:`step_wait`.

        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.

        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.

        rewards : `np.ndarray` instance (dtype `np.float_`)
            A vector of rewards from the vectorized environment.

        dones : `np.ndarray` instance (dtype `np.bool_`)
            A vector whose entries indicate whether the episode has ended.

        infos : list of dict
            A list of auxiliary diagnostic information dicts from
            sub-environments.
        """
        self.step_async(actions)
        outcome = self.step_wait()
        return outcome

    def close_extras(self, **kwargs):
        r"""Release resources beyond what this base class manages.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def close(self, **kwargs):
        r"""Close the vectorized environment and release its resources.

        Does nothing if :attr:`closed` is already ``True``. Otherwise closes
        the viewer (if one exists), calls :meth:`close_extras` with the given
        keyword arguments and finally sets :attr:`closed` to ``True``.

        .. warning::

            This method does not close the sub-environments itself; that is
            the job of :meth:`close_extras`. This keeps it generic for both
            synchronous and asynchronous vectorized environments.

        .. note::

            ``__del__`` calls this method with ``terminate=True`` when the
            environment has not been closed yet.
        """
        if not self.closed:
            viewer = self.viewer
            if viewer is not None:
                viewer.close()
            self.close_extras(**kwargs)
            self.closed = True

    def seed(self, seeds=None):
        r"""Set the random seeds of the sub-environments.

        This base implementation does nothing.

        Parameters
        ----------
        seeds : list of int, or int, optional
            Random seed for each individual environment. If `seeds` is a list
            of length `num_envs`, then the items of the list are chosen as
            random seeds. If `seeds` is an int, then each environment uses the
            random seed `seeds + n`, where `n` is the index of the environment
            (between `0` and `num_envs - 1`).
        """
        return None

    def __del__(self):
        # A missing ``closed`` attribute (``__init__`` did not finish) is
        # treated the same as an already closed environment.
        already_closed = getattr(self, "closed", True)
        if already_closed:
            return
        self.close(terminate=True)

    def __repr__(self):
        cls_name = self.__class__.__name__
        if self.spec is not None:
            return "{}({}, {})".format(cls_name, self.spec.id, self.num_envs)
        return "{}({})".format(cls_name, self.num_envs)
class VectorEnvWrapper(VectorEnv):
    r"""Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments.
    A subclass can override some methods to change the behavior of the
    original vectorized environment without touching the original code.

    Methods defined on :class:`VectorEnv` are forwarded explicitly to the
    wrapped environment; any other public attribute that is not found on the
    wrapper is looked up on the wrapped environment through
    :meth:`__getattr__`.

    .. note::

        Don't forget to call ``super().__init__(env)`` if the subclass
        overrides :meth:`__init__`.

    Parameters
    ----------
    env : `VectorEnv` instance
        The vectorized environment to wrap.
    """

    def __init__(self, env):
        # ``VectorEnv.__init__`` is not called here: the wrapper only stores
        # the wrapped environment and delegates to it.
        assert isinstance(env, VectorEnv)
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)

    def reset_async(self):
        r"""Forward to ``env.reset_async()``."""
        return self.env.reset_async()

    def reset_wait(self, **kwargs):
        r"""Forward to ``env.reset_wait(**kwargs)``.

        Keyword arguments are accepted and passed through so that the
        signature matches ``VectorEnv.reset_wait(**kwargs)``.
        """
        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        r"""Forward to ``env.step_async(actions)``."""
        return self.env.step_async(actions)

    def step_wait(self, **kwargs):
        r"""Forward to ``env.step_wait(**kwargs)``.

        Keyword arguments are accepted and passed through so that the
        signature matches ``VectorEnv.step_wait(**kwargs)``.
        """
        return self.env.step_wait(**kwargs)

    def close(self, **kwargs):
        r"""Forward to ``env.close(**kwargs)``."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs):
        r"""Forward to ``env.close_extras(**kwargs)``."""
        return self.env.close_extras(**kwargs)

    def seed(self, seeds=None):
        r"""Forward to ``env.seed(seeds)``."""
        return self.env.seed(seeds)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):
        r"""Look up a missing public attribute on the wrapped environment.

        Raises
        ------
        AttributeError
            If ``name`` starts with an underscore, or if ``name`` is ``"env"``
            (meaning the wrapped environment has not been set on this
            instance).
        """
        if name.startswith("_"):
            raise AttributeError(
                "attempted to get missing private attribute '{}'".format(name)
            )
        if name == "env":
            # ``__getattr__`` only runs when normal lookup failed, so reaching
            # this point means ``self.env`` was never assigned (for example
            # ``__init__`` raised or was skipped by a subclass). Evaluating
            # ``self.env`` below would re-enter this method forever and end in
            # a RecursionError, which ``getattr(obj, name, default)`` and
            # ``hasattr`` do not swallow -- notably in the inherited
            # ``__del__``.
            raise AttributeError(
                "'{}' object has no attribute 'env'".format(
                    self.__class__.__name__
                )
            )
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        r"""The ``unwrapped`` attribute of the wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
        return "<{}, {}>".format(self.__class__.__name__, self.env)