Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return callables from callables in Deephaven UI #540

Merged
merged 8 commits into from
Jun 19, 2024
7 changes: 5 additions & 2 deletions plugins/ui/src/deephaven/ui/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _wrapped_callable(
func: Callable,
*args: Any,
**kwargs: Any,
) -> None:
) -> Any:
"""
Filter the args and kwargs and call the specified function with the filtered args and kwargs.

Expand All @@ -139,14 +139,17 @@ def _wrapped_callable(
func: The function to call
*args: args, used by the dispatcher
**kwargs: kwargs, used by the dispatcher

Returns:
The result of the function call.
"""
args = args if max_args is None else args[:max_args]
kwargs = (
kwargs
if kwargs_set is None
else {k: v for k, v in kwargs.items() if k in kwargs_set}
)
func(*args, **kwargs)
return func(*args, **kwargs)


def wrap_callable(func: Callable) -> Callable:
Expand Down
83 changes: 80 additions & 3 deletions plugins/ui/src/deephaven/ui/object_types/ElementMessageStream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from hmac import new
mattrunyon marked this conversation as resolved.
Show resolved Hide resolved
import io
import json
import sys
Expand All @@ -19,6 +20,7 @@
from .._internal import wrap_callable
from ..elements import Element
from ..renderer import NodeEncoder, Renderer, RenderedNode
from ..renderer.NodeEncoder import CALLABLE_KEY
from .._internal import RenderContext, StateUpdateCallable, ExportedRenderState
from .ErrorCode import ErrorCode

Expand Down Expand Up @@ -98,6 +100,24 @@ class ElementMessageStream(MessageStream):
Callables and render functions to be called on the next render loop.
"""

_callable_dict: dict[str, Callable]
"""
Dict of callable IDs to callables.
This is intended to be used by the renderer and can be replaced on each render.
"""

_temp_callable_dict: dict[str, Callable]
"""
Dict of callable IDs to callables returned from other callables.
These are not generated by the renderer and can be removed by the client.
This should not be cleaned out on each render like _callable_dict.
"""

_next_temp_callable_id: int
"""
The next ID to use for temporary callables.
"""

_render_lock: threading.Lock
"""
Lock to ensure only one thread is rendering at a time.
Expand Down Expand Up @@ -147,6 +167,9 @@ def __init__(self, element: Element, connection: MessageStream):
self._renderer = Renderer(self._context)
self._update_queue = Queue()
self._callable_queue = Queue()
self._callable_dict = {}
self._temp_callable_dict = {}
self._next_temp_callable_id = 0
self._render_lock = threading.Lock()
self._is_dirty = False
self._render_state = _RenderState.IDLE
Expand Down Expand Up @@ -327,6 +350,8 @@ def _make_request(self, method: str, *params: Any) -> dict[str, Any]:
def _make_dispatcher(self) -> Dispatcher:
dispatcher = Dispatcher()
dispatcher["setState"] = self._set_state
dispatcher["callCallable"] = self._call_callable
dispatcher["closeCallable"] = self._close_callable
return dispatcher

def _set_state(self, state: ExportedRenderState) -> None:
Expand All @@ -340,6 +365,58 @@ def _set_state(self, state: ExportedRenderState) -> None:
self._context.import_state(state)
self._mark_dirty()

def _call_callable(self, callable_id: str, args: Any) -> Any:
"""
Call a callable by its ID.
If the result is a callable, it is registered as a temporary callable.

Args:
callable_id: The ID of the callable to call
args: The array of arguments to pass to the callable. These will be spread as positional args to the callable.
"""
logger.debug("Calling callable %s with %s", callable_id, args)
fn = self._callable_dict.get(callable_id) or self._temp_callable_dict.get(
callable_id
)
if fn is None:
logger.error("Callable not found: %s", callable_id)
return
result = fn(*args)

def serialize_callables(node: Any) -> Any:
if callable(node):
new_id = f"tempCb{self._next_temp_callable_id}"
self._next_temp_callable_id += 1
self._temp_callable_dict[new_id] = node
return {
CALLABLE_KEY: new_id,
}
raise TypeError(
f"A Deephaven UI callback returned a non-serializable value. Object of type {type(node).__name__} is not JSON serializable"
)

try:
return json.dumps(result, default=serialize_callables)
except Exception as e:
# This is shown to the user in the Python console
# The stack trace from logger.exception is useless to the user
# Stack trace only includes the internals of the serialization process
logger.error(e)
return {
mofojed marked this conversation as resolved.
Show resolved Hide resolved
"serialization_error": f"Cannot serialize callable {callable_id} result"
}

def _close_callable(self, callable_id: str) -> None:
"""
Close a callable by its ID.

Args:
callable_id: The ID of the callable to close
"""
logger.debug("Closing callable %s", callable_id)
self._callable_dict.pop(callable_id, None)
self._temp_callable_dict.pop(callable_id, None)

def _send_document_update(
self, root: RenderedNode, state: ExportedRenderState
) -> None:
Expand All @@ -366,11 +443,11 @@ def _send_document_update(
payload = json.dumps(request)
logger.debug(f"Sending payload: {payload}")

dispatcher = self._make_dispatcher()
callable_dict = {}
for callable, callable_id in callable_id_dict.items():
logger.debug("Registering callable %s", callable_id)
dispatcher[callable_id] = wrap_callable(callable)
self._dispatcher = dispatcher
callable_dict[callable_id] = wrap_callable(callable)
self._callable_dict = callable_dict
if self._is_closed:
# The connection is closed, so this component will not update anymore
# delete the context so the objects in the collected scope are released
Expand Down
25 changes: 18 additions & 7 deletions plugins/ui/src/js/src/widget/WidgetHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {
import { WidgetDescriptor } from '@deephaven/dashboard';
import type { dh } from '@deephaven/jsapi-types';
import Log from '@deephaven/log';
import { EMPTY_FUNCTION } from '@deephaven/utils';
import { EMPTY_FUNCTION, assertNotNull } from '@deephaven/utils';
import {
CALLABLE_KEY,
OBJECT_KEY,
Expand All @@ -34,7 +34,7 @@ import {
METHOD_DOCUMENT_UPDATED,
} from './WidgetTypes';
import DocumentHandler from './DocumentHandler';
import { getComponentForElement } from './WidgetUtils';
import { getComponentForElement, wrapCallable } from './WidgetUtils';
import WidgetErrorView from './WidgetErrorView';
import ReactPanelContentOverlayContext from '../layout/ReactPanelContentOverlayContext';

Expand Down Expand Up @@ -113,6 +113,15 @@ function WidgetHandler({
[jsonClient]
);

const callableFinalizationRegistry = useMemo(
() =>
new FinalizationRegistry(callableId => {
log.debug2('Closing callable', callableId);
jsonClient?.request('closeCallable', [callableId]);
}),
[jsonClient]
);

const parseDocument = useCallback(
/**
* Parse the data from the server, replacing some of the nodes on the way.
Expand All @@ -124,6 +133,7 @@ function WidgetHandler({
* @returns The parsed data
*/
(data: string) => {
assertNotNull(jsonClient);
// Keep track of exported objects that are no longer in use after this render.
// We close those objects that are no longer referenced, as they will never be referenced again.
const deadObjectMap = new Map(exportedObjectMap.current);
Expand All @@ -133,10 +143,11 @@ function WidgetHandler({
if (isCallableNode(value)) {
const callableId = value[CALLABLE_KEY];
log.debug2('Registering callableId', callableId);
return async (...args: unknown[]) => {
log.debug('Callable called', callableId, ...args);
return jsonClient?.request(callableId, args);
};
return wrapCallable(
jsonClient,
callableId,
callableFinalizationRegistry
);
}
if (isObjectNode(value)) {
// Replace this node with the exported object
Expand Down Expand Up @@ -180,7 +191,7 @@ function WidgetHandler({
);
return parsedData;
},
[jsonClient]
[jsonClient, callableFinalizationRegistry]
);

const updateExportedObjects = useCallback(
Expand Down
122 changes: 121 additions & 1 deletion plugins/ui/src/js/src/widget/WidgetUtils.test.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import React from 'react';
import type { JSONRPCServerAndClient } from 'json-rpc-2.0';
import { Text } from '@deephaven/components';
import { TestUtils } from '@deephaven/utils';
import { ELEMENT_NAME, ELEMENT_PREFIX } from '../elements/ElementConstants';
import { ElementNode, ELEMENT_KEY } from '../elements/ElementUtils';
import {
ElementNode,
ELEMENT_KEY,
CALLABLE_KEY,
} from '../elements/ElementUtils';
import HTMLElementView from '../elements/HTMLElementView';
import IconElementView from '../elements/IconElementView';
import { SPECTRUM_ELEMENT_TYPE_PREFIX } from '../elements/SpectrumElementUtils';
Expand All @@ -11,8 +17,23 @@ import {
getComponentForElement,
getComponentTypeForElement,
getPreservedData,
wrapCallable,
} from './WidgetUtils';

const mockJsonRequest = jest.fn(() =>
Promise.resolve(JSON.stringify({ result: 'mock' }))
);

const mockJsonClient = TestUtils.createMockProxy<JSONRPCServerAndClient>({
request: mockJsonRequest,
});

const mockFinilizationRegistry = TestUtils.createMockProxy<
FinalizationRegistry<unknown>
>({
register: jest.fn(),
});

describe('getComponentTypeForElement', () => {
it.each(
Object.keys(elementComponentMap) as (keyof typeof elementComponentMap)[]
Expand Down Expand Up @@ -91,3 +112,102 @@ describe('getPreservedData', () => {
expect(actual).toEqual({ panelIds: widgetData.panelIds });
});
});

describe('wrapCallable', () => {
beforeEach(() => {
jest.clearAllMocks();
});

it('should return a function that sends a request to the client', () => {
wrapCallable(mockJsonClient, 'testMethod', mockFinilizationRegistry)();
expect(mockJsonClient.request).toHaveBeenCalledWith('callCallable', [
'testMethod',
[],
]);
});

it('should return a function that sends a request to the client with args', () => {
wrapCallable(
mockJsonClient,
'testMethod',
mockFinilizationRegistry
)('a', { b: 'b' });
expect(mockJsonClient.request).toHaveBeenCalledWith('callCallable', [
'testMethod',
['a', { b: 'b' }],
]);
});

it('should register the function in the finalization registry', () => {
const wrapped = wrapCallable(
mockJsonClient,
'testMethod',
mockFinilizationRegistry
);

expect(mockFinilizationRegistry.register).toHaveBeenCalledWith(
wrapped,
'testMethod',
wrapped
);
});

it('should wrap returned callables', async () => {
mockJsonRequest.mockResolvedValueOnce(
JSON.stringify({
[CALLABLE_KEY]: 'nestedCb',
})
);

const wrappedResult = await wrapCallable(
mockJsonClient,
'testMethod',
mockFinilizationRegistry
)();
expect(wrappedResult).toBeInstanceOf(Function);

expect(mockFinilizationRegistry.register).toHaveBeenCalledTimes(2);
expect(mockFinilizationRegistry.register).toHaveBeenLastCalledWith(
wrappedResult,
'nestedCb',
wrappedResult
);
});

it('should wrap nested returned callables', async () => {
mockJsonRequest.mockResolvedValueOnce(
JSON.stringify({
nestedCallable: {
[CALLABLE_KEY]: 'nestedCb',
},
someOtherProp: 'mock',
})
);

const wrappedResult = (await wrapCallable(
mockJsonClient,
'testMethod',
mockFinilizationRegistry
)()) as { nestedCallable: () => void; someOtherProp: string };

expect(wrappedResult).toMatchObject({
nestedCallable: expect.any(Function),
someOtherProp: 'mock',
});

expect(mockFinilizationRegistry.register).toHaveBeenCalledTimes(2);
expect(mockFinilizationRegistry.register).toHaveBeenLastCalledWith(
wrappedResult.nestedCallable,
'nestedCb',
wrappedResult.nestedCallable
);
});

it('should reject if the result is not parseable', () => {
mockJsonRequest.mockResolvedValueOnce('not a json string');

expect(
wrapCallable(mockJsonClient, 'testMethod', mockFinilizationRegistry)()
).rejects.toBeInstanceOf(Error);
});
});
Loading
Loading