diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index fa005883..7e415838 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -12,7 +12,7 @@ from . import _engine from . import vector from . import op -from .typing import encode_type +from .typing import encode_enriched_type class _NameBuilder: _existing_names: set[str] @@ -419,7 +419,7 @@ def __init__( inspect.Parameter.KEYWORD_ONLY): raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name") engine_ds = flow_builder_state.engine_flow_builder.add_direct_input( - param_name, encode_type(param_type)) + param_name, encode_enriched_type(param_type)) kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds)) output = flow_fn(**kwargs) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 0a6df995..1f625255 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -8,7 +8,7 @@ from enum import Enum from threading import Lock -from .typing import encode_type +from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES from . import _engine @@ -57,16 +57,86 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs): spec = self._spec_cls(**spec) executor = self._executor_cls(spec) result_type = executor.analyze(*args, **kwargs) - return (encode_type(result_type), executor) + return (encode_enriched_type(result_type), executor) def _to_engine_value(value: Any) -> Any: """Convert a Python value to an engine value.""" if dataclasses.is_dataclass(value): return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] - elif isinstance(value, list) or isinstance(value, tuple): + if isinstance(value, (list, tuple)): return [_to_engine_value(v) for v in value] return value +def _make_engine_struct_value_converter( + field_path: list[str], + src_fields: list[dict[str, Any]], + dst_dataclass_type: type, + ) -> Callable[[list], Any]: + """Make a converter from an engine field values to a Python value.""" + + src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)} + def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]: + src_idx = src_name_to_idx.get(name) + if src_idx is not None: + field_path.append(f'.{name}') + field_converter = _make_engine_value_converter( + field_path, src_fields[src_idx]['type'], param.annotation) + field_path.pop() + return lambda values: field_converter(values[src_idx]) + + default_value = param.default + if default_value is inspect.Parameter.empty: + raise ValueError( + f"Field without default value is missing in input: {''.join(field_path)}") + + return lambda _: default_value + + field_value_converters = [ + make_closure_for_value(name, param) + for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()] + + return lambda values: dst_dataclass_type( + *(converter(values) for converter in field_value_converters)) + +def _make_engine_value_converter( + field_path: list[str], + src_type: dict[str, Any], + dst_annotation, + ) -> Callable[[Any], Any]: + """Make a converter from an engine value to a Python value.""" + + src_type_kind = src_type['kind'] + + if dst_annotation is inspect.Parameter.empty: + if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES: + raise ValueError(f"Missing type annotation for `{''.join(field_path)}`." + f"It's required for {src_type_kind} type.") + return lambda value: value + + dst_type_info = analyze_type_info(dst_annotation) + + if src_type_kind != dst_type_info.kind: + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})") + + if dst_type_info.dataclass_type is not None: + return _make_engine_struct_value_converter( + field_path, src_type['fields'], dst_type_info.dataclass_type) + + if src_type_kind in COLLECTION_TYPES: + field_path.append('[*]') + elem_type_info = analyze_type_info(dst_type_info.elem_type) + if elem_type_info.dataclass_type is None: + raise ValueError(f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.kind}`, a dataclass type expected") + elem_converter = _make_engine_struct_value_converter( + field_path, src_type['row']['fields'], elem_type_info.dataclass_type) + field_path.pop() + return lambda value: [elem_converter(v) for v in value] if value is not None else None + + return lambda value: value + _gpu_dispatch_lock = Lock() def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]: @@ -105,6 +175,9 @@ def behavior_version(self): return behavior_version class _WrappedClass(cls_type, _Fallback): + _args_converters: list[Callable[[Any], Any]] + _kwargs_converters: dict[str, Callable[[str, Any], Any]] + def __init__(self, spec): super().__init__() self.spec = spec @@ -114,16 +187,19 @@ def analyze(self, *args, **kwargs): Analyze the spec and arguments. In this phase, argument types should be validated. It should return the expected result type for the current op. """ + self._args_converters = [] + self._kwargs_converters = {} + # Match arguments with parameters. next_param_idx = 0 - for arg in args: + for arg in args: if next_param_idx >= len(expected_args): - raise ValueError(f"Too many arguments: {len(args)} > {len(expected_args)}") + raise ValueError(f"Too many arguments passed in: {len(args)} > {len(expected_args)}") arg_name, arg_param = expected_args[next_param_idx] if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD: - raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}") - if arg_param.annotation is not inspect.Parameter.empty: - arg.validate_arg(arg_name, encode_type(arg_param.annotation)) + raise ValueError(f"Too many positional arguments passed in: {len(args)} > {next_param_idx}") + self._args_converters.append( + _make_engine_value_converter([arg_name], arg.value_type['type'], arg_param.annotation)) if arg_param.kind != inspect.Parameter.VAR_POSITIONAL: next_param_idx += 1 @@ -136,10 +212,10 @@ def analyze(self, *args, **kwargs): or arg[1].kind == inspect.Parameter.VAR_KEYWORD), None) if expected_arg is None: - raise ValueError(f"Unexpected keyword argument: {kwarg_name}") + raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}") arg_param = expected_arg[1] - if arg_param.annotation is not inspect.Parameter.empty: - kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation)) + self._kwargs_converters[kwarg_name] = _make_engine_value_converter( + [kwarg_name], kwarg.value_type['type'], arg_param.annotation) missing_args = [name for (name, arg) in expected_kwargs if arg.default is inspect.Parameter.empty @@ -164,15 +240,17 @@ def prepare(self): setup_method(self) def __call__(self, *args, **kwargs): + converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args)) + converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()} if gpu: # For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel. # Besides, multiprocessing is more appropriate for pytorch. # For now, we use a lock to ensure only one task is executed at a time. # TODO: Implement multi-processing dispatching. with _gpu_dispatch_lock: - output = super().__call__(*args, **kwargs) + output = super().__call__(*converted_args, **converted_kwargs) else: - output = super().__call__(*args, **kwargs) + output = super().__call__(*converted_args, **converted_kwargs) return _to_engine_value(output) _WrappedClass.__name__ = cls.__name__ diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index f640af43..3fc85583 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -9,6 +9,7 @@ class Vector(NamedTuple): class TypeKind(NamedTuple): kind: str + class TypeAttr: key: str value: Any @@ -24,6 +25,8 @@ def __init__(self, key: str, value: Any): Range = Annotated[tuple[int, int], TypeKind('Range')] Json = Annotated[Any, TypeKind('Json')] +COLLECTION_TYPES = ('Table', 'List') + R = TypeVar("R") if TYPE_CHECKING: @@ -46,10 +49,6 @@ class List: # type: ignore[unreachable] def __class_getitem__(cls, item: type[R]): return Annotated[list[item], TypeKind('List')] -def _dump_field_schema(field: dataclasses.Field) -> dict[str, Any]: - encoded = _encode_enriched_type(field.type) - encoded['name'] = field.name - return encoded @dataclasses.dataclass class AnalyzedTypeInfo: """ @@ -58,7 +57,7 @@ class AnalyzedTypeInfo: kind: str vector_info: Vector | None elem_type: type | None - struct_fields: tuple[dataclasses.Field, ...] | None + dataclass_type: type | None attrs: dict[str, Any] | None nullable: bool = False @@ -99,18 +98,18 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: elif isinstance(attr, TypeKind): kind = attr.kind - struct_fields = None + dataclass_type = None elem_type = None - if dataclasses.is_dataclass(t): + if isinstance(t, type) and dataclasses.is_dataclass(t): if kind is None: kind = 'Struct' elif kind != 'Struct': raise ValueError(f"Unexpected type kind for struct: {kind}") - struct_fields = dataclasses.fields(t) + dataclass_type = t elif base_type is collections.abc.Sequence or base_type is list: if kind is None: kind = 'Vector' if vector_info is not None else 'List' - elif kind not in ('Vector', 'List', 'Table'): + elif not (kind == 'Vector' or kind in COLLECTION_TYPES): raise ValueError(f"Unexpected type kind for list: {kind}") args = typing.get_args(t) @@ -134,15 +133,20 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: raise ValueError(f"type unsupported yet: {base_type}") return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type, - struct_fields=struct_fields, attrs=attrs, nullable=nullable) + dataclass_type=dataclass_type, attrs=attrs, nullable=nullable) + +def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]: + return [{ 'name': field.name, + **encode_enriched_type_info(analyze_type_info(field.type)) + } for field in dataclasses.fields(dataclass_type)] def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: encoded_type: dict[str, Any] = { 'kind': type_info.kind } if type_info.kind == 'Struct': - if type_info.struct_fields is None: - raise ValueError("Struct type must have a struct fields") - encoded_type['fields'] = [_dump_field_schema(field) for field in type_info.struct_fields] + if type_info.dataclass_type is None: + raise ValueError("Struct type must have a dataclass type") + encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type) elif type_info.kind == 'Vector': if type_info.vector_info is None: @@ -152,21 +156,22 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type)) encoded_type['dimension'] = type_info.vector_info.dim - elif type_info.kind in ('List', 'Table'): + elif type_info.kind in COLLECTION_TYPES: if type_info.elem_type is None: raise ValueError(f"{type_info.kind} type must have an element type") - row_type_inof = analyze_type_info(type_info.elem_type) - if row_type_inof.struct_fields is None: - raise ValueError(f"{type_info.kind} type must have a struct fields") + row_type_info = analyze_type_info(type_info.elem_type) + if row_type_info.dataclass_type is None: + raise ValueError(f"{type_info.kind} type must have a dataclass type") encoded_type['row'] = { - 'fields': [_dump_field_schema(field) for field in row_type_inof.struct_fields], + 'fields': _encode_fields_schema(row_type_info.dataclass_type), } return encoded_type -def _encode_enriched_type(t) -> dict[str, Any]: - enriched_type_info = analyze_type_info(t) - +def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str, Any]: + """ + Encode an enriched type info to a CocoIndex engine's type representation + """ encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)} if enriched_type_info.attrs is not None: @@ -178,10 +183,11 @@ def _encode_enriched_type(t) -> dict[str, Any]: return encoded -def encode_type(t) -> dict[str, Any] | None: +def encode_enriched_type(t) -> dict[str, Any] | None: """ - Convert a Python type to a CocoIndex's type in JSON. + Convert a Python type to a CocoIndex engine's type representation """ if t is None: return None - return _encode_enriched_type(t) + + return encode_enriched_type_info(analyze_type_info(t)) diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 3ca7b4bd..6ed35adf 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -139,20 +139,6 @@ impl PyOpArgSchema { fn analyzed_value(&self) -> &crate::py::Pythonized { &self.analyzed_value } - - fn validate_arg( - &self, - name: &str, - typ: crate::py::Pythonized, - ) -> PyResult<()> { - if self.value_type.0.typ != typ.0.typ { - return Err(PyException::new_err(format!( - "argument `{}` type mismatch, input type: {}, argument type: {}", - name, self.value_type.0.typ, typ.0.typ - ))); - } - Ok(()) - } } struct PyFunctionExecutor {