Skip to content

Commit

Permalink
[gym_jiminy/common] Speed-up nested data structure utilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Duburcq committed Aug 16, 2023
1 parent 90ad70d commit 6c1bee5
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from gymnasium.core import RenderFrame
from gymnasium.envs.registration import EnvSpec

from ..utils import DataNested, is_breakpoint, zeros, copyto, copy
from ..utils import DataNested, is_breakpoint, zeros, copyto, copy, contains

from .generic_bases import (DT_EPS,
ObsT,
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(self,
super().__init__() # Do not forward any argument

# By default, bind the action to the one of the base environment
assert self.action_space.contains(env.action)
assert contains(env.action, self.action_space)
self.action = env.action # type: ignore[assignment]

def __getattr__(self, name: str) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def mahony_filter(q: np.ndarray,
omega = gyro - bias_hat + kp * omega_mes

# Early return if there is no IMU motion
if np.all(np.abs(omega) < 1e-6):
if (np.abs(omega) < 1e-6).all():
return

# Compute Axis-Angle repr. of the angular velocity: exp3(dt * omega)
Expand Down Expand Up @@ -197,7 +197,7 @@ def refresh_observation(self, measurement: BaseObsT) -> None:
if not self.env.is_simulation_running:
is_initialized = False
if not self.exact_init:
if np.all(np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY):
if (np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY).all():
LOGGER.warning(
"The acceleration at reset is too small. Impossible "
"to initialize Mahony filter for 'exact_init=False'.")
Expand Down
23 changes: 15 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from ..utils import (FieldNested,
DataNested,
zeros,
copyto,
clip,
contains,
build_clip,
build_copyto,
get_fieldnames,
register_variables)
from ..bases import (ObsT,
Expand Down Expand Up @@ -220,6 +221,12 @@ def __init__(self,
self.observation: ObsT = zeros(self.observation_space)
self.action: ActT = zeros(self.action_space)

# Define specialized copy and clipping operators for efficiency
self._get_clipped_observation = build_clip(
self.observation, self.observation_space)
self._copy_observation = build_copyto(self.observation)
self._copy_action = build_copyto(self.action)

# Set robot in neutral configuration
qpos = self._neutral()
framesForwardKinematics(
Expand Down Expand Up @@ -741,7 +748,7 @@ def reset(self, # type: ignore[override]
# Make sure the state is valid, otherwise there `refresh_observation`
# and `_initialize_observation_space` are probably inconsistent.
try:
obs: ObsT = clip(env.observation, env.observation_space)
obs: ObsT = self._get_clipped_observation()
except (TypeError, ValueError) as e:
raise RuntimeError(
"The observation computed by `refresh_observation` is "
Expand Down Expand Up @@ -801,7 +808,7 @@ def step(self, # type: ignore[override]
f"'nan' value found in action ({action}).")

# Update the action
copyto(self.action, action)
self._copy_action(action)

# Try performing a single simulation step
try:
Expand Down Expand Up @@ -878,7 +885,7 @@ def step(self, # type: ignore[override]
self.num_steps += 1

# Clip (and copy) the most derived observation before returning it
obs = clip(env.observation, env.observation_space, check=False)
obs = self._get_clipped_observation()

return obs, reward, done, truncated, deepcopy(self._info)

Expand Down Expand Up @@ -1374,7 +1381,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
checking whether the simulation already started. It is not exactly
the same but it does the job regarding preserving efficiency.
"""
copyto(self.observation, cast(DataNested, measurement))
self._copy_observation(cast(DataNested, measurement))

def compute_command(self, action: ActT) -> np.ndarray:
"""Compute the motors efforts to apply on the robot.
Expand All @@ -1395,7 +1402,7 @@ def compute_command(self, action: ActT) -> np.ndarray:
# pylint: disable=unused-argument

# Check if the action is out-of-bounds, in debug mode only
if self.debug and not self.action_space.contains(action):
if self.debug and not contains(action, self.action_space):
LOGGER.warning("The action is out-of-bounds.")

if not isinstance(action, np.ndarray):
Expand Down Expand Up @@ -1431,7 +1438,7 @@ def has_terminated(self) -> Tuple[bool, bool]:
"method.")

# Check if the observation is out-of-bounds
truncated = not self.observation_space.contains(self.observation)
truncated = not contains(self.observation, self.observation_space)

return False, truncated

Expand Down
8 changes: 7 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
set_value,
copyto,
copy,
clip)
clip,
contains,
build_copyto,
build_clip)
from .helpers import (is_breakpoint,
get_fieldnames,
register_variables)
Expand All @@ -21,9 +24,12 @@
'zeros',
'fill',
'set_value',
'build_copyto',
'copyto',
'copy',
'build_clip',
'clip',
'contains',
'is_breakpoint',
'get_fieldnames',
'register_variables'
Expand Down
Loading

0 comments on commit 6c1bee5

Please sign in to comment.