-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[gym_jiminy] Add FilterObservation and NormalizeObservation wrappers.
- Loading branch information
Alexis Duburcq
committed
Aug 26, 2023
1 parent
504c6dd
commit 90af82f
Showing
7 changed files
with
230 additions
and
57 deletions.
There are no files selected for viewing
8 changes: 4 additions & 4 deletions
8
python/gym_jiminy/common/gym_jiminy/common/wrappers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
# pylint: disable=missing-module-docstring | ||
|
||
from .frame_rate_limiter import FrameRateLimiter | ||
from .frame_stack import PartialFrameStack, StackedJiminyEnv | ||
from .observation_filter import FilteredJiminyEnv | ||
from .observation_stack import PartialObservationStack, StackedJiminyEnv | ||
|
||
|
||
__all__ = [ | ||
'FrameRateLimiter', | ||
'PartialFrameStack', | ||
'FilteredJiminyEnv', | ||
'PartialObservationStack', | ||
'StackedJiminyEnv' | ||
] |
108 changes: 108 additions & 0 deletions
108
python/gym_jiminy/common/gym_jiminy/common/wrappers/observation_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" TODO: Write documentation. | ||
""" | ||
from collections import OrderedDict | ||
from typing import Sequence, Union, Generic | ||
from typing_extensions import TypeAlias | ||
|
||
import numpy as np | ||
import gymnasium as gym | ||
|
||
from ..bases import ObsT, ActT, EngineObsType, BasePipelineWrapper | ||
|
||
|
||
FilteredObsType: TypeAlias = ObsT | ||
|
||
|
||
class FilteredJiminyEnv(BasePipelineWrapper[FilteredObsType, ActT, ObsT, ActT], | ||
Generic[ObsT, ActT]): | ||
"""Normalize action space without clipping, contrary to usual | ||
implementations. | ||
""" | ||
def __init__(self, | ||
env: gym.Env[ObsT, ActT], | ||
nested_filter_keys: Sequence[Union[Sequence[str], str]] | ||
) -> None: | ||
# Make sure that the observation space derives from 'gym.spaces.Dict' | ||
assert isinstance(env.observation_space, gym.spaces.Dict) | ||
|
||
# Make sure all nested keys are stored in sequence | ||
self.nested_filter_keys = [] | ||
for key_nested in nested_filter_keys: | ||
if isinstance(key_nested, str): | ||
key_nested = (key_nested,) | ||
self.nested_filter_keys.append(key_nested) | ||
|
||
# Remove redundant nested keys if any | ||
for i, key_nested in list(enumerate(self.nested_filter_keys))[::-1]: | ||
for j, path in list(enumerate(self.nested_filter_keys[:i]))[::-1]: | ||
if path[:len(key_nested)] == key_nested: | ||
self.nested_filter_keys.pop(j) | ||
elif path == key_nested[:len(path)]: | ||
self.nested_filter_keys.pop(i) | ||
break | ||
|
||
# Initialize base class | ||
super().__init__(env) | ||
|
||
# Bind observation of the environment for all filtered keys | ||
self.observation = OrderedDict() | ||
for key_nested in self.nested_filter_keys: | ||
observation_filtered = self.observation | ||
observation = self.env.observation | ||
for key in key_nested[:-1]: | ||
if key not in observation_filtered.keys(): | ||
observation_filtered[key] = OrderedDict() | ||
observation_filtered = observation_filtered[key] | ||
observation = observation[key] | ||
observation_filtered[key_nested[-1]] = observation[key_nested[-1]] | ||
|
||
def _setup(self) -> None: | ||
"""Configure the wrapper. | ||
In addition to the base implementation, it configures the controller | ||
and registers its target to the telemetry. | ||
""" | ||
# Call base implementation | ||
super()._setup() | ||
|
||
# Compute the observe and control update periods | ||
self.observe_dt = self.env.observe_dt | ||
self.control_dt = self.env.control_dt | ||
|
||
def _initialize_action_space(self) -> None: | ||
"""Configure the action space. | ||
""" | ||
self.action_space = self.env.action_space | ||
|
||
def _initialize_observation_space(self) -> None: | ||
"""Configure the observation space. | ||
""" | ||
self.observation_space = gym.spaces.Dict() | ||
for key_nested in self.nested_filter_keys: | ||
space_filtered = self.observation_space | ||
space = self.env.observation_space | ||
for key in key_nested[:-1]: | ||
if key not in space_filtered: | ||
space_filtered[key] = gym.spaces.Dict() | ||
space_filtered = space_filtered[key] | ||
space = space[key] | ||
space_filtered[key_nested[-1]] = space[key_nested[-1]] | ||
|
||
def refresh_observation(self, measurement: EngineObsType) -> None: | ||
"""Compute high-level features based on the current wrapped | ||
environment's observation. | ||
It simply forwards the command computed by the wrapped environment | ||
without any processing. | ||
""" | ||
self.env.refresh_observation(measurement) | ||
|
||
def compute_command(self, action: ActT) -> np.ndarray: | ||
"""Compute the motors efforts to apply on the robot. | ||
It simply forwards the command computed by the wrapped environment | ||
without any processing. | ||
:param action: High-level target to achieve by means of the command. | ||
""" | ||
return self.env.compute_command(action) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 4 additions & 1 deletion
5
python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
# pylint: disable=missing-module-docstring | ||
|
||
from .normal_action import NormalizeAction | ||
from .normalize import NormalizeAction, NormalizeObservation | ||
from .frame_rate_limiter import FrameRateLimiter | ||
from .meta_envs import HierarchicalTaskSettableEnv, TaskSchedulingWrapper | ||
|
||
|
||
__all__ = [ | ||
"NormalizeAction", | ||
"NormalizeObservation", | ||
"FrameRateLimiter", | ||
"HierarchicalTaskSettableEnv", | ||
"TaskSchedulingWrapper" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 0 additions & 47 deletions
47
python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/normal_action.py
This file was deleted.
Oops, something went wrong.
109 changes: 109 additions & 0 deletions
109
python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/normalize.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
""" TODO: Write documentation. | ||
""" | ||
from typing import TypeVar | ||
|
||
import numpy as np | ||
import numba as nb | ||
|
||
import gymnasium as gym | ||
|
||
|
||
ObsT = TypeVar("ObsT") | ||
|
||
|
||
@nb.jit(nopython=True, inline='always') | ||
def _normalize(value: np.ndarray, | ||
mean: np.ndarray, | ||
scale: np.ndarray) -> np.ndarray: | ||
"""Element-wise normalization of array. | ||
:param value: Un-normalized data. | ||
:param mean: mean. | ||
:param scale: scale. | ||
""" | ||
return (value - mean) / scale | ||
|
||
|
||
@nb.jit(nopython=True, inline='always') | ||
def _denormalize(value: np.ndarray, | ||
mean: np.ndarray, | ||
scale: np.ndarray) -> np.ndarray: | ||
"""Reverse element-wise normalization of array. | ||
:param value: Normalized data. | ||
:param mean: mean. | ||
:param scale: scale. | ||
""" | ||
return mean + value * scale | ||
|
||
|
||
class NormalizeAction(gym.ActionWrapper): | ||
"""Normalize action space without clipping. | ||
""" | ||
def __init__(self, env: gym.Env[ObsT, np.ndarray]) -> None: | ||
# Make sure that the action space derives from 'gym.spaces.Box' | ||
assert isinstance(env.action_space, gym.spaces.Box) | ||
|
||
# Make sure that it is bounded | ||
low, high = env.action_space.low, env.action_space.high | ||
assert all(np.all(np.isfinite(val)) for val in (low, high)), \ | ||
"Action space must have finite bounds." | ||
|
||
# Assert that it has floating-point dtype | ||
assert env.action_space.dtype is not None | ||
dtype = env.action_space.dtype.type | ||
assert issubclass(dtype, np.floating) | ||
|
||
# Initialize base class | ||
super().__init__(env) | ||
|
||
# Define the action space | ||
self._action_mean = (high + low) / 2.0 | ||
self._action_scale = (high - low) / 2.0 | ||
self.action_space = gym.spaces.Box( | ||
low=-1.0, high=1.0, shape=env.action_space.shape, dtype=dtype) | ||
|
||
# Copy 'mirror_mat' attribute if specified | ||
if hasattr(env.action_space, "mirror_mat"): | ||
self.action_space.mirror_mat = ( # type: ignore[attr-defined] | ||
env.action_space.mirror_mat) | ||
|
||
def action(self, action: np.ndarray) -> np.ndarray: | ||
return _denormalize(action, self._action_mean, self._action_scale) | ||
|
||
|
||
class NormalizeObservation(gym.ObservationWrapper): | ||
"""Normalize observation space without clipping. | ||
""" | ||
def __init__(self, env: gym.Env[ObsT, np.ndarray]) -> None: | ||
# Make sure that the action space derives from 'gym.spaces.Box' | ||
assert isinstance(env.observation_space, gym.spaces.Box) | ||
|
||
# Make sure that it is bounded | ||
low, high = env.observation_space.low, env.observation_space.high | ||
assert all(np.all(np.isfinite(val)) for val in (low, high)), \ | ||
"Observation space must have finite bounds." | ||
|
||
# Assert that it has floating-point dtype | ||
assert env.observation_space.dtype is not None | ||
dtype = env.observation_space.dtype.type | ||
assert issubclass(dtype, np.floating) | ||
|
||
# Initialize base class | ||
super().__init__(env) | ||
|
||
# Define the observation space | ||
self._observation_mean = (high + low) / 2.0 | ||
self._observation_scale = (high - low) / 2.0 | ||
self.observation_space = gym.spaces.Box( | ||
low=-1.0, high=1.0, shape=env.observation_space.shape, dtype=dtype) | ||
|
||
# Copy 'mirror_mat' attribute if specified | ||
if hasattr(env.observation_space, "mirror_mat"): | ||
self.observation_space.mirror_mat = ( # type: ignore[attr-defined] | ||
env.observation_space.mirror_mat) | ||
|
||
def observation(self, observation: np.ndarray) -> np.ndarray: | ||
return _normalize(observation, | ||
self._observation_mean, | ||
self._observation_scale) |