Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/manual_extraction/manual_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,21 @@ class ArgInfo:
@dataclasses.dataclass
class MethodInfo:
    # Record with a name, a collection of ArgInfo entries and a description.
    name: str
    # Declared with cocoindex's `List` marker type rather than a bare
    # `list[...]`; the stale duplicate `list[ArgInfo]` annotation was dropped
    # because a later annotation of the same field overrides it anyway.
    args: cocoindex.typing.List[ArgInfo]
    description: str

@dataclasses.dataclass
class ClassInfo:
    # Record with a name, a description and a collection of MethodInfo entries.
    name: str
    description: str
    # Declared with cocoindex's `List` marker type rather than a bare
    # `list[...]`; the stale duplicate `list[MethodInfo]` annotation was
    # dropped because a later annotation of the same field overrides it anyway.
    methods: cocoindex.typing.List[MethodInfo]

@dataclasses.dataclass
class ManualInfo:
    # Record with a title, a description, and two collections: ClassInfo
    # entries and MethodInfo entries.
    title: str
    description: str
    # Both collections use cocoindex's `Table` marker type; the stale duplicate
    # `list[...]` annotations were dropped because the later annotation of the
    # same field overrides the earlier one anyway.
    classes: cocoindex.typing.Table[ClassInfo]
    methods: cocoindex.typing.Table[MethodInfo]


class ExtractManual(cocoindex.op.FunctionSpec):
Expand Down
40 changes: 34 additions & 6 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing
import collections
import dataclasses
from typing import Annotated, NamedTuple, Any
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING

class Vector(NamedTuple):
dim: int | None
Expand All @@ -21,6 +21,28 @@ def __init__(self, key: str, value: Any):
Range = Annotated[tuple[int, int], TypeKind('Range')]
Json = Annotated[Any, TypeKind('Json')]

R = TypeVar("R")

if TYPE_CHECKING:
    # Static analysis only: both names are generic aliases over `list[R]`
    # that differ solely in the attached TypeKind marker.
    Table = Annotated[list[R], TypeKind('Table')]
    List = Annotated[list[R], TypeKind('List')]
else:
    # Runtime: subscripting either class (e.g. `Table[Row]`) builds the
    # matching annotated list type on demand.

    # pylint: disable=too-few-public-methods
    class Table:  # type: ignore[unreachable]
        """
        Table type: a list of rows, where the first field of each row is the key.

        `Table[T]` evaluates to `list[T]` annotated with `TypeKind('Table')`.
        """

        def __class_getitem__(cls, item: type[R]):
            rows = list[item]
            return Annotated[rows, TypeKind('Table')]

    # pylint: disable=too-few-public-methods
    class List:  # type: ignore[unreachable]
        """
        List type: a list of ordered rows.

        `List[T]` evaluates to `list[T]` annotated with `TypeKind('List')`.
        """

        def __class_getitem__(cls, item: type[R]):
            rows = list[item]
            return Annotated[rows, TypeKind('List')]

def _find_annotation(metadata, cls):
for m in iter(metadata):
if isinstance(m, cls):
Expand All @@ -43,6 +65,7 @@ def _dump_fields_schema(cls: type) -> list[dict[str, Any]]:

def _dump_type(t, metadata):
origin_type = typing.get_origin(t)
type_kind = _find_annotation(metadata, TypeKind)
if origin_type is collections.abc.Sequence or origin_type is list:
args = typing.get_args(t)
elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0])
Expand All @@ -54,10 +77,16 @@ def _dump_type(t, metadata):
'dimension': vector_annot.dim,
}
elif dataclasses.is_dataclass(elem_type):
encoded_type = {
'kind': 'List',
'row': { 'fields': _dump_fields_schema(elem_type) },
}
if type_kind is not None and type_kind.kind == 'Table':
encoded_type = {
'kind': 'Table',
'row': { 'fields': _dump_fields_schema(elem_type) },
}
else:
encoded_type = {
'kind': 'List',
'row': { 'fields': _dump_fields_schema(elem_type) },
}
else:
raise ValueError(f"Unsupported type: {t}")
elif dataclasses.is_dataclass(t):
Expand All @@ -66,7 +95,6 @@ def _dump_type(t, metadata):
'fields': _dump_fields_schema(t),
}
else:
type_kind = _find_annotation(metadata, TypeKind)
if type_kind is not None:
kind = type_kind.kind
else:
Expand Down