Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Add angular momentum, stability, and friction rewards. #808

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ generated-members = torch, jiminy
[tool.pylint.basic]
# Good variable names which should always be accepted, separated by a comma
good-names =
i, j, k, l, N, # Python: for-loop indices
tb, np, nb, mp, tp, # Python: classical modules
fd, _, # Python: context
t, q, v, x, u, s, qx, qy, qz, qw, # Physics: state, action
I, R, H, T, M, dt, # Physics: dynamics
a, b, c, y, z, n, e, # Maths / Algebra : variables
f, rg, lo, hi, op, fn, # Maths / Algebra : operators
kp, kd, ki, # Control: Gains
ax # Matplotlib
i, j, k, l, N, # Python: for-loop indices
tb, np, nb, mp, tp, # Python: classical modules
fd, _, # Python: context
t, q, v, x, u, s, qx, qy, qz, qw, # Physics: state, action
I, R, H, T, M, dt, # Physics: dynamics
A, a, b, c, y, z, n, e, # Maths / Algebra: variables
f, rg, lo, hi, op, fn, # Maths / Algebra: operators
kp, kd, ki, # Control: Gains
ax # Matplotlib

[tool.pylint.format]
# Regexp for a line that is allowed to be longer than the limit
Expand Down
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /fp:contract")
endif()
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /EHsc /bigobj /Gy /Zc:inline /Zc:preprocessor /Zc:__cplusplus /permissive- /DWIN32 /D_USE_MATH_DEFINES /DNOMINMAX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
/MP /EHsc /bigobj /Gy /Zc:inline /Zc:preprocessor /Zc:__cplusplus /permissive- /DWIN32 \
/D_USE_MATH_DEFINES /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR /DNOMINMAX")
set(CMAKE_CXX_FLAGS_DEBUG "/Zi /Od")
set(CMAKE_CXX_FLAGS_RELEASE "/DNDEBUG /O2 /Ob3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} /Zi")
Expand Down
7 changes: 6 additions & 1 deletion build_tools/build_install_deps_windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ if (-not (Test-Path env:GENERATOR)) {
}

### Set common CMAKE_C/CXX_FLAGS
${CMAKE_CXX_FLAGS} = "${env:CMAKE_CXX_FLAGS} /MP2 /EHsc /bigobj /Gy /Zc:inline /Zc:preprocessor /Zc:__cplusplus /permissive- /DWIN32 /D_USE_MATH_DEFINES /DNOMINMAX"
# * Flag "_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR" is a dirty workaround for a VC runtime
# conflict caused by different search path ordering at compile-time and run-time, leading to segfault:
# https://github.com/actions/runner-images/issues/10004
${CMAKE_CXX_FLAGS} = "${env:CMAKE_CXX_FLAGS} $(
) /MP2 /EHsc /bigobj /Gy /Zc:inline /Zc:preprocessor /Zc:__cplusplus /permissive- $(
) /DWIN32 /D_USE_MATH_DEFINES /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR /DNOMINMAX"
if (${BUILD_TYPE} -eq "Debug") {
${CMAKE_CXX_FLAGS} = "${CMAKE_CXX_FLAGS} /Zi /Od"
} else {
Expand Down
25 changes: 22 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..utils import DataNested
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv
from ..quantities import QuantityManager


Expand Down Expand Up @@ -220,6 +221,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Call super to allow mixing interfaces through multiple inheritance
super().__init__(*args, **kwargs)

# Define convenience proxy for quantity manager
self.quantities = self.unwrapped.quantities

def _setup(self) -> None:
"""Configure the observer-controller.

Expand Down Expand Up @@ -335,11 +339,26 @@ def _controller_handle(self,
# '_controller_handle' as it is never called more often than necessary.
self.__is_observation_refreshed = False

def stop(self) -> None:
    """Terminate the ongoing episode right away, regardless of whether any
    termination or truncation condition has been met.

    .. note::
        This method is primarily meant for data analysis and debugging.
        The episode must be stopped for its final state to be logged,
        otherwise that state will be absent from plots and viewer replay.
        In addition, sensor data will not be available when replaying
        through the object-oriented method `replay`. Prefer the helper
        method `play_logs_data` to replay an episode that cannot be
        stopped at the moment.
    """
    # Stopping the episode is delegated to the underlying simulator
    simulator = self.simulator
    simulator.stop()

@property
def unwrapped(self) -> "InterfaceJiminyEnv":
"""Base environment of the pipeline.
@abstractmethod
def unwrapped(self) -> "BaseJiminyEnv":
"""The "underlying environment at the basis of the pipeline from which
this environment is part of.
"""
return self

@property
@abstractmethod
Expand Down
49 changes: 33 additions & 16 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
* a wrapper to combine a controller block and a `BaseJiminyEnv` environment,
eventually already wrapped, so that it appears as a black-box environment.
"""
import sys
import math
import logging
from weakref import ref
from copy import deepcopy
from abc import abstractmethod
from collections import OrderedDict
from typing import (
Dict, Any, List, Optional, Tuple, Union, Generic, TypeVar, SupportsFloat,
Callable, cast)
Callable, cast, TYPE_CHECKING)

import numpy as np

Expand All @@ -37,6 +39,8 @@
from .blocks import BaseControllerBlock, BaseObserverBlock

from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv


OtherObsT = TypeVar('OtherObsT', bound=DataNested)
Expand All @@ -46,6 +50,9 @@
TransformedActT = TypeVar('TransformedActT', bound=DataNested)


LOGGER = logging.getLogger(__name__)


class BasePipelineWrapper(
InterfaceJiminyEnv[ObsT, ActT],
Generic[ObsT, ActT, BaseObsT, BaseActT]):
Expand Down Expand Up @@ -101,7 +108,20 @@ def __getattr__(self, name: str) -> Any:

It enables to get access to the attribute and methods of the wrapped
environment directly without having to do it through `env`.

.. warning::
This fallback incurs a significant runtime overhead. As such, it
must only be used for debug and manual analysis between episodes.
Calling this method in script mode while a simulation is already
running would trigger a warning to avoid relying on it by mistake.
"""
if self.is_simulation_running and not hasattr(sys, 'ps1'):
# `hasattr(sys, 'ps1')` is used to detect whether the method was
# called from an interpreter or within a script. For details, see:
# https://stackoverflow.com/a/64523765/4820605
LOGGER.warning(
"Relying on fallback attribute getter is inefficient and "
"strongly discouraged at runtime.")
return getattr(self.__getattribute__('env'), name)

def __dir__(self) -> List[str]:
Expand Down Expand Up @@ -143,9 +163,7 @@ def np_random(self, value: np.random.Generator) -> None:
self.env.np_random = value

@property
def unwrapped(self) -> InterfaceJiminyEnv:
"""Base environment of the pipeline.
"""
def unwrapped(self) -> "BaseJiminyEnv":
return self.env.unwrapped

@property
Expand Down Expand Up @@ -236,8 +254,7 @@ def step(self, # type: ignore[override]
self._copyto_action(action)

# Make sure that the pipeline has not changed since last reset
env_derived = (
self.unwrapped.derived) # type: ignore[attr-defined]
env_derived = self.unwrapped.derived
if env_derived is not self:
raise RuntimeError(
"Pipeline environment has changed. Please call 'reset' "
Expand Down Expand Up @@ -532,14 +549,14 @@ def __init__(self,
# Register the observer's internal state and feature to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.observer.name)
except ValueError:
pass
self.env.register_variable('feature', # type: ignore[attr-defined]
self.observer.observation,
self.observer.fieldnames,
self.observer.name)
self.unwrapped.register_variable('feature',
self.observer.observation,
self.observer.fieldnames,
self.observer.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down Expand Up @@ -750,14 +767,14 @@ def __init__(self,
# Register the controller's internal state and target to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.controller.name)
except ValueError:
pass
self.env.register_variable('action', # type: ignore[attr-defined]
self.action,
self.controller.fieldnames,
self.controller.name)
self.unwrapped.register_variable('action',
self.action,
self.controller.fieldnames,
self.controller.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=missing-module-docstring

from .mixin import (radial_basis_function,
from .mixin import (CUTOFF_ESP,
radial_basis_function,
AdditiveMixtureReward,
MultiplicativeMixtureReward)
from .generic import (BaseTrackingReward,
Expand All @@ -11,9 +12,11 @@
TrackingCapturePointReward,
TrackingFootPositionsReward,
TrackingFootOrientationsReward,
MinimizeAngularMomentumReward)
MinimizeAngularMomentumReward,
MinimizeFrictionReward)

__all__ = [
"CUTOFF_ESP",
"radial_basis_function",
"AdditiveMixtureReward",
"MultiplicativeMixtureReward",
Expand All @@ -25,5 +28,6 @@
"TrackingFootPositionsReward",
"TrackingFootOrientationsReward",
"MinimizeAngularMomentumReward",
"MinimizeFrictionReward",
"SurviveReward"
]
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ class BaseTrackingReward(BaseQuantityReward):
otherwise an exception will be raised. See `DatasetTrajectoryQuantity` and
`AbstractQuantity` documentations for details.

The error transform in a normalized reward to maximize by applying RBF
The error is transformed into a normalized reward to maximize by applying RBF
kernel on the error. The reward will be 1.0 if the error cancels out
completely and less than 0.01 above the user-specified cutoff threshold.
completely and less than 'CUTOFF_ESP' above the user-specified cutoff
threshold.
"""
def __init__(self,
env: InterfaceJiminyEnv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from ..bases import (
InterfaceJiminyEnv, StateQuantity, QuantityEvalMode, BaseQuantityReward)
from ..quantities import (
MaskedQuantity, UnaryOpQuantity, AverageBaseOdometryVelocity,
AverageBaseMomentum, MultiFootRelativeXYZQuat, CapturePoint)
MaskedQuantity, UnaryOpQuantity, AverageBaseOdometryVelocity, CapturePoint,
MultiFootRelativeXYZQuat, MultiContactRelativeForceTangential,
AverageBaseMomentum)
from ..quantities.locomotion import sanitize_foot_frame_names
from ..utils import quat_difference

Expand Down Expand Up @@ -221,3 +222,34 @@ def __init__(self,
partial(radial_basis_function, cutoff=self.cutoff, order=2),
is_normalized=True,
is_terminal=False)


class MinimizeFrictionReward(BaseQuantityReward):
    """Reward the agent for keeping the tangential forces small at all the
    contact points and collision bodies, and for avoiding jerky intermittent
    contact states.

    All the local tangential forces are aggregated using the L2-norm. The
    L1-norm would be more natural in this specific case, but the L2-norm is
    preferable because it promotes space-time regularity, i.e. it balances
    the force distribution evenly between all the candidate contact points
    and avoids jerky contact forces over time (high-frequency vibrations).
    The L1-norm is completely insensitive to these phenomena.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # Keep track of the user-specified cutoff threshold
        self.cutoff = cutoff

        # Quantity to evaluate, specified as (class, keyword-arguments)
        quantity_spec = (MultiContactRelativeForceTangential, {})

        # RBF kernel mapping the aggregated forces to a normalized reward
        transform_fn = partial(
            radial_basis_function, cutoff=self.cutoff, order=2)

        # Delegate the actual initialization to the base implementation
        super().__init__(env,
                         "reward_friction",
                         quantity_spec,
                         transform_fn,
                         is_normalized=True,
                         is_terminal=False)
23 changes: 17 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# Reward value at cutoff threshold
RBF_CUTOFF_ESP = 1.0e-2
CUTOFF_ESP = 1.0e-2


ArrayOrScalar = Union[np.ndarray, float]
Expand Down Expand Up @@ -46,12 +46,23 @@ def radial_basis_function(error: ArrayOrScalar,
:param cutoff: Cut-off threshold to consider.
:param order: Order of Lp-Norm that will be used as distance metric.
"""
error_ = np.asarray(error).reshape((-1,))
if order == 2:
squared_dist_rel = np.dot(error_, error_) / math.pow(cutoff, 2)
error_ = np.asarray(error)
is_contiguous = error_.flags.f_contiguous or error_.flags.c_contiguous
if is_contiguous or order != 2:
if error_.ndim > 1 and not is_contiguous:
error_ = np.ascontiguousarray(error_)
if error_.flags.c_contiguous:
error1d = np.asarray(error_).ravel()
else:
error1d = np.asarray(error_.T).ravel()
if order == 2:
squared_dist_rel = np.dot(error1d, error1d) / math.pow(cutoff, 2)
else:
squared_dist_rel = math.pow(
np.linalg.norm(error1d, order) / cutoff, 2)
else:
squared_dist_rel = math.pow(np.linalg.norm(error_, order) / cutoff, 2)
return math.pow(RBF_CUTOFF_ESP, squared_dist_rel)
squared_dist_rel = np.sum(np.square(error_)) / math.pow(cutoff, 2)
return math.pow(CUTOFF_ESP, squared_dist_rel)


class AdditiveMixtureReward(BaseMixtureReward):
Expand Down
Loading
Loading