From 955209822b24cce45d7e6f57a1d5af06e220048a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sun, 11 Feb 2024 20:35:14 +0100 Subject: [PATCH] [gym/common] Fix circular reference issue preventing garbage collection. --- .../gym_jiminy/common/bases/pipeline_bases.py | 18 ++++++++---------- .../gym_jiminy/common/envs/env_generic.py | 14 +++++++++----- python/jiminy_py/src/jiminy_py/simulator.py | 19 +++++++++---------- .../viewer/panda3d/panda3d_visualizer.py | 10 ++-------- python/jiminy_py/unit_py/test_flexible_arm.py | 8 ++++---- python/jiminy_py/unit_py/test_simple_mass.py | 15 +++++++-------- 6 files changed, 39 insertions(+), 45 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py index 95a54b7a44..2b6ccc60a9 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py @@ -10,7 +10,7 @@ eventually already wrapped, so that it appears as a black-box environment. """ import math -from weakref import ref +import weakref from copy import deepcopy from abc import abstractmethod from collections import OrderedDict @@ -167,14 +167,14 @@ def reset(self, # type: ignore[override] # Create weak reference to self. # This is necessary to avoid circular reference that would make the # corresponding object noncollectable and hence cause a memory leak. - pipeline_wrapper_ref = ref(self) + pipeline_wrapper_proxy = weakref.proxy(self) # Extra reset_hook from options if provided derived_reset_hook: Optional[Callable[[], InterfaceJiminyEnv]] = ( options or {}).get("reset_hook") # Define chained controller hook - def reset_hook() -> Optional[InterfaceJiminyEnv]: + def reset_hook() -> Optional[weakref.ProxyType]: """Register the block to the higher-level block. This method is used internally to make sure that `_setup` method @@ -182,20 +182,18 @@ def reset_hook() -> Optional[InterfaceJiminyEnv]: lowest-level block to the highest-level one, right after reset of the low-level simulator and just before performing the first step. """ - nonlocal pipeline_wrapper_ref, derived_reset_hook + nonlocal pipeline_wrapper_proxy, derived_reset_hook - # Extract and initialize the pipeline wrapper - pipeline_wrapper = pipeline_wrapper_ref() - assert pipeline_wrapper is not None - pipeline_wrapper._setup() + # Initialize the pipeline wrapper + pipeline_wrapper_proxy._setup() # Forward the environment provided by the reset hook of higher- # level block if any, or use this wrapper otherwise. if derived_reset_hook is None: - env_derived: InterfaceJiminyEnv = pipeline_wrapper + env_derived: weakref.ProxyType = pipeline_wrapper_proxy else: assert callable(derived_reset_hook) - env_derived = derived_reset_hook() or pipeline_wrapper + env_derived = derived_reset_hook() or pipeline_wrapper_proxy return env_derived diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py b/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py index 2ad1a39ff9..cb6cd0e0fd 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py @@ -4,12 +4,14 @@ """ import os import math +import weakref import logging import tempfile from copy import deepcopy from collections import OrderedDict from collections.abc import Mapping from itertools import chain +from functools import partial from typing import ( Dict, Any, List, cast, no_type_check, Optional, Tuple, Callable, Iterable, Union, SupportsFloat, Iterator, Generic, Sequence, Mapping as MappingT, @@ -190,7 +192,7 @@ def __init__(self, self.robot.sensor_measurements) # Top-most block of the pipeline to which the environment is part of - self._env_derived: InterfaceJiminyEnv = self + self._env_derived: weakref.ProxyType = weakref.proxy(self) # Store references to the variables to register to the telemetry self._registered_variables: MutableMappingT[ @@ -730,16 +732,18 @@ def reset(self, # type: ignore[override] # Similarly, the observer and controller update periods must be set. reset_hook: Optional[Callable[[], InterfaceJiminyEnv]] = ( options or {}).get("reset_hook") - env: InterfaceJiminyEnv = self + env: weakref.ProxyType = weakref.proxy(self) if reset_hook is not None: assert callable(reset_hook) - env_derived = reset_hook() or self + env_derived = reset_hook() or env assert env_derived.unwrapped is self env = env_derived self._env_derived = env - # Instantiate the actual controller - controller = jiminy.FunctionalController(env._controller_handle) + # Instantiate the actual controller. + # Note that a weak reference must be used to avoid circular reference. + controller = jiminy.FunctionalController( + partial(env._controller_handle.__func__, env)) controller.initialize(self.robot) self.simulator.set_controller(controller) diff --git a/python/jiminy_py/src/jiminy_py/simulator.py b/python/jiminy_py/src/jiminy_py/simulator.py index 12066c88a9..0d8acaea52 100644 --- a/python/jiminy_py/src/jiminy_py/simulator.py +++ b/python/jiminy_py/src/jiminy_py/simulator.py @@ -7,7 +7,7 @@ import logging import pathlib import tempfile -from weakref import ref +import weakref from copy import deepcopy from itertools import chain from functools import partial @@ -93,21 +93,20 @@ def __init__(self, # pylint: disable=unused-argument # Wrap callback in nested function to hide update of progress bar # Note that a weak reference must be used to avoid circular reference # resulting in uncollectable object and hence memory leak. - simulator_ref = ref(self) + simulator_proxy = weakref.proxy(self) - def callback_wrapper(t: float, + def callback_wrapper(simulator_proxy: weakref.ProxyType, + t: float, *args: Any, **kwargs: Any) -> None: - nonlocal simulator_ref - simulator = simulator_ref() - assert simulator is not None - if simulator.__pbar is not None: - simulator.__pbar.update(t - simulator.__pbar.n) - simulator._callback(t, *args, **kwargs) + if simulator_proxy.__pbar is not None: + simulator_proxy.__pbar.update(t - simulator_proxy.__pbar.n) + simulator_proxy._callback(t, *args, **kwargs) # Instantiate the low-level Jiminy engine, then initialize it self.engine = engine_class() - self.engine.initialize(robot, controller, callback_wrapper) + self.engine.initialize( + robot, controller, partial(callback_wrapper, simulator_proxy)) # Create shared memories and python-native attribute for fast access self.stepper_state = self.engine.stepper_state diff --git a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py index 807cb53534..24af71f315 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py +++ b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py @@ -8,12 +8,12 @@ import math import array import signal +import weakref import warnings import importlib import threading import multiprocessing as mp import xml.etree.ElementTree as ET -from weakref import ref from functools import wraps from itertools import chain from datetime import datetime @@ -1794,25 +1794,19 @@ def async_mode(self) -> AbstractContextManager: right before the next method execution instead of being thrown on the spot. """ - proxy_ref = ref(self) + proxy = weakref.proxy(self) class ContextAsyncMode(AbstractContextManager): """Context manager forcing async execution when forwarding request to the underlying panda3d viewer instance. """ def __enter__(self) -> None: - nonlocal proxy_ref - proxy = proxy_ref() - assert proxy is not None proxy._is_async = True def __exit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: - nonlocal proxy_ref - proxy = proxy_ref() - assert proxy is not None proxy._is_async = False return ContextAsyncMode() diff --git a/python/jiminy_py/unit_py/test_flexible_arm.py b/python/jiminy_py/unit_py/test_flexible_arm.py index 42f1dfe14b..bc88deece2 100644 --- a/python/jiminy_py/unit_py/test_flexible_arm.py +++ b/python/jiminy_py/unit_py/test_flexible_arm.py @@ -102,14 +102,14 @@ def setUp(self): # Remove temporary file os.remove(urdf_path) - # Instantiate and initialize the controller - controller = jiminy.FunctionalController() - controller.initialize(robot) + # Instantiate and initialize a controller doing nothing + noop_controller = jiminy.FunctionalController() + noop_controller.initialize(robot) # Create a simulator using this robot and controller self.simulator = Simulator( robot, - controller, + noop_controller, viewer_kwargs=dict( camera_pose=((0.0, -2.0, 0.0), (np.pi/2, 0.0, 0.0), None) )) diff --git a/python/jiminy_py/unit_py/test_simple_mass.py b/python/jiminy_py/unit_py/test_simple_mass.py index 14b8b0139f..850eb6f656 100644 --- a/python/jiminy_py/unit_py/test_simple_mass.py +++ b/python/jiminy_py/unit_py/test_simple_mass.py @@ -1,11 +1,12 @@ """This file aims at verifying the sanity of the physics and the integration method of jiminy on simple mass. """ +import weakref import unittest -import numpy as np from enum import Enum -from weakref import ref from itertools import product + +import numpy as np from scipy.signal import savgol_filter import jiminy_py.core as jiminy @@ -195,20 +196,18 @@ def test_contact_sensor(self): engine = jiminy.Engine() # No control law, only check sensors data - engine_ref = ref(engine) + engine_proxy = weakref.proxy(engine) def check_sensor_measurements(t, q, v, sensor_measurements, command): # Verify sensor data, if the engine has been initialized - nonlocal engine_ref, frame_pose - engine = engine_ref() - assert engine is not None - if engine.is_initialized: + nonlocal engine_proxy + if engine_proxy.is_initialized: f_linear = sensor_measurements[ ContactSensor.type, self.body_name] f_wrench = sensor_measurements[ ForceSensor.type, self.body_name] f_contact_sensor = frame_pose * Force(f_linear, np.zeros(3)) f_force_sensor = frame_pose * Force(*np.split(f_wrench, 2)) - f_true = engine.system_state.f_external[joint_index] + f_true = engine_proxy.system_state.f_external[joint_index] self.assertTrue(np.allclose( f_contact_sensor.linear, f_true.linear, atol=TOLERANCE)) self.assertTrue(np.allclose(