From a551161617748e023acca65428fadd4ee508c05b Mon Sep 17 00:00:00 2001 From: LJ Date: Sat, 19 Apr 2025 12:39:56 -0700 Subject: [PATCH] feat(async-fn): allow `prepare` to be async; rs always call py in async --- python/cocoindex/op.py | 46 ++++++++----------- src/ops/py_factory.rs | 100 ++++++++++++++++++----------------------- src/py/mod.rs | 7 +-- 3 files changed, 62 insertions(+), 91 deletions(-) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index f97ee853..f46f3cbd 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -5,9 +5,9 @@ import dataclasses import inspect -from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform +from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform from enum import Enum -from threading import Lock +from functools import partial from .typing import encode_enriched_type from .convert import to_engine_value, make_engine_value_converter @@ -61,7 +61,7 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs): return (encode_enriched_type(result_type), executor) -_gpu_dispatch_lock = Lock() +_gpu_dispatch_lock = asyncio.Lock() @dataclasses.dataclass class OpArgs: @@ -75,11 +75,15 @@ class OpArgs: cache: bool = False behavior_version: int | None = None +def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]: + if inspect.iscoroutinefunction(call): + return call + return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs)) + def _register_op_factory( category: OpCategory, expected_args: list[tuple[str, inspect.Parameter]], expected_return, - is_async: bool, executor_cls: type, spec_cls: type, op_args: OpArgs, @@ -97,10 +101,12 @@ def behavior_version(self): class _WrappedClass(executor_cls, _Fallback): _args_converters: list[Callable[[Any], Any]] _kwargs_converters: dict[str, Callable[[str, Any], Any]] + _acall: Callable def __init__(self, spec): super().__init__() self.spec = spec + self._acall = _to_async_call(super().__call__) def analyze(self, *args, **kwargs): """ @@ -157,31 +163,19 @@ def analyze(self, *args, **kwargs): else: return expected_return - def prepare(self): + async def prepare(self): """ Prepare for execution. It's executed after `analyze` and before any `__call__` execution. """ - setup_method = getattr(executor_cls, 'prepare', None) + setup_method = getattr(super(), 'prepare', None) if setup_method is not None: - setup_method(self) + await _to_async_call(setup_method)() - def __call__(self, *args, **kwargs): + async def __call__(self, *args, **kwargs): converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args)) converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()} - if is_async: - async def _inner(): - if op_args.gpu: - await asyncio.to_thread(_gpu_dispatch_lock.acquire) - try: - output = await super(_WrappedClass, self).__call__( - *converted_args, **converted_kwargs) - finally: - if op_args.gpu: - _gpu_dispatch_lock.release() - return to_engine_value(output) - return _inner() if op_args.gpu: # For GPU executions, data-level parallelism is applied, so we don't want to @@ -189,10 +183,10 @@ async def _inner(): # Besides, multiprocessing is more appropriate for pytorch. # For now, we use a lock to ensure only one task is executed at a time. # TODO: Implement multi-processing dispatching. - with _gpu_dispatch_lock: - output = super().__call__(*converted_args, **converted_kwargs) + async with _gpu_dispatch_lock: + output = await self._acall(*converted_args, **converted_kwargs) else: - output = super().__call__(*converted_args, **converted_kwargs) + output = await self._acall(*converted_args, **converted_kwargs) return to_engine_value(output) _WrappedClass.__name__ = executor_cls.__name__ @@ -203,9 +197,7 @@ async def _inner(): if category == OpCategory.FUNCTION: _engine.register_function_factory( - spec_cls.__name__, - _FunctionExecutorFactory(spec_cls, _WrappedClass), - is_async) + spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)) else: raise ValueError(f"Unsupported executor type {category}") @@ -230,7 +222,6 @@ def _inner(cls: type[Executor]) -> type: category=spec_cls._op_category, expected_args=list(sig.parameters.items())[1:], # First argument is `self` expected_return=sig.return_annotation, - is_async=inspect.iscoroutinefunction(cls.__call__), executor_cls=cls, spec_cls=spec_cls, op_args=op_args) @@ -266,7 +257,6 @@ class _Spec(FunctionSpec): category=OpCategory.FUNCTION, expected_args=list(sig.parameters.items()), expected_return=sig.return_annotation, - is_async=inspect.iscoroutinefunction(fn), executor_cls=_Executor, spec_cls=_Spec, op_args=op_args) diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 7dddb7a8..3eec3ac8 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -39,7 +39,6 @@ impl PyOpArgSchema { struct PyFunctionExecutor { py_function_executor: Py, - is_async: bool, py_exec_ctx: Arc, num_positional_args: usize, @@ -91,36 +90,22 @@ impl PyFunctionExecutor { impl SimpleFunctionExecutor for Arc { async fn evaluate(&self, input: Vec) -> Result { let self = self.clone(); - let result = if self.is_async { - let result_fut = Python::with_gil(|py| -> Result<_> { - let result = self.call_py_fn(py, input)?; - let task_locals = pyo3_async_runtimes::TaskLocals::new( - self.py_exec_ctx.event_loop.bind(py).clone(), - ); - Ok(pyo3_async_runtimes::into_future_with_locals( - &task_locals, - result, - )?) - })?; - let result = result_fut.await?; - Python::with_gil(|py| -> Result<_> { - Ok(py::value_from_py_object( - &self.result_type.typ, - &result.into_bound(py), - )?) - })? - } else { - tokio::task::spawn_blocking(move || { - Python::with_gil(|py| -> Result<_> { - Ok(py::value_from_py_object( - &self.result_type.typ, - &self.call_py_fn(py, input)?, - )?) - }) - }) - .await?? - }; - Ok(result) + let result_fut = Python::with_gil(|py| -> Result<_> { + let result_coro = self.call_py_fn(py, input)?; + let task_locals = + pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone()); + Ok(pyo3_async_runtimes::into_future_with_locals( + &task_locals, + result_coro, + )?) + })?; + let result = result_fut.await?; + Python::with_gil(|py| -> Result<_> { + Ok(py::value_from_py_object( + &self.result_type.typ, + &result.into_bound(py), + )?) + }) } fn enable_cache(&self) -> bool { @@ -134,7 +119,6 @@ impl SimpleFunctionExecutor for Arc { pub(crate) struct PyFunctionFactory { pub py_function_factory: Py, - pub is_async: bool, } impl SimpleFunctionFactory for PyFunctionFactory { @@ -195,31 +179,33 @@ impl SimpleFunctionFactory for PyFunctionFactory { .as_ref() .ok_or_else(|| anyhow!("Python execution context is missing"))? .clone(); - let executor = tokio::task::spawn_blocking(move || -> Result<_> { - let (enable_cache, behavior_version) = - Python::with_gil(|py| -> anyhow::Result<_> { - executor.call_method(py, "prepare", (), None)?; - let enable_cache = executor - .call_method(py, "enable_cache", (), None)? - .extract::(py)?; - let behavior_version = executor - .call_method(py, "behavior_version", (), None)? - .extract::>(py)?; - Ok((enable_cache, behavior_version)) - })?; - Ok(Box::new(Arc::new(PyFunctionExecutor { - py_function_executor: executor, - is_async: self.is_async, - py_exec_ctx, - num_positional_args, - kw_args_names, - result_type, - enable_cache, - behavior_version, - })) as Box) - }) - .await??; - Ok(executor) + let (prepare_fut, enable_cache, behavior_version) = + Python::with_gil(|py| -> anyhow::Result<_> { + let prepare_coro = executor.call_method(py, "prepare", (), None)?; + let prepare_fut = pyo3_async_runtimes::into_future_with_locals( + &pyo3_async_runtimes::TaskLocals::new( + py_exec_ctx.event_loop.bind(py).clone(), + ), + prepare_coro.into_bound(py), + )?; + let enable_cache = executor + .call_method(py, "enable_cache", (), None)? + .extract::(py)?; + let behavior_version = executor + .call_method(py, "behavior_version", (), None)? + .extract::>(py)?; + Ok((prepare_fut, enable_cache, behavior_version)) + })?; + prepare_fut.await?; + Ok(Box::new(Arc::new(PyFunctionExecutor { + py_function_executor: executor, + py_exec_ctx, + num_positional_args, + kw_args_names, + result_type, + enable_cache, + behavior_version, + })) as Box) } }; diff --git a/src/py/mod.rs b/src/py/mod.rs index cbc88592..42672ebf 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -68,14 +68,9 @@ fn stop(py: Python<'_>) -> PyResult<()> { } #[pyfunction] -fn register_function_factory( - name: String, - py_function_factory: Py, - is_async: bool, -) -> PyResult<()> { +fn register_function_factory(name: String, py_function_factory: Py) -> PyResult<()> { let factory = PyFunctionFactory { py_function_factory, - is_async, }; register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result() }