Permalink
Browse files

Added .shape attr to Dict space. Fixed Tuple.from_jsonable() method a…

…s it has been returning <zip> object instead of list of tuples. Test with nested Dict space passed. (#763)
  • Loading branch information...
Kismuz authored and wojzaremba committed Nov 5, 2017
1 parent c659a85 commit 69b677e6d8bfc0b86f586ca5ee13620b20fab90e
Showing with 63 additions and 12 deletions.
  1. +27 −1 gym/spaces/dict_space.py
  2. +1 −1 gym/spaces/discrete.py
  3. +3 −2 gym/spaces/multi_binary.py
  4. +27 −7 gym/spaces/tests/test_spaces.py
  5. +5 −1 gym/spaces/tuple_space.py
View
@@ -3,17 +3,43 @@
class Dict(Space):
"""
A dictionary of simpler spaces
A dictionary of simpler spaces.
Example usage:
self.observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
Example usage [nested]:
self.nested_observation_space = spaces.Dict({
'sensors': spaces.Dict({
'position': spaces.Box(low=-100, high=100, shape=(3)),
'velocity': spaces.Box(low=-1, high=1, shape=(3)),
'front_cam': spaces.Tuple((
spaces.Box(low=0, high=1, shape=(10, 10, 3)),
spaces.Box(low=0, high=1, shape=(10, 10, 3))
)),
'rear_cam': spaces.Box(low=0, high=1, shape=(10, 10, 3)),
}),
'ext_controller': spaces.MultiDiscrete([ [0,4], [0,1], [0,1] ]),
'inner_state':spaces.Dict({
'charge': spaces.Discrete(100),
'system_checks': spaces.MultiBinary(10),
'job_status': spaces.Dict({
'task': spaces.Discrete(5),
'progress': spaces.Box(low=0, high=100, shape=()),
})
})
})
"""
def __init__(self, spaces):
if isinstance(spaces, dict):
spaces = OrderedDict(sorted(list(spaces.items())))
if isinstance(spaces, list):
spaces = OrderedDict(spaces)
self.spaces = spaces
self.shape = self._get_shape()
def _get_shape(self):
return OrderedDict([(k, space.shape) for k, space in self.spaces.items()])
def sample(self):
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
View
@@ -25,7 +25,7 @@ def contains(self, x):
@property
def shape(self):
return ()
return (self.n,)
def __repr__(self):
return "Discrete(%d)" % self.n
def __eq__(self, other):
@@ -5,11 +5,12 @@
class MultiBinary(gym.Space):
def __init__(self, n):
self.n = n
self.shape = (n,)
def sample(self):
return prng.np_random.randint(low=0, high=2, size=self.n)
def contains(self, x):
return ((x==0) | (x==1)).all()
def to_jsonable(self, sample_n):
return sample_n.tolist()
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n):
return np.array(sample_n)
return [np.asarray(sample) for sample in sample_n]
@@ -1,16 +1,36 @@
import json # note: ujson fails this test due to float equality
import numpy as np
import pytest
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, Dict
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
@pytest.mark.parametrize("space", [
Discrete(3),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(np.array([0,0]),np.array([1,5]))]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ]),
Dict({"position": Discrete(5), "velocity": Box(np.array([0,0]),np.array([1,5]))}),
Discrete(3),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(np.array([0,0]),np.array([1,5]))]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiBinary(10),
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ]),
Dict({
'sensors': Dict({
'position': Box(low=-100, high=100, shape=(3)),
'velocity': Box(low=-1, high=1, shape=(3)),
'front_cam': Tuple((
Box(low=0, high=1, shape=(10, 10, 3)),
Box(low=0, high=1, shape=(10, 10, 3))
)),
'rear_cam': Box(low=0, high=1, shape=(10, 10, 3)),
}),
'ext_controller': MultiDiscrete([[0, 4], [0, 1], [0, 1]]),
'inner_state': Dict({
'charge': Discrete(100),
'system_checks': MultiBinary(10),
'job_status': Dict({
'task': Discrete(5),
'progress': Box(low=0, high=100, shape=()),
})
})
})
])
def test_roundtripping(space):
sample_1 = space.sample()
@@ -9,6 +9,7 @@ class Tuple(Space):
"""
def __init__(self, spaces):
self.spaces = spaces
self.shape = self._get_shape()
def sample(self):
return tuple([space.sample() for space in self.spaces])
@@ -19,6 +20,9 @@ def contains(self, x):
return isinstance(x, tuple) and len(x) == len(self.spaces) and all(
space.contains(part) for (space,part) in zip(self.spaces,x))
def _get_shape(self):
return tuple([space.shape for space in self.spaces])
def __repr__(self):
return "Tuple(" + ", ". join([str(s) for s in self.spaces]) + ")"
@@ -28,4 +32,4 @@ def to_jsonable(self, sample_n):
for i, space in enumerate(self.spaces)]
def from_jsonable(self, sample_n):
return zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]

1 comment on commit 69b677e

@diegoalejogm

This comment has been minimized.

Show comment
Hide comment
@diegoalejogm

diegoalejogm Dec 6, 2017

When are these changes going to be uploaded to Pip/PyPI?

diegoalejogm commented on 69b677e Dec 6, 2017

When are these changes going to be uploaded to Pip/PyPI?

Please sign in to comment.