Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 90 additions & 22 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import _engine # type: ignore
from .subprocess_exec import executor_stub
from .convert import (
dump_engine_object,
make_engine_value_encoder,
make_engine_value_decoder,
make_engine_key_decoder,
Expand All @@ -32,6 +33,7 @@
AnalyzedDictType,
EnrichedValueType,
decode_engine_field_schemas,
FieldSchema,
)
from .runtime import to_async_call

Expand Down Expand Up @@ -432,16 +434,43 @@ class _TargetConnectorContext:
target_name: str
spec: Any
prepared_spec: Any
key_fields_schema: list[FieldSchema]
key_decoder: Callable[[Any], Any]
value_fields_schema: list[FieldSchema]
value_decoder: Callable[[Any], Any]


def _build_args(
method: Callable[..., Any], num_required_args: int, **kwargs: Any
) -> list[Any]:
signature = inspect.signature(method)
for param in signature.parameters.values():
if param.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
raise ValueError(
f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
)
if len(signature.parameters) < num_required_args:
raise ValueError(
f"Method {method.__name__} must have at least {num_required_args} required arguments: "
f"{', '.join(list(kwargs.keys())[:num_required_args])}"
)
if len(kwargs) > len(kwargs):
raise ValueError(
f"Method {method.__name__} can only have at most {num_required_args} arguments: {', '.join(kwargs.keys())}"
)
return [v for _, v in zip(signature.parameters, kwargs.values())]


class _TargetConnector:
"""
The connector class passed to the engine.
"""

_spec_cls: type
_state_cls: type
_connector_cls: type

_get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
Expand All @@ -451,8 +480,9 @@ class _TargetConnector:
_mutate_async_fn: Callable[..., Awaitable[None]]
_mutatation_type: AnalyzedDictType | None

def __init__(self, spec_cls: type, connector_cls: type):
def __init__(self, spec_cls: type, state_cls: type, connector_cls: type):
self._spec_cls = spec_cls
self._state_cls = state_cls
self._connector_cls = connector_cls

self._get_persistent_key_fn = _get_required_method(
Expand Down Expand Up @@ -517,8 +547,8 @@ def create_export_context(
self,
name: str,
spec: dict[str, Any],
key_fields_schema: list[Any],
value_fields_schema: list[Any],
raw_key_fields_schema: list[Any],
raw_value_fields_schema: list[Any],
) -> _TargetConnectorContext:
key_annotation, value_annotation = (
(
Expand All @@ -529,36 +559,72 @@ def create_export_context(
else (Any, Any)
)

key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
key_decoder = make_engine_key_decoder(
["(key)"],
decode_engine_field_schemas(key_fields_schema),
analyze_type_info(key_annotation),
["<key>"], key_fields_schema, analyze_type_info(key_annotation)
)
value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
value_decoder = make_engine_struct_decoder(
["(value)"],
decode_engine_field_schemas(value_fields_schema),
analyze_type_info(value_annotation),
["<value>"], value_fields_schema, analyze_type_info(value_annotation)
)

loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
prepare_method = getattr(self._connector_cls, "prepare", None)
if prepare_method is None:
prepared_spec = loaded_spec
else:
prepared_spec = prepare_method(loaded_spec)

return _TargetConnectorContext(
target_name=name,
spec=loaded_spec,
prepared_spec=prepared_spec,
prepared_spec=None,
key_fields_schema=key_fields_schema,
key_decoder=key_decoder,
value_fields_schema=value_fields_schema,
value_decoder=value_decoder,
)

def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
return self._get_persistent_key_fn(
export_context.spec, export_context.target_name
args = _build_args(
self._get_persistent_key_fn,
1,
spec=export_context.spec,
target_name=export_context.target_name,
)
return dump_engine_object(self._get_persistent_key_fn(*args))

def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
    """
    Compute the target's persistent setup state and dump it for the engine.

    When the connector class defines a `get_setup_state()` method, it is
    invoked with `spec` (required) and optionally the key/value field
    schemas; otherwise the spec itself is used as the state.  Either way the
    result must be an instance of `self._state_cls`, and is serialized via
    `dump_engine_object()` before being handed back to the engine.
    """
    get_persistent_state_fn = getattr(self._connector_cls, "get_setup_state", None)
    if get_persistent_state_fn is not None:
        call_args = _build_args(
            get_persistent_state_fn,
            1,
            spec=export_context.spec,
            key_fields_schema=export_context.key_fields_schema,
            value_fields_schema=export_context.value_fields_schema,
        )
        state = get_persistent_state_fn(*call_args)
        if not isinstance(state, self._state_cls):
            raise ValueError(
                f"Method {get_persistent_state_fn.__name__} must return an instance of {self._state_cls}, got {type(state)}"
            )
        return dump_engine_object(state)

    # No connector-provided method: the spec must already double as the state.
    state = export_context.spec
    if not isinstance(state, self._state_cls):
        raise ValueError(
            f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._state_cls}"
        )
    return dump_engine_object(state)

async def prepare_async(self, export_context: _TargetConnectorContext) -> None:
    """
    Populate `export_context.prepared_spec`.

    If the connector class defines a `prepare()` method, it is called
    (wrapped by `to_async_call()` so both sync and async implementations
    work) with `spec` (required) and optionally the key/value field schemas,
    and its result becomes the prepared spec.  Without a `prepare()` method,
    the raw spec is reused as-is.
    """
    prepare_fn = getattr(self._connector_cls, "prepare", None)
    if prepare_fn is not None:
        call_args = _build_args(
            prepare_fn,
            1,
            spec=export_context.spec,
            key_fields_schema=export_context.key_fields_schema,
            value_fields_schema=export_context.value_fields_schema,
        )
        export_context.prepared_spec = await to_async_call(prepare_fn)(*call_args)
    else:
        # Connector has no prepare(): the raw spec doubles as the prepared spec.
        export_context.prepared_spec = export_context.spec

def describe_resource(self, key: Any) -> str:
describe_fn = getattr(self._connector_cls, "describe", None)
Expand All @@ -572,13 +638,13 @@ async def apply_setup_changes_async(
) -> None:
for key, previous, current in changes:
prev_specs = [
_load_spec_from_engine(self._spec_cls, spec)
_load_spec_from_engine(self._state_cls, spec)
if spec is not None
else None
for spec in previous
]
curr_spec = (
_load_spec_from_engine(self._spec_cls, current)
_load_spec_from_engine(self._state_cls, current)
if current is not None
else None
)
Expand Down Expand Up @@ -611,7 +677,9 @@ async def mutate_async(
)


def target_connector(spec_cls: type) -> Callable[[type], type]:
def target_connector(
spec_cls: type, state_cls: type | None = None
) -> Callable[[type], type]:
"""
Decorate a class to provide a target connector for an op.
"""
Expand All @@ -622,7 +690,7 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:

# Register the target connector.
def _inner(connector_cls: type) -> type:
connector = _TargetConnector(spec_cls, connector_cls)
connector = _TargetConnector(spec_cls, state_cls or spec_cls, connector_cls)
_engine.register_target_connector(spec_cls.__name__, connector)
return connector_cls

Expand Down
78 changes: 50 additions & 28 deletions src/ops/py_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,43 +283,65 @@ impl interface::TargetFactory for PyExportTargetFactory {
.ok_or_else(|| anyhow!("Python execution context is missing"))?
.clone();
for data_collection in data_collections.into_iter() {
let (py_export_ctx, persistent_key) =
Python::with_gil(|py| -> Result<(Py<PyAny>, serde_json::Value)> {
// Deserialize the spec to Python object.
let py_export_ctx = self
.py_target_connector
.call_method(
py,
"create_export_context",
(
&data_collection.name,
pythonize(py, &data_collection.spec)?,
pythonize(py, &data_collection.key_fields_schema)?,
pythonize(py, &data_collection.value_fields_schema)?,
),
None,
)
.to_result_with_py_trace(py)?;

// Call the `get_persistent_key` method to get the persistent key.
let persistent_key = self
.py_target_connector
.call_method(py, "get_persistent_key", (&py_export_ctx,), None)
.to_result_with_py_trace(py)?;
let persistent_key = depythonize(&persistent_key.into_bound(py))?;
Ok((py_export_ctx, persistent_key))
})?;
let (py_export_ctx, persistent_key, setup_state) = Python::with_gil(|py| {
// Deserialize the spec to Python object.
let py_export_ctx = self
.py_target_connector
.call_method(
py,
"create_export_context",
(
&data_collection.name,
pythonize(py, &data_collection.spec)?,
pythonize(py, &data_collection.key_fields_schema)?,
pythonize(py, &data_collection.value_fields_schema)?,
),
None,
)
.to_result_with_py_trace(py)?;

// Call the `get_persistent_key` method to get the persistent key.
let persistent_key = self
.py_target_connector
.call_method(py, "get_persistent_key", (&py_export_ctx,), None)
.to_result_with_py_trace(py)?;
let persistent_key: serde_json::Value =
depythonize(&persistent_key.into_bound(py))?;

let setup_state = self
.py_target_connector
.call_method(py, "get_setup_state", (&py_export_ctx,), None)
.to_result_with_py_trace(py)?;
let setup_state: serde_json::Value = depythonize(&setup_state.into_bound(py))?;

anyhow::Ok((py_export_ctx, persistent_key, setup_state))
})?;

let factory = self.clone();
let py_exec_ctx = py_exec_ctx.clone();
let build_output = interface::ExportDataCollectionBuildOutput {
export_context: Box::pin(async move {
Ok(Arc::new(PyTargetExecutorContext {
Python::with_gil(|py| {
let prepare_coro = factory
.py_target_connector
.call_method(py, "prepare_async", (&py_export_ctx,), None)
.to_result_with_py_trace(py)?;
let task_locals = pyo3_async_runtimes::TaskLocals::new(
py_exec_ctx.event_loop.bind(py).clone(),
);
anyhow::Ok(pyo3_async_runtimes::into_future_with_locals(
&task_locals,
prepare_coro.into_bound(py),
)?)
})?
.await?;
anyhow::Ok(Arc::new(PyTargetExecutorContext {
py_export_ctx,
py_exec_ctx,
}) as Arc<dyn Any + Send + Sync>)
}),
setup_key: persistent_key,
desired_setup_state: data_collection.spec,
desired_setup_state: setup_state,
};
build_outputs.push(build_output);
}
Expand Down
Loading