From 23661ca975ff9fed71862cd7dc2af6fce18fc5b6 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 25 Jul 2025 09:30:10 -0700 Subject: [PATCH 1/2] refactor: make the logic to encode/decode values better arranged --- python/cocoindex/convert.py | 236 +++++++++++--------- python/cocoindex/tests/test_convert.py | 2 +- python/cocoindex/tests/test_typing.py | 285 ++++++++----------------- python/cocoindex/typing.py | 261 ++++++++++++---------- 4 files changed, 378 insertions(+), 406 deletions(-) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 946153e0..e36b59eb 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -13,12 +13,19 @@ from .typing import ( KEY_FIELD_NAME, TABLE_TYPES, - DtypeRegistry, analyze_type_info, encode_enriched_type, - extract_ndarray_scalar_dtype, is_namedtuple_type, is_struct_type, + AnalyzedTypeInfo, + AnalyzedAnyType, + AnalyzedDictType, + AnalyzedListType, + AnalyzedBasicType, + AnalyzedUnionType, + AnalyzedUnknownType, + AnalyzedStructType, + is_numpy_number_type, ) @@ -79,46 +86,88 @@ def make_engine_value_decoder( Returns: A decoder from an engine value to a Python value. """ + src_type_kind = src_type["kind"] - dst_is_any = ( - dst_annotation is None - or dst_annotation is inspect.Parameter.empty - or dst_annotation is Any - ) - if dst_is_any: - if src_type_kind == "Union": - return lambda value: value[1] - if src_type_kind == "Struct": - return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"]) - if src_type_kind in TABLE_TYPES: - if src_type_kind == "LTable": + dst_type_info = analyze_type_info(dst_annotation) + dst_type_variant = dst_type_info.variant + + if isinstance(dst_type_variant, AnalyzedUnknownType): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.core_type}`, an unsupported type" + ) + + if src_type_kind == "Struct": + return _make_engine_struct_value_decoder( + field_path, + src_type["fields"], + dst_type_info, + ) + + if src_type_kind in TABLE_TYPES: + field_path.append("[*]") + engine_fields_schema = src_type["row"]["fields"] + + if src_type_kind == "LTable": + if isinstance(dst_type_variant, AnalyzedAnyType): return _make_engine_ltable_to_list_dict_decoder( - field_path, src_type["row"]["fields"] + field_path, engine_fields_schema + ) + if not isinstance(dst_type_variant, AnalyzedListType): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.core_type}`, a list type expected" ) - elif src_type_kind == "KTable": + row_decoder = _make_engine_struct_value_decoder( + field_path, + engine_fields_schema, + analyze_type_info(dst_type_variant.elem_type), + ) + + def decode(value: Any) -> Any | None: + if value is None: + return None + return [row_decoder(v) for v in value] + + elif src_type_kind == "KTable": + if isinstance(dst_type_variant, AnalyzedAnyType): return _make_engine_ktable_to_dict_dict_decoder( - field_path, src_type["row"]["fields"] + field_path, engine_fields_schema + ) + if not isinstance(dst_type_variant, AnalyzedDictType): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.core_type}`, a dict type expected" ) - return lambda value: value - # Handle struct -> dict binding for explicit dict annotations - is_dict_annotation = False - if dst_annotation is dict: - is_dict_annotation = True - elif getattr(dst_annotation, "__origin__", None) is dict: - args = getattr(dst_annotation, "__args__", ()) - if args == (str, Any): - 
is_dict_annotation = True - if is_dict_annotation and src_type_kind == "Struct": - return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"]) + key_field_schema = engine_fields_schema[0] + field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}") + key_decoder = make_engine_value_decoder( + field_path, key_field_schema["type"], dst_type_variant.key_type + ) + field_path.pop() + value_decoder = _make_engine_struct_value_decoder( + field_path, + engine_fields_schema[1:], + analyze_type_info(dst_type_variant.value_type), + ) - dst_type_info = analyze_type_info(dst_annotation) + def decode(value: Any) -> Any | None: + if value is None: + return None + return {key_decoder(v[0]): value_decoder(v[1:]) for v in value} + + field_path.pop() + return decode if src_type_kind == "Union": + if isinstance(dst_type_variant, AnalyzedAnyType): + return lambda value: value[1] + dst_type_variants = ( - dst_type_info.union_variant_types - if dst_type_info.union_variant_types is not None + dst_type_variant.variant_types + if isinstance(dst_type_variant, AnalyzedUnionType) else [dst_annotation] ) src_type_variants = src_type["types"] @@ -142,43 +191,36 @@ def make_engine_value_decoder( decoders.append(decoder) return lambda value: decoders[value[0]](value[1]) - if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind): - raise ValueError( - f"Type mismatch for `{''.join(field_path)}`: " - f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" - ) - - if dst_type_info.kind in ("Float32", "Float64", "Int64"): - dst_core_type = dst_type_info.core_type - - def decode_scalar(value: Any) -> Any | None: - if value is None: - if dst_type_info.nullable: - return None - raise ValueError( - f"Received null for non-nullable scalar `{''.join(field_path)}`" - ) - return dst_core_type(value) - - return decode_scalar + if isinstance(dst_type_variant, AnalyzedAnyType): + return lambda value: value if src_type_kind == "Vector": field_path_str = "".join(field_path) + if not isinstance(dst_type_variant, AnalyzedListType): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.core_type}`, a list type expected" + ) expected_dim = ( - dst_type_info.vector_info.dim if dst_type_info.vector_info else None + dst_type_variant.vector_info.dim + if dst_type_variant and dst_type_variant.vector_info + else None ) - elem_decoder = None + vec_elem_decoder = None scalar_dtype = None - if dst_type_info.np_number_type is None: # for Non-NDArray vector - elem_decoder = make_engine_value_decoder( + if ( + dst_type_variant + and is_numpy_number_type(dst_type_variant.elem_type) + and dst_type_info.base_type is np.ndarray + ): + scalar_dtype = dst_type_variant.elem_type + else: + vec_elem_decoder = make_engine_value_decoder( field_path + ["[*]"], src_type["element_type"], - dst_type_info.elem_type, + dst_type_variant and dst_type_variant.elem_type, ) - else: # for NDArray vector - scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type) - _ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype) def decode_vector(value: Any) -> Any | None: if value is None: @@ -197,54 +239,33 @@ def decode_vector(value: Any) -> Any | None: f"expected {expected_dim}, got {len(value)}" ) - if elem_decoder is not None: # for Non-NDArray vector - return [elem_decoder(v) for v in value] + if vec_elem_decoder is not None: # for Non-NDArray vector + return [vec_elem_decoder(v) for v in value] else: # for NDArray vector return np.array(value, 
dtype=scalar_dtype) return decode_vector - if dst_type_info.struct_type is not None: - return _make_engine_struct_value_decoder( - field_path, src_type["fields"], dst_type_info.struct_type - ) - - if src_type_kind in TABLE_TYPES: - field_path.append("[*]") - elem_type_info = analyze_type_info(dst_type_info.elem_type) - if elem_type_info.struct_type is None: + if isinstance(dst_type_variant, AnalyzedBasicType): + if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind): raise ValueError( f"Type mismatch for `{''.join(field_path)}`: " - f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected" - ) - engine_fields_schema = src_type["row"]["fields"] - if elem_type_info.key_type is not None: - key_field_schema = engine_fields_schema[0] - field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}") - key_decoder = make_engine_value_decoder( - field_path, key_field_schema["type"], elem_type_info.key_type - ) - field_path.pop() - value_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema[1:], elem_type_info.struct_type + f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})" ) - def decode(value: Any) -> Any | None: - if value is None: - return None - return {key_decoder(v[0]): value_decoder(v[1:]) for v in value} - else: - elem_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema, elem_type_info.struct_type - ) + if dst_type_variant.kind in ("Float32", "Float64", "Int64"): + dst_core_type = dst_type_info.core_type - def decode(value: Any) -> Any | None: + def decode_scalar(value: Any) -> Any | None: if value is None: - return None - return [elem_decoder(v) for v in value] + if dst_type_info.nullable: + return None + raise ValueError( + f"Received null for non-nullable scalar `{''.join(field_path)}`" + ) + return dst_core_type(value) - field_path.pop() - return decode + return decode_scalar return lambda value: value @@ -252,11 +273,36 @@ def decode(value: Any) -> Any | None: def _make_engine_struct_value_decoder( field_path: list[str], src_fields: list[dict[str, Any]], - dst_struct_type: type, + dst_type_info: AnalyzedTypeInfo, ) -> Callable[[list[Any]], Any]: """Make a decoder from an engine field values to a Python value.""" + dst_type_variant = dst_type_info.variant + + use_dict = False + if isinstance(dst_type_variant, AnalyzedAnyType): + use_dict = True + elif isinstance(dst_type_variant, AnalyzedDictType): + analyzed_key_type = analyze_type_info(dst_type_variant.key_type) + analyzed_value_type = analyze_type_info(dst_type_variant.value_type) + use_dict = ( + isinstance(analyzed_key_type.variant, AnalyzedAnyType) + or ( + isinstance(analyzed_key_type.variant, AnalyzedBasicType) + and analyzed_key_type.variant.kind == "Str" + ) + ) and isinstance(analyzed_value_type.variant, AnalyzedAnyType) + if use_dict: + return _make_engine_struct_to_dict_decoder(field_path, src_fields) + + if not isinstance(dst_type_variant, AnalyzedStructType): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected" + ) + src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)} + dst_struct_type = dst_type_variant.struct_type parameters: Mapping[str, inspect.Parameter] if dataclasses.is_dataclass(dst_struct_type): diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 57260a55..8d913f97 100644 --- a/python/cocoindex/tests/test_convert.py 
+++ b/python/cocoindex/tests/test_convert.py @@ -1216,7 +1216,7 @@ class MixedStruct: numpy_float: np.float64 python_float: float string: str - annotated_int: Annotated[np.int64, TypeKind("int")] + annotated_int: Annotated[np.int64, TypeKind("Int64")] annotated_float: Float32 instance = MixedStruct( diff --git a/python/cocoindex/tests/test_typing.py b/python/cocoindex/tests/test_typing.py index adf995a0..8f64071a 100644 --- a/python/cocoindex/tests/test_typing.py +++ b/python/cocoindex/tests/test_typing.py @@ -9,6 +9,10 @@ from numpy.typing import NDArray from cocoindex.typing import ( + AnalyzedBasicType, + AnalyzedDictType, + AnalyzedListType, + AnalyzedStructType, AnalyzedTypeInfo, TypeAttr, TypeKind, @@ -33,83 +37,67 @@ class SimpleNamedTuple(NamedTuple): def test_ndarray_float32_no_dim() -> None: typ = NDArray[np.float32] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=None) - assert result.elem_type == np.float32 - assert result.key_type is None - assert result.struct_type is None + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info is None + assert result.variant.elem_type == np.float32 assert result.nullable is False - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.float32] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.float32] def test_vector_float32_no_dim() -> None: typ = Vector[np.float32] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=None) - assert result.elem_type == np.float32 - assert result.key_type is None - assert result.struct_type is None + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info == VectorInfo(dim=None) + assert result.variant.elem_type == np.float32 assert result.nullable is False - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.float32] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.float32] def test_ndarray_float64_with_dim() -> None: typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=128) - assert result.elem_type == np.float64 - assert result.key_type is None - assert result.struct_type is None + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info == VectorInfo(dim=128) + assert result.variant.elem_type == np.float64 assert result.nullable is False - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.float64] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.float64] def test_vector_float32_with_dim() -> None: typ = Vector[np.float32, Literal[384]] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=384) - assert result.elem_type == np.float32 - assert result.key_type is None - assert result.struct_type is None + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info == VectorInfo(dim=384) + assert result.variant.elem_type 
== np.float32 assert result.nullable is False - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.float32] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.float32] def test_ndarray_int64_no_dim() -> None: typ = NDArray[np.int64] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=None) - assert result.elem_type == np.int64 + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info is None + assert result.variant.elem_type == np.int64 assert result.nullable is False - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.int64] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.int64] def test_nullable_ndarray() -> None: typ = NDArray[np.float32] | None result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.vector_info == VectorInfo(dim=None) - assert result.elem_type == np.float32 - assert result.key_type is None - assert result.struct_type is None + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.vector_info is None + assert result.variant.elem_type == np.float32 assert result.nullable is True - assert result.np_number_type is not None - assert get_origin(result.np_number_type) == np.ndarray - assert get_args(result.np_number_type)[1] == np.dtype[np.float32] + assert get_origin(result.core_type) == np.ndarray + assert get_args(result.core_type)[1] == np.dtype[np.float32] def test_scalar_numpy_types() -> None: @@ -119,38 +107,37 @@ def test_scalar_numpy_types() -> None: (np.float64, "Float64"), ]: type_info = analyze_type_info(np_type) - assert type_info.kind == expected_kind, ( - f"Expected {expected_kind} for {np_type}, got {type_info.kind}" + assert isinstance(type_info.variant, AnalyzedBasicType) + assert type_info.variant.kind == expected_kind, ( + f"Expected {expected_kind} for {np_type}, got {type_info.variant.kind}" ) - assert type_info.np_number_type == np_type, ( - f"Expected {np_type}, got {type_info.np_number_type}" + assert type_info.core_type == np_type, ( + f"Expected {np_type}, got {type_info.core_type}" ) - assert type_info.elem_type is None - assert type_info.vector_info is None def test_vector_str() -> None: typ = Vector[str] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.elem_type is str - assert result.vector_info == VectorInfo(dim=None) + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.elem_type is str + assert result.variant.vector_info == VectorInfo(dim=None) def test_vector_complex64() -> None: typ = Vector[np.complex64] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.elem_type == np.complex64 - assert result.vector_info == VectorInfo(dim=None) + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.elem_type == np.complex64 + assert result.variant.vector_info == VectorInfo(dim=None) def test_non_numpy_vector() -> None: typ = Vector[float, Literal[3]] result = analyze_type_info(typ) - assert result.kind == "Vector" - assert result.elem_type is float - assert result.vector_info == VectorInfo(dim=3) + assert isinstance(result.variant, AnalyzedListType) + assert result.variant.elem_type is float 
+ assert result.variant.vector_info == VectorInfo(dim=3) def test_ndarray_any_dtype() -> None: @@ -165,13 +152,9 @@ def test_list_of_primitives() -> None: typ = list[str] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Vector", core_type=list[str], - vector_info=VectorInfo(dim=None), - elem_type=str, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=list, + variant=AnalyzedListType(elem_type=str, vector_info=None), attrs=None, nullable=False, ) @@ -181,13 +164,9 @@ def test_list_of_structs() -> None: typ = list[SimpleDataclass] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="LTable", core_type=list[SimpleDataclass], - vector_info=None, - elem_type=SimpleDataclass, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=list, + variant=AnalyzedListType(elem_type=SimpleDataclass, vector_info=None), attrs=None, nullable=False, ) @@ -197,13 +176,9 @@ def test_sequence_of_int() -> None: typ = Sequence[int] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Vector", core_type=Sequence[int], - vector_info=VectorInfo(dim=None), - elem_type=int, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=Sequence, + variant=AnalyzedListType(elem_type=int, vector_info=None), attrs=None, nullable=False, ) @@ -213,13 +188,9 @@ def test_list_with_vector_info() -> None: typ = Annotated[list[int], VectorInfo(dim=5)] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Vector", core_type=list[int], - vector_info=VectorInfo(dim=5), - elem_type=int, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=list, + variant=AnalyzedListType(elem_type=int, vector_info=VectorInfo(dim=5)), attrs=None, nullable=False, ) @@ -229,13 +200,9 @@ def test_dict_str_int() -> None: typ = dict[str, int] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="KTable", core_type=dict[str, int], - vector_info=None, - elem_type=(str, int), - key_type=None, - struct_type=None, - np_number_type=None, + base_type=dict, + variant=AnalyzedDictType(key_type=str, value_type=int), attrs=None, nullable=False, ) @@ -245,13 +212,9 @@ def test_mapping_str_dataclass() -> None: typ = Mapping[str, SimpleDataclass] result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="KTable", core_type=Mapping[str, SimpleDataclass], - vector_info=None, - elem_type=(str, SimpleDataclass), - key_type=None, - struct_type=None, - np_number_type=None, + base_type=Mapping, + variant=AnalyzedDictType(key_type=str, value_type=SimpleDataclass), attrs=None, nullable=False, ) @@ -261,13 +224,9 @@ def test_dataclass() -> None: typ = SimpleDataclass result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Struct", core_type=SimpleDataclass, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=SimpleDataclass, - np_number_type=None, + base_type=SimpleDataclass, + variant=AnalyzedStructType(struct_type=SimpleDataclass), attrs=None, nullable=False, ) @@ -277,29 +236,9 @@ def test_named_tuple() -> None: typ = SimpleNamedTuple result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Struct", core_type=SimpleNamedTuple, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=SimpleNamedTuple, - np_number_type=None, - attrs=None, - nullable=False, - ) - - -def test_tuple_key_value() -> None: - typ = (str, int) - result = analyze_type_info(typ) - assert result == AnalyzedTypeInfo( - kind="Int64", - core_type=int, - 
vector_info=None, - elem_type=None, - key_type=str, - struct_type=None, - np_number_type=None, + base_type=SimpleNamedTuple, + variant=AnalyzedStructType(struct_type=SimpleNamedTuple), attrs=None, nullable=False, ) @@ -309,13 +248,9 @@ def test_str() -> None: typ = str result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Str", core_type=str, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=str, + variant=AnalyzedBasicType(kind="Str"), attrs=None, nullable=False, ) @@ -325,13 +260,9 @@ def test_bool() -> None: typ = bool result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Bool", core_type=bool, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=bool, + variant=AnalyzedBasicType(kind="Bool"), attrs=None, nullable=False, ) @@ -341,13 +272,9 @@ def test_bytes() -> None: typ = bytes result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Bytes", core_type=bytes, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=bytes, + variant=AnalyzedBasicType(kind="Bytes"), attrs=None, nullable=False, ) @@ -357,13 +284,9 @@ def test_uuid() -> None: typ = uuid.UUID result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Uuid", core_type=uuid.UUID, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=uuid.UUID, + variant=AnalyzedBasicType(kind="Uuid"), attrs=None, nullable=False, ) @@ -373,13 +296,9 @@ def test_date() -> None: typ = datetime.date result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Date", core_type=datetime.date, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=datetime.date, + variant=AnalyzedBasicType(kind="Date"), attrs=None, nullable=False, ) @@ -389,13 +308,9 @@ def test_time() -> None: typ = datetime.time result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Time", core_type=datetime.time, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=datetime.time, + variant=AnalyzedBasicType(kind="Time"), attrs=None, nullable=False, ) @@ -405,13 +320,9 @@ def test_timedelta() -> None: typ = datetime.timedelta result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="TimeDelta", core_type=datetime.timedelta, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=datetime.timedelta, + variant=AnalyzedBasicType(kind="TimeDelta"), attrs=None, nullable=False, ) @@ -421,13 +332,9 @@ def test_float() -> None: typ = float result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Float64", core_type=float, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=float, + variant=AnalyzedBasicType(kind="Float64"), attrs=None, nullable=False, ) @@ -437,13 +344,9 @@ def test_int() -> None: typ = int result = analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Int64", core_type=int, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=int, + variant=AnalyzedBasicType(kind="Int64"), attrs=None, nullable=False, ) @@ -453,13 +356,9 @@ def test_type_with_attributes() -> None: typ = Annotated[str, TypeAttr("key", "value")] result = 
analyze_type_info(typ) assert result == AnalyzedTypeInfo( - kind="Str", core_type=str, - vector_info=None, - elem_type=None, - key_type=None, - struct_type=None, - np_number_type=None, + base_type=str, + variant=AnalyzedBasicType(kind="Str"), attrs={"key": "value"}, nullable=False, ) @@ -494,7 +393,7 @@ def test_encode_enriched_type_ltable() -> None: typ = list[SimpleDataclass] result = encode_enriched_type(typ) assert result["type"]["kind"] == "LTable" - assert result["type"]["row"]["kind"] == "Struct" + assert "fields" in result["type"]["row"] assert len(result["type"]["row"]["fields"]) == 2 @@ -525,16 +424,18 @@ def test_encode_scalar_numpy_types_schema() -> None: assert not schema.get("nullable", False) -def test_invalid_struct_kind() -> None: +def test_annotated_struct_with_type_kind() -> None: typ = Annotated[SimpleDataclass, TypeKind("Vector")] - with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"): - analyze_type_info(typ) + result = analyze_type_info(typ) + assert isinstance(result.variant, AnalyzedBasicType) + assert result.variant.kind == "Vector" -def test_invalid_list_kind() -> None: +def test_annotated_list_with_type_kind() -> None: typ = Annotated[list[int], TypeKind("Struct")] - with pytest.raises(ValueError, match="Unexpected type kind for list: Struct"): - analyze_type_info(typ) + result = analyze_type_info(typ) + assert isinstance(result.variant, AnalyzedBasicType) + assert result.variant.kind == "Struct" def test_unsupported_type() -> None: diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 3c7e1dc0..b53c476e 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -88,8 +88,6 @@ def __class_getitem__(self, params): TABLE_TYPES: tuple[str, str] = ("KTable", "LTable") KEY_FIELD_NAME: str = "_key" -ElementType = type | tuple[type, type] | Annotated[Any, TypeKind] - def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any: args = typing.get_args(ndarray_type) @@ -108,7 +106,7 @@ def is_namedtuple_type(t: type) -> bool: return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields") -def is_struct_type(t: ElementType | None) -> bool: +def is_struct_type(t: Any) -> bool: return isinstance(t, type) and ( dataclasses.is_dataclass(t) or is_namedtuple_type(t) ) @@ -120,14 +118,14 @@ class DtypeRegistry: Maps NumPy dtypes to their CocoIndex type kind. """ - _DTYPE_TO_KIND: dict[ElementType, str] = { + _DTYPE_TO_KIND: dict[Any, str] = { np.float32: "Float32", np.float64: "Float64", np.int64: "Int64", } @classmethod - def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str: + def validate_dtype_and_get_kind(cls, dtype: Any) -> str: """ Validate that the given dtype is supported, and get its CocoIndex kind by dtype. """ @@ -144,26 +142,65 @@ def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str: return kind +class AnalyzedAnyType(NamedTuple): + pass + + +class AnalyzedBasicType(NamedTuple): + """ + For types that fit into basic type, and annotated with basic type or Json type. 
+ """ + + kind: str + + +class AnalyzedListType(NamedTuple): + elem_type: Any + vector_info: VectorInfo | None + + +class AnalyzedStructType(NamedTuple): + struct_type: type + + +class AnalyzedUnionType(NamedTuple): + variant_types: list[Any] + nullable: bool + + +class AnalyzedDictType(NamedTuple): + key_type: Any + value_type: Any + + +class AnalyzedUnknownType(NamedTuple): + pass + + +AnalyzedTypeVariant = ( + AnalyzedAnyType + | AnalyzedBasicType + | AnalyzedListType + | AnalyzedStructType + | AnalyzedUnionType + | AnalyzedDictType + | AnalyzedUnknownType +) + + @dataclasses.dataclass class AnalyzedTypeInfo: """ Analyzed info of a Python type. """ - kind: str + # The type without annotations. e.g. int, list[int], dict[str, int] core_type: Any - vector_info: VectorInfo | None # For Vector - elem_type: ElementType | None # For Vector and Table - - key_type: type | None # For element of KTable - struct_type: type | None # For Struct, a dataclass or namedtuple - np_number_type: ( - type | None - ) # NumPy dtype for the element type, if represented by numpy.ndarray or a NumPy scalar - + # The type without annotations and parameters. e.g. int, list, dict + base_type: Any + variant: AnalyzedTypeVariant attrs: dict[str, Any] | None nullable: bool = False - union_variant_types: typing.List[ElementType] | None = None # For Union def analyze_type_info(t: Any) -> AnalyzedTypeInfo: @@ -171,14 +208,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: Analyze a Python type annotation and extract CocoIndex-specific type information. Type annotations for specific CocoIndex types are expected. Raises ValueError for Any, empty, or untyped dict types. """ - if isinstance(t, tuple) and len(t) == 2: - kt, vt = t - result = analyze_type_info(vt) - result.key_type = kt - return result annotations: tuple[Annotation, ...] = () base_type = None + type_args: tuple[Any, ...] 
= () nullable = False while True: base_type = typing.get_origin(t) @@ -186,7 +219,12 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: annotations = t.__metadata__ t = t.__origin__ else: + if base_type is None: + base_type = t + else: + type_args = typing.get_args(t) break + core_type = t attrs: dict[str, Any] | None = None vector_info: VectorInfo | None = None @@ -201,74 +239,42 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: elif isinstance(attr, TypeKind): kind = attr.kind - struct_type: type | None = None - elem_type: ElementType | None = None - union_variant_types: typing.List[ElementType] | None = None - key_type: type | None = None - np_number_type: type | None = None - if is_struct_type(t): - struct_type = t + variant: AnalyzedTypeVariant | None = None - if kind is None: - kind = "Struct" - elif kind != "Struct": - raise ValueError(f"Unexpected type kind for struct: {kind}") + if kind is not None: + variant = AnalyzedBasicType(kind=kind) + elif base_type is None or base_type is Any or base_type is inspect.Parameter.empty: + variant = AnalyzedAnyType() + elif is_struct_type(base_type): + variant = AnalyzedStructType(struct_type=t) elif is_numpy_number_type(t): - np_number_type = t kind = DtypeRegistry.validate_dtype_and_get_kind(t) + variant = AnalyzedBasicType(kind=kind) elif base_type is collections.abc.Sequence or base_type is list: - args = typing.get_args(t) - elem_type = args[0] - - if kind is None: - if is_struct_type(elem_type): - kind = "LTable" - if vector_info is not None: - raise ValueError( - "Vector element must be a simple type, not a struct" - ) - else: - kind = "Vector" - if vector_info is None: - vector_info = VectorInfo(dim=None) - elif not (kind == "Vector" or kind in TABLE_TYPES): - raise ValueError(f"Unexpected type kind for list: {kind}") + elem_type = type_args[0] if len(type_args) > 0 else None + variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info) elif base_type is np.ndarray: - kind = "Vector" np_number_type = t elem_type = extract_ndarray_scalar_dtype(np_number_type) _ = DtypeRegistry.validate_dtype_and_get_kind(elem_type) - vector_info = VectorInfo(dim=None) if vector_info is None else vector_info - + variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info) elif base_type is collections.abc.Mapping or base_type is dict or t is dict: - args = typing.get_args(t) - if len(args) == 0: # Handle untyped dict - raise ValueError( - "Untyped dict is not accepted as a specific type annotation; please provide a concrete type, " - "e.g. a dataclass or namedtuple for Struct types, a dict[str, T] for KTable types." 
- ) - else: - elem_type = (args[0], args[1]) - kind = "KTable" + key_type = type_args[0] if len(type_args) > 0 else None + elem_type = type_args[1] if len(type_args) > 1 else None + variant = AnalyzedDictType(key_type=key_type, value_type=elem_type) elif base_type in (types.UnionType, typing.Union): - possible_types = typing.get_args(t) - non_none_types = [ - arg for arg in possible_types if arg not in (None, types.NoneType) - ] - + non_none_types = [arg for arg in type_args if arg not in (None, types.NoneType)] if len(non_none_types) == 0: return analyze_type_info(None) - nullable = len(non_none_types) < len(possible_types) - + nullable = len(non_none_types) < len(type_args) if len(non_none_types) == 1: result = analyze_type_info(non_none_types[0]) result.nullable = nullable return result - kind = "Union" - union_variant_types = non_none_types - elif kind is None: + variant = AnalyzedUnionType(variant_types=non_none_types, nullable=nullable) + else: if t is bytes: kind = "Bytes" elif t is str: @@ -293,25 +299,21 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: raise ValueError( f"Unsupported as a specific type annotation for CocoIndex data type (https://cocoindex.io/docs/core/data_types): {t}" ) + variant = AnalyzedBasicType(kind=kind) return AnalyzedTypeInfo( - kind=kind, - core_type=t, - vector_info=vector_info, - elem_type=elem_type, - union_variant_types=union_variant_types, - key_type=key_type, - struct_type=struct_type, - np_number_type=np_number_type, + core_type=core_type, + base_type=base_type, + variant=variant, attrs=attrs, nullable=nullable, ) -def _encode_fields_schema( +def _encode_struct_schema( struct_type: type, key_type: type | None = None -) -> list[dict[str, Any]]: - result = [] +) -> dict[str, Any]: + fields = [] def add_field(name: str, t: Any) -> None: try: @@ -323,7 +325,7 @@ def add_field(name: str, t: Any) -> None: ) raise type_info["name"] = name - result.append(type_info) + fields.append(type_info) if key_type is not None: add_field(KEY_FIELD_NAME, key_type) @@ -335,45 +337,68 @@ def add_field(name: str, t: Any) -> None: for name, field_type in struct_type.__annotations__.items(): add_field(name, field_type) + result: dict[str, Any] = {"fields": fields} + if doc := inspect.getdoc(struct_type): + result["description"] = doc return result def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: - encoded_type: dict[str, Any] = {"kind": type_info.kind} - - if type_info.kind == "Struct": - if type_info.struct_type is None: - raise ValueError("Struct type must have a dataclass or namedtuple type") - encoded_type["fields"] = _encode_fields_schema( - type_info.struct_type, type_info.key_type - ) - if doc := inspect.getdoc(type_info.struct_type): - encoded_type["description"] = doc - - elif type_info.kind == "Vector": - if type_info.vector_info is None: - raise ValueError("Vector type must have a vector info") - if type_info.elem_type is None: - raise ValueError("Vector type must have an element type") - elem_type_info = analyze_type_info(type_info.elem_type) - encoded_type["element_type"] = _encode_type(elem_type_info) - encoded_type["dimension"] = type_info.vector_info.dim - - elif type_info.kind == "Union": - if type_info.union_variant_types is None: - raise ValueError("Union type must have a variant type list") - encoded_type["types"] = [ - _encode_type(analyze_type_info(typ)) - for typ in type_info.union_variant_types - ] - - elif type_info.kind in TABLE_TYPES: - if type_info.elem_type is None: - raise ValueError(f"{type_info.kind} type must have an 
element type") - row_type_info = analyze_type_info(type_info.elem_type) - encoded_type["row"] = _encode_type(row_type_info) - - return encoded_type + variant = type_info.variant + + if isinstance(variant, AnalyzedAnyType): + raise ValueError("Specific type annotation is expected") + + if isinstance(variant, AnalyzedUnknownType): + raise ValueError(f"Unsupported type annotation: {type_info.core_type}") + + if isinstance(variant, AnalyzedBasicType): + return {"kind": variant.kind} + + if isinstance(variant, AnalyzedStructType): + encoded_type = _encode_struct_schema(variant.struct_type) + encoded_type["kind"] = "Struct" + return encoded_type + + if isinstance(variant, AnalyzedListType): + elem_type_info = analyze_type_info(variant.elem_type) + encoded_elem_type = _encode_type(elem_type_info) + if isinstance(elem_type_info.variant, AnalyzedStructType): + if variant.vector_info is not None: + raise ValueError("LTable type must not have a vector info") + return { + "kind": "LTable", + "row": _encode_struct_schema(elem_type_info.variant.struct_type), + } + else: + vector_info = variant.vector_info + return { + "kind": "Vector", + "element_type": encoded_elem_type, + "dimension": vector_info and vector_info.dim, + } + + if isinstance(variant, AnalyzedDictType): + value_type_info = analyze_type_info(variant.value_type) + if not isinstance(value_type_info.variant, AnalyzedStructType): + raise ValueError( + f"KTable value must have a Struct type, got {value_type_info.core_type}" + ) + return { + "kind": "KTable", + "row": _encode_struct_schema( + value_type_info.variant.struct_type, + variant.key_type, + ), + } + + if isinstance(variant, AnalyzedUnionType): + return { + "kind": "Union", + "types": [ + _encode_type(analyze_type_info(typ)) for typ in variant.variant_types + ], + } def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str, Any]: From 4e2a397d53880b9266faccea20d8f94356a55cbb Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 25 Jul 2025 11:07:34 -0700 Subject: [PATCH 2/2] docs: fix docstring for new types --- python/cocoindex/typing.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index b53c476e..cb0ae887 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -143,7 +143,9 @@ def validate_dtype_and_get_kind(cls, dtype: Any) -> str: class AnalyzedAnyType(NamedTuple): - pass + """ + When the type annotation is missing or matches any type. + """ class AnalyzedBasicType(NamedTuple): @@ -155,26 +157,43 @@ class AnalyzedBasicType(NamedTuple): class AnalyzedListType(NamedTuple): + """ + Any list type, e.g. list[T], Sequence[T], NDArray[T], etc. + """ + elem_type: Any vector_info: VectorInfo | None class AnalyzedStructType(NamedTuple): + """ + Any struct type, e.g. dataclass, NamedTuple, etc. + """ + struct_type: type class AnalyzedUnionType(NamedTuple): + """ + Any union type, e.g. T1 | T2 | ..., etc. + """ + variant_types: list[Any] - nullable: bool class AnalyzedDictType(NamedTuple): + """ + Any dict type, e.g. dict[T1, T2], Mapping[T1, T2], etc. + """ + key_type: Any value_type: Any class AnalyzedUnknownType(NamedTuple): - pass + """ + Any type that is not supported by CocoIndex. + """ AnalyzedTypeVariant = ( @@ -206,7 +225,6 @@ class AnalyzedTypeInfo: def analyze_type_info(t: Any) -> AnalyzedTypeInfo: """ Analyze a Python type annotation and extract CocoIndex-specific type information. 
- Type annotations for specific CocoIndex types are expected. Raises ValueError for Any, empty, or untyped dict types. """ annotations: tuple[Annotation, ...] = () @@ -273,7 +291,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: result.nullable = nullable return result - variant = AnalyzedUnionType(variant_types=non_none_types, nullable=nullable) + variant = AnalyzedUnionType(variant_types=non_none_types) else: if t is bytes: kind = "Bytes"
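Below is a minimal usage sketch (editor's illustration, not part of the patch) of the refactored analyze_type_info() API, mirroring the updated assertions in test_typing.py above. The Point dataclass is a made-up example; the snippet assumes the cocoindex package built with the changes in this PR.

    import dataclasses
    from typing import get_origin

    import numpy as np
    from numpy.typing import NDArray

    from cocoindex.typing import (
        AnalyzedBasicType,
        AnalyzedListType,
        AnalyzedStructType,
        analyze_type_info,
    )

    # Basic scalar: the CocoIndex kind now lives on the variant, not on AnalyzedTypeInfo.
    info = analyze_type_info(int)
    assert isinstance(info.variant, AnalyzedBasicType) and info.variant.kind == "Int64"

    # NumPy array: core_type keeps the full parameterized type; the variant carries elem_type.
    info = analyze_type_info(NDArray[np.float32])
    assert get_origin(info.core_type) == np.ndarray
    assert isinstance(info.variant, AnalyzedListType)
    assert info.variant.elem_type == np.float32

    # Struct: dataclasses (and NamedTuples) map to AnalyzedStructType.
    @dataclasses.dataclass
    class Point:
        x: float
        y: float

    info = analyze_type_info(Point)
    assert isinstance(info.variant, AnalyzedStructType)
    assert info.variant.struct_type is Point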