diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index 6ad30892..a523c49a 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -18,6 +18,7 @@
 from . import _engine  # type: ignore
 from .subprocess_exec import executor_stub
 from .convert import (
+    dump_engine_object,
     make_engine_value_encoder,
     make_engine_value_decoder,
     make_engine_key_decoder,
@@ -32,6 +33,7 @@
     AnalyzedDictType,
     EnrichedValueType,
     decode_engine_field_schemas,
+    FieldSchema,
 )
 from .runtime import to_async_call
@@ -432,16 +434,43 @@ class _TargetConnectorContext:
     target_name: str
     spec: Any
     prepared_spec: Any
+    key_fields_schema: list[FieldSchema]
     key_decoder: Callable[[Any], Any]
+    value_fields_schema: list[FieldSchema]
     value_decoder: Callable[[Any], Any]
 
 
+def _build_args(
+    method: Callable[..., Any], num_required_args: int, **kwargs: Any
+) -> list[Any]:
+    signature = inspect.signature(method)
+    for param in signature.parameters.values():
+        if param.kind not in (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        ):
+            raise ValueError(
+                f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
+            )
+    if len(signature.parameters) < num_required_args:
+        raise ValueError(
+            f"Method {method.__name__} must have at least {num_required_args} required arguments: "
+            f"{', '.join(list(kwargs.keys())[:num_required_args])}"
+        )
+    if len(signature.parameters) > len(kwargs):
+        raise ValueError(
+            f"Method {method.__name__} can only have at most {len(kwargs)} arguments: {', '.join(kwargs.keys())}"
+        )
+    return [v for _, v in zip(signature.parameters, kwargs.values())]
+
+
 class _TargetConnector:
     """
     The connector class passed to the engine.
     """
 
     _spec_cls: type
+    _state_cls: type
     _connector_cls: type
 
     _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
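
The `_build_args` helper added above implements an arity-based calling convention: a connector hook declares only the positional parameters it cares about, and the wrapper passes exactly that many values. A minimal sketch of the behavior, assuming the `_build_args` defined in the hunk above (the two hook functions below are hypothetical, not part of this change):

def get_key_minimal(spec):  # declares only the required argument
    return spec

def get_key_with_name(spec, target_name):  # opts into the optional second argument
    return (spec, target_name)

# At least `num_required_args` parameters must exist; beyond that, only as many
# positional values as the signature declares are passed through.
_build_args(get_key_minimal, 1, spec="my-spec", target_name="docs")    # -> ["my-spec"]
_build_args(get_key_with_name, 1, spec="my-spec", target_name="docs")  # -> ["my-spec", "docs"]
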
@@ -451,8 +480,9 @@
     _mutate_async_fn: Callable[..., Awaitable[None]]
     _mutatation_type: AnalyzedDictType | None
 
-    def __init__(self, spec_cls: type, connector_cls: type):
+    def __init__(self, spec_cls: type, state_cls: type, connector_cls: type):
         self._spec_cls = spec_cls
+        self._state_cls = state_cls
         self._connector_cls = connector_cls
 
         self._get_persistent_key_fn = _get_required_method(
@@ -517,8 +547,8 @@ def create_export_context(
         self,
         name: str,
         spec: dict[str, Any],
-        key_fields_schema: list[Any],
-        value_fields_schema: list[Any],
+        raw_key_fields_schema: list[Any],
+        raw_value_fields_schema: list[Any],
     ) -> _TargetConnectorContext:
         key_annotation, value_annotation = (
             (
@@ -529,36 +559,72 @@
             else (Any, Any)
         )
 
+        key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
         key_decoder = make_engine_key_decoder(
-            ["(key)"],
-            decode_engine_field_schemas(key_fields_schema),
-            analyze_type_info(key_annotation),
+            [""], key_fields_schema, analyze_type_info(key_annotation)
         )
+        value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
         value_decoder = make_engine_struct_decoder(
-            ["(value)"],
-            decode_engine_field_schemas(value_fields_schema),
-            analyze_type_info(value_annotation),
+            [""], value_fields_schema, analyze_type_info(value_annotation)
         )
 
         loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
-        prepare_method = getattr(self._connector_cls, "prepare", None)
-        if prepare_method is None:
-            prepared_spec = loaded_spec
-        else:
-            prepared_spec = prepare_method(loaded_spec)
-
         return _TargetConnectorContext(
             target_name=name,
             spec=loaded_spec,
-            prepared_spec=prepared_spec,
+            prepared_spec=None,
+            key_fields_schema=key_fields_schema,
             key_decoder=key_decoder,
+            value_fields_schema=value_fields_schema,
             value_decoder=value_decoder,
         )
 
     def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
-        return self._get_persistent_key_fn(
-            export_context.spec, export_context.target_name
+        args = _build_args(
+            self._get_persistent_key_fn,
+            1,
+            spec=export_context.spec,
+            target_name=export_context.target_name,
+        )
+        return dump_engine_object(self._get_persistent_key_fn(*args))
+
+    def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
+        get_persistent_state_fn = getattr(self._connector_cls, "get_setup_state", None)
+        if get_persistent_state_fn is None:
+            state = export_context.spec
+            if not isinstance(state, self._state_cls):
+                raise ValueError(
+                    f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._state_cls}"
+                )
+        else:
+            args = _build_args(
+                get_persistent_state_fn,
+                1,
+                spec=export_context.spec,
+                key_fields_schema=export_context.key_fields_schema,
+                value_fields_schema=export_context.value_fields_schema,
+            )
+            state = get_persistent_state_fn(*args)
+            if not isinstance(state, self._state_cls):
+                raise ValueError(
+                    f"Method {get_persistent_state_fn.__name__} must return an instance of {self._state_cls}, got {type(state)}"
+                )
+        return dump_engine_object(state)
+
+    async def prepare_async(self, export_context: _TargetConnectorContext) -> None:
+        prepare_fn = getattr(self._connector_cls, "prepare", None)
+        if prepare_fn is None:
+            export_context.prepared_spec = export_context.spec
+            return
+        args = _build_args(
+            prepare_fn,
+            1,
+            spec=export_context.spec,
+            key_fields_schema=export_context.key_fields_schema,
+            value_fields_schema=export_context.value_fields_schema,
         )
+        async_prepare_fn = to_async_call(prepare_fn)
+        export_context.prepared_spec = await async_prepare_fn(*args)
 
     def describe_resource(self, key: Any) -> str:
         describe_fn = getattr(self._connector_cls, "describe", None)
@@ -572,13 +638,13 @@ async def apply_setup_changes_async(
     ) -> None:
         for key, previous, current in changes:
             prev_specs = [
-                _load_spec_from_engine(self._spec_cls, spec)
+                _load_spec_from_engine(self._state_cls, spec)
                 if spec is not None
                 else None
                 for spec in previous
             ]
             curr_spec = (
-                _load_spec_from_engine(self._spec_cls, current)
+                _load_spec_from_engine(self._state_cls, current)
                 if current is not None
                 else None
             )
@@ -611,7 +677,9 @@ async def mutate_async(
        )
 
 
-def target_connector(spec_cls: type) -> Callable[[type], type]:
+def target_connector(
+    spec_cls: type, state_cls: type | None = None
+) -> Callable[[type], type]:
     """
     Decorate a class to provide a target connector for an op.
     """
@@ -622,7 +690,7 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:
 
     # Register the target connector.
     def _inner(connector_cls: type) -> type:
-        connector = _TargetConnector(spec_cls, connector_cls)
+        connector = _TargetConnector(spec_cls, state_cls or spec_cls, connector_cls)
         _engine.register_target_connector(spec_cls.__name__, connector)
         return connector_cls
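
Taken together, the Python-side changes separate the flow-facing spec (`spec_cls`) from the persisted setup state (`state_cls`) and route every connector hook through `_build_args`. Below is a rough sketch of how a connector author might use the new surface. Only the decorator parameters and the hook names come from the code above; `MyTargetSpec`, `MyTargetState`, the method bodies, and any constraints cocoindex places on the state class are assumptions for illustration (the spec class follows the existing `cocoindex.op.TargetSpec` convention), and the remaining required hooks (mutation and setup-change handling) are unchanged by this diff and omitted.

import dataclasses
from typing import Any

import cocoindex


class MyTargetSpec(cocoindex.op.TargetSpec):
    # Flow-facing configuration, as in existing custom targets.
    url: str
    table: str


@dataclasses.dataclass
class MyTargetState:
    # Persisted setup state that the engine stores and diffs; a plain dataclass
    # is used here as a stand-in.
    table: str


@cocoindex.op.target_connector(spec_cls=MyTargetSpec, state_cls=MyTargetState)
class MyTargetConnector:
    @staticmethod
    def get_persistent_key(spec: MyTargetSpec, target_name: str) -> str:
        # Declaring target_name is optional: _build_args passes only as many
        # positional arguments as the signature asks for.
        return f"{spec.url}/{spec.table}"

    @staticmethod
    def get_setup_state(spec: MyTargetSpec) -> MyTargetState:
        # The returned state, not the raw spec, becomes the desired setup state
        # recorded by the engine (dumped via dump_engine_object above).
        return MyTargetState(table=spec.table)

    @staticmethod
    async def prepare(spec: MyTargetSpec) -> Any:
        # May be sync or async; the engine wraps it with to_async_call and awaits
        # it (prepare_async) while building the export context.
        return {"connection": None}  # placeholder for e.g. a connection pool
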
diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs
index 9a49db74..ea7ec62d 100644
--- a/src/ops/py_factory.rs
+++ b/src/ops/py_factory.rs
@@ -283,43 +283,65 @@ impl interface::TargetFactory for PyExportTargetFactory {
             .ok_or_else(|| anyhow!("Python execution context is missing"))?
             .clone();
 
         for data_collection in data_collections.into_iter() {
-            let (py_export_ctx, persistent_key) =
-                Python::with_gil(|py| -> Result<(Py<PyAny>, serde_json::Value)> {
-                    // Deserialize the spec to Python object.
-                    let py_export_ctx = self
-                        .py_target_connector
-                        .call_method(
-                            py,
-                            "create_export_context",
-                            (
-                                &data_collection.name,
-                                pythonize(py, &data_collection.spec)?,
-                                pythonize(py, &data_collection.key_fields_schema)?,
-                                pythonize(py, &data_collection.value_fields_schema)?,
-                            ),
-                            None,
-                        )
-                        .to_result_with_py_trace(py)?;
-
-                    // Call the `get_persistent_key` method to get the persistent key.
-                    let persistent_key = self
-                        .py_target_connector
-                        .call_method(py, "get_persistent_key", (&py_export_ctx,), None)
-                        .to_result_with_py_trace(py)?;
-                    let persistent_key = depythonize(&persistent_key.into_bound(py))?;
-                    Ok((py_export_ctx, persistent_key))
-                })?;
+            let (py_export_ctx, persistent_key, setup_state) = Python::with_gil(|py| {
+                // Deserialize the spec to Python object.
+                let py_export_ctx = self
+                    .py_target_connector
+                    .call_method(
+                        py,
+                        "create_export_context",
+                        (
+                            &data_collection.name,
+                            pythonize(py, &data_collection.spec)?,
+                            pythonize(py, &data_collection.key_fields_schema)?,
+                            pythonize(py, &data_collection.value_fields_schema)?,
+                        ),
+                        None,
+                    )
+                    .to_result_with_py_trace(py)?;
+
+                // Call the `get_persistent_key` method to get the persistent key.
+                let persistent_key = self
+                    .py_target_connector
+                    .call_method(py, "get_persistent_key", (&py_export_ctx,), None)
+                    .to_result_with_py_trace(py)?;
+                let persistent_key: serde_json::Value =
+                    depythonize(&persistent_key.into_bound(py))?;
+                let setup_state = self
+                    .py_target_connector
+                    .call_method(py, "get_setup_state", (&py_export_ctx,), None)
+                    .to_result_with_py_trace(py)?;
+                let setup_state: serde_json::Value = depythonize(&setup_state.into_bound(py))?;
+
+                anyhow::Ok((py_export_ctx, persistent_key, setup_state))
+            })?;
+
+            let factory = self.clone();
             let py_exec_ctx = py_exec_ctx.clone();
             let build_output = interface::ExportDataCollectionBuildOutput {
                 export_context: Box::pin(async move {
-                    Ok(Arc::new(PyTargetExecutorContext {
+                    Python::with_gil(|py| {
+                        let prepare_coro = factory
+                            .py_target_connector
+                            .call_method(py, "prepare_async", (&py_export_ctx,), None)
+                            .to_result_with_py_trace(py)?;
+                        let task_locals = pyo3_async_runtimes::TaskLocals::new(
+                            py_exec_ctx.event_loop.bind(py).clone(),
+                        );
+                        anyhow::Ok(pyo3_async_runtimes::into_future_with_locals(
+                            &task_locals,
+                            prepare_coro.into_bound(py),
+                        )?)
+                    })?
+                    .await?;
+                    anyhow::Ok(Arc::new(PyTargetExecutorContext {
                         py_export_ctx,
                         py_exec_ctx,
                     }) as Arc)
                 }),
                 setup_key: persistent_key,
-                desired_setup_state: data_collection.spec,
+                desired_setup_state: setup_state,
            };
            build_outputs.push(build_output);
        }
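
On the engine side, the Rust factory now records whatever `get_setup_state` returns as `desired_setup_state` (instead of the raw target spec), and `apply_setup_changes_async` decodes the stored values back with `state_cls`, so the setup-change hook deals in state instances end to end. A sketch of that hook, continuing the hypothetical connector above; its exact signature is not shown in these hunks and is assumed here from the existing custom-target convention:

    @staticmethod
    def apply_setup_change(
        key: str,
        previous: list[MyTargetState | None],
        current: MyTargetState | None,
    ) -> None:
        # `previous` and `current` arrive as state_cls instances (MyTargetState),
        # not spec_cls instances, because the stored JSON is now loaded with
        # _load_spec_from_engine(self._state_cls, ...).
        if current is None:
            ...  # target removed: tear down whatever the previous states describe
        else:
            ...  # create or migrate toward `current`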