Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Improve support of dynamic computation graph. #751

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,16 @@ def cache(self, cache: SharedCache[ValueT]) -> None:
self._cache = cache
self._has_cache = True

@property
def is_active(self) -> bool:
def is_active(self, any_cache_owner: bool = False) -> bool:
"""Whether this quantity is considered active, namely `initialize` has
been called at least once since previous tracking reset, either by this
exact instance or any identical quantity if shared cache is available.
been called at least once since previous tracking reset.

:param any_owner: False to check only if this exact instance is active,
True if any of the identical quantities (sharing the
same cache) is considered sufficient.
Optional: False by default.
"""
if self._cache is None:
if not any_cache_owner or self._cache is None:
return self._is_active
return any(owner._is_active for owner in self._cache.owners)

Expand All @@ -229,8 +232,8 @@ def get(self) -> ValueT:
evaluate it and store it in cache.

This quantity is considered active as soon as this method has been
called at least once since previous tracking reset. The corresponding
property `is_active` will be true even before calling `initialize`.
called at least once since previous tracking reset. The method
`is_active` will be return true even before calling `initialize`.

.. warning::
This method is not meant to be overloaded.
Expand Down Expand Up @@ -289,7 +292,7 @@ def reset(self, reset_tracking: bool = False) -> None:

# Reset all requirements first
for quantity in self.requirements.values():
quantity.reset()
quantity.reset(reset_tracking)

# More work has to be done if shared cache is available and has value
if self._has_cache:
Expand Down
13 changes: 9 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,15 @@ def initialize(self) -> None:
self.frame_names = {self.parent.frame_name}
if self.cache:
for owner in self.cache.owners:
parent = owner.parent
assert isinstance(parent, EulerAnglesFrame)
if parent.is_active:
self.frame_names.add(parent.frame_name)
# We only consider active instances of `_BatchEulerAnglesFrame`
# instead of their corresponding parent `EulerAnglesFrame`.
# This is necessary because a derived quantity may feature
# `_BatchEulerAnglesFrame` as a requirement without actually
# relying on it depending on whether it is part of the optimal
# computation path at the time being or not.
if owner.is_active(any_cache_owner=False):
assert isinstance(owner.parent, EulerAnglesFrame)
self.frame_names.add(owner.parent.frame_name)

# Re-allocate memory as the number of frames is not known in advance.
# Note that Fortran memory layout (column-major) is used for speed up
Expand Down
59 changes: 59 additions & 0 deletions python/gym_jiminy/examples/quantity_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import timeit

Check notice on line 1 in python/gym_jiminy/examples/quantity_benchmark.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/examples/quantity_benchmark.py#L1

Missing module docstring

import matplotlib.pyplot as plt
import gymnasium as gym

import gym_jiminy.common.bases.quantity
from gym_jiminy.common.bases import QuantityManager
from gym_jiminy.common.quantities import EulerAnglesFrame

# Define number of samples for benchmarking
N_SAMPLES = 20000

# Disable caching by forcing `SharedCache.has_value` to always return `False`
setattr(gym_jiminy.common.bases.quantity.SharedCache,
"has_value",
property(lambda self: False))

# Instantiate a dummy environment
env = gym.make("gym_jiminy.envs:atlas")
env.reset()
env.step(env.action)

# Define quantity manager and add quantities to benchmark
nframes = len(env.pinocchio_model.frames)
quantity_manager = QuantityManager(
env.simulator,
{
f"rpy_{i}": (EulerAnglesFrame, dict(frame_name=frame.name))
for i, frame in enumerate(env.pinocchio_model.frames)
})

# Run the benchmark for all batch size
time_per_frame_all = []
for i in range(1, nframes):
# Reset tracking
quantity_manager.reset(reset_tracking=True)

# Fetch all quantities once to update dynamic computation graph
for j, quantity in enumerate(quantity_manager.quantities.values()):
quantity.get()
if i == j + 1:
break

# Extract batched data buffer of `EulerAnglesFrame` quantities
shared_data = quantity.requirements['data']

Check warning on line 45 in python/gym_jiminy/examples/quantity_benchmark.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/examples/quantity_benchmark.py#L45

Using possibly undefined loop variable 'quantity'

# Benchmark computation of batched data buffer
duration = timeit.timeit(
'shared_data.get()', number=N_SAMPLES, globals={
"shared_data": shared_data
})
time_per_frame_all.append(duration / N_SAMPLES / i * 1e9)

# Plot the result
plt.figure()
plt.plot(time_per_frame_all)
plt.xlabel("Number of frames")
plt.ylabel("Average computation time per frame (ns)")
plt.show()
2 changes: 1 addition & 1 deletion python/jiminy_py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def finalize_options(self) -> None:
# Panda3d is NOT supported by PyPy even if built from source.
# - 1.10.12 fixes numerous bugs
# - 1.10.13 crashes when generating wheels on MacOS
"panda3d>=1.10.14",
"panda3d>=1.10.13",
# Photo-realistic shader for Panda3d to improve rendering of meshes.
# - 0.11.X is not backward compatible.
"panda3d-simplepbr==0.11.2",
Expand Down
Loading