From e3fc8070d78b28111d92140a498ec434316e2583 Mon Sep 17 00:00:00 2001 From: Chris Dickinson Date: Thu, 21 Nov 2024 15:07:16 -0800 Subject: [PATCH] feat: add host context support Add the ability to pass "per invocation" host context accessible via CurrentPlugin. Additionally, allow referring to CurrentPlugin in type-inferred host functions. --- extism/extism.py | 22 +++++++++++++++++++--- tests/test_extism.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/extism/extism.py b/extism/extism.py index 3a7d29d..6224e57 100644 --- a/extism/extism.py +++ b/extism/extism.py @@ -398,11 +398,18 @@ def __init__(self, namespace, name, func, user_data): arg_names = [arg for arg in hints.keys() if arg != "return"] returns = hints.pop("return", None) + uses_current_plugin = False + if len(arg_names) > 0 and hints.get(arg_names[0], None) == CurrentPlugin: + uses_current_plugin = True + arg_names = arg_names[1:] + args = [_map_arg(arg, hints[arg]) for arg in arg_names] + returns = [] if returns is None else _map_ret(returns) def inner_func(plugin, inputs, outputs, *user_data): - inner_args = [ + first_arg = [plugin] if uses_current_plugin else [] + inner_args = first_arg + [ extract(plugin, slot) for ((_, extract), slot) in zip(args, inputs) ] @@ -523,6 +530,7 @@ def call( function_name: str, data: Union[str, bytes], parse: Callable[[Any], Any] = lambda xs: bytes(xs), + host_context: Any = None, ) -> Any: """ Call a function by name with the provided input data @@ -533,11 +541,13 @@ def call( :raises: An :py:class:`extism.Error <.extism.Error>` if the guest function call was unsuccessful. :returns: The returned bytes from the guest function as interpreted by the ``parse`` parameter. """ + + host_context = _ffi.new_handle(host_context) if isinstance(data, str): data = data.encode() self._check_error( - _lib.extism_plugin_call( - self.plugin, function_name.encode(), data, len(data) + _lib.extism_plugin_call_with_host_context( + self.plugin, function_name.encode(), data, len(data), host_context ) ) out_len = _lib.extism_plugin_output_length(self.plugin) @@ -608,6 +618,12 @@ def memory(self, mem: Memory) -> _ffi.buffer: return None return _ffi.buffer(p + mem.offset, mem.length) + def host_context(self) -> Any: + result = _lib.extism_current_plugin_host_context(self.pointer) + if result == 0: + return None + return _ffi.from_handle(result) + def alloc(self, size: int) -> Memory: """ Allocate a new block of memory. diff --git a/tests/test_extism.py b/tests/test_extism.py index f76e59b..c401a56 100644 --- a/tests/test_extism.py +++ b/tests/test_extism.py @@ -1,3 +1,4 @@ +from collections import namedtuple import unittest import extism import hashlib @@ -148,6 +149,40 @@ def hello_world( self.assertIsInstance(result, Gribble) self.assertEqual(result.frobbitz(), "gromble robble") + def test_host_context(self): + if not hasattr(typing, "Annotated"): + return + + # Testing two things here: one, if we see CurrentPlugin as the first arg, we pass it through. + # Two, it's possible to refer to fetch the host context from the current plugin. + @extism.host_fn(user_data=b"test") + def hello_world( + current_plugin: extism.CurrentPlugin, + inp: typing.Annotated[dict, extism.Json], + *user_data, + ) -> typing.Annotated[Gribble, extism.Pickle]: + ctx = current_plugin.host_context() + ctx.x = 1000 + return Gribble("robble") + + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) + + class Foo: + x = 100 + y = 200 + + foo = Foo() + + res = plugin.call("count_vowels", "aaa", host_context=foo) + + self.assertEqual(foo.x, 1000) + self.assertEqual(foo.y, 200) + result = pickle.loads(res) + self.assertIsInstance(result, Gribble) + self.assertEqual(result.frobbitz(), "gromble robble") + def test_extism_plugin_cancel(self): plugin = extism.Plugin(self._loop_manifest()) cancel_handle = plugin.cancel_handle()