Skip to content

Commit

Permalink
[gym_jiminy/common] Speed-up nested data structure utilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Duburcq committed Aug 16, 2023
1 parent 90ad70d commit 6146091
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ good-names =
t, q, v, x, u, s, qx, qy, qz, qw, # Physics: state, action
I, R, H, T, M, dt, # Physics: dynamics
a, b, c, y, z, n, e, # Maths / Algebra : variables
f, rg, lo, hi, # Maths / Algebra : operators
f, rg, lo, hi, op, # Maths / Algebra : operators
kp, kd, ki, # Control: Gains
ax # Matplotlib

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _observer_handle(self,
# Refresh the observation if not already done
if not self.__is_observation_refreshed:
measurement: EngineObsType = OrderedDict(
t=np.array((t,)),
t=np.array(t),
states=OrderedDict(agent=OrderedDict(q=q, v=v)),
measurements=OrderedDict(sensors_data))
self.refresh_observation(measurement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def mahony_filter(q: np.ndarray,
omega = gyro - bias_hat + kp * omega_mes

# Early return if there is no IMU motion
if np.all(np.abs(omega) < 1e-6):
if (np.abs(omega) < 1e-6).all():
return

# Compute Axis-Angle repr. of the angular velocity: exp3(dt * omega)
Expand Down Expand Up @@ -197,7 +197,7 @@ def refresh_observation(self, measurement: BaseObsT) -> None:
if not self.env.is_simulation_running:
is_initialized = False
if not self.exact_init:
if np.all(np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY):
if (np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY).all():
LOGGER.warning(
"The acceleration at reset is too small. Impossible "
"to initialize Mahony filter for 'exact_init=False'.")
Expand Down
33 changes: 23 additions & 10 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from ..utils import (FieldNested,
DataNested,
zeros,
copyto,
clip,
build_clip,
build_copyto,
build_contains,
get_fieldnames,
register_variables)
from ..bases import (ObsT,
Expand Down Expand Up @@ -220,6 +221,15 @@ def __init__(self,
self.observation: ObsT = zeros(self.observation_space)
self.action: ActT = zeros(self.action_space)

# Define specialized operators for efficiency
self._get_clipped_env_observation: Callable[
[], DataNested] = lambda : OrderedDict()

Check warning on line 226 in python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py#L226

Lambda may not be necessary
self._copyto_observation = build_copyto(self.observation)
self._copyto_action = build_copyto(self.action)
self._contains_observation = build_contains(
self.observation, self.observation_space)
self._contains_action = build_contains(self.action, self.action_space)

# Set robot in neutral configuration
qpos = self._neutral()
framesForwardKinematics(
Expand Down Expand Up @@ -738,10 +748,14 @@ def reset(self, # type: ignore[override]
self.system_state.v,
self.robot.sensors_data)

# Initialize specialized most-derived observation clipping operator
self._get_clipped_env_observation = build_clip(
env.observation, env.observation_space)

# Make sure the state is valid, otherwise there `refresh_observation`
# and `_initialize_observation_space` are probably inconsistent.
try:
obs: ObsT = clip(env.observation, env.observation_space)
obs: ObsT = cast(ObsT, self._get_clipped_env_observation())
except (TypeError, ValueError) as e:
raise RuntimeError(
"The observation computed by `refresh_observation` is "
Expand Down Expand Up @@ -801,7 +815,7 @@ def step(self, # type: ignore[override]
f"'nan' value found in action ({action}).")

# Update the action
copyto(self.action, action)
self._copyto_action(action)

# Try performing a single simulation step
try:
Expand All @@ -817,8 +831,7 @@ def step(self, # type: ignore[override]
# Update the observer at the end of the step.
# This is necessary because, internally, it is called at the beginning
# of the every integration steps, during the controller update.
env = self._env_derived
env._observer_handle(
self._env_derived._observer_handle(
self.stepper_state.t,
self.system_state.q,
self.system_state.v,
Expand Down Expand Up @@ -878,7 +891,7 @@ def step(self, # type: ignore[override]
self.num_steps += 1

# Clip (and copy) the most derived observation before returning it
obs = clip(env.observation, env.observation_space, check=False)
obs = self._get_clipped_env_observation()

return obs, reward, done, truncated, deepcopy(self._info)

Expand Down Expand Up @@ -1374,7 +1387,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
checking whether the simulation already started. It is not exactly
the same but it does the job regarding preserving efficiency.
"""
copyto(self.observation, cast(DataNested, measurement))
self._copyto_observation(measurement)

def compute_command(self, action: ActT) -> np.ndarray:
"""Compute the motors efforts to apply on the robot.
Expand All @@ -1395,7 +1408,7 @@ def compute_command(self, action: ActT) -> np.ndarray:
# pylint: disable=unused-argument

# Check if the action is out-of-bounds, in debug mode only
if self.debug and not self.action_space.contains(action):
if self.debug and not self._contains_action():
LOGGER.warning("The action is out-of-bounds.")

if not isinstance(action, np.ndarray):
Expand Down Expand Up @@ -1431,7 +1444,7 @@ def has_terminated(self) -> Tuple[bool, bool]:
"method.")

# Check if the observation is out-of-bounds
truncated = not self.observation_space.contains(self.observation)
truncated = not self._contains_observation()

return False, truncated

Expand Down
10 changes: 9 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
set_value,
copyto,
copy,
clip)
clip,
contains,
build_contains,
build_copyto,
build_clip)
from .helpers import (is_breakpoint,
get_fieldnames,
register_variables)
Expand All @@ -24,6 +28,10 @@
'copyto',
'copy',
'clip',
'contains',
'build_contains',
'build_copyto',
'build_clip',
'is_breakpoint',
'get_fieldnames',
'register_variables'
Expand Down

0 comments on commit 6146091

Please sign in to comment.