Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import _engine
from . import vector
from . import op
from .typing import encode_type
from .typing import encode_enriched_type

class _NameBuilder:
_existing_names: set[str]
Expand Down Expand Up @@ -419,7 +419,7 @@ def __init__(
inspect.Parameter.KEYWORD_ONLY):
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
param_name, encode_type(param_type))
param_name, encode_enriched_type(param_type))
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))

output = flow_fn(**kwargs)
Expand Down
104 changes: 91 additions & 13 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import Enum
from threading import Lock

from .typing import encode_type
from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES
from . import _engine


Expand Down Expand Up @@ -57,16 +57,86 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
spec = self._spec_cls(**spec)
executor = self._executor_cls(spec)
result_type = executor.analyze(*args, **kwargs)
return (encode_type(result_type), executor)
return (encode_enriched_type(result_type), executor)

def _to_engine_value(value: Any) -> Any:
"""Convert a Python value to an engine value."""
if dataclasses.is_dataclass(value):
return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
elif isinstance(value, list) or isinstance(value, tuple):
if isinstance(value, (list, tuple)):
return [_to_engine_value(v) for v in value]
return value

def _make_engine_struct_value_converter(
field_path: list[str],
src_fields: list[dict[str, Any]],
dst_dataclass_type: type,
) -> Callable[[list], Any]:
"""Make a converter from an engine field values to a Python value."""

src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
src_idx = src_name_to_idx.get(name)
if src_idx is not None:
field_path.append(f'.{name}')
field_converter = _make_engine_value_converter(
field_path, src_fields[src_idx]['type'], param.annotation)
field_path.pop()
return lambda values: field_converter(values[src_idx])

default_value = param.default
if default_value is inspect.Parameter.empty:
raise ValueError(
f"Field without default value is missing in input: {''.join(field_path)}")

return lambda _: default_value

field_value_converters = [
make_closure_for_value(name, param)
for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]

return lambda values: dst_dataclass_type(
*(converter(values) for converter in field_value_converters))

def _make_engine_value_converter(
field_path: list[str],
src_type: dict[str, Any],
dst_annotation,
) -> Callable[[Any], Any]:
"""Make a converter from an engine value to a Python value."""

src_type_kind = src_type['kind']

if dst_annotation is inspect.Parameter.empty:
if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES:
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
f"It's required for {src_type_kind} type.")
return lambda value: value

dst_type_info = analyze_type_info(dst_annotation)

if src_type_kind != dst_type_info.kind:
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")

if dst_type_info.dataclass_type is not None:
return _make_engine_struct_value_converter(
field_path, src_type['fields'], dst_type_info.dataclass_type)

if src_type_kind in COLLECTION_TYPES:
field_path.append('[*]')
elem_type_info = analyze_type_info(dst_type_info.elem_type)
if elem_type_info.dataclass_type is None:
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.kind}`, a dataclass type expected")
elem_converter = _make_engine_struct_value_converter(
field_path, src_type['row']['fields'], elem_type_info.dataclass_type)
field_path.pop()
return lambda value: [elem_converter(v) for v in value] if value is not None else None

return lambda value: value

_gpu_dispatch_lock = Lock()

def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]:
Expand Down Expand Up @@ -105,6 +175,9 @@ def behavior_version(self):
return behavior_version

class _WrappedClass(cls_type, _Fallback):
_args_converters: list[Callable[[Any], Any]]
_kwargs_converters: dict[str, Callable[[str, Any], Any]]

def __init__(self, spec):
super().__init__()
self.spec = spec
Expand All @@ -114,16 +187,19 @@ def analyze(self, *args, **kwargs):
Analyze the spec and arguments. In this phase, argument types should be validated.
It should return the expected result type for the current op.
"""
self._args_converters = []
self._kwargs_converters = {}

# Match arguments with parameters.
next_param_idx = 0
for arg in args:
for arg in args:
if next_param_idx >= len(expected_args):
raise ValueError(f"Too many arguments: {len(args)} > {len(expected_args)}")
raise ValueError(f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
arg_name, arg_param = expected_args[next_param_idx]
if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}")
if arg_param.annotation is not inspect.Parameter.empty:
arg.validate_arg(arg_name, encode_type(arg_param.annotation))
raise ValueError(f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
self._args_converters.append(
_make_engine_value_converter([arg_name], arg.value_type['type'], arg_param.annotation))
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
next_param_idx += 1

Expand All @@ -136,10 +212,10 @@ def analyze(self, *args, **kwargs):
or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
None)
if expected_arg is None:
raise ValueError(f"Unexpected keyword argument: {kwarg_name}")
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
arg_param = expected_arg[1]
if arg_param.annotation is not inspect.Parameter.empty:
kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation))
self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)

missing_args = [name for (name, arg) in expected_kwargs
if arg.default is inspect.Parameter.empty
Expand All @@ -164,15 +240,17 @@ def prepare(self):
setup_method(self)

def __call__(self, *args, **kwargs):
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()}
if gpu:
# For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
# Besides, multiprocessing is more appropriate for pytorch.
# For now, we use a lock to ensure only one task is executed at a time.
# TODO: Implement multi-processing dispatching.
with _gpu_dispatch_lock:
output = super().__call__(*args, **kwargs)
output = super().__call__(*converted_args, **converted_kwargs)
else:
output = super().__call__(*args, **kwargs)
output = super().__call__(*converted_args, **converted_kwargs)
return _to_engine_value(output)

_WrappedClass.__name__ = cls.__name__
Expand Down
54 changes: 30 additions & 24 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Vector(NamedTuple):

class TypeKind(NamedTuple):
kind: str

class TypeAttr:
key: str
value: Any
Expand All @@ -24,6 +25,8 @@ def __init__(self, key: str, value: Any):
Range = Annotated[tuple[int, int], TypeKind('Range')]
Json = Annotated[Any, TypeKind('Json')]

COLLECTION_TYPES = ('Table', 'List')

R = TypeVar("R")

if TYPE_CHECKING:
Expand All @@ -46,10 +49,6 @@ class List: # type: ignore[unreachable]
def __class_getitem__(cls, item: type[R]):
return Annotated[list[item], TypeKind('List')]

def _dump_field_schema(field: dataclasses.Field) -> dict[str, Any]:
encoded = _encode_enriched_type(field.type)
encoded['name'] = field.name
return encoded
@dataclasses.dataclass
class AnalyzedTypeInfo:
"""
Expand All @@ -58,7 +57,7 @@ class AnalyzedTypeInfo:
kind: str
vector_info: Vector | None
elem_type: type | None
struct_fields: tuple[dataclasses.Field, ...] | None
dataclass_type: type | None
attrs: dict[str, Any] | None
nullable: bool = False

Expand Down Expand Up @@ -99,18 +98,18 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
elif isinstance(attr, TypeKind):
kind = attr.kind

struct_fields = None
dataclass_type = None
elem_type = None
if dataclasses.is_dataclass(t):
if isinstance(t, type) and dataclasses.is_dataclass(t):
if kind is None:
kind = 'Struct'
elif kind != 'Struct':
raise ValueError(f"Unexpected type kind for struct: {kind}")
struct_fields = dataclasses.fields(t)
dataclass_type = t
elif base_type is collections.abc.Sequence or base_type is list:
if kind is None:
kind = 'Vector' if vector_info is not None else 'List'
elif kind not in ('Vector', 'List', 'Table'):
elif not (kind == 'Vector' or kind in COLLECTION_TYPES):
raise ValueError(f"Unexpected type kind for list: {kind}")

args = typing.get_args(t)
Expand All @@ -134,15 +133,20 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
raise ValueError(f"type unsupported yet: {base_type}")

return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
struct_fields=struct_fields, attrs=attrs, nullable=nullable)
dataclass_type=dataclass_type, attrs=attrs, nullable=nullable)

def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]:
return [{ 'name': field.name,
**encode_enriched_type_info(analyze_type_info(field.type))
} for field in dataclasses.fields(dataclass_type)]

def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
encoded_type: dict[str, Any] = { 'kind': type_info.kind }

if type_info.kind == 'Struct':
if type_info.struct_fields is None:
raise ValueError("Struct type must have a struct fields")
encoded_type['fields'] = [_dump_field_schema(field) for field in type_info.struct_fields]
if type_info.dataclass_type is None:
raise ValueError("Struct type must have a dataclass type")
encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type)

elif type_info.kind == 'Vector':
if type_info.vector_info is None:
Expand All @@ -152,21 +156,22 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
encoded_type['dimension'] = type_info.vector_info.dim

elif type_info.kind in ('List', 'Table'):
elif type_info.kind in COLLECTION_TYPES:
if type_info.elem_type is None:
raise ValueError(f"{type_info.kind} type must have an element type")
row_type_inof = analyze_type_info(type_info.elem_type)
if row_type_inof.struct_fields is None:
raise ValueError(f"{type_info.kind} type must have a struct fields")
row_type_info = analyze_type_info(type_info.elem_type)
if row_type_info.dataclass_type is None:
raise ValueError(f"{type_info.kind} type must have a dataclass type")
encoded_type['row'] = {
'fields': [_dump_field_schema(field) for field in row_type_inof.struct_fields],
'fields': _encode_fields_schema(row_type_info.dataclass_type),
}

return encoded_type

def _encode_enriched_type(t) -> dict[str, Any]:
enriched_type_info = analyze_type_info(t)

def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str, Any]:
"""
Encode an enriched type info to a CocoIndex engine's type representation
"""
encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)}

if enriched_type_info.attrs is not None:
Expand All @@ -178,10 +183,11 @@ def _encode_enriched_type(t) -> dict[str, Any]:
return encoded


def encode_type(t) -> dict[str, Any] | None:
def encode_enriched_type(t) -> dict[str, Any] | None:
"""
Convert a Python type to a CocoIndex's type in JSON.
Convert a Python type to a CocoIndex engine's type representation
"""
if t is None:
return None
return _encode_enriched_type(t)

return encode_enriched_type_info(analyze_type_info(t))
14 changes: 0 additions & 14 deletions src/ops/py_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,6 @@ impl PyOpArgSchema {
fn analyzed_value(&self) -> &crate::py::Pythonized<plan::AnalyzedValueMapping> {
&self.analyzed_value
}

fn validate_arg(
&self,
name: &str,
typ: crate::py::Pythonized<schema::EnrichedValueType>,
) -> PyResult<()> {
if self.value_type.0.typ != typ.0.typ {
return Err(PyException::new_err(format!(
"argument `{}` type mismatch, input type: {}, argument type: {}",
name, self.value_type.0.typ, typ.0.typ
)));
}
Ok(())
}
}

struct PyFunctionExecutor {
Expand Down