diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 2eacad9f..f640af43 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -1,6 +1,7 @@ import typing import collections import dataclasses +import types from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING class Vector(NamedTuple): @@ -59,16 +60,31 @@ class AnalyzedTypeInfo: elem_type: type | None struct_fields: tuple[dataclasses.Field, ...] | None attrs: dict[str, Any] | None + nullable: bool = False def analyze_type_info(t) -> AnalyzedTypeInfo: """ Analyze a Python type and return the analyzed info. """ annotations: tuple[Annotation, ...] = () - if typing.get_origin(t) is Annotated: - annotations = t.__metadata__ - t = t.__origin__ - base_type = typing.get_origin(t) + base_type = None + nullable = False + while True: + base_type = typing.get_origin(t) + if base_type is Annotated: + annotations = t.__metadata__ + t = t.__origin__ + elif base_type is types.UnionType: + possible_types = typing.get_args(t) + non_none_types = [arg for arg in possible_types if arg not in (None, types.NoneType)] + if len(non_none_types) != 1: + raise ValueError( + f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}") + t = non_none_types[0] + if len(possible_types) > 1: + nullable = True + else: + break attrs = None vector_info = None @@ -118,7 +134,7 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: raise ValueError(f"type unsupported yet: {base_type}") return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type, - struct_fields=struct_fields, attrs=attrs) + struct_fields=struct_fields, attrs=attrs, nullable=nullable) def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: encoded_type: dict[str, Any] = { 'kind': type_info.kind } @@ -150,9 +166,15 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: def _encode_enriched_type(t) -> dict[str, Any]: enriched_type_info = analyze_type_info(t) - encoded = {'type': _encode_type(enriched_type_info)} + + encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)} + if enriched_type_info.attrs is not None: encoded['attrs'] = enriched_type_info.attrs + + if enriched_type_info.nullable: + encoded['nullable'] = True + return encoded