Skip to content

Commit

Permalink
[gym/common] Restrict usage of error-prone '__getattr__' fallback in …
Browse files Browse the repository at this point in the history
…pipelines.
  • Loading branch information
duburcqa committed Jun 10, 2024
1 parent 2f5adea commit 71b8873
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 34 deletions.
23 changes: 21 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..utils import DataNested
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv
from ..quantities import QuantityManager


Expand Down Expand Up @@ -220,6 +221,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Call super to allow mixing interfaces through multiple inheritance
super().__init__(*args, **kwargs)

# Define convenience proxy for quantity manager
self.quantities = self.unwrapped.quantities

def _setup(self) -> None:
"""Configure the observer-controller.
Expand Down Expand Up @@ -335,9 +339,24 @@ def _controller_handle(self,
# '_controller_handle' as it is never called more often than necessary.
self.__is_observation_refreshed = False

def stop(self) -> None:
    """Stop the episode immediately without waiting for a termination or
    truncation condition to be satisfied.

    This simply forwards the request to the underlying simulator.

    .. note::
        This method is mainly intended for data analysis and debugging.
        Stopping the episode is necessary to log the final state, otherwise
        it will be missing from plots and viewer replay. Moreover, sensor
        data will not be available during replay using the object-oriented
        method `replay`. The helper method `play_logs_data` must be
        preferred to replay an episode that cannot be stopped for the time
        being.
    """
    self.simulator.stop()

@property
def unwrapped(self) -> "InterfaceJiminyEnv":
"""Base environment of the pipeline.
def unwrapped(self) -> "BaseJiminyEnv":
"""The underlying environment at the basis of the pipeline that this
environment is part of.
"""
return self

Expand Down
45 changes: 29 additions & 16 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
eventually already wrapped, so that it appears as a black-box environment.
"""
import math
import logging
from weakref import ref
from copy import deepcopy
from abc import abstractmethod
from collections import OrderedDict
from typing import (
Dict, Any, List, Optional, Tuple, Union, Generic, TypeVar, SupportsFloat,
Callable, cast)
Callable, cast, TYPE_CHECKING)

import numpy as np

Expand All @@ -37,6 +38,8 @@
from .blocks import BaseControllerBlock, BaseObserverBlock

from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv


OtherObsT = TypeVar('OtherObsT', bound=DataNested)
Expand All @@ -46,6 +49,9 @@
TransformedActT = TypeVar('TransformedActT', bound=DataNested)


LOGGER = logging.getLogger(__name__)


class BasePipelineWrapper(
InterfaceJiminyEnv[ObsT, ActT],
Generic[ObsT, ActT, BaseObsT, BaseActT]):
Expand Down Expand Up @@ -101,7 +107,17 @@ def __getattr__(self, name: str) -> Any:
It provides direct access to the attributes and methods of the wrapped
environment without having to go through `env`.
.. warning::
This fallback incurs a significant runtime overhead. As such, it
must only be used for debug and manual analysis between episodes.
Calling this method while a simulation is running triggers a
warning, so as to avoid relying on it by mistake.
"""
if self.is_simulation_running:
LOGGER.warning(
"Relying on fallback attribute getter is inefficient and "
"strongly discouraged at runtime.")
return getattr(self.__getattribute__('env'), name)

def __dir__(self) -> List[str]:
Expand Down Expand Up @@ -143,9 +159,7 @@ def np_random(self, value: np.random.Generator) -> None:
self.env.np_random = value

@property
def unwrapped(self) -> InterfaceJiminyEnv:
"""Base environment of the pipeline.
"""
def unwrapped(self) -> "BaseJiminyEnv":
return self.env.unwrapped

@property
Expand Down Expand Up @@ -236,8 +250,7 @@ def step(self, # type: ignore[override]
self._copyto_action(action)

# Make sure that the pipeline has not changed since the last reset
env_derived = (
self.unwrapped.derived) # type: ignore[attr-defined]
env_derived = self.unwrapped.derived
if env_derived is not self:
raise RuntimeError(
"Pipeline environment has changed. Please call 'reset' "
Expand Down Expand Up @@ -532,14 +545,14 @@ def __init__(self,
# Register the observer's internal state and feature to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.observer.name)
except ValueError:
pass
self.env.register_variable('feature', # type: ignore[attr-defined]
self.observer.observation,
self.observer.fieldnames,
self.observer.name)
self.unwrapped.register_variable('feature',
self.observer.observation,
self.observer.fieldnames,
self.observer.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down Expand Up @@ -750,14 +763,14 @@ def __init__(self,
# Register the controller's internal state and target to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.controller.name)
except ValueError:
pass
self.env.register_variable('action', # type: ignore[attr-defined]
self.action,
self.controller.fieldnames,
self.controller.name)
self.unwrapped.register_variable('action',
self.action,
self.controller.fieldnames,
self.controller.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down
31 changes: 18 additions & 13 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class BaseJiminyEnv(InterfaceJiminyEnv[ObsT, ActT],
to implement one. It has been designed to be highly flexible and easy to
customize by overloading it to fit the vast majority of users' needs.
"""

derived: "InterfaceJiminyEnv"
"""Top-most block of the pipeline that this environment is part of when
leveraging the modular pipeline design capability.
"""

def __init__(self,
simulator: Simulator,
step_dt: float,
Expand Down Expand Up @@ -186,8 +192,8 @@ def __init__(self,
self.sensor_measurements: SensorMeasurementStackMap = OrderedDict(
self.robot.sensor_measurements)

# Top-most block of the pipeline to which the environment is part of
self.derived: InterfaceJiminyEnv = self
# Top-most block of the pipeline is the environment itself by default
self.derived = self

# Store references to the variables to register to the telemetry
self._registered_variables: MutableMappingT[
Expand Down Expand Up @@ -215,6 +221,9 @@ def __init__(self,
self.num_steps = np.array(-1, dtype=np.int64)
self._num_steps_beyond_terminate: Optional[int] = None

# Initialize a quantity manager for later use
self.quantities = QuantityManager(self)

# Initialize the interfaces through multiple inheritance
super().__init__() # Do not forward extra arguments, if any

Expand All @@ -233,9 +242,6 @@ def __init__(self,
"`BaseJiminyEnv.compute_command` must be overloaded in case "
"of custom action spaces.")

# Initialize a quantity manager for later use
self.quantities = QuantityManager(self)

# Define specialized operators for efficiency.
# Note that a partial view of observation corresponding to measurement
# must be extracted since only this one must be updated during refresh.
Expand Down Expand Up @@ -599,8 +605,8 @@ def reset(self, # type: ignore[override]
if seed is not None:
self._initialize_seed(seed)

# Stop the simulator
self.simulator.stop()
# Stop the episode if one is still running
self.stop()

# Remove external forces, if any
self.simulator.remove_all_forces()
Expand Down Expand Up @@ -854,7 +860,7 @@ def step(self, # type: ignore[override]
self.simulator.step(self.step_dt)
except Exception:
# Stop the simulation before raising the exception
self.simulator.stop()
self.stop()
raise

# Make sure there is no 'nan' value in observation
Expand Down Expand Up @@ -1023,8 +1029,8 @@ def replay(self, **kwargs: Any) -> None:
kwargs['close_backend'] = not self.simulator.is_viewer_available

# Stop any running simulation before replay if `has_terminated` is True
if self.is_simulation_running and any(self.has_terminated({})):
self.simulator.stop()
if any(self.has_terminated({})):
self.stop()

with viewer_lock:
# Call render before replay in order to take into account custom
Expand Down Expand Up @@ -1135,8 +1141,7 @@ def _interact(key: Optional[str] = None) -> bool:

# Stop the simulation to unlock the robot.
# This allows contact forces to be displayed during replay.
if self.simulator.is_simulation_running:
self.simulator.stop()
self.stop()

# Disable play interactive mode flag and restore training flag
self._is_interactive = False
Expand Down Expand Up @@ -1213,7 +1218,7 @@ def evaluate(self,
action = policy_fn(obs, reward, terminated or truncated, info)
obs, reward, terminated, truncated, info = env.step(action)
info_episode.append(info)
self.simulator.stop()
self.stop()
except KeyboardInterrupt:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def __init__(self,
# Whether the stack has been shifted to the left since last update
self._was_stack_shifted = True

# Bind action of the base environment
assert self.action_space.contains(self.env.action)
self.action = self.env.action

def _initialize_action_space(self) -> None:
"""Configure the action space.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(self, # pylint: disable=unused-argument
self.human_only = human_only

# Extract proxies for convenience
assert isinstance(env.unwrapped, BaseJiminyEnv)
self._step_dt_rel = env.unwrapped.step_dt / speed_ratio

# Buffer to keep track of the last time `step` method was called
Expand Down
6 changes: 4 additions & 2 deletions python/gym_jiminy/unit_py/test_pipeline_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _test_pid_standing(self):
# Run the simulation
while self.env.stepper_state.t < 9.0:
self.env.step(action)
self.env.stop()

# Export figure
fd, pdf_path = mkstemp(prefix="plot_", suffix=".pdf")
Expand Down Expand Up @@ -263,6 +264,7 @@ def test_pd_controller(self):
env.unwrapped._height_neutral = float("-inf")
while env.stepper_state.t < 2.0:
env.step(0.2 * env.action_space.sample())
env.stop()

# Extract the target position and velocity of a single motor
adapter_name, controller_name = adapter.name, controller.name
Expand All @@ -284,9 +286,9 @@ def test_pd_controller(self):
command_vel[(update_ratio-1)::update_ratio],
atol=TOLERANCE)
np.testing.assert_allclose(
target_accel_diff, target_accel[1:], atol=TOLERANCE)
target_accel_diff[:-1], target_accel[1:-1], atol=TOLERANCE)
np.testing.assert_allclose(
target_vel_diff, target_vel[1:], atol=TOLERANCE)
target_vel_diff[:-1], target_vel[1:-1], atol=TOLERANCE)

# Make sure that the position and velocity targets are within bounds
motor = env.robot.motors[-1]
Expand Down
1 change: 1 addition & 0 deletions python/gym_jiminy/unit_py/test_pipeline_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def configure_telemetry() -> InterfaceJiminyEnv:

env.reset(seed=0, options=dict(reset_hook=configure_telemetry))
env.step(env.action)
env.stop()

controller = env.env.env.env.controller
assert isinstance(controller, PDController)
Expand Down

0 comments on commit 71b8873

Please sign in to comment.