[BUG] GymWrapper does not work with nested observation gym.spaces.Dict #640

Closed
raphajaner opened this issue Nov 3, 2022 · 12 comments
Labels
bug Something isn't working

Comments

@raphajaner

Describe the bug

Hi All,

First of all: thanks for the great work here!

I think I have encountered a bug in the GymWrapper in torchrl.envs.libs.gym.GymWrapper. When I use a gym.Env with an observation space containing nested gym.spaces.Dict, a KeyError is thrown, since GymLikeEnv.read_obs() only adds "next_" to the first level of the Dict but not to nested sub-Dicts:

observations = {"next_" + key: value for key, value in observations.items()}

Since _gym_to_torchrl_spec_transform() in torchrl.envs.libs.gym adds "next_" in a recursive call to all sub-Dicts, the nested keys are missing the necessary "next_" prefix. Nested Dict observation spaces are commonly used (https://www.gymlibrary.dev/api/spaces/#dict), so I guess this should work properly.
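
For illustration, a minimal sketch of what the current top-level renaming produces (key names taken from the repro below):

obs = {"sensor_4": {"sensor_41": 0.0}}
renamed = {"next_" + key: value for key, value in obs.items()}
print(renamed)  # {'next_sensor_4': {'sensor_41': 0.0}}
# The spec built by _gym_to_torchrl_spec_transform() expects the nested key
# "next_sensor_41" as well, hence the KeyError when encoding.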

To Reproduce

#!/usr/bin/env python
from torchrl.envs.libs.gym import GymWrapper
from gym import spaces, Env
import numpy as np


class CustomGym(Env):
    def __init__(self):
        self.action_space = spaces.Discrete(5)
        self.observation_space = spaces.Dict(
            {
                'sensor_1': spaces.Box(low=0, high=255, shape=(5, 5, 3), dtype=np.uint8),
                'sensor_2': spaces.Box(low=0, high=255, shape=(5, 5, 3), dtype=np.uint8),
                'sensor_3': spaces.Box(np.array([-2, -1, -5, 0]), np.array([2, 1, 30, 1]), dtype=np.float32),
                'sensor_4': spaces.Dict({'sensor_41': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
                                         'sensor_42': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
                                         'sensor_43': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32)})
            }
        )

    def reset(self):
        return self.observation_space.sample()


if __name__ == '__main__':
    env = CustomGym()
    env = GymWrapper(env)  # raises a KeyError because of the nested Dict observation space

Reason and Possible fixes

The issue can be fixed by renaming nested observation Dicts recursively in GymLikeEnv.read_obs(), so that the "next_" prefix is added at every level:

    def read_obs(
        self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
    ) -> Dict[str, Any]:
        """Reads an observation from the environment and returns an observation compatible with the output TensorDict.

        Args:
            observations (observation under a format dictated by the inner env): observation to be read.

        """
        if isinstance(observations, dict):

            def rename(obs):
                return {
                    "next_" + key: rename(value) if isinstance(value, dict) else value
                    for key, value in obs.items()
                }

            observations = rename(observations)
        if not isinstance(observations, (TensorDict, dict)):
            key = list(self.observation_spec.keys())[0]
            observations = {key: observations}
        observations = self.observation_spec.encode(observations)
        return observations
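
For illustration, a minimal standalone sketch of the recursive renaming on a plain nested dict (the helper mirrors the proposed fix and is not part of torchrl):

def rename(obs):
    # Prefix every key with "next_", descending into nested dicts.
    return {
        "next_" + key: rename(value) if isinstance(value, dict) else value
        for key, value in obs.items()
    }

print(rename({"sensor_3": [0.0], "sensor_4": {"sensor_41": [1.0]}}))
# {'next_sensor_3': [0.0], 'next_sensor_4': {'next_sensor_41': [1.0]}}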

The style checker does not allow lambda functions; otherwise, the fix could be as simple as:

    rename = lambda obs: {
        "next_" + key: rename(value) if isinstance(value, dict) else value
        for key, value in obs.items()
    }

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@raphajaner raphajaner added the bug Something isn't working label Nov 3, 2022
@raphajaner raphajaner changed the title [BUG] GymWrapper does not work with nested gym.spaces.Dict [BUG] GymWrapper does not work with nested observation gym.spaces.Dict Nov 3, 2022
@vmoens
Contributor

vmoens commented Nov 3, 2022

Hey @raphajaner thanks for this, and glad you like the lib.
I'll fix that asap, bear with me.

@vmoens
Contributor

vmoens commented Nov 3, 2022

Oh by the way, since you have worked out the solution, do you want to implement the fix? I can take care of that otherwise. The only thing missing, from what I see, would be to reproduce your code example in a test.

Side note: since we send envs and objects from process to process, I usually try to avoid using lambda functions as they don't serialize well. There are solutions but not using one is always easier :)

@raphajaner raphajaner reopened this Nov 3, 2022
@raphajaner
Author

Yes sure, I’ll take care of it :) Thanks for the feedback!

@raphajaner
Author

Hi @vmoens,
I was looking into fixing the issue yesterday, but I think it might be a bit more involved and could require some small design decisions. When a Dict obs space is used, a CompositeSpec is created and all of its keys start with "next_". That's where the original issue occurs: only the first level of the Dict gets renamed. This means that in _make_specs, the resulting CompositeSpec is used directly as self.observation_spec. However, when a different observation space is used (e.g., gym.spaces.Discrete), the following gets triggered:

        if not isinstance(self.observation_spec, CompositeSpec):
            if self.from_pixels:
                self.observation_spec = CompositeSpec(next_pixels=self.observation_spec)
            else:
                self.observation_spec = CompositeSpec(
                    next_observation=self.observation_spec
                )

So for all other obs space types, "next_" is only added at this point, summarizing the space under the key "next_observation" or "next_pixels", respectively. In the step_mdp function in torchrl.envs.utils, the "next_" part gets removed again to produce the correct output td:

new_keys = [key[5:] for key in keys]

which works fine when the obs are always under the "next_observation" key, but it won't work for CompositeSpecs created directly from nested Dict gym obs spaces. A fix here should also be easy: rename the keys in step_mdp recursively too (sketched below).
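
As a hedged sketch of such a recursive variant on a plain nested dict (the helper name strip_next is hypothetical, not part of torchrl):

def strip_next(d):
    # Strip the "next_" prefix at every nesting level.
    return {
        (key[5:] if key.startswith("next_") else key): (
            strip_next(value) if isinstance(value, dict) else value
        )
        for key, value in d.items()
    }

print(strip_next({"next_sensor_4": {"next_sensor_41": 0.0}}))
# {'sensor_4': {'sensor_41': 0.0}}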

However, I'm generally a bit unsure whether, for Dict gym obs spaces, all the keys of the Dict should be renamed and kept separately, instead of simply wrapping the whole Dict in a CompositeSpec with the single key "next_observation" that then gives access to the individual items. IMO the latter makes more sense: for example, I'd rather access the "camera" feature of the next observation via td["next_observation"]["camera"] than via td["next_camera"]. What is your take on this?

@vmoens
Contributor

vmoens commented Nov 4, 2022

Hey!
I see your point. We're thinking about redesigning this API. I will open a PR with that shortly, but I'd be glad to get your thoughts about it.

First, I think the "next_obs" prefix is messy and makes it hard to get the tensordict of the next step. Second, it does not scale well to other problems (e.g., MCTS or planners in general, where we explore many different possible actions for a single state). Finally, it requires users to name the obs in the specs with the "next" prefix, which they might forget and find cumbersome.

Here's what I would see:
Before: env.step returns

TensorDict({
  "state": stuff,
  "reward": reward,
  "done": done,
  "action": action,
  "next_state": stuff,
  "other": foo,
}, [])

We would change that to:

TensorDict({
  "state": stuff,
  "reward": reward,
  "done": done,
  "action": action,
  "next": TensorDict({
      "state": stuff,
    }, []),
  "other": foo,
}, [])

That way, step_mdp just needs to do tensordict = tensordict["next"].clone(recurse=False) (we clone it, otherwise the original tensordict will keep track of the whole trajectory!).
If you like the previous API, you can just do tensordict.flatten_keys("_").
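
A rough, hedged sketch of that access pattern (the import path is assumed and may differ by version; the keys are placeholders):

import torch
from tensordict import TensorDict  # may be torchrl.data.TensorDict in older versions

td = TensorDict({
    "state": torch.zeros(3),
    "next": TensorDict({"state": torch.ones(3)}, []),
}, [])

next_td = td["next"].clone(recurse=False)  # what step_mdp would return
flat = td.flatten_keys("_")                # recovers the old "next_state"-style keys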

So in your case you'd have this

TensorDict({
  "state": stuff,
  "reward": reward,
  "done": done,
  "action": action,
  "camera": cam,
  "next": TensorDict({
      "state": stuff,
      "camera": cam,
    }, []),
  "other": foo,
}, [])

Thoughts?

cc @shagunsodhani (by the way, it's funny that we were just talking about that feature a couple of hours ago and @raphajaner came up with a very similar idea!)

@vmoens
Contributor

vmoens commented Nov 4, 2022

Forgot to mention: in the MCTS case, we'd have a slightly different tensordict:

TensorDict({
  "state": stuff,
  "reward": reward,
  "done": done,
  "action": action,
  "next": TensorDict({
      "0": TensorDict({
        "state": stuff,
        "reward": reward,
        "done": done,
      }, []),
      "1": TensorDict({
        "state": stuff,
        "reward": reward,
        "done": done,
      }, []),
    }, []),
  "other": foo,
}, [])

where 0 and 1 are the 2 possible actions.

For continuous action domains, we'd need a custom hashing function given by the user to convert actions to a string.
But that means that if we'd like to generalize this to the other envs, each would need to index the next state with the action taken, meaning that we'd need a hash function for each env (i.e., it would quickly become a nightmare and a very difficult thing to sell to the average user).
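
Purely as an illustration of such a user-supplied hash (hypothetical, not an API from this thread), an action tensor could be mapped to a string key like this:

import hashlib
import torch

def action_key(action: torch.Tensor) -> str:
    # Hash the raw bytes of the action to obtain a short string key.
    return hashlib.sha1(action.detach().cpu().numpy().tobytes()).hexdigest()[:8]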

@raphajaner
Author

Yes, that seems like a very reasonable solution! I like the thought of having more flexibility in "next" for other cases. Just a small note about the example with the "camera": I'd see the "camera" as part of the overall "state", along with other sensors, e.g., like this:

TensorDict({
  "state": TensorDict({
      "vehicle_state": stuff,
      "camera": cam,
    }, []),
  "reward": reward,
  "done": done,
  "action": action,
  "next": TensorDict({
      "state": TensorDict({
          "vehicle_state": stuff,
          "camera": cam,
        }, []),
    }, []),
  "other": foo,
}, [])

As you're already thinking about redesigning the API, I guess it could also make sense to think about the "done" part, which has recently been split into "truncated" and "terminated" in gym. IMO that made a lot of sense.

I think I see your point regarding continuous action domains in MCTS. I'm not sure if this makes sense, but couldn't it be possible to use the object id() of the next state as the hash index and include it directly in the action as info?

@vmoens
Contributor

vmoens commented Nov 4, 2022

I think I see your point regarding continuous action domains in MCTS. I'm not sure if this makes sense, but couldn't it be possible to use the object id() of the next state as the hash index and include it directly in the action as info?

I thought about it, but I guess it'll break in distributed settings (which are typical for MCTS).

@shagunsodhani
Contributor

As you're already thinking about redesigning the API, I guess it could also make sense to think about the "done" part, which has recently been split into "truncated" and "terminated" in gym. IMO that made a lot of sense.

I am not sure how actively the community would pick up this change, so we may want to wait for a while before changing the API.

@vmoens
Contributor

vmoens commented Nov 5, 2022 via email

@shagunsodhani
Contributor

Supporting both makes sense. Though if this requires a lot of work, we may want to wait a while and see how strong the demand for the new API is.

@raphajaner
Author

Closed by #649
