Skip to content

Commit

Permalink
Add keyboard event callback for Python passive viewer.
Browse files Browse the repository at this point in the history
Fixes #766
Fixes #846

PiperOrigin-RevId: 549449372
Change-Id: I37d17f1162c66d0e8482d402aae22d9a16b59deb
  • Loading branch information
saran-t authored and Copybara-Service committed Jul 19, 2023
1 parent f7847ba commit 06b7083
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 13 deletions.
2 changes: 2 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Python bindings
``update_texture`` methods to allow users to update renderable assets.
(`#812 <https://github.com/deepmind/mujoco/issues/812>`_, `#958 <https://github.com/deepmind/mujoco/issues/958>`_,
`#965 <https://github.com/deepmind/mujoco/issues/965>`_)
- Allow a custom keyboard event callback to be specified in the :ref:`passive viewer<PyViewerPassive>`.
(`#766 <https://github.com/deepmind/mujoco/issues/766>`_)
- Fix GLFW crash when Python exits while the passive viewer is running.
(`#790 <https://github.com/deepmind/mujoco/issues/790>`_)

Expand Down
23 changes: 23 additions & 0 deletions doc/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,29 @@ illustrative example that does **not** necessarily keep the physics ticking at t
if time_until_next_step > 0:
time.sleep(time_until_next_step)
Optionally, ``viewer.launch_passive`` also accepts a callable as a keyword argument ``key_callback``, which gets called
each time a keyboard event occurs in the viewer window. This allows user scripts to react to various key presses, e.g.
pause or resume the run loop when the spacebar is pressed.

.. code-block:: python
paused = False
def key_callback(keycode):
if chr(keycode) == ' ':
global paused
paused = not paused
...
with mujoco.viewer.launch_passive(m, d, key_callback=key_callback) as viewer:
while viewer.is_running():
...
if not paused:
mujoco.mj_step(m, d)
viewer.sync()
...
.. _PyUsage:

Expand Down
9 changes: 5 additions & 4 deletions python/mujoco/mjpython/mjpython.mm
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def __init__(self):
self._termination = self.__class__.NOT_TERMINATED
self._busy = False
def launch_on_ui_thread(self, model, data, handle_return):
def launch_on_ui_thread(self, model, data, handle_return, key_callback):
with self._cond:
if self._busy or self._task is not None:
raise RuntimeError('another MuJoCo viewer is already open')
else:
self._task = (model, data, handle_return)
self._task = (model, data, handle_return, key_callback)
self._cond.notify()
def terminate(self):
Expand Down Expand Up @@ -294,10 +294,11 @@ int main(int argc, char** argv) {
break
# Otherwise, launch the viewer.
model, data, handle_return = task
model, data, handle_return, key_callback = task
ctypes.CDLL(None).mjpython_show_dock_icon()
mujoco.viewer._launch_internal(
model, data, run_physics_thread=False, handle_return=handle_return)
model, data, run_physics_thread=False, handle_return=handle_return,
key_callback=key_callback)
ctypes.CDLL(None).mjpython_hide_dock_icon()
finally:
Expand Down
35 changes: 32 additions & 3 deletions python/mujoco/simulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@ constexpr inline std::size_t sizeof_arr(const T (&arr)[N]) {
return sizeof(arr);
}

template <typename Adapter>
class UIAdapterWithPyCallback : public Adapter {
public:
template <typename... Args>
UIAdapterWithPyCallback(py::handle key_callback, Args&&... args)
: Adapter(std::forward<Args>(args)...) {
if (!key_callback.is_none()) {
Py_XINCREF(key_callback.ptr());
key_callback_ = key_callback.ptr();
}
}

~UIAdapterWithPyCallback() override { Py_XDECREF(key_callback_); }

protected:
void OnKey(int key, int scancode, int act) override {
Adapter::OnKey(key, scancode, act);
if (this->IsKeyDownEvent(act) && key_callback_) {
py::gil_scoped_acquire gil;
(py::handle(key_callback_))(this->last_key_);
}
}

private:
PyObject* key_callback_ = nullptr;
};

class SimulateWrapper {
public:
SimulateWrapper(std::unique_ptr<PlatformUIAdapter> platform_ui_adapter,
Expand Down Expand Up @@ -166,10 +193,12 @@ PYBIND11_MODULE(_simulate, pymodule) {
py::class_<SimulateWrapper>(pymodule, "Simulate")
.def_readonly_static("MAX_GEOM", &mujoco::Simulate::kMaxGeom)
.def(py::init([](py::object scn, py::object cam, py::object opt,
py::object pert, bool fully_managed) {
py::object pert, bool fully_managed,
py::object key_callback) {
return std::make_unique<SimulateWrapper>(
std::make_unique<mujoco::GlfwAdapter>(), scn, cam, opt, pert,
fully_managed);
std::make_unique<UIAdapterWithPyCallback<mujoco::GlfwAdapter>>(
key_callback),
scn, cam, opt, pert, fully_managed);
}))
.def("destroy", &SimulateWrapper::Destroy,
py::call_guard<py::gil_scoped_release>())
Expand Down
31 changes: 26 additions & 5 deletions python/mujoco/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

CallbackType = Callable[[mujoco.MjModel, mujoco.MjData], None]
LoaderType = Callable[[], Tuple[mujoco.MjModel, mujoco.MjData]]
KeyCallbackType = Callable[[int], None]

# Loader function that also returns a file path for the GUI to display.
_LoaderWithPathType = Callable[[], Tuple[mujoco.MjModel, mujoco.MjData, str]]
Expand Down Expand Up @@ -142,9 +143,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# Python launcher (mjpython) to implement the required dispatching mechanism.
class _MjPythonBase(metaclass=abc.ABCMeta):

def launch_on_ui_thread(self, model: mujoco.MjModel, data: mujoco.MjData):
def launch_on_ui_thread(
self,
model: mujoco.MjModel,
data: mujoco.MjData,
handle_return: Optional['queue.Queue[Handle]'],
key_callback: Optional[KeyCallbackType],
):
pass


# When running under mjpython, the launcher initializes this object.
_MJPYTHON: Optional[_MjPythonBase] = None

Expand Down Expand Up @@ -299,6 +307,7 @@ def _launch_internal(
run_physics_thread: bool,
loader: Optional[_InternalLoaderType] = None,
handle_return: Optional['queue.Queue[Handle]'] = None,
key_callback: Optional[KeyCallbackType] = None,
) -> None:
"""Internal API, so that the public API has more readable type annotations."""
if model is None and data is not None:
Expand Down Expand Up @@ -327,7 +336,7 @@ def _loader(m=model, d=data) -> Tuple[mujoco.MjModel, mujoco.MjData]:
cam = mujoco.MjvCamera()
opt = mujoco.MjvOption()
pert = mujoco.MjvPerturb()
simulate = _Simulate(scn, cam, opt, pert, run_physics_thread)
simulate = _Simulate(scn, cam, opt, pert, run_physics_thread, key_callback)

# Initialize GLFW if not using mjpython.
if _MJPYTHON is None:
Expand Down Expand Up @@ -377,12 +386,20 @@ def launch_from_path(path: str) -> None:
_launch_internal(run_physics_thread=True, loader=_file_loader(path))


def launch_passive(model: mujoco.MjModel, data: mujoco.MjData) -> Handle:
def launch_passive(
model: mujoco.MjModel,
data: mujoco.MjData,
*,
key_callback: Optional[KeyCallbackType] = None,
) -> Handle:
"""Launches a passive Simulate GUI without blocking the running thread."""
if not isinstance(model, mujoco.MjModel):
raise ValueError(f'`model` is not a mujoco.MjModel: got {model!r}')
if not isinstance(data, mujoco.MjData):
raise ValueError(f'`data` is not a mujoco.MjData: got {data!r}')
if key_callback is not None and not callable(key_callback):
raise ValueError(
f'`key_callback` is not callable: got {key_callback!r}')

mujoco.mj_forward(model, data)
handle_return = queue.Queue(1)
Expand All @@ -391,7 +408,11 @@ def launch_passive(model: mujoco.MjModel, data: mujoco.MjData) -> Handle:
thread = threading.Thread(
target=_launch_internal,
args=(model, data),
kwargs=dict(run_physics_thread=False, handle_return=handle_return),
kwargs=dict(
run_physics_thread=False,
handle_return=handle_return,
key_callback=key_callback,
),
)
thread.daemon = True
thread.start()
Expand All @@ -400,7 +421,7 @@ def launch_passive(model: mujoco.MjModel, data: mujoco.MjData) -> Handle:
raise RuntimeError(
'`launch_passive` requires that the Python script be run under '
'`mjpython` on macOS')
_MJPYTHON.launch_on_ui_thread(model, data, handle_return)
_MJPYTHON.launch_on_ui_thread(model, data, handle_return, key_callback)

return handle_return.get()

Expand Down
2 changes: 2 additions & 0 deletions simulate/platform_ui_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ void PlatformUIAdapter::OnKey(int key, int scancode, int act) {
if (event_callback_) {
event_callback_(&state_);
}

last_key_ = mj_key;
}

void PlatformUIAdapter::OnMouseButton(int button, int act) {
Expand Down
3 changes: 2 additions & 1 deletion simulate/platform_ui_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,15 @@ class PlatformUIAdapter {

// Event handlers
void OnFilesDrop(int count, const char** paths);
void OnKey(int key, int scancode, int act);
virtual void OnKey(int key, int scancode, int act);
void OnMouseButton(int button, int act);
void OnMouseMove(double x, double y);
void OnScroll(double xoffset, double yoffset);
void OnWindowRefresh();
void OnWindowResize(int width, int height);

mjuiState state_;
int last_key_;
void (*event_callback_)(mjuiState*);
void (*layout_callback_)(mjuiState*);

Expand Down

0 comments on commit 06b7083

Please sign in to comment.