Skip to content

Commit

Permalink
Respect the order of keys in a Dict's observation space when flatteni…
Browse files Browse the repository at this point in the history
…ng (#1748)

* Respect the order of keys in a Dict's observation space when flattening

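Prior to this change, `flatten` used the order of the key/value pairs in the observation itself rather than the order defined by the Dict's observation space. `unflatten` already uses the order specified by the Dict's observation space, so the two functions could disagree; with this change `flatten` follows the same order.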

* Add tests for FlattenObservation
  • Loading branch information
dwiel authored and pzhokhov committed Dec 6, 2019
1 parent 59d401e commit 3ee7e67
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gym/spaces/utils.py
Expand Up @@ -35,7 +35,7 @@ def flatten(space, x):
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(space.spaces[key], item) for key, item in x.items()])
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
elif isinstance(space, MultiDiscrete):
Expand Down
96 changes: 96 additions & 0 deletions tests/gym/wrappers/flatten_test.py
@@ -0,0 +1,96 @@
"""Tests for the flatten observation wrapper."""

from collections import OrderedDict

import numpy as np
import pytest

import gym
from gym.spaces import Box, Dict, unflatten, flatten
from gym.wrappers import FlattenObservation


class FakeEnvironment(gym.Env):
    """Minimal environment stub that samples observations from a given space."""

    def __init__(self, observation_space):
        # The space is stored on the instance; reset() draws its samples from it.
        self.observation_space = observation_space

    def reset(self):
        """Sample an observation, keep it on ``self.observation`` and return it."""
        sampled = self.observation_space.sample()
        # Remembered so callers can compare it with what a wrapper returned.
        self.observation = sampled
        return sampled


# Each entry is (observation_space, ordered_values). When ordered_values is
# True, the tests additionally assert that the flattened vector is sorted
# ascending, which holds when the constant per-key values increase in the
# order the keys are declared in the space.
OBSERVATION_SPACES = (
    # Keys declared key1, key2, key3 with constant values 0, 1, 2.
    (
        Dict(
            OrderedDict(
                [
                    ("key1", Box(low=0, high=0, shape=(2, 3), dtype=np.float32)),
                    ("key2", Box(low=1, high=1, shape=(), dtype=np.float32)),
                    ("key3", Box(low=2, high=2, shape=(2,), dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    # Same keys declared key2, key3, key1; the values still increase in
    # declaration order (0, 1, 2).
    (
        Dict(
            OrderedDict(
                [
                    ("key2", Box(low=0, high=0, shape=(), dtype=np.float32)),
                    ("key3", Box(low=1, high=1, shape=(2,), dtype=np.float32)),
                    ("key1", Box(low=2, high=2, shape=(2, 3), dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    # Plain dict with a non-degenerate range [-1, 1]: only the round trip is
    # checked, not the ordering of the flattened values.
    (
        Dict(
            {
                "key1": Box(low=-1, high=1, shape=(2, 3), dtype=np.float32),
                "key2": Box(low=-1, high=1, shape=(), dtype=np.float32),
                "key3": Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            }
        ),
        False,
    ),
)


class TestFlattenEnvironment(object):
    """Round-trip and ordering checks for flattening Dict observations."""

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """
        Check that observations flattened by the FlattenObservation wrapper
        come out in the expected order.
        """
        env = FakeEnvironment(observation_space=observation_space)
        flat_obs = FlattenObservation(env).reset()

        # Rebuild the structured observation using the *unwrapped* space.
        restored = unflatten(env.observation_space, flat_obs)
        sampled = env.observation

        self._check_observations(sampled, flat_obs, restored, ordered_values)

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """
        Exercise the flatten and unflatten functions directly, without a wrapper.
        """
        sampled = observation_space.sample()

        flat_obs = flatten(observation_space, sampled)
        restored = unflatten(observation_space, flat_obs)

        self._check_observations(sampled, flat_obs, restored, ordered_values)

    def _check_observations(self, original, flattened, unflattened, ordered_values):
        """Assert the round trip is lossless and, if requested, that the flat vector is sorted.

        original: the observation as sampled from the space.
        flattened: the flat representation of ``original``.
        unflattened: the result of unflattening ``flattened``.
        ordered_values: when true, also require ``flattened`` to be in ascending order.
        """
        # unflatten(flatten(original)) must give back the same keys and values.
        assert set(unflattened.keys()) == set(original.keys())
        for key, expected in original.items():
            np.testing.assert_allclose(unflattened[key], expected)

        if not ordered_values:
            return

        # The values must have been flattened in the order they appeared in
        # the OrderedDict, i.e. the flat vector equals its sorted version.
        np.testing.assert_allclose(sorted(flattened), flattened)

0 comments on commit 3ee7e67

Please sign in to comment.