Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 141 additions & 95 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
from .typing import (
KEY_FIELD_NAME,
TABLE_TYPES,
DtypeRegistry,
analyze_type_info,
encode_enriched_type,
extract_ndarray_scalar_dtype,
is_namedtuple_type,
is_struct_type,
AnalyzedTypeInfo,
AnalyzedAnyType,
AnalyzedDictType,
AnalyzedListType,
AnalyzedBasicType,
AnalyzedUnionType,
AnalyzedUnknownType,
AnalyzedStructType,
is_numpy_number_type,
)


Expand Down Expand Up @@ -79,46 +86,88 @@ def make_engine_value_decoder(
Returns:
A decoder from an engine value to a Python value.
"""

src_type_kind = src_type["kind"]

dst_is_any = (
dst_annotation is None
or dst_annotation is inspect.Parameter.empty
or dst_annotation is Any
)
if dst_is_any:
if src_type_kind == "Union":
return lambda value: value[1]
if src_type_kind == "Struct":
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
if src_type_kind in TABLE_TYPES:
if src_type_kind == "LTable":
dst_type_info = analyze_type_info(dst_annotation)
dst_type_variant = dst_type_info.variant

if isinstance(dst_type_variant, AnalyzedUnknownType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, an unsupported type"
)

if src_type_kind == "Struct":
return _make_engine_struct_value_decoder(
field_path,
src_type["fields"],
dst_type_info,
)

if src_type_kind in TABLE_TYPES:
field_path.append("[*]")
engine_fields_schema = src_type["row"]["fields"]

if src_type_kind == "LTable":
if isinstance(dst_type_variant, AnalyzedAnyType):
return _make_engine_ltable_to_list_dict_decoder(
field_path, src_type["row"]["fields"]
field_path, engine_fields_schema
)
if not isinstance(dst_type_variant, AnalyzedListType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a list type expected"
)
elif src_type_kind == "KTable":
row_decoder = _make_engine_struct_value_decoder(
field_path,
engine_fields_schema,
analyze_type_info(dst_type_variant.elem_type),
)

def decode(value: Any) -> Any | None:
if value is None:
return None
return [row_decoder(v) for v in value]

elif src_type_kind == "KTable":
if isinstance(dst_type_variant, AnalyzedAnyType):
return _make_engine_ktable_to_dict_dict_decoder(
field_path, src_type["row"]["fields"]
field_path, engine_fields_schema
)
if not isinstance(dst_type_variant, AnalyzedDictType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a dict type expected"
)
return lambda value: value

# Handle struct -> dict binding for explicit dict annotations
is_dict_annotation = False
if dst_annotation is dict:
is_dict_annotation = True
elif getattr(dst_annotation, "__origin__", None) is dict:
args = getattr(dst_annotation, "__args__", ())
if args == (str, Any):
is_dict_annotation = True
if is_dict_annotation and src_type_kind == "Struct":
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
key_field_schema = engine_fields_schema[0]
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
key_decoder = make_engine_value_decoder(
field_path, key_field_schema["type"], dst_type_variant.key_type
)
field_path.pop()
value_decoder = _make_engine_struct_value_decoder(
field_path,
engine_fields_schema[1:],
analyze_type_info(dst_type_variant.value_type),
)

dst_type_info = analyze_type_info(dst_annotation)
def decode(value: Any) -> Any | None:
if value is None:
return None
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}

field_path.pop()
return decode

if src_type_kind == "Union":
if isinstance(dst_type_variant, AnalyzedAnyType):
return lambda value: value[1]

dst_type_variants = (
dst_type_info.union_variant_types
if dst_type_info.union_variant_types is not None
dst_type_variant.variant_types
if isinstance(dst_type_variant, AnalyzedUnionType)
else [dst_annotation]
)
src_type_variants = src_type["types"]
Expand All @@ -142,43 +191,36 @@ def make_engine_value_decoder(
decoders.append(decoder)
return lambda value: decoders[value[0]](value[1])

if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
)

if dst_type_info.kind in ("Float32", "Float64", "Int64"):
dst_core_type = dst_type_info.core_type

def decode_scalar(value: Any) -> Any | None:
if value is None:
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable scalar `{''.join(field_path)}`"
)
return dst_core_type(value)

return decode_scalar
if isinstance(dst_type_variant, AnalyzedAnyType):
return lambda value: value

if src_type_kind == "Vector":
field_path_str = "".join(field_path)
if not isinstance(dst_type_variant, AnalyzedListType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a list type expected"
)
expected_dim = (
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
dst_type_variant.vector_info.dim
if dst_type_variant and dst_type_variant.vector_info
else None
)

elem_decoder = None
vec_elem_decoder = None
scalar_dtype = None
if dst_type_info.np_number_type is None: # for Non-NDArray vector
elem_decoder = make_engine_value_decoder(
if (
dst_type_variant
and is_numpy_number_type(dst_type_variant.elem_type)
and dst_type_info.base_type is np.ndarray
):
scalar_dtype = dst_type_variant.elem_type
else:
vec_elem_decoder = make_engine_value_decoder(
field_path + ["[*]"],
src_type["element_type"],
dst_type_info.elem_type,
dst_type_variant and dst_type_variant.elem_type,
)
else: # for NDArray vector
scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)

def decode_vector(value: Any) -> Any | None:
if value is None:
Expand All @@ -197,66 +239,70 @@ def decode_vector(value: Any) -> Any | None:
f"expected {expected_dim}, got {len(value)}"
)

if elem_decoder is not None: # for Non-NDArray vector
return [elem_decoder(v) for v in value]
if vec_elem_decoder is not None: # for Non-NDArray vector
return [vec_elem_decoder(v) for v in value]
else: # for NDArray vector
return np.array(value, dtype=scalar_dtype)

return decode_vector

if dst_type_info.struct_type is not None:
return _make_engine_struct_value_decoder(
field_path, src_type["fields"], dst_type_info.struct_type
)

if src_type_kind in TABLE_TYPES:
field_path.append("[*]")
elem_type_info = analyze_type_info(dst_type_info.elem_type)
if elem_type_info.struct_type is None:
if isinstance(dst_type_variant, AnalyzedBasicType):
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected"
)
engine_fields_schema = src_type["row"]["fields"]
if elem_type_info.key_type is not None:
key_field_schema = engine_fields_schema[0]
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
key_decoder = make_engine_value_decoder(
field_path, key_field_schema["type"], elem_type_info.key_type
)
field_path.pop()
value_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema[1:], elem_type_info.struct_type
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})"
)

def decode(value: Any) -> Any | None:
if value is None:
return None
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
else:
elem_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema, elem_type_info.struct_type
)
if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
dst_core_type = dst_type_info.core_type

def decode(value: Any) -> Any | None:
def decode_scalar(value: Any) -> Any | None:
if value is None:
return None
return [elem_decoder(v) for v in value]
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable scalar `{''.join(field_path)}`"
)
return dst_core_type(value)

field_path.pop()
return decode
return decode_scalar

return lambda value: value


def _make_engine_struct_value_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
dst_struct_type: type,
dst_type_info: AnalyzedTypeInfo,
) -> Callable[[list[Any]], Any]:
"""Make a decoder from an engine field values to a Python value."""

dst_type_variant = dst_type_info.variant

use_dict = False
if isinstance(dst_type_variant, AnalyzedAnyType):
use_dict = True
elif isinstance(dst_type_variant, AnalyzedDictType):
analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
use_dict = (
isinstance(analyzed_key_type.variant, AnalyzedAnyType)
or (
isinstance(analyzed_key_type.variant, AnalyzedBasicType)
and analyzed_key_type.variant.kind == "Str"
)
) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
if use_dict:
return _make_engine_struct_to_dict_decoder(field_path, src_fields)

if not isinstance(dst_type_variant, AnalyzedStructType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
)

src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
dst_struct_type = dst_type_variant.struct_type

parameters: Mapping[str, inspect.Parameter]
if dataclasses.is_dataclass(dst_struct_type):
Expand Down
2 changes: 1 addition & 1 deletion python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ class MixedStruct:
numpy_float: np.float64
python_float: float
string: str
annotated_int: Annotated[np.int64, TypeKind("int")]
annotated_int: Annotated[np.int64, TypeKind("Int64")]
annotated_float: Float32

instance = MixedStruct(
Expand Down
Loading
Loading