Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing
import collections
import dataclasses
import types
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING

class Vector(NamedTuple):
Expand Down Expand Up @@ -59,16 +60,31 @@ class AnalyzedTypeInfo:
elem_type: type | None
struct_fields: tuple[dataclasses.Field, ...] | None
attrs: dict[str, Any] | None
nullable: bool = False

def analyze_type_info(t) -> AnalyzedTypeInfo:
"""
Analyze a Python type and return the analyzed info.
"""
annotations: tuple[Annotation, ...] = ()
if typing.get_origin(t) is Annotated:
annotations = t.__metadata__
t = t.__origin__
base_type = typing.get_origin(t)
base_type = None
nullable = False
while True:
base_type = typing.get_origin(t)
if base_type is Annotated:
annotations = t.__metadata__
t = t.__origin__
elif base_type is types.UnionType:
possible_types = typing.get_args(t)
non_none_types = [arg for arg in possible_types if arg not in (None, types.NoneType)]
if len(non_none_types) != 1:
raise ValueError(
f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}")
t = non_none_types[0]
if len(possible_types) > 1:
nullable = True
else:
break

attrs = None
vector_info = None
Expand Down Expand Up @@ -118,7 +134,7 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
raise ValueError(f"type unsupported yet: {base_type}")

return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
struct_fields=struct_fields, attrs=attrs)
struct_fields=struct_fields, attrs=attrs, nullable=nullable)

def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
encoded_type: dict[str, Any] = { 'kind': type_info.kind }
Expand Down Expand Up @@ -150,9 +166,15 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:

def _encode_enriched_type(t) -> dict[str, Any]:
enriched_type_info = analyze_type_info(t)
encoded = {'type': _encode_type(enriched_type_info)}

encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)}

if enriched_type_info.attrs is not None:
encoded['attrs'] = enriched_type_info.attrs

if enriched_type_info.nullable:
encoded['nullable'] = True

return encoded


Expand Down