Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym_jiminy/common] Speed-up nested data structure utilities. #632

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ good-names =
t, q, v, x, u, s, qx, qy, qz, qw, # Physics: state, action
I, R, H, T, M, dt, # Physics: dynamics
a, b, c, y, z, n, e, # Maths / Algebra : variables
f, rg, lo, hi, # Maths / Algebra : operators
f, rg, lo, hi, op, # Maths / Algebra : operators
kp, kd, ki, # Control: Gains
ax # Matplotlib

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _observer_handle(self,
# Refresh the observation if not already done
if not self.__is_observation_refreshed:
measurement: EngineObsType = OrderedDict(
t=np.array((t,)),
t=np.array(t),
states=OrderedDict(agent=OrderedDict(q=q, v=v)),
measurements=OrderedDict(sensors_data))
self.refresh_observation(measurement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def mahony_filter(q: np.ndarray,
omega = gyro - bias_hat + kp * omega_mes

# Early return if there is no IMU motion
if np.all(np.abs(omega) < 1e-6):
if (np.abs(omega) < 1e-6).all():
return

# Compute Axis-Angle repr. of the angular velocity: exp3(dt * omega)
Expand Down Expand Up @@ -197,7 +197,7 @@ def refresh_observation(self, measurement: BaseObsT) -> None:
if not self.env.is_simulation_running:
is_initialized = False
if not self.exact_init:
if np.all(np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY):
if (np.abs(self.acc) < 0.1 * EARTH_SURFACE_GRAVITY).all():
LOGGER.warning(
"The acceleration at reset is too small. Impossible "
"to initialize Mahony filter for 'exact_init=False'.")
Expand Down
33 changes: 23 additions & 10 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from ..utils import (FieldNested,
DataNested,
zeros,
copyto,
clip,
build_clip,
build_copyto,
build_contains,
get_fieldnames,
register_variables)
from ..bases import (ObsT,
Expand Down Expand Up @@ -220,6 +221,15 @@ def __init__(self,
self.observation: ObsT = zeros(self.observation_space)
self.action: ActT = zeros(self.action_space)

# Define specialized operators for efficiency
self._get_clipped_env_observation: Callable[
[], DataNested] = lambda: OrderedDict()
self._copyto_observation = build_copyto(self.observation)
self._copyto_action = build_copyto(self.action)
self._contains_observation = build_contains(
self.observation, self.observation_space)
self._contains_action = build_contains(self.action, self.action_space)

# Set robot in neutral configuration
qpos = self._neutral()
framesForwardKinematics(
Expand Down Expand Up @@ -738,10 +748,14 @@ def reset(self, # type: ignore[override]
self.system_state.v,
self.robot.sensors_data)

# Initialize specialized most-derived observation clipping operator
self._get_clipped_env_observation = build_clip(
env.observation, env.observation_space)

# Make sure the state is valid, otherwise there `refresh_observation`
# and `_initialize_observation_space` are probably inconsistent.
try:
obs: ObsT = clip(env.observation, env.observation_space)
obs: ObsT = cast(ObsT, self._get_clipped_env_observation())
except (TypeError, ValueError) as e:
raise RuntimeError(
"The observation computed by `refresh_observation` is "
Expand Down Expand Up @@ -801,7 +815,7 @@ def step(self, # type: ignore[override]
f"'nan' value found in action ({action}).")

# Update the action
copyto(self.action, action)
self._copyto_action(action)

# Try performing a single simulation step
try:
Expand All @@ -817,8 +831,7 @@ def step(self, # type: ignore[override]
# Update the observer at the end of the step.
# This is necessary because, internally, it is called at the beginning
# of the every integration steps, during the controller update.
env = self._env_derived
env._observer_handle(
self._env_derived._observer_handle(
self.stepper_state.t,
self.system_state.q,
self.system_state.v,
Expand Down Expand Up @@ -878,7 +891,7 @@ def step(self, # type: ignore[override]
self.num_steps += 1

# Clip (and copy) the most derived observation before returning it
obs = clip(env.observation, env.observation_space, check=False)
obs = self._get_clipped_env_observation()

return obs, reward, done, truncated, deepcopy(self._info)

Expand Down Expand Up @@ -1374,7 +1387,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
checking whether the simulation already started. It is not exactly
the same but it does the job regarding preserving efficiency.
"""
copyto(self.observation, cast(DataNested, measurement))
self._copyto_observation(measurement)

def compute_command(self, action: ActT) -> np.ndarray:
"""Compute the motors efforts to apply on the robot.
Expand All @@ -1395,7 +1408,7 @@ def compute_command(self, action: ActT) -> np.ndarray:
# pylint: disable=unused-argument

# Check if the action is out-of-bounds, in debug mode only
if self.debug and not self.action_space.contains(action):
if self.debug and not self._contains_action():
LOGGER.warning("The action is out-of-bounds.")

if not isinstance(action, np.ndarray):
Expand Down Expand Up @@ -1431,7 +1444,7 @@ def has_terminated(self) -> Tuple[bool, bool]:
"method.")

# Check if the observation is out-of-bounds
truncated = not self.observation_space.contains(self.observation)
truncated = not self._contains_observation()

return False, truncated

Expand Down
10 changes: 9 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
set_value,
copyto,
copy,
clip)
clip,
contains,
build_contains,
build_copyto,
build_clip)
from .helpers import (is_breakpoint,
get_fieldnames,
register_variables)
Expand All @@ -24,6 +28,10 @@
'copyto',
'copy',
'clip',
'contains',
'build_contains',
'build_copyto',
'build_clip',
'is_breakpoint',
'get_fieldnames',
'register_variables'
Expand Down
Loading
Loading