Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 18 additions & 28 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import dataclasses
import inspect

from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform
from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform
from enum import Enum
from threading import Lock
from functools import partial

from .typing import encode_enriched_type
from .convert import to_engine_value, make_engine_value_converter
Expand Down Expand Up @@ -61,7 +61,7 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
return (encode_enriched_type(result_type), executor)


_gpu_dispatch_lock = Lock()
_gpu_dispatch_lock = asyncio.Lock()

@dataclasses.dataclass
class OpArgs:
Expand All @@ -75,11 +75,15 @@ class OpArgs:
cache: bool = False
behavior_version: int | None = None

def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]:
if inspect.iscoroutinefunction(call):
return call
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))

def _register_op_factory(
category: OpCategory,
expected_args: list[tuple[str, inspect.Parameter]],
expected_return,
is_async: bool,
executor_cls: type,
spec_cls: type,
op_args: OpArgs,
Expand All @@ -97,10 +101,12 @@ def behavior_version(self):
class _WrappedClass(executor_cls, _Fallback):
_args_converters: list[Callable[[Any], Any]]
_kwargs_converters: dict[str, Callable[[str, Any], Any]]
_acall: Callable

def __init__(self, spec):
super().__init__()
self.spec = spec
self._acall = _to_async_call(super().__call__)

def analyze(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -157,42 +163,30 @@ def analyze(self, *args, **kwargs):
else:
return expected_return

def prepare(self):
async def prepare(self):
"""
Prepare for execution.
It's executed after `analyze` and before any `__call__` execution.
"""
setup_method = getattr(executor_cls, 'prepare', None)
setup_method = getattr(super(), 'prepare', None)
if setup_method is not None:
setup_method(self)
await _to_async_call(setup_method)()

def __call__(self, *args, **kwargs):
async def __call__(self, *args, **kwargs):
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
for arg_name, arg in kwargs.items()}
if is_async:
async def _inner():
if op_args.gpu:
await asyncio.to_thread(_gpu_dispatch_lock.acquire)
try:
output = await super(_WrappedClass, self).__call__(
*converted_args, **converted_kwargs)
finally:
if op_args.gpu:
_gpu_dispatch_lock.release()
return to_engine_value(output)
return _inner()

if op_args.gpu:
# For GPU executions, data-level parallelism is applied, so we don't want to
# execute different tasks in parallel.
# Besides, multiprocessing is more appropriate for pytorch.
# For now, we use a lock to ensure only one task is executed at a time.
# TODO: Implement multi-processing dispatching.
with _gpu_dispatch_lock:
output = super().__call__(*converted_args, **converted_kwargs)
async with _gpu_dispatch_lock:
output = await self._acall(*converted_args, **converted_kwargs)
else:
output = super().__call__(*converted_args, **converted_kwargs)
output = await self._acall(*converted_args, **converted_kwargs)
return to_engine_value(output)

_WrappedClass.__name__ = executor_cls.__name__
Expand All @@ -203,9 +197,7 @@ async def _inner():

if category == OpCategory.FUNCTION:
_engine.register_function_factory(
spec_cls.__name__,
_FunctionExecutorFactory(spec_cls, _WrappedClass),
is_async)
spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass))
else:
raise ValueError(f"Unsupported executor type {category}")

Expand All @@ -230,7 +222,6 @@ def _inner(cls: type[Executor]) -> type:
category=spec_cls._op_category,
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
expected_return=sig.return_annotation,
is_async=inspect.iscoroutinefunction(cls.__call__),
executor_cls=cls,
spec_cls=spec_cls,
op_args=op_args)
Expand Down Expand Up @@ -266,7 +257,6 @@ class _Spec(FunctionSpec):
category=OpCategory.FUNCTION,
expected_args=list(sig.parameters.items()),
expected_return=sig.return_annotation,
is_async=inspect.iscoroutinefunction(fn),
executor_cls=_Executor,
spec_cls=_Spec,
op_args=op_args)
Expand Down
100 changes: 43 additions & 57 deletions src/ops/py_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ impl PyOpArgSchema {

struct PyFunctionExecutor {
py_function_executor: Py<PyAny>,
is_async: bool,
py_exec_ctx: Arc<crate::py::PythonExecutionContext>,

num_positional_args: usize,
Expand Down Expand Up @@ -91,36 +90,22 @@ impl PyFunctionExecutor {
impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
async fn evaluate(&self, input: Vec<value::Value>) -> Result<value::Value> {
let self = self.clone();
let result = if self.is_async {
let result_fut = Python::with_gil(|py| -> Result<_> {
let result = self.call_py_fn(py, input)?;
let task_locals = pyo3_async_runtimes::TaskLocals::new(
self.py_exec_ctx.event_loop.bind(py).clone(),
);
Ok(pyo3_async_runtimes::into_future_with_locals(
&task_locals,
result,
)?)
})?;
let result = result_fut.await?;
Python::with_gil(|py| -> Result<_> {
Ok(py::value_from_py_object(
&self.result_type.typ,
&result.into_bound(py),
)?)
})?
} else {
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| -> Result<_> {
Ok(py::value_from_py_object(
&self.result_type.typ,
&self.call_py_fn(py, input)?,
)?)
})
})
.await??
};
Ok(result)
let result_fut = Python::with_gil(|py| -> Result<_> {
let result_coro = self.call_py_fn(py, input)?;
let task_locals =
pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
Ok(pyo3_async_runtimes::into_future_with_locals(
&task_locals,
result_coro,
)?)
})?;
let result = result_fut.await?;
Python::with_gil(|py| -> Result<_> {
Ok(py::value_from_py_object(
&self.result_type.typ,
&result.into_bound(py),
)?)
})
}

fn enable_cache(&self) -> bool {
Expand All @@ -134,7 +119,6 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {

pub(crate) struct PyFunctionFactory {
pub py_function_factory: Py<PyAny>,
pub is_async: bool,
}

impl SimpleFunctionFactory for PyFunctionFactory {
Expand Down Expand Up @@ -195,31 +179,33 @@ impl SimpleFunctionFactory for PyFunctionFactory {
.as_ref()
.ok_or_else(|| anyhow!("Python execution context is missing"))?
.clone();
let executor = tokio::task::spawn_blocking(move || -> Result<_> {
let (enable_cache, behavior_version) =
Python::with_gil(|py| -> anyhow::Result<_> {
executor.call_method(py, "prepare", (), None)?;
let enable_cache = executor
.call_method(py, "enable_cache", (), None)?
.extract::<bool>(py)?;
let behavior_version = executor
.call_method(py, "behavior_version", (), None)?
.extract::<Option<u32>>(py)?;
Ok((enable_cache, behavior_version))
})?;
Ok(Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
is_async: self.is_async,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
})) as Box<dyn SimpleFunctionExecutor>)
})
.await??;
Ok(executor)
let (prepare_fut, enable_cache, behavior_version) =
Python::with_gil(|py| -> anyhow::Result<_> {
let prepare_coro = executor.call_method(py, "prepare", (), None)?;
let prepare_fut = pyo3_async_runtimes::into_future_with_locals(
&pyo3_async_runtimes::TaskLocals::new(
py_exec_ctx.event_loop.bind(py).clone(),
),
prepare_coro.into_bound(py),
)?;
let enable_cache = executor
.call_method(py, "enable_cache", (), None)?
.extract::<bool>(py)?;
let behavior_version = executor
.call_method(py, "behavior_version", (), None)?
.extract::<Option<u32>>(py)?;
Ok((prepare_fut, enable_cache, behavior_version))
})?;
prepare_fut.await?;
Ok(Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
})) as Box<dyn SimpleFunctionExecutor>)
}
};

Expand Down
7 changes: 1 addition & 6 deletions src/py/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,9 @@ fn stop(py: Python<'_>) -> PyResult<()> {
}

#[pyfunction]
fn register_function_factory(
name: String,
py_function_factory: Py<PyAny>,
is_async: bool,
) -> PyResult<()> {
fn register_function_factory(name: String, py_function_factory: Py<PyAny>) -> PyResult<()> {
let factory = PyFunctionFactory {
py_function_factory,
is_async,
};
register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result()
}
Expand Down