Skip to content

Commit

Permalink
[gym/common] Add safety limits termination condition.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Jun 23, 2024
1 parent a3ab419 commit 50dec09
Show file tree
Hide file tree
Showing 12 changed files with 300 additions and 74 deletions.
117 changes: 69 additions & 48 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Sequence, Callable, Union, Optional, TypeVar
from typing import Tuple, Sequence, Callable, Union, Optional, Generic, TypeVar

import numpy as np

Expand Down Expand Up @@ -170,7 +170,7 @@ def __call__(self, terminated: bool, info: InfoType) -> float:
return value


class QuantityReward(AbstractReward):
class QuantityReward(AbstractReward, Generic[ValueT]):
"""Convenience class making it easy to derive reward components from
generic quantities.
Expand Down Expand Up @@ -266,7 +266,7 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
return None

# Evaluate raw quantity
value = self.env.quantities[self.name]
value = self.quantity.get()

# Early return if quantity is None
if value is None:
Expand Down Expand Up @@ -426,13 +426,37 @@ class AbstractTerminationCondition(ABC):

def __init__(self,
env: InterfaceJiminyEnv,
name: str) -> None:
name: str,
grace_period: float = 0.0,
*,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward.
:param name: Desired name of the termination condition. This name will
be used as key for storing the current episode state from
the perspective of this specific condition in 'info', and
to add the underlying quantity to the set of quantities
already managed by the environment. As a result, it must be
unique, otherwise an exception will be raised.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
self.env = env
self._name = name
self.grace_period = grace_period
self.is_truncation = is_truncation
self.is_training_only = is_training_only

@property
def name(self) -> str:
Expand All @@ -445,15 +469,12 @@ def name(self) -> str:
return self._name

@abstractmethod
def compute(self, info: InfoType) -> EpisodeState:
"""Evaluate the termination condition.
def compute(self, info: InfoType) -> bool:
"""Evaluate the termination condition at hands.
:param info: Dictionary of extra information for monitoring. It will be
updated in-place for storing terminated and truncated
flags in 'info' as a tri-state `EpisodeState` value.
:returns: Current episode state from the sole perspective of the
termination condition at hands.
"""

def __call__(self, info: InfoType) -> Tuple[bool, bool]:
Expand All @@ -464,10 +485,10 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
current state of the environment under the ongoing action.
.. note::
This method is a lightweight wrapper around `compute` to split the
episode state in two boolean flags to comply with Gym API. 'info'
will be updated to store either custom debug information if any
or the episode state otherwise.
This method is a lightweight wrapper around `compute` to return two
boolean flags 'terminated', 'truncated' complying with Gym API.
'info' will be updated to store either custom debug information, if
any, or a tri-state `EpisodeState` episode state otherwise.
.. warning::
This method is not meant to be overloaded.
Expand All @@ -478,11 +499,17 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
:returns: terminated and truncated flags.
"""
# Evaluate the reward and store extra information
# Skip termination condition in eval mode or during grace period
termination_info: InfoType = {}
episode_state = self.compute(termination_info)
is_terminated = episode_state == EpisodeState.TERMINATED
is_truncated = episode_state == EpisodeState.TRUNCATED
if (self.is_training_only and not self.env.is_training) or (
self.env.stepper_state.t < self.grace_period):
# Always continue
is_terminated, is_truncated = False, False
else:
# Evaluate the termination condition and store extra information
is_done = self.compute(termination_info)
is_terminated = is_done and not self.is_truncation
is_truncated = is_done and self.is_truncation

# Store episode state as info
if self.name in info.keys():
Expand All @@ -492,13 +519,19 @@ def __call__(self, info: InfoType) -> Tuple[bool, bool]:
if termination_info:
info[self.name] = termination_info
else:
if is_terminated:
episode_state = EpisodeState.TERMINATED
elif is_truncated:
episode_state = EpisodeState.TRUNCATED
else:
episode_state = EpisodeState.CONTINUED
info[self.name] = episode_state

# Returning terminated and truncated flags
return is_terminated, is_truncated


class QuantityTermination(AbstractTerminationCondition):
class QuantityTermination(AbstractTerminationCondition, Generic[ValueT]):
"""Convenience class making it easy to derive termination conditions from
generic quantities.
Expand Down Expand Up @@ -547,14 +580,16 @@ def __init__(self,
Optional: False by default.
"""
# Backup user argument(s)
self.low = np.asarray(low) if low is not None else None
self.high = np.asarray(high) if high is not None else None
self.grace_period = grace_period
self.is_truncation = is_truncation
self.is_training_only = is_training_only
self.low = low
self.high = high

# Call base implementation
super().__init__(env, name)
super().__init__(
env,
name,
grace_period,
is_truncation=is_truncation,
is_training_only=is_training_only)

# Add quantity to the set of quantities managed by the environment
self.env.quantities[self.name] = quantity
Expand All @@ -569,40 +604,26 @@ def __del__(self) -> None:
# This method must not fail under any circumstances
pass

def compute(self, info: InfoType) -> EpisodeState:
def compute(self, info: InfoType) -> bool:
"""Evaluate the termination condition.
The underlying quantity is first evaluated. The episode continues if
its value is within bounds for all its components, otherwise the
episode is either truncated or terminated according to 'is_truncation'.
all the elements of its value are within bounds, otherwise the episode
is either truncated or terminated according to 'is_truncation'.
.. warning::
This method is not meant to be overloaded.
:returns: Current episode state from the sole perspective of the
termination condition at hands.
"""
# Skip termination condition in eval mode or during grace period
if (self.is_training_only and not self.env.is_training) or (
self.env.stepper_state.t < self.grace_period):
return EpisodeState.CONTINUED

# Evaluate the quantity
value = self.env.quantities[self.name]
value = self.quantity.get()

# Check if the quantity is within bound.
# Check if the quantity is out-of-bounds.
# Note that it may be `None` if the quantity is ill-defined for the
# current simulation state, which triggers termination unconditionally.
is_valid = value is not None
is_valid &= self.low is None or bool(np.all(self.low <= value))
is_valid &= self.high is None or bool(np.all(value <= self.high))

# Determine the episode state to return
if is_valid:
return EpisodeState.CONTINUED
if self.is_truncation:
return EpisodeState.TRUNCATED
return EpisodeState.TERMINATED
is_done = value is None
is_done |= self.low is not None and bool(np.any(self.low > value))
is_done |= self.high is not None and bool(np.any(value > self.high))
return is_done


QuantityTermination.name.__doc__ = \
Expand Down
3 changes: 2 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def __getattr__(self, name: str) -> Any:
Calling this method in script mode while a simulation is already
running would trigger a warning to avoid relying on it by mistake.
"""
if self.is_simulation_running and not hasattr(sys, 'ps1'):
if (self.is_simulation_running and self.env.is_training and
not hasattr(sys, 'ps1')):
# `hasattr(sys, 'ps1')` is used to detect whether the method was
# called from an interpreter or within a script. For details, see:
# https://stackoverflow.com/a/64523765/4820605
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
TrackingActuatedJointPositionsReward,
SurviveReward,
DriftTrackingQuantityTermination,
ShiftTrackingQuantityTermination)
ShiftTrackingQuantityTermination,
MechanicalSafetyTermination)
from .locomotion import (TrackingBaseHeightReward,
TrackingBaseOdometryVelocityReward,
TrackingCapturePointReward,
Expand Down Expand Up @@ -39,6 +40,7 @@
"SurviveReward",
"DriftTrackingQuantityTermination",
"ShiftTrackingQuantityTermination",
"MechanicalSafetyTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
"FootCollisionTermination"
Expand Down
Loading

0 comments on commit 50dec09

Please sign in to comment.