[Bug Report] Total order of keys for spaces.Dict #3023

Closed
1 task done
XuehaiPan opened this issue Aug 13, 2022 · 6 comments · Fixed by #3024


@XuehaiPan
Contributor


Describe the bug

Companion to issue jax-ml/jax#11871. A total order is required for dictionary keys.

Equal Dict spaces dict1 == dict2 do not imply dict1.sample() == dict2.sample() and flatten_space(dict1) == flatten_space(dict2).

In [1]: from gym import spaces

In [2]: dict1 = spaces.Dict({1: spaces.Box(0, 1, shape=(1,)), 'a': spaces.Box(0, 2, shape=(1,))})

In [3]: dict2 = spaces.Dict({'a': spaces.Box(0, 2, shape=(1,)), 1: spaces.Box(0, 1, shape=(1,))})

In [4]: dict1
Out[4]: Dict(1: Box(0.0, 1.0, (1,), float32), a: Box(0.0, 2.0, (1,), float32))

In [5]: dict2
Out[5]: Dict(a: Box(0.0, 2.0, (1,), float32), 1: Box(0.0, 1.0, (1,), float32))

In [6]: dict1 == dict2
Out[6]: True

Not equally seeded in subspaces:

In [7]: dict1.seed(0)
Out[7]: [0, 2488343231644625808, 5874934615388537134]

In [8]: dict2.seed(0)
Out[8]: [0, 2488343231644625808, 5874934615388537134]

In [9]: dict1.sample() == dict2.sample()
Out[9]: False

Different order while flattening:

In [10]: spaces.flatten_space(dict1)
Out[10]: Box(0.0, [1. 2.], (2,), float32)

In [11]: spaces.flatten_space(dict2)
Out[11]: Box(0.0, [2. 1.], (2,), float32)

In [12]: spaces.flatten_space(dict1) == spaces.flatten_space(dict2)
Out[12]: False

The order of keys is important when seeding the subspaces and flattening the space.


In the spaces.Dict.__init__ method, we always convert the inputs into an OrderedDict:

gym/gym/spaces/dict.py

Lines 82 to 92 in 8b74413

if spaces is None:
    spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
    try:
        spaces = OrderedDict(sorted(spaces.items()))
    except TypeError:  # raise when sort by different types of keys
        spaces = OrderedDict(spaces.items())
if isinstance(spaces, Sequence):
    spaces = OrderedDict(spaces)
assert isinstance(spaces, OrderedDict), "spaces must be a dictionary"

However, sorted raises a TypeError when the keys are of types that cannot be compared with each other (e.g. int vs. str):

In [1]: d = {1: 1, 'a': 2}

In [2]: sorted(d.items())
TypeError: '<' not supported between instances of 'str' and 'int'

So a fallback was added at line 88 in PR #2491. This means that when the keys are not sortable, they are kept in insertion order (guaranteed for dict since Python 3.7, and a CPython implementation detail in 3.6). However, the order of keys is important when seeding the subspaces and flattening the space.
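A minimal sketch of what the fallback does, using plain Python values in place of subspaces. `normalize` is a hypothetical helper that mirrors only the sort-or-fall-back step quoted above; it is not part of Gym.

```python
from collections import OrderedDict


def normalize(spaces):
    """Hypothetical helper mirroring the sort-with-fallback step above."""
    try:
        # Mutually comparable keys: sorted order, independent of insertion order.
        return OrderedDict(sorted(spaces.items()))
    except TypeError:
        # Mixed key types (e.g. int vs. str): keep insertion order.
        return OrderedDict(spaces.items())


# String-only keys end up in the same order either way.
same_a = normalize({"b": 2, "a": 1})
same_b = normalize({"a": 1, "b": 2})

# Mixed keys keep whatever order the caller happened to use.
mixed_a = normalize({1: "x", "a": "y"})
mixed_b = normalize({"a": "y", 1: "x"})

print(list(same_a), list(same_b))    # ['a', 'b'] ['a', 'b']
print(list(mixed_a), list(mixed_b))  # [1, 'a'] ['a', 1]
```

With string-only keys the result is deterministic; with mixed keys it depends on how the caller built the dict.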

seed: (seed in order of dict.spaces.values())

gym/gym/spaces/dict.py

Lines 131 to 135 in 8b74413

    for subspace, subseed in zip(self.spaces.values(), subseeds):
        seeds.append(subspace.seed(int(subseed))[0])
elif seed is None:
    for space in self.spaces.values():
        seeds += space.seed(seed)
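To illustrate why positional pairing matters here, the following sketch replaces subspaces with strings and uses made-up subseed values. `assign_subseeds` is a hypothetical stand-in for the `zip(self.spaces.values(), subseeds)` loop, not Gym code.

```python
from collections import OrderedDict

# Stand-ins for the subseeds drawn from the parent RNG (values are made up).
subseeds = [111, 222]

# Same keys and values, different key order.
order1 = OrderedDict([(1, "box01"), ("a", "box02")])
order2 = OrderedDict([("a", "box02"), (1, "box01")])


def assign_subseeds(spaces, subseeds):
    """Pair each key with the subseed at the same position."""
    return {key: subseed for key, subseed in zip(spaces, subseeds)}


seeds1 = assign_subseeds(order1, subseeds)
seeds2 = assign_subseeds(order2, subseeds)

print(seeds1)            # {1: 111, 'a': 222}
print(seeds2)            # {'a': 111, 1: 222}
print(seeds1 == seeds2)  # False
```

The same parent seed therefore gives the subspace under key 1 a different subseed in each ordering, which is consistent with the differing samples shown in the report.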

flatten_space: (flatten in order of dict.spaces.values())

gym/gym/spaces/utils.py

Lines 326 to 333 in 8b74413

@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Box:
    space_list = [flatten_space(s) for s in space.spaces.values()]
    return Box(
        low=np.concatenate([s.low for s in space_list]),
        high=np.concatenate([s.high for s in space_list]),
        dtype=np.result_type(*[s.dtype for s in space_list]),
    )

As a result, dict1 == dict2 does not imply flatten_space(dict1) == flatten_space(dict2).
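The same effect can be reproduced without Gym or NumPy. In this sketch, `flatten_high` is a hypothetical stand-in that concatenates per-key upper bounds in iteration order, the way the `high=` argument is built above; the bounds match the two Box subspaces from the report.

```python
def flatten_high(spaces):
    """Concatenate per-subspace upper bounds in iteration order."""
    flat = []
    for high in spaces.values():
        flat.extend(high)
    return flat


# Upper bounds of the two subspaces, keyed as in the report.
highs1 = {1: [1.0], "a": [2.0]}
highs2 = {"a": [2.0], 1: [1.0]}

print(highs1 == highs2)      # True: dict equality ignores key order
print(flatten_high(highs1))  # [1.0, 2.0]
print(flatten_high(highs2))  # [2.0, 1.0]
```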

Code example

See Description above.

System Info

  • Describe how Gym was installed: pip
  • What OS/version of Linux you're using: Ubuntu 20.04 LTS
  • Python version: 3.9

Additional context

See also:

Checklist

  • I have checked that there is no similar issue in the repo (required)
@XuehaiPan
Contributor Author

As I commented in jax-ml/jax#11871 (comment), a feasible solution is sorting by key = (obj.__class__.__qualname__, obj). It works for all Python built-in types. However, for user-defined types, there is no total order any more.
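A minimal sketch of that sort key, again with strings standing in for subspaces. `total_order_key` is a hypothetical name. The approach assumes that keys of the same type are mutually comparable and that no two distinct key types share a `__qualname__`; neither is guaranteed for user-defined types.

```python
from collections import OrderedDict


def total_order_key(item):
    """Sort key for (key, value) pairs: type name first, then the key itself."""
    key, _value = item
    return (type(key).__qualname__, key)


# Same contents, different insertion order, mixed int and str keys.
spaces1 = {1: "box01", "a": "box02", 0: "box03"}
spaces2 = {"a": "box02", 0: "box03", 1: "box01"}

sorted1 = OrderedDict(sorted(spaces1.items(), key=total_order_key))
sorted2 = OrderedDict(sorted(spaces2.items(), key=total_order_key))

print(list(sorted1))       # [0, 1, 'a']
print(sorted1 == sorted2)  # True, even though OrderedDict equality is order-sensitive
```

Because the type name is compared first, an int key is never compared directly with a str key, so the TypeError from plain `sorted` does not occur.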

@pseudo-rnd-thoughts
Contributor

pseudo-rnd-thoughts commented Aug 13, 2022

Thanks for the issue, but I'm confused about what needs solving.
Is it that the __eq__ function needs to consider the order of the keys, not just that the same keys exist in both dicts?
Or that we should sort regardless of the key type? I should note that in our typing, we assume that the key type is always str, so this issue shouldn't occur. In what cases are you using a non-str key?

Looking back, I dislike that we sort the dict keys rather than just following the order of insertion, but it seems too late to change.

@XuehaiPan
Contributor Author

XuehaiPan commented Aug 13, 2022

I'm confused about what needs solving.

@pseudo-rnd-thoughts, the issue is:

Equal Dict spaces dict1 == dict2 do not imply dict1.sample() == dict2.sample() and flatten_space(dict1) == flatten_space(dict2).

There are two solutions:

  1. find a total order and sort the keys by it. (e.g. key = (obj.__class__.__qualname__, obj) works for all Python built-in types, but may not work for user-defined types)
  2. consider the order of the keys in Dict.__eq__ method.

Is it that the __eq__ function needs to consider the order of the keys, not just that the same keys exist in both dicts?

Python built-in dict does not consider the key order either. When we compare OrderedDict vs. dict, the key order is also ignored.

In [1]: from collections import OrderedDict

In [2]: d1 = {1: 0, 'a': 1}

In [3]: d2 = {'a': 1, 1: 0}

In [4]: d1
Out[4]: {1: 0, 'a': 1}

In [5]: d2
Out[5]: {'a': 1, 1: 0}

In [6]: d1 == d2
Out[6]: True

In [7]: od1 = OrderedDict([(1, 0), ('a', 1)])

In [8]: od2 = OrderedDict([('a', 1), (1, 0)])

In [9]: d1 == od1
Out[9]: True

In [10]: d2 == od1
Out[10]: True

When comparing two OrderedDicts, the key order is taken into account.

In [11]: od1 == od2
Out[11]: False

In [12]: dict(od1) == dict(od2)
Out[12]: True

We save the dict structure as OrderedDict in our spaces.Dict implementation. It would be fine to consider the key order when comparing two Dict spaces.


Looking back, I dislike that we sort the dict keys rather than just following the order of insertion, but it seems too late to change.

We do not need to disable the sorting behavior, just updating the __eq__ method is enough.

@pseudo-rnd-thoughts
Contributor

Ok that makes sense.
Would you like to make the PR, or should I add it to #2977?

@XuehaiPan
Contributor Author

Would you like to make the PR

Done.

@RedTachyon
Contributor

I definitely think this should be fixed by updating __eq__ to reflect the underlying OrderedDict structure. The current implementation of __eq__ is inherited from collections.abc.Mapping, which casts both operands to a plain dict before comparing them.
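A rough sketch of the difference between the inherited comparison and an order-sensitive one. `LooseDict` and `StrictDict` are hypothetical classes written for illustration; this is not the code from the fixing PR, and integers stand in for subspaces.

```python
from collections import OrderedDict
from collections.abc import Mapping


class LooseDict(Mapping):
    """Hypothetical wrapper that keeps the inherited Mapping.__eq__."""

    def __init__(self, items):
        self.spaces = OrderedDict(items)

    def __getitem__(self, key):
        return self.spaces[key]

    def __iter__(self):
        return iter(self.spaces)

    def __len__(self):
        return len(self.spaces)


class StrictDict(LooseDict):
    """Same wrapper, but equality also takes key order into account."""

    def __eq__(self, other):
        if not isinstance(other, StrictDict):
            return NotImplemented
        # OrderedDict == OrderedDict compares items pairwise, in order.
        return self.spaces == other.spaces


items = [(1, 0), ("a", 1)]
swapped = [("a", 1), (1, 0)]

print(LooseDict(items) == LooseDict(swapped))    # True: order ignored
print(StrictDict(items) == StrictDict(swapped))  # False: order considered
print(StrictDict(items) == StrictDict(items))    # True
```

With the stricter comparison, two containers that compare equal also iterate in the same order, so order-dependent operations such as seeding and flattening would agree for them.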
