Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 9 additions & 67 deletions python/cocoindex/engine_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import datetime
import dataclasses
from enum import Enum
from typing import Any, Mapping, TypeVar, overload, get_origin

Expand All @@ -24,18 +23,12 @@
analyze_type_info,
encode_enriched_type,
is_namedtuple_type,
is_pydantic_model,
extract_ndarray_elem_dtype,
)


T = TypeVar("T")

try:
import pydantic, pydantic_core
except ImportError:
pass


def get_auto_default_for_type(
type_info: AnalyzedTypeInfo,
Expand Down Expand Up @@ -175,67 +168,16 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
if isinstance(variant, AnalyzedStructType):
struct_type = variant.struct_type
init_kwargs: dict[str, Any] = {}
missing_fields: list[tuple[str, Any]] = []
if dataclasses.is_dataclass(struct_type):
if not isinstance(v, Mapping):
raise ValueError(f"Expected dict for dataclass, got {type(v)}")

for dc_field in dataclasses.fields(struct_type):
if dc_field.name in v:
init_kwargs[dc_field.name] = load_engine_object(
dc_field.type, v[dc_field.name]
)
else:
if (
dc_field.default is dataclasses.MISSING
and dc_field.default_factory is dataclasses.MISSING
):
missing_fields.append((dc_field.name, dc_field.type))

elif is_namedtuple_type(struct_type):
if not isinstance(v, Mapping):
raise ValueError(f"Expected dict for NamedTuple, got {type(v)}")
# Dict format (from dump/load functions)
annotations = getattr(struct_type, "__annotations__", {})
field_names = list(getattr(struct_type, "_fields", ()))
field_defaults = getattr(struct_type, "_field_defaults", {})

for name in field_names:
f_type = annotations.get(name, Any)
if name in v:
init_kwargs[name] = load_engine_object(f_type, v[name])
elif name not in field_defaults:
missing_fields.append((name, f_type))

elif is_pydantic_model(struct_type):
if not isinstance(v, Mapping):
raise ValueError(f"Expected dict for Pydantic model, got {type(v)}")

model_fields: dict[str, pydantic.fields.FieldInfo]
if hasattr(struct_type, "model_fields"):
model_fields = struct_type.model_fields # type: ignore[attr-defined]
for field_info in variant.fields:
if field_info.name in v:
init_kwargs[field_info.name] = load_engine_object(
field_info.type_hint, v[field_info.name]
)
else:
model_fields = {}

for name, pyd_field in model_fields.items():
if name in v:
init_kwargs[name] = load_engine_object(
pyd_field.annotation, v[name]
)
elif (
getattr(pyd_field, "default", pydantic_core.PydanticUndefined)
is pydantic_core.PydanticUndefined
and getattr(pyd_field, "default_factory") is None
):
missing_fields.append((name, pyd_field.annotation))
else:
assert False, "Unsupported struct type"

for name, f_type in missing_fields:
type_info = analyze_type_info(f_type)
auto_default, is_supported = get_auto_default_for_type(type_info)
if is_supported:
init_kwargs[name] = auto_default
type_info = analyze_type_info(field_info.type_hint)
auto_default, is_supported = get_auto_default_for_type(type_info)
if is_supported:
init_kwargs[field_info.name] = auto_default
return struct_type(**init_kwargs)

# Union with discriminator support via "kind"
Expand Down
147 changes: 30 additions & 117 deletions python/cocoindex/engine_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from __future__ import annotations

import dataclasses
import inspect
import warnings
from typing import Any, Callable, Mapping, TypeVar
from typing import Any, Callable, TypeVar

import numpy as np
from .typing import (
Expand All @@ -19,8 +18,8 @@
AnalyzedTypeInfo,
AnalyzedUnionType,
AnalyzedUnknownType,
AnalyzedStructFieldInfo,
analyze_type_info,
is_namedtuple_type,
is_pydantic_model,
is_numpy_number_type,
ValueType,
Expand Down Expand Up @@ -124,69 +123,20 @@ def encode_struct_dict(value: Any) -> Any:
return encode_struct_dict

if isinstance(variant, AnalyzedStructType):
struct_type = variant.struct_type

if dataclasses.is_dataclass(struct_type):
fields = dataclasses.fields(struct_type)
field_encoders = [
make_engine_value_encoder(analyze_type_info(f.type)) for f in fields
]
field_names = [f.name for f in fields]

def encode_dataclass(value: Any) -> Any:
if value is None:
return None
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]

return encode_dataclass

elif is_namedtuple_type(struct_type):
annotations = struct_type.__annotations__
field_names = list(getattr(struct_type, "_fields", ()))
field_encoders = [
make_engine_value_encoder(
analyze_type_info(annotations[name])
if name in annotations
else ANY_TYPE_INFO
)
for name in field_names
]

def encode_namedtuple(value: Any) -> Any:
if value is None:
return None
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]

return encode_namedtuple

elif is_pydantic_model(struct_type):
# Type guard: ensure we have model_fields attribute
if hasattr(struct_type, "model_fields"):
field_names = list(struct_type.model_fields.keys()) # type: ignore[attr-defined]
field_encoders = [
make_engine_value_encoder(
analyze_type_info(struct_type.model_fields[name].annotation) # type: ignore[attr-defined]
)
for name in field_names
]
else:
raise ValueError(f"Invalid Pydantic model: {struct_type}")
field_encoders = [
(
field_info.name,
make_engine_value_encoder(analyze_type_info(field_info.type_hint)),
)
for field_info in variant.fields
]

def encode_pydantic(value: Any) -> Any:
if value is None:
return None
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]
def encode_struct(value: Any) -> Any:
if value is None:
return None
return [encoder(getattr(value, name)) for name, encoder in field_encoders]

return encode_pydantic
return encode_struct

def encode_basic_value(value: Any) -> Any:
if isinstance(value, np.number):
Expand Down Expand Up @@ -475,51 +425,12 @@ def make_engine_struct_decoder(
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
dst_struct_type = dst_type_variant.struct_type

parameters: Mapping[str, inspect.Parameter]
if dataclasses.is_dataclass(dst_struct_type):
parameters = inspect.signature(dst_struct_type).parameters
elif is_namedtuple_type(dst_struct_type):
defaults = getattr(dst_struct_type, "_field_defaults", {})
fields = getattr(dst_struct_type, "_fields", ())
parameters = {
name: inspect.Parameter(
name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=defaults.get(name, inspect.Parameter.empty),
annotation=dst_struct_type.__annotations__.get(
name, inspect.Parameter.empty
),
)
for name in fields
}
elif is_pydantic_model(dst_struct_type):
# For Pydantic models, we can use model_fields to get field information
parameters = {}
# Type guard: ensure we have model_fields attribute
if hasattr(dst_struct_type, "model_fields"):
model_fields = dst_struct_type.model_fields # type: ignore[attr-defined]
else:
model_fields = {}
for name, field_info in model_fields.items():
default_value = (
field_info.default
if field_info.default is not ...
else inspect.Parameter.empty
)
parameters[name] = inspect.Parameter(
name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=default_value,
annotation=field_info.annotation,
)
else:
raise ValueError(f"Unsupported struct type: {dst_struct_type}")

def make_closure_for_field(
name: str, param: inspect.Parameter
field_info: AnalyzedStructFieldInfo,
) -> Callable[[list[Any]], Any]:
name = field_info.name
src_idx = src_name_to_idx.get(name)
type_info = analyze_type_info(param.annotation)
type_info = analyze_type_info(field_info.type_hint)

with ChildFieldPath(field_path, f".{name}"):
if src_idx is not None:
Expand All @@ -531,42 +442,44 @@ def make_closure_for_field(
)
return lambda values: field_decoder(values[src_idx])

default_value = param.default
default_value = field_info.default_value
if default_value is not inspect.Parameter.empty:
return lambda _: default_value

auto_default, is_supported = get_auto_default_for_type(type_info)
if is_supported:
warnings.warn(
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: "
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
UserWarning,
stacklevel=4,
)
return lambda _: auto_default

raise ValueError(
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: {''.join(field_path)}"
)

field_value_decoder = [
make_closure_for_field(name, param) for (name, param) in parameters.items()
]

# Different construction for different struct types
if is_pydantic_model(dst_struct_type):
# Pydantic models prefer keyword arguments
field_names = list(parameters.keys())
pydantic_fields_decoder = [
(field_info.name, make_closure_for_field(field_info))
for field_info in dst_type_variant.fields
]
return lambda values: dst_struct_type(
**{
field_names[i]: decoder(values)
for i, decoder in enumerate(field_value_decoder)
field_name: decoder(values)
for field_name, decoder in pydantic_fields_decoder
}
)
else:
struct_fields_decoder = [
make_closure_for_field(field_info) for field_info in dst_type_variant.fields
]
# Dataclasses and NamedTuples can use positional arguments
return lambda values: dst_struct_type(
*(decoder(values) for decoder in field_value_decoder)
*(decoder(values) for decoder in struct_fields_decoder)
)


Expand Down
1 change: 0 additions & 1 deletion python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
StructSchema,
StructType,
TableType,
TypeAttr,
encode_enriched_type_info,
resolve_forward_ref,
analyze_type_info,
Expand Down
32 changes: 32 additions & 0 deletions python/cocoindex/tests/test_engine_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,3 +1690,35 @@ class MixedStruct:
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
mixed = MixedStruct(name="test", pydantic_order=order)
validate_full_roundtrip(mixed, MixedStruct)


def test_forward_ref_in_dataclass() -> None:
"""Test mixing Pydantic models with dataclasses."""

@dataclass
class Event:
name: "str"
tag: "Tag"

validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)


def test_forward_ref_in_namedtuple() -> None:
"""Test mixing Pydantic models with dataclasses."""

class Event(NamedTuple):
name: "str"
tag: "Tag"

validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_forward_ref_in_pydantic() -> None:
"""Test mixing Pydantic models with dataclasses."""

class Event(BaseModel):
name: "str"
tag: "Tag"

validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)
Loading
Loading