Skip to content

Commit

Permalink
Make test a bit clearer
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice committed Sep 24, 2021
1 parent 53772da commit bcc5c54
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 46 deletions.
9 changes: 5 additions & 4 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def step_fn(actions):
if command == "reset":
observation = env.reset()
pipe.send((observation, True))
elif command == 'step':
elif command == "step":
observation, reward, done, info = step_fn(data)
pipe.send(((observation, reward, done, info), True))
elif command == "seed":
Expand Down Expand Up @@ -467,10 +467,11 @@ def step_fn(actions):
index, observation, shared_memory, observation_space
)
pipe.send((None, True))
elif command == 'step':
elif command == "step":
observation, reward, done, info = step_fn(data)
write_to_shared_memory(index, observation, shared_memory,
observation_space)
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
pipe.send(((None, reward, done, info), True))
elif command == "seed":
env.seed(data)
Expand Down
132 changes: 90 additions & 42 deletions gym/vector/tests/test_vector_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from typing import Callable, Type
import pytest
import numpy as np

Expand Down Expand Up @@ -58,79 +59,126 @@ def test_custom_space_vector_env():
assert isinstance(env.action_space, Tuple)


@pytest.mark.parametrize('base_env', ["CubeCrash-v0", "CartPole-v0"])
@pytest.mark.parametrize('async_inner', [False, True])
@pytest.mark.parametrize('async_outer', [False, True])
@pytest.mark.parametrize('inner_envs', [1, 4, 7])
@pytest.mark.parametrize('outer_envs', [1, 4, 7])
def test_nesting_vector_envs(base_env: str,
async_inner: bool,
async_outer: bool,
inner_envs: int,
outer_envs: int):
inner_vector_wrapper = AsyncVectorEnv if async_inner else SyncVectorEnv
# When nesting AsyncVectorEnvs, only the "innermost" envs can have
@pytest.mark.parametrize("base_env", ["CubeCrash-v0", "CartPole-v0"])
@pytest.mark.parametrize("async_inner", [False, True])
@pytest.mark.parametrize("async_outer", [False, True])
@pytest.mark.parametrize("n_inner_envs", [1, 4, 7])
@pytest.mark.parametrize("n_outer_envs", [1, 4, 7])
def test_nesting_vector_envs(
base_env: str,
async_inner: bool,
async_outer: bool,
n_inner_envs: int,
n_outer_envs: int,
):
"""Tests nesting of vector envs: Using a VectorEnv of VectorEnvs.
This can be useful for example when running a large number of environments
on a machine with few cores, as worker process of an AsyncVectorEnv can themselves
run multiple environments sequentially using a SyncVectorEnv (a.k.a. chunking).
This test uses creates `n_outer_envs` vectorized environments, each of which has
`n_inner_envs` inned environments. If `async_outer` is True, then the outermost
wrapper is an `AsyncVectorEnv` and a `SyncVectorEnv` when `async_outer` is False.
Same goes for the "inner" environments.
Parameters
----------
- base_env : str
The base environment id.
- async_inner : bool
Wether the inner VectorEnv will be async or not.
- async_outer : bool
Wether the outer VectorEnv will be async or not.
- n_inner_envs : int
Number of inner environments.
- n_outer_envs : int
Number of outer environments.
"""

inner_vectorenv_type: Type[VectorEnv] = (
AsyncVectorEnv if async_inner else SyncVectorEnv
)
outer_vectorenv_type: Type[VectorEnv] = (
partial(AsyncVectorEnv, daemon=False) if async_outer else SyncVectorEnv
)
# NOTE: When nesting AsyncVectorEnvs, only the "innermost" envs can have
# `daemon=True`, otherwise the "daemonic processes are not allowed to have
# children" AssertionError is raised in `multiprocessing.process`.
outer_vector_wrapper = (
partial(AsyncVectorEnv, daemon=False) if async_outer
else SyncVectorEnv

# Create the VectorEnv of VectorEnvs
env = outer_vectorenv_type(
[
partial(
inner_vectorenv_type,
env_fns=[
make_env(base_env, seed=n_inner_envs * i + j)
for j in range(n_inner_envs)
],
)
for i in range(n_outer_envs)
]
)

env = outer_vector_wrapper([ # type: ignore
partial(inner_vector_wrapper, [
make_env(base_env, seed=inner_envs * i + j) for j in range(inner_envs)
]) for i in range(outer_envs)
])


# Create a single test environment.
with make_env(base_env, 0)() as temp_single_env:
single_observation_space = temp_single_env.observation_space
single_action_space = temp_single_env.action_space

assert isinstance(single_observation_space, Box)
assert isinstance(env.observation_space, Box)
assert env.observation_space.shape == (outer_envs, inner_envs, *single_observation_space.shape)
assert env.observation_space.shape == (
n_outer_envs,
n_inner_envs,
*single_observation_space.shape,
)
assert env.observation_space.dtype == single_observation_space.dtype

assert isinstance(env.action_space, spaces.Tuple)
assert len(env.action_space.spaces) == outer_envs
assert len(env.action_space.spaces) == n_outer_envs
assert all(
isinstance(outer_action_space, spaces.Tuple) and
len(outer_action_space.spaces) == inner_envs
isinstance(outer_action_space, spaces.Tuple)
and len(outer_action_space.spaces) == n_inner_envs
for outer_action_space in env.action_space.spaces
)
assert all([
len(inner_action_space.spaces) == inner_envs
for inner_action_space in env.action_space.spaces
])
assert all([
inner_action_space.spaces[i] == single_action_space
for inner_action_space in env.action_space.spaces
for i in range(inner_envs)
])
assert all(
[
len(inner_action_space.spaces) == n_inner_envs
for inner_action_space in env.action_space.spaces
]
)
assert all(
[
inner_action_space.spaces[i] == single_action_space
for inner_action_space in env.action_space.spaces
for i in range(n_inner_envs)
]
)

with env:
observations = env.reset()
assert observations in env.observation_space

actions = env.action_space.sample()
assert actions in env.action_space

observations, rewards, dones, _ = env.step(actions)
assert observations in env.observation_space

assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (outer_envs, inner_envs) + single_observation_space.shape
assert (
observations.shape
== (n_outer_envs, n_inner_envs) + single_observation_space.shape
)

assert isinstance(rewards, np.ndarray)
assert isinstance(rewards[0], np.ndarray)
assert rewards.ndim == 2
assert rewards.shape == (outer_envs, inner_envs)
assert rewards.shape == (n_outer_envs, n_inner_envs)

assert isinstance(dones, np.ndarray)
assert dones.dtype == np.bool_
assert dones.ndim == 2
assert dones.shape == (outer_envs, inner_envs)
assert dones.shape == (n_outer_envs, n_inner_envs)

0 comments on commit bcc5c54

Please sign in to comment.