Skip to content

Commit

Permalink
[gym/common] Fix circular reference issue preventing garbage collection.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 11, 2024
1 parent 3d7c11b commit 9552098
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 45 deletions.
18 changes: 8 additions & 10 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
eventually already wrapped, so that it appears as a black-box environment.
"""
import math
from weakref import ref
import weakref
from copy import deepcopy
from abc import abstractmethod
from collections import OrderedDict
Expand Down Expand Up @@ -167,35 +167,33 @@ def reset(self, # type: ignore[override]
# Create weak reference to self.
# This is necessary to avoid circular reference that would make the
# corresponding object noncollectable and hence cause a memory leak.
pipeline_wrapper_ref = ref(self)
pipeline_wrapper_proxy = weakref.proxy(self)

# Extra reset_hook from options if provided
derived_reset_hook: Optional[Callable[[], InterfaceJiminyEnv]] = (
options or {}).get("reset_hook")

# Define chained controller hook
def reset_hook() -> Optional[InterfaceJiminyEnv]:
def reset_hook() -> Optional[weakref.ProxyType]:
"""Register the block to the higher-level block.
This method is used internally to make sure that `_setup` method
of connected blocks are called in the right order, namely from the
lowest-level block to the highest-level one, right after reset of
the low-level simulator and just before performing the first step.
"""
nonlocal pipeline_wrapper_ref, derived_reset_hook
nonlocal pipeline_wrapper_proxy, derived_reset_hook

# Extract and initialize the pipeline wrapper
pipeline_wrapper = pipeline_wrapper_ref()
assert pipeline_wrapper is not None
pipeline_wrapper._setup()
# Initialize the pipeline wrapper
pipeline_wrapper_proxy._setup()

# Forward the environment provided by the reset hook of higher-
# level block if any, or use this wrapper otherwise.
if derived_reset_hook is None:
env_derived: InterfaceJiminyEnv = pipeline_wrapper
env_derived: weakref.ProxyType = pipeline_wrapper_proxy
else:
assert callable(derived_reset_hook)
env_derived = derived_reset_hook() or pipeline_wrapper
env_derived = derived_reset_hook() or pipeline_wrapper_proxy

return env_derived

Expand Down
14 changes: 9 additions & 5 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
"""
import os
import math
import weakref
import logging
import tempfile
from copy import deepcopy
from collections import OrderedDict
from collections.abc import Mapping
from itertools import chain
from functools import partial
from typing import (
Dict, Any, List, cast, no_type_check, Optional, Tuple, Callable, Iterable,
Union, SupportsFloat, Iterator, Generic, Sequence, Mapping as MappingT,
Expand Down Expand Up @@ -190,7 +192,7 @@ def __init__(self,
self.robot.sensor_measurements)

# Top-most block of the pipeline to which the environment is part of
self._env_derived: InterfaceJiminyEnv = self
self._env_derived: weakref.ProxyType = weakref.proxy(self)

# Store references to the variables to register to the telemetry
self._registered_variables: MutableMappingT[
Expand Down Expand Up @@ -730,16 +732,18 @@ def reset(self, # type: ignore[override]
# Similarly, the observer and controller update periods must be set.
reset_hook: Optional[Callable[[], InterfaceJiminyEnv]] = (
options or {}).get("reset_hook")
env: InterfaceJiminyEnv = self
env: weakref.ProxyType = weakref.proxy(self)
if reset_hook is not None:
assert callable(reset_hook)
env_derived = reset_hook() or self
env_derived = reset_hook() or env
assert env_derived.unwrapped is self
env = env_derived
self._env_derived = env

# Instantiate the actual controller
controller = jiminy.FunctionalController(env._controller_handle)
# Instantiate the actual controller.
# Note that a weak reference must be used to avoid circular reference.
controller = jiminy.FunctionalController(
partial(env._controller_handle.__func__, env))
controller.initialize(self.robot)
self.simulator.set_controller(controller)

Expand Down
19 changes: 9 additions & 10 deletions python/jiminy_py/src/jiminy_py/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import pathlib
import tempfile
from weakref import ref
import weakref
from copy import deepcopy
from itertools import chain
from functools import partial
Expand Down Expand Up @@ -93,21 +93,20 @@ def __init__(self, # pylint: disable=unused-argument
# Wrap callback in nested function to hide update of progress bar
# Note that a weak reference must be used to avoid circular reference
# resulting in uncollectable object and hence memory leak.
simulator_ref = ref(self)
simulator_proxy = weakref.proxy(self)

def callback_wrapper(t: float,
def callback_wrapper(simulator_proxy: weakref.ProxyType,
t: float,
*args: Any,
**kwargs: Any) -> None:
nonlocal simulator_ref
simulator = simulator_ref()
assert simulator is not None
if simulator.__pbar is not None:
simulator.__pbar.update(t - simulator.__pbar.n)
simulator._callback(t, *args, **kwargs)
if simulator_proxy.__pbar is not None:
simulator_proxy.__pbar.update(t - simulator_proxy.__pbar.n)
simulator_proxy._callback(t, *args, **kwargs)

# Instantiate the low-level Jiminy engine, then initialize it
self.engine = engine_class()
self.engine.initialize(robot, controller, callback_wrapper)
self.engine.initialize(
robot, controller, partial(callback_wrapper, simulator_proxy))

# Create shared memories and python-native attribute for fast access
self.stepper_state = self.engine.stepper_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import math
import array
import signal
import weakref
import warnings
import importlib
import threading
import multiprocessing as mp
import xml.etree.ElementTree as ET
from weakref import ref
from functools import wraps
from itertools import chain
from datetime import datetime
Expand Down Expand Up @@ -1794,25 +1794,19 @@ def async_mode(self) -> AbstractContextManager:
right before the next method execution instead of being thrown on
the spot.
"""
proxy_ref = ref(self)
proxy = weakref.proxy(self)

class ContextAsyncMode(AbstractContextManager):
"""Context manager forcing async execution when forwarding request
to the underlying panda3d viewer instance.
"""
def __enter__(self) -> None:
nonlocal proxy_ref
proxy = proxy_ref()
assert proxy is not None
proxy._is_async = True

def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
nonlocal proxy_ref
proxy = proxy_ref()
assert proxy is not None
proxy._is_async = False

return ContextAsyncMode()
Expand Down
8 changes: 4 additions & 4 deletions python/jiminy_py/unit_py/test_flexible_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def setUp(self):
# Remove temporary file
os.remove(urdf_path)

# Instantiate and initialize the controller
controller = jiminy.FunctionalController()
controller.initialize(robot)
# Instantiate and initialize a controller doing nothing
noop_controller = jiminy.FunctionalController()
noop_controller.initialize(robot)

# Create a simulator using this robot and controller
self.simulator = Simulator(
robot,
controller,
noop_controller,
viewer_kwargs=dict(
camera_pose=((0.0, -2.0, 0.0), (np.pi/2, 0.0, 0.0), None)
))
Expand Down
15 changes: 7 additions & 8 deletions python/jiminy_py/unit_py/test_simple_mass.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""This file aims at verifying the sanity of the physics and the integration
method of jiminy on simple mass.
"""
import weakref
import unittest
import numpy as np
from enum import Enum
from weakref import ref
from itertools import product

import numpy as np
from scipy.signal import savgol_filter

import jiminy_py.core as jiminy
Expand Down Expand Up @@ -195,20 +196,18 @@ def test_contact_sensor(self):
engine = jiminy.Engine()

# No control law, only check sensors data
engine_ref = ref(engine)
engine_proxy = weakref.proxy(engine)
def check_sensor_measurements(t, q, v, sensor_measurements, command):
# Verify sensor data, if the engine has been initialized
nonlocal engine_ref, frame_pose
engine = engine_ref()
assert engine is not None
if engine.is_initialized:
nonlocal engine_proxy
if engine_proxy.is_initialized:
f_linear = sensor_measurements[
ContactSensor.type, self.body_name]
f_wrench = sensor_measurements[
ForceSensor.type, self.body_name]
f_contact_sensor = frame_pose * Force(f_linear, np.zeros(3))
f_force_sensor = frame_pose * Force(*np.split(f_wrench, 2))
f_true = engine.system_state.f_external[joint_index]
f_true = engine_proxy.system_state.f_external[joint_index]
self.assertTrue(np.allclose(
f_contact_sensor.linear, f_true.linear, atol=TOLERANCE))
self.assertTrue(np.allclose(
Expand Down

0 comments on commit 9552098

Please sign in to comment.