diff --git a/examples/manual_extraction/manual_extraction.py b/examples/manual_extraction/manual_extraction.py
index bb59cefb..e43927a2 100644
--- a/examples/manual_extraction/manual_extraction.py
+++ b/examples/manual_extraction/manual_extraction.py
@@ -38,21 +38,21 @@ class ArgInfo:
 @dataclasses.dataclass
 class MethodInfo:
     name: str
-    args: list[ArgInfo]
+    args: cocoindex.typing.List[ArgInfo]
     description: str
 
 @dataclasses.dataclass
 class ClassInfo:
     name: str
     description: str
-    methods: list[MethodInfo]
+    methods: cocoindex.typing.List[MethodInfo]
 
 @dataclasses.dataclass
 class ManualInfo:
     title: str
     description: str
-    classes: list[ClassInfo]
-    methods: list[MethodInfo]
+    classes: cocoindex.typing.Table[ClassInfo]
+    methods: cocoindex.typing.Table[MethodInfo]
 
 
 class ExtractManual(cocoindex.op.FunctionSpec):
diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py
index bfd526ac..85e68a2a 100644
--- a/python/cocoindex/typing.py
+++ b/python/cocoindex/typing.py
@@ -1,7 +1,7 @@
 import typing
 import collections
 import dataclasses
-from typing import Annotated, NamedTuple, Any
+from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING
 
 class Vector(NamedTuple):
     dim: int | None
@@ -21,6 +21,33 @@ def __init__(self, key: str, value: Any):
 Range = Annotated[tuple[int, int], TypeKind('Range')]
 Json = Annotated[Any, TypeKind('Json')]
 
+R = TypeVar("R")
+
+# Static type checkers see `Table[X]` / `List[X]` as generic aliases of
+# `list[X]`; at runtime the classes below resolve the subscript through
+# `__class_getitem__` to `Annotated[list[X], TypeKind(...)]`.
+if TYPE_CHECKING:
+    Table = Annotated[list[R], TypeKind('Table')]
+    List = Annotated[list[R], TypeKind('List')]
+else:
+    # pylint: disable=too-few-public-methods
+    class Table:  # type: ignore[unreachable]
+        """
+        A Table type, which has a list of rows. The first field of each row is the key.
+        """
+        def __class_getitem__(cls, item: type[R]):
+            """Return `list[item]` annotated with `TypeKind('Table')`."""
+            return Annotated[list[item], TypeKind('Table')]
+
+    # pylint: disable=too-few-public-methods
+    class List:  # type: ignore[unreachable]
+        """
+        A List type, which has a list of ordered rows.
+        """
+        def __class_getitem__(cls, item: type[R]):
+            """Return `list[item]` annotated with `TypeKind('List')`."""
+            return Annotated[list[item], TypeKind('List')]
+
 def _find_annotation(metadata, cls):
     for m in iter(metadata):
         if isinstance(m, cls):
@@ -43,6 +70,7 @@ def _dump_fields_schema(cls: type) -> list[dict[str, Any]]:
 
 def _dump_type(t, metadata):
     origin_type = typing.get_origin(t)
+    type_kind = _find_annotation(metadata, TypeKind)
     if origin_type is collections.abc.Sequence or origin_type is list:
         args = typing.get_args(t)
         elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0])
@@ -54,10 +82,13 @@ def _dump_type(t, metadata):
                 'dimension': vector_annot.dim,
             }
         elif dataclasses.is_dataclass(elem_type):
-            encoded_type = {
-                'kind': 'List',
-                'row': { 'fields': _dump_fields_schema(elem_type) },
-            }
+            # Only an explicit Table annotation selects the 'Table' encoding;
+            # every other list of dataclasses is still encoded as 'List'.
+            is_table = type_kind is not None and type_kind.kind == 'Table'
+            encoded_type = {
+                'kind': 'Table' if is_table else 'List',
+                'row': { 'fields': _dump_fields_schema(elem_type) },
+            }
         else:
             raise ValueError(f"Unsupported type: {t}")
     elif dataclasses.is_dataclass(t):
@@ -66,7 +97,6 @@ def _dump_type(t, metadata):
             'fields': _dump_fields_schema(t),
         }
     else:
-        type_kind = _find_annotation(metadata, TypeKind)
         if type_kind is not None:
             kind = type_kind.kind
         else: