Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Fix circular reference issue preventing garbage collection. #715

Merged
merged 2 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
"""
import os
import math
import weakref
import logging
import tempfile
from copy import deepcopy
from collections import OrderedDict
from collections.abc import Mapping
from itertools import chain
from functools import partial
from typing import (
Dict, Any, List, cast, no_type_check, Optional, Tuple, Callable, Iterable,
Union, SupportsFloat, Iterator, Generic, Sequence, Mapping as MappingT,
Expand Down Expand Up @@ -733,13 +735,15 @@ def reset(self, # type: ignore[override]
env: InterfaceJiminyEnv = self
if reset_hook is not None:
assert callable(reset_hook)
env_derived = reset_hook() or self
env_derived = reset_hook() or env
assert env_derived.unwrapped is self
env = env_derived
self._env_derived = env

# Instantiate the actual controller
controller = jiminy.FunctionalController(env._controller_handle)
# Instantiate the actual controller.
# Note that a weak reference must be used to avoid circular reference.
controller = jiminy.FunctionalController(
partial(type(env)._controller_handle, weakref.proxy(env)))
controller.initialize(self.robot)
self.simulator.set_controller(controller)

Expand Down
16 changes: 16 additions & 0 deletions python/gym_jiminy/unit_py/test_pipeline_design.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
""" TODO: Write documentation
"""
import gc
import os
import weakref
import unittest

import numpy as np
Expand Down Expand Up @@ -109,6 +111,20 @@ def test_override_default(self):
env = self.ANYmalPipelineEnv()
self.assertEqual(env.unwrapped.step_dt, self.step_dt)

def test_memory_leak(self):
"""Check that memory is freed when environment goes out of scope.

This test aims to detect circular references between Python and C++
objects that cannot be tracked by Python, which would make it
impossible for the garbage collector to release memory.
"""
env = self.ANYmalPipelineEnv()
env.reset()
proxy = weakref.proxy(env)
env = None
gc.collect()
self.assertRaises(ReferenceError, lambda: proxy.action)

def test_initial_state(self):
""" TODO: Write documentation
"""
Expand Down
19 changes: 9 additions & 10 deletions python/jiminy_py/src/jiminy_py/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import pathlib
import tempfile
from weakref import ref
import weakref
from copy import deepcopy
from itertools import chain
from functools import partial
Expand Down Expand Up @@ -93,21 +93,20 @@ def __init__(self, # pylint: disable=unused-argument
# Wrap callback in nested function to hide update of progress bar
# Note that a weak reference must be used to avoid circular reference
# resulting in uncollectable object and hence memory leak.
simulator_ref = ref(self)
simulator_proxy = weakref.proxy(self)

def callback_wrapper(t: float,
def callback_wrapper(simulator_proxy: weakref.ProxyType,
t: float,
*args: Any,
**kwargs: Any) -> None:
nonlocal simulator_ref
simulator = simulator_ref()
assert simulator is not None
if simulator.__pbar is not None:
simulator.__pbar.update(t - simulator.__pbar.n)
simulator._callback(t, *args, **kwargs)
if simulator_proxy.__pbar is not None:
simulator_proxy.__pbar.update(t - simulator_proxy.__pbar.n)
simulator_proxy._callback(t, *args, **kwargs)

# Instantiate the low-level Jiminy engine, then initialize it
self.engine = engine_class()
self.engine.initialize(robot, controller, callback_wrapper)
self.engine.initialize(
robot, controller, partial(callback_wrapper, simulator_proxy))

# Create shared memories and python-native attribute for fast access
self.stepper_state = self.engine.stepper_state
Expand Down
38 changes: 19 additions & 19 deletions python/jiminy_py/src/jiminy_py/viewer/meshcat/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def __init__(self,
with open(os.devnull, 'w') as stdout, redirect_stdout(stdout):
with open(os.devnull, 'w') as stderr, redirect_stderr(stderr):
self.gui = meshcat.Visualizer(zmq_url)
self.__zmq_socket = self.gui.window.zmq_socket
self._zmq_socket = self.gui.window.zmq_socket

# Create a backend recorder. It is not fully initialized to reduce
# overhead when not used, which is way more usual than the contrary.
Expand Down Expand Up @@ -345,10 +345,10 @@ def __del__(self) -> None:
def close(self) -> None:
""" TODO: Write documentation.
"""
if hasattr(self, "__zmq_socket"):
if not self.__zmq_socket.closed:
self.__zmq_socket.send(b"stop")
self.__zmq_socket.close()
if hasattr(self, "_zmq_socket"):
if not self._zmq_socket.closed:
self._zmq_socket.send(b"stop")
self._zmq_socket.close()
if hasattr(self, "comm_manager") and self.comm_manager is not None:
self.comm_manager.close()
if hasattr(self, "recorder") is not None:
Expand All @@ -363,15 +363,15 @@ def wait(self, require_client: bool = False) -> str:
# perform a single `do_one_iteration`, just in case there is
# already comm waiting in the queue to be registered, but it should
# not be necessary.
self.__zmq_socket.send(b"wait")
self._zmq_socket.send(b"wait")
if self.comm_manager is None:
self.__zmq_socket.recv()
self._zmq_socket.recv()
else:
while True:
try:
# Try first, just in case there is already a comm for
# websocket available.
self.__zmq_socket.recv(flags=zmq.NOBLOCK)
self._zmq_socket.recv(flags=zmq.NOBLOCK)
break
except zmq.error.ZMQError:
# No websocket nor comm connection available at this
Expand All @@ -395,21 +395,21 @@ def wait(self, require_client: bool = False) -> str:
# of comms currently registered. It is necessary to check for a reply
# of the server periodically, and the number of responses corresponds
# to the actual number of comms.
self.__zmq_socket.send(b"ready")
self._zmq_socket.send(b"ready")
if self.comm_manager is not None:
while True:
process_kernel_comm()
try:
msg = self.__zmq_socket.recv(flags=zmq.NOBLOCK)
msg = self._zmq_socket.recv(flags=zmq.NOBLOCK)
return msg.decode("utf-8")
except zmq.error.ZMQError:
pass
return self.__zmq_socket.recv().decode("utf-8")
return self._zmq_socket.recv().decode("utf-8")

def set_legend_item(self, uniq_id: str, color: str, text: str) -> None:
""" TODO: Write documentation.
"""
self.__zmq_socket.send_multipart([
self._zmq_socket.send_multipart([
b"set_property", # Frontend command. Used by Python zmq server
b"", # Tree path. Empty path means root
umsgpack.packb({ # Backend command. Used by javascript
Expand All @@ -419,12 +419,12 @@ def set_legend_item(self, uniq_id: str, color: str, text: str) -> None:
"color": color # "rgba(0, 0, 0, 0)" and "black" supported
})
])
self.__zmq_socket.recv() # Receive acknowledgement
self._zmq_socket.recv() # Receive acknowledgement

def remove_legend_item(self, uniq_id: str) -> None:
""" TODO: Write documentation.
"""
self.__zmq_socket.send_multipart([
self._zmq_socket.send_multipart([
b"set_property",
b"",
umsgpack.packb({
Expand All @@ -433,7 +433,7 @@ def remove_legend_item(self, uniq_id: str) -> None:
"text": "" # Empty message means delete the item, if any
})
])
self.__zmq_socket.recv()
self._zmq_socket.recv()

def set_watermark(self,
img_fullpath: str,
Expand Down Expand Up @@ -465,7 +465,7 @@ def set_watermark(self,
img_data = f"data:image/{img_format};base64,{img_raw}"

# Send ZMQ request to acknowledge reply
self.__zmq_socket.send_multipart([
self._zmq_socket.send_multipart([
b"set_property",
b"",
umsgpack.packb({
Expand All @@ -475,20 +475,20 @@ def set_watermark(self,
"height": height
})
])
self.__zmq_socket.recv()
self._zmq_socket.recv()

def remove_watermark(self) -> None:
""" TODO: Write documentation.
"""
self.__zmq_socket.send_multipart([
self._zmq_socket.send_multipart([
b"set_property",
b"",
umsgpack.packb({
"type": "watermark",
"data": "" # Empty string means delete the watermark, if any
})
])
self.__zmq_socket.recv()
self._zmq_socket.recv()

def start_recording(self, fps: float, width: int, height: int) -> None:
""" TODO: Write documentation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import math
import array
import signal
import weakref
import warnings
import importlib
import threading
import multiprocessing as mp
import xml.etree.ElementTree as ET
from weakref import ref
from functools import wraps
from itertools import chain
from datetime import datetime
Expand Down Expand Up @@ -1794,25 +1794,19 @@ def async_mode(self) -> AbstractContextManager:
right before the next method execution instead of being thrown on
the spot.
"""
proxy_ref = ref(self)
proxy = weakref.proxy(self)

class ContextAsyncMode(AbstractContextManager):
"""Context manager forcing async execution when forwarding request
to the underlying panda3d viewer instance.
"""
def __enter__(self) -> None:
nonlocal proxy_ref
proxy = proxy_ref()
assert proxy is not None
proxy._is_async = True

def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
nonlocal proxy_ref
proxy = proxy_ref()
assert proxy is not None
proxy._is_async = False

return ContextAsyncMode()
Expand Down
8 changes: 4 additions & 4 deletions python/jiminy_py/unit_py/test_flexible_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def setUp(self):
# Remove temporary file
os.remove(urdf_path)

# Instantiate and initialize the controller
controller = jiminy.FunctionalController()
controller.initialize(robot)
# Instantiate and initialize a controller doing nothing
noop_controller = jiminy.FunctionalController()
noop_controller.initialize(robot)

# Create a simulator using this robot and controller
self.simulator = Simulator(
robot,
controller,
noop_controller,
viewer_kwargs=dict(
camera_pose=((0.0, -2.0, 0.0), (np.pi/2, 0.0, 0.0), None)
))
Expand Down
15 changes: 7 additions & 8 deletions python/jiminy_py/unit_py/test_simple_mass.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""This file aims at verifying the sanity of the physics and the integration
method of jiminy on simple mass.
"""
import weakref
import unittest
import numpy as np
from enum import Enum
from weakref import ref
from itertools import product

import numpy as np
from scipy.signal import savgol_filter

import jiminy_py.core as jiminy
Expand Down Expand Up @@ -195,20 +196,18 @@ def test_contact_sensor(self):
engine = jiminy.Engine()

# No control law, only check sensors data
engine_ref = ref(engine)
engine_proxy = weakref.proxy(engine)
def check_sensor_measurements(t, q, v, sensor_measurements, command):
# Verify sensor data, if the engine has been initialized
nonlocal engine_ref, frame_pose
engine = engine_ref()
assert engine is not None
if engine.is_initialized:
nonlocal engine_proxy
if engine_proxy.is_initialized:
f_linear = sensor_measurements[
ContactSensor.type, self.body_name]
f_wrench = sensor_measurements[
ForceSensor.type, self.body_name]
f_contact_sensor = frame_pose * Force(f_linear, np.zeros(3))
f_force_sensor = frame_pose * Force(*np.split(f_wrench, 2))
f_true = engine.system_state.f_external[joint_index]
f_true = engine_proxy.system_state.f_external[joint_index]
self.assertTrue(np.allclose(
f_contact_sensor.linear, f_true.linear, atol=TOLERANCE))
self.assertTrue(np.allclose(
Expand Down
Loading