diff --git a/python/cocoindex/engine_object.py b/python/cocoindex/engine_object.py index b371d801..52a684a5 100644 --- a/python/cocoindex/engine_object.py +++ b/python/cocoindex/engine_object.py @@ -5,7 +5,6 @@ from __future__ import annotations import datetime -import dataclasses from enum import Enum from typing import Any, Mapping, TypeVar, overload, get_origin @@ -24,18 +23,12 @@ analyze_type_info, encode_enriched_type, is_namedtuple_type, - is_pydantic_model, extract_ndarray_elem_dtype, ) T = TypeVar("T") -try: - import pydantic, pydantic_core -except ImportError: - pass - def get_auto_default_for_type( type_info: AnalyzedTypeInfo, @@ -175,67 +168,16 @@ def load_engine_object(expected_type: Any, v: Any) -> Any: if isinstance(variant, AnalyzedStructType): struct_type = variant.struct_type init_kwargs: dict[str, Any] = {} - missing_fields: list[tuple[str, Any]] = [] - if dataclasses.is_dataclass(struct_type): - if not isinstance(v, Mapping): - raise ValueError(f"Expected dict for dataclass, got {type(v)}") - - for dc_field in dataclasses.fields(struct_type): - if dc_field.name in v: - init_kwargs[dc_field.name] = load_engine_object( - dc_field.type, v[dc_field.name] - ) - else: - if ( - dc_field.default is dataclasses.MISSING - and dc_field.default_factory is dataclasses.MISSING - ): - missing_fields.append((dc_field.name, dc_field.type)) - - elif is_namedtuple_type(struct_type): - if not isinstance(v, Mapping): - raise ValueError(f"Expected dict for NamedTuple, got {type(v)}") - # Dict format (from dump/load functions) - annotations = getattr(struct_type, "__annotations__", {}) - field_names = list(getattr(struct_type, "_fields", ())) - field_defaults = getattr(struct_type, "_field_defaults", {}) - - for name in field_names: - f_type = annotations.get(name, Any) - if name in v: - init_kwargs[name] = load_engine_object(f_type, v[name]) - elif name not in field_defaults: - missing_fields.append((name, f_type)) - - elif is_pydantic_model(struct_type): - if not isinstance(v, Mapping): - raise ValueError(f"Expected dict for Pydantic model, got {type(v)}") - - model_fields: dict[str, pydantic.fields.FieldInfo] - if hasattr(struct_type, "model_fields"): - model_fields = struct_type.model_fields # type: ignore[attr-defined] + for field_info in variant.fields: + if field_info.name in v: + init_kwargs[field_info.name] = load_engine_object( + field_info.type_hint, v[field_info.name] + ) else: - model_fields = {} - - for name, pyd_field in model_fields.items(): - if name in v: - init_kwargs[name] = load_engine_object( - pyd_field.annotation, v[name] - ) - elif ( - getattr(pyd_field, "default", pydantic_core.PydanticUndefined) - is pydantic_core.PydanticUndefined - and getattr(pyd_field, "default_factory") is None - ): - missing_fields.append((name, pyd_field.annotation)) - else: - assert False, "Unsupported struct type" - - for name, f_type in missing_fields: - type_info = analyze_type_info(f_type) - auto_default, is_supported = get_auto_default_for_type(type_info) - if is_supported: - init_kwargs[name] = auto_default + type_info = analyze_type_info(field_info.type_hint) + auto_default, is_supported = get_auto_default_for_type(type_info) + if is_supported: + init_kwargs[field_info.name] = auto_default return struct_type(**init_kwargs) # Union with discriminator support via "kind" diff --git a/python/cocoindex/engine_value.py b/python/cocoindex/engine_value.py index fb027c0d..31f16843 100644 --- a/python/cocoindex/engine_value.py +++ b/python/cocoindex/engine_value.py @@ -4,10 +4,9 @@ from __future__ import annotations -import dataclasses import inspect import warnings -from typing import Any, Callable, Mapping, TypeVar +from typing import Any, Callable, TypeVar import numpy as np from .typing import ( @@ -19,8 +18,8 @@ AnalyzedTypeInfo, AnalyzedUnionType, AnalyzedUnknownType, + AnalyzedStructFieldInfo, analyze_type_info, - is_namedtuple_type, is_pydantic_model, is_numpy_number_type, ValueType, @@ -124,69 +123,20 @@ def encode_struct_dict(value: Any) -> Any: return encode_struct_dict if isinstance(variant, AnalyzedStructType): - struct_type = variant.struct_type - - if dataclasses.is_dataclass(struct_type): - fields = dataclasses.fields(struct_type) - field_encoders = [ - make_engine_value_encoder(analyze_type_info(f.type)) for f in fields - ] - field_names = [f.name for f in fields] - - def encode_dataclass(value: Any) -> Any: - if value is None: - return None - return [ - encoder(getattr(value, name)) - for encoder, name in zip(field_encoders, field_names) - ] - - return encode_dataclass - - elif is_namedtuple_type(struct_type): - annotations = struct_type.__annotations__ - field_names = list(getattr(struct_type, "_fields", ())) - field_encoders = [ - make_engine_value_encoder( - analyze_type_info(annotations[name]) - if name in annotations - else ANY_TYPE_INFO - ) - for name in field_names - ] - - def encode_namedtuple(value: Any) -> Any: - if value is None: - return None - return [ - encoder(getattr(value, name)) - for encoder, name in zip(field_encoders, field_names) - ] - - return encode_namedtuple - - elif is_pydantic_model(struct_type): - # Type guard: ensure we have model_fields attribute - if hasattr(struct_type, "model_fields"): - field_names = list(struct_type.model_fields.keys()) # type: ignore[attr-defined] - field_encoders = [ - make_engine_value_encoder( - analyze_type_info(struct_type.model_fields[name].annotation) # type: ignore[attr-defined] - ) - for name in field_names - ] - else: - raise ValueError(f"Invalid Pydantic model: {struct_type}") + field_encoders = [ + ( + field_info.name, + make_engine_value_encoder(analyze_type_info(field_info.type_hint)), + ) + for field_info in variant.fields + ] - def encode_pydantic(value: Any) -> Any: - if value is None: - return None - return [ - encoder(getattr(value, name)) - for encoder, name in zip(field_encoders, field_names) - ] + def encode_struct(value: Any) -> Any: + if value is None: + return None + return [encoder(getattr(value, name)) for name, encoder in field_encoders] - return encode_pydantic + return encode_struct def encode_basic_value(value: Any) -> Any: if isinstance(value, np.number): @@ -475,51 +425,12 @@ def make_engine_struct_decoder( src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)} dst_struct_type = dst_type_variant.struct_type - parameters: Mapping[str, inspect.Parameter] - if dataclasses.is_dataclass(dst_struct_type): - parameters = inspect.signature(dst_struct_type).parameters - elif is_namedtuple_type(dst_struct_type): - defaults = getattr(dst_struct_type, "_field_defaults", {}) - fields = getattr(dst_struct_type, "_fields", ()) - parameters = { - name: inspect.Parameter( - name=name, - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=defaults.get(name, inspect.Parameter.empty), - annotation=dst_struct_type.__annotations__.get( - name, inspect.Parameter.empty - ), - ) - for name in fields - } - elif is_pydantic_model(dst_struct_type): - # For Pydantic models, we can use model_fields to get field information - parameters = {} - # Type guard: ensure we have model_fields attribute - if hasattr(dst_struct_type, "model_fields"): - model_fields = dst_struct_type.model_fields # type: ignore[attr-defined] - else: - model_fields = {} - for name, field_info in model_fields.items(): - default_value = ( - field_info.default - if field_info.default is not ... - else inspect.Parameter.empty - ) - parameters[name] = inspect.Parameter( - name=name, - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=default_value, - annotation=field_info.annotation, - ) - else: - raise ValueError(f"Unsupported struct type: {dst_struct_type}") - def make_closure_for_field( - name: str, param: inspect.Parameter + field_info: AnalyzedStructFieldInfo, ) -> Callable[[list[Any]], Any]: + name = field_info.name src_idx = src_name_to_idx.get(name) - type_info = analyze_type_info(param.annotation) + type_info = analyze_type_info(field_info.type_hint) with ChildFieldPath(field_path, f".{name}"): if src_idx is not None: @@ -531,14 +442,14 @@ def make_closure_for_field( ) return lambda values: field_decoder(values[src_idx]) - default_value = param.default + default_value = field_info.default_value if default_value is not inspect.Parameter.empty: return lambda _: default_value auto_default, is_supported = get_auto_default_for_type(type_info) if is_supported: warnings.warn( - f"Field '{name}' (type {param.annotation}) without default value is missing in input: " + f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: " f"{''.join(field_path)}. Auto-assigning default value: {auto_default}", UserWarning, stacklevel=4, @@ -546,27 +457,29 @@ def make_closure_for_field( return lambda _: auto_default raise ValueError( - f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}" + f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: {''.join(field_path)}" ) - field_value_decoder = [ - make_closure_for_field(name, param) for (name, param) in parameters.items() - ] - # Different construction for different struct types if is_pydantic_model(dst_struct_type): # Pydantic models prefer keyword arguments - field_names = list(parameters.keys()) + pydantic_fields_decoder = [ + (field_info.name, make_closure_for_field(field_info)) + for field_info in dst_type_variant.fields + ] return lambda values: dst_struct_type( **{ - field_names[i]: decoder(values) - for i, decoder in enumerate(field_value_decoder) + field_name: decoder(values) + for field_name, decoder in pydantic_fields_decoder } ) else: + struct_fields_decoder = [ + make_closure_for_field(field_info) for field_info in dst_type_variant.fields + ] # Dataclasses and NamedTuples can use positional arguments return lambda values: dst_struct_type( - *(decoder(values) for decoder in field_value_decoder) + *(decoder(values) for decoder in struct_fields_decoder) ) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 3a15fd5a..781349dd 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -37,7 +37,6 @@ StructSchema, StructType, TableType, - TypeAttr, encode_enriched_type_info, resolve_forward_ref, analyze_type_info, diff --git a/python/cocoindex/tests/test_engine_value.py b/python/cocoindex/tests/test_engine_value.py index ec603c13..a75f775d 100644 --- a/python/cocoindex/tests/test_engine_value.py +++ b/python/cocoindex/tests/test_engine_value.py @@ -1690,3 +1690,35 @@ class MixedStruct: order = OrderPydantic(order_id="O1", name="item1", price=10.0) mixed = MixedStruct(name="test", pydantic_order=order) validate_full_roundtrip(mixed, MixedStruct) + + +def test_forward_ref_in_dataclass() -> None: + """Test mixing Pydantic models with dataclasses.""" + + @dataclass + class Event: + name: "str" + tag: "Tag" + + validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event) + + +def test_forward_ref_in_namedtuple() -> None: + """Test mixing Pydantic models with dataclasses.""" + + class Event(NamedTuple): + name: "str" + tag: "Tag" + + validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event) + + +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") +def test_forward_ref_in_pydantic() -> None: + """Test mixing Pydantic models with dataclasses.""" + + class Event(BaseModel): + name: "str" + tag: "Tag" + + validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index c4b0ef60..fea72571 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -10,12 +10,14 @@ Annotated, Any, Generic, + Iterator, Literal, NamedTuple, Protocol, TypeVar, overload, Self, + get_type_hints, ) import numpy as np @@ -192,6 +194,17 @@ class AnalyzedListType(NamedTuple): vector_info: VectorInfo | None +class AnalyzedStructFieldInfo(NamedTuple): + """ + Info about a field in a struct type. + """ + + name: str + type_hint: Any + default_value: Any + description: str | None + + class AnalyzedStructType(NamedTuple): """ Any struct type, e.g. dataclass, NamedTuple, etc. @@ -199,6 +212,42 @@ class AnalyzedStructType(NamedTuple): struct_type: type + @property + def fields(self) -> Iterator[AnalyzedStructFieldInfo]: + type_hints = get_type_hints(self.struct_type, include_extras=True) + if dataclasses.is_dataclass(self.struct_type): + parameters = inspect.signature(self.struct_type).parameters + for name, parameter in parameters.items(): + yield AnalyzedStructFieldInfo( + name=name, + type_hint=type_hints.get(name, Any), + default_value=parameter.default, + description=None, + ) + elif is_namedtuple_type(self.struct_type): + fields = getattr(self.struct_type, "_fields", ()) + defaults = getattr(self.struct_type, "_field_defaults", {}) + for name in fields: + yield AnalyzedStructFieldInfo( + name=name, + type_hint=type_hints.get(name, Any), + default_value=defaults.get(name, inspect.Parameter.empty), + description=None, + ) + elif is_pydantic_model(self.struct_type): + model_fields = getattr(self.struct_type, "model_fields", {}) + for name, field_info in model_fields.items(): + yield AnalyzedStructFieldInfo( + name=name, + type_hint=type_hints.get(name, Any), + default_value=field_info.default + if field_info.default is not ... + else inspect.Parameter.empty, + description=field_info.description, + ) + else: + raise ValueError(f"Unsupported struct type: {self.struct_type}") + class AnalyzedUnionType(NamedTuple): """ @@ -355,7 +404,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: def _encode_struct_schema( - struct_type: type, key_type: type | None = None + struct_info: AnalyzedStructType, key_type: type | None = None ) -> tuple[dict[str, Any], int | None]: fields = [] @@ -367,7 +416,7 @@ def add_field( except ValueError as e: e.add_note( f"Failed to encode annotation for field - " - f"{struct_type.__name__}.{name}: {analyzed_type.core_type}" + f"{struct_info.struct_type.__name__}.{name}: {analyzed_type.core_type}" ) raise type_info["name"] = name @@ -375,26 +424,9 @@ def add_field( type_info["description"] = description fields.append(type_info) - def add_fields_from_struct(struct_type: type) -> None: - if dataclasses.is_dataclass(struct_type): - for field in dataclasses.fields(struct_type): - add_field(field.name, analyze_type_info(field.type)) - elif is_namedtuple_type(struct_type): - for name, field_type in struct_type.__annotations__.items(): - add_field(name, analyze_type_info(field_type)) - elif is_pydantic_model(struct_type): - # Type guard: ensure we have pydantic available and struct_type has model_fields - if hasattr(struct_type, "model_fields"): - for name, field_info in struct_type.model_fields.items(): # type: ignore[attr-defined] - # Get the annotation from the field info - field_type = field_info.annotation - # Extract description from Pydantic field info - description = getattr(field_info, "description", None) - add_field(name, analyze_type_info(field_type), description) - else: - raise ValueError(f"Invalid Pydantic model: {struct_type}") - else: - raise ValueError(f"Unsupported struct type: {struct_type}") + def add_fields_from_struct(struct_info: AnalyzedStructType) -> None: + for field in struct_info.fields: + add_field(field.name, analyze_type_info(field.type_hint), field.description) result: dict[str, Any] = {} num_key_parts = None @@ -404,15 +436,15 @@ def add_fields_from_struct(struct_type: type) -> None: add_field(KEY_FIELD_NAME, key_type_info) num_key_parts = 1 elif isinstance(key_type_info.variant, AnalyzedStructType): - add_fields_from_struct(key_type_info.variant.struct_type) + add_fields_from_struct(key_type_info.variant) num_key_parts = len(fields) else: raise ValueError(f"Unsupported key type: {key_type}") - add_fields_from_struct(struct_type) + add_fields_from_struct(struct_info) result["fields"] = fields - if doc := inspect.getdoc(struct_type): + if doc := inspect.getdoc(struct_info): result["description"] = doc return result, num_key_parts @@ -430,7 +462,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: return {"kind": variant.kind} if isinstance(variant, AnalyzedStructType): - encoded_type, _ = _encode_struct_schema(variant.struct_type) + encoded_type, _ = _encode_struct_schema(variant) encoded_type["kind"] = "Struct" return encoded_type @@ -440,7 +472,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: if isinstance(elem_type_info.variant, AnalyzedStructType): if variant.vector_info is not None: raise ValueError("LTable type must not have a vector info") - row_type, _ = _encode_struct_schema(elem_type_info.variant.struct_type) + row_type, _ = _encode_struct_schema(elem_type_info.variant) return {"kind": "LTable", "row": row_type} else: vector_info = variant.vector_info @@ -457,7 +489,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: f"KTable value must have a Struct type, got {value_type_info.core_type}" ) row_type, num_key_parts = _encode_struct_schema( - value_type_info.variant.struct_type, + value_type_info.variant, variant.key_type, ) return {