Skip to content

Commit

Permalink
[gym_jiminy] Add FilterObservation and NormalizeObservation wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Duburcq committed Aug 26, 2023
1 parent 504c6dd commit 90af82f
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 57 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# pylint: disable=missing-module-docstring

from .frame_rate_limiter import FrameRateLimiter
from .frame_stack import PartialFrameStack, StackedJiminyEnv
from .observation_filter import FilteredJiminyEnv
from .observation_stack import PartialObservationStack, StackedJiminyEnv


# Names exported as the public API of this sub-package
__all__ = [
'FrameRateLimiter',
'PartialFrameStack',
'FilteredJiminyEnv',
'PartialObservationStack',
'StackedJiminyEnv'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
""" TODO: Write documentation.
"""
from collections import OrderedDict
from typing import Sequence, Union, Generic
from typing_extensions import TypeAlias

import numpy as np
import gymnasium as gym

from ..bases import ObsT, ActT, EngineObsType, BasePipelineWrapper


FilteredObsType: TypeAlias = ObsT


class FilteredJiminyEnv(BasePipelineWrapper[FilteredObsType, ActT, ObsT, ActT],
                        Generic[ObsT, ActT]):
    """Expose only a subset of the (possibly nested) observation of the
    wrapped environment.

    The observation space of the wrapped environment must derive from
    `gym.spaces.Dict`. The leaves to keep are selected through nested keys,
    i.e. sequences of dictionary keys leading from the root of the
    observation to the entry of interest. The action space is left untouched.
    """
    def __init__(self,
                 env: gym.Env[ObsT, ActT],
                 nested_filter_keys: Sequence[Union[Sequence[str], str]]
                 ) -> None:
        """
        :param env: Environment to wrap. Its observation space must be a
                    `gym.spaces.Dict`.
        :param nested_filter_keys: Nested keys of the observation entries to
                                   keep. A plain string is interpreted as a
                                   top-level key. Keys that are already
                                   covered by another one (duplicates or
                                   descendants) are discarded.
        """
        # Make sure that the observation space derives from 'gym.spaces.Dict'
        assert isinstance(env.observation_space, gym.spaces.Dict)

        # Store all nested keys as tuples and remove the redundant ones.
        # This must be done before initializing the base class, since the
        # observation space is built from these keys.
        self.nested_filter_keys = self._remove_redundant_keys(
            nested_filter_keys)

        # Initialize base class
        super().__init__(env)

        # Bind observation of the environment for all filtered keys.
        # The leaves are shared with the wrapped environment, not copied.
        self.observation = OrderedDict()
        for key_nested in self.nested_filter_keys:
            observation_filtered = self.observation
            observation = self.env.observation
            for key in key_nested[:-1]:
                if key not in observation_filtered:
                    observation_filtered[key] = OrderedDict()
                observation_filtered = observation_filtered[key]
                observation = observation[key]
            observation_filtered[key_nested[-1]] = observation[key_nested[-1]]

    @staticmethod
    def _remove_redundant_keys(
            nested_filter_keys: Sequence[Union[Sequence[str], str]]
            ) -> Sequence[Sequence[str]]:
        """Convert nested keys to tuples and drop those that are redundant.

        A key is redundant if it is equal to, or a descendant of, another
        key of the collection, e.g. `("a", "b")` is redundant with `("a",)`.
        The relative order of the keys that are kept is preserved.

        :param nested_filter_keys: User-specified nested keys.

        :returns: List of non-redundant nested keys, each stored as a tuple.

        :raises ValueError: If one of the nested keys is empty.
        """
        kept_keys = []
        for key_nested in nested_filter_keys:
            # Always work with tuples, otherwise comparing a key specified
            # as a list with one specified as a tuple would never match.
            if isinstance(key_nested, str):
                key = (key_nested,)
            else:
                key = tuple(key_nested)
            if not key:
                raise ValueError("Nested filter keys must not be empty.")

            # Skip the key if an ancestor (or itself) is already selected
            if any(key[:len(other)] == other for other in kept_keys):
                continue

            # Discard already selected keys that are descendants of this one
            kept_keys = [
                other for other in kept_keys if other[:len(key)] != key]
            kept_keys.append(key)
        return kept_keys

    def _setup(self) -> None:
        """Configure the wrapper.

        In addition to the base implementation, it copies the observe and
        control update periods of the wrapped environment.
        """
        # Call base implementation
        super()._setup()

        # Use the same observe and control update periods as the environment
        self.observe_dt = self.env.observe_dt
        self.control_dt = self.env.control_dt

    def _initialize_action_space(self) -> None:
        """Configure the action space.

        It is the same as the one of the wrapped environment.
        """
        self.action_space = self.env.action_space

    def _initialize_observation_space(self) -> None:
        """Configure the observation space.

        It mirrors the nested structure of the observation space of the
        wrapped environment, restricted to the filtered keys.
        """
        self.observation_space = gym.spaces.Dict()
        for key_nested in self.nested_filter_keys:
            space_filtered = self.observation_space
            space = self.env.observation_space
            for key in key_nested[:-1]:
                # Create intermediate sub-spaces on the fly if necessary
                if key not in space_filtered:
                    space_filtered[key] = gym.spaces.Dict()
                space_filtered = space_filtered[key]
                space = space[key]
            space_filtered[key_nested[-1]] = space[key_nested[-1]]

    def refresh_observation(self, measurement: EngineObsType) -> None:
        """Refresh the observation of the wrapped environment.

        It simply forwards the measurement to the wrapped environment
        without any processing. The filtered observation holds references to
        the entries of the observation of the wrapped environment, so it is
        presumably kept up-to-date as long as the latter is updated in-place.

        :param measurement: Low-level measure from the engine, forwarded
                            as-is to the wrapped environment.
        """
        self.env.refresh_observation(measurement)

    def compute_command(self, action: ActT) -> np.ndarray:
        """Compute the motors efforts to apply on the robot.

        It simply forwards the command computed by the wrapped environment
        without any processing.

        :param action: High-level target to achieve by means of the command.
        """
        return self.env.compute_command(action)
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
StackedObsType: TypeAlias = ObsT


class PartialFrameStack(
class PartialObservationStack(
gym.Wrapper, # [StackedObsType, ActT, ObsT, ActT],
Generic[ObsT, ActT]):
"""Observation wrapper that partially stacks observations in a rolling
manner.
This wrapper combines and extends OpenAI Gym wrappers `FrameStack` and
`FilterObservation` to support nested filter keys.
`FilteredJiminyEnv` to support nested filter keys.
It adds one extra dimension to all the leaves of the original observation
spaces that must be stacked. If so, the first dimension corresponds to the
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(self,
self.__n_last_stack = 0

# Instantiate wrapper
self.wrapper = PartialFrameStack(env, **kwargs)
self.wrapper = PartialObservationStack(env, **kwargs)

# Bind the observation of the wrapper
self.observation = self.wrapper.observation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=missing-module-docstring

from .normal_action import NormalizeAction
from .normalize import NormalizeAction, NormalizeObservation
from .frame_rate_limiter import FrameRateLimiter
from .meta_envs import HierarchicalTaskSettableEnv, TaskSchedulingWrapper


# Names exported as the public API of this sub-package
__all__ = [
"NormalizeAction",
"NormalizeObservation",
"FrameRateLimiter",
"HierarchicalTaskSettableEnv",
"TaskSchedulingWrapper"
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from jiminy_py.viewer import sleep

from ..bases import ObsT, ActT, InfoType, JiminyEnvInterface
from ..envs import BaseJiminyEnv
from gym_jiminy.common.bases import ObsT, ActT, InfoType, JiminyEnvInterface
from gym_jiminy.common.envs import BaseJiminyEnv


class FrameRateLimiter(gym.Wrapper, # [ObsT, ActT, ObsT, ActT],
Expand Down

This file was deleted.

109 changes: 109 additions & 0 deletions python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
""" TODO: Write documentation.
"""
from typing import TypeVar

import numpy as np
import numba as nb

import gymnasium as gym


ObsT = TypeVar("ObsT")


@nb.jit(nopython=True, inline='always')
def _normalize(value: np.ndarray,
               mean: np.ndarray,
               scale: np.ndarray) -> np.ndarray:
    """Center and rescale an array element-wise.

    :param value: Raw data to normalize.
    :param mean: Offset subtracted from the data.
    :param scale: Factor by which the centered data is divided.

    :returns: `(value - mean) / scale`, computed element-wise.
    """
    centered = value - mean
    return centered / scale


@nb.jit(nopython=True, inline='always')
def _denormalize(value: np.ndarray,
                 mean: np.ndarray,
                 scale: np.ndarray) -> np.ndarray:
    """Undo the element-wise normalization of an array.

    :param value: Normalized data.
    :param mean: Offset added back to the rescaled data.
    :param scale: Factor by which the normalized data is multiplied.

    :returns: `mean + value * scale`, computed element-wise.
    """
    rescaled = value * scale
    return mean + rescaled


class NormalizeAction(gym.ActionWrapper):
    """Action wrapper exposing a normalized action space in [-1.0, 1.0].

    Normalized actions are mapped back affinely to the bounds of the original
    action space before being forwarded to the wrapped environment. No
    clipping is applied.
    """
    def __init__(self, env: gym.Env[ObsT, np.ndarray]) -> None:
        """
        :param env: Environment to wrap. Its action space must be a bounded
                    `gym.spaces.Box` with floating-point dtype.
        """
        # Only box action spaces are supported
        assert isinstance(env.action_space, gym.spaces.Box)

        # Both bounds must be finite to define the affine mapping
        lower, upper = env.action_space.low, env.action_space.high
        assert np.all(np.isfinite(lower)) and np.all(np.isfinite(upper)), \
            "Action space must have finite bounds."

        # The data type must be floating-point
        assert env.action_space.dtype is not None
        scalar_type = env.action_space.dtype.type
        assert issubclass(scalar_type, np.floating)

        # Initialize base class
        super().__init__(env)

        # Affine mapping from the normalized to the original action space
        self._action_mean = (lower + upper) / 2.0
        self._action_scale = (upper - lower) / 2.0

        # Normalized action space
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=env.action_space.shape,
            dtype=scalar_type)

        # Propagate the optional 'mirror_mat' attribute if any
        try:
            mirror_mat = env.action_space.mirror_mat
        except AttributeError:
            pass
        else:
            self.action_space.mirror_mat = (  # type: ignore[attr-defined]
                mirror_mat)

    def action(self, action: np.ndarray) -> np.ndarray:
        """Map a normalized action back to the original action space.

        :param action: Normalized action.
        """
        mean, scale = self._action_mean, self._action_scale
        return _denormalize(action, mean, scale)


class NormalizeObservation(gym.ObservationWrapper):
    """Normalize observation space without clipping.

    Observations of the wrapped environment are mapped affinely from their
    original bounds to [-1.0, 1.0]. Dimensions whose lower and upper bounds
    are equal are only centered, so that they map to 0.0 instead of
    triggering a division by zero.
    """
    # NOTE(review): the type parameters of 'gym.Env' look swapped in the
    # signature below (the observation is the 'np.ndarray' here). It is kept
    # as-is since no action type variable is available in this module.
    def __init__(self, env: gym.Env[ObsT, np.ndarray]) -> None:
        """
        :param env: Environment to wrap. Its observation space must be a
                    bounded `gym.spaces.Box` with floating-point dtype.
        """
        # Make sure that the observation space derives from 'gym.spaces.Box'
        assert isinstance(env.observation_space, gym.spaces.Box)

        # Make sure that it is bounded
        low, high = env.observation_space.low, env.observation_space.high
        assert all(np.all(np.isfinite(val)) for val in (low, high)), \
            "Observation space must have finite bounds."

        # Assert that it has floating-point dtype
        assert env.observation_space.dtype is not None
        dtype = env.observation_space.dtype.type
        assert issubclass(dtype, np.floating)

        # Initialize base class
        super().__init__(env)

        # Define the affine mapping to the normalized observation space.
        # A unit scale is used for degenerate dimensions (low == high), as
        # dividing by a zero scale would yield NaN or raise an exception.
        self._observation_mean = (high + low) / 2.0
        scale = (high - low) / 2.0
        self._observation_scale = np.where(
            scale != 0.0, scale, np.ones_like(scale))

        # Define the observation space
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=env.observation_space.shape, dtype=dtype)

        # Copy 'mirror_mat' attribute if specified
        if hasattr(env.observation_space, "mirror_mat"):
            self.observation_space.mirror_mat = (  # type: ignore[attr-defined]
                env.observation_space.mirror_mat)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Normalize an observation of the wrapped environment.

        :param observation: Original observation of the wrapped environment.
        """
        return _normalize(observation,
                          self._observation_mean,
                          self._observation_scale)

0 comments on commit 90af82f

Please sign in to comment.