diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39206df44..ef8e3508a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
 
 ### Changed
 
 - Typehint utils collection implementations
+- Typehint Structure types
 
 ## v0.1.0-b35 (2025-08-20)
 
diff --git a/tiled/client/base.py b/tiled/client/base.py
index b5d83cd7f..ccbd89a32 100644
--- a/tiled/client/base.py
+++ b/tiled/client/base.py
@@ -2,7 +2,7 @@
 from copy import copy, deepcopy
 from dataclasses import asdict
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 from urllib.parse import parse_qs, urlparse
 
 import json_merge_patch
@@ -11,6 +11,7 @@ from httpx import URL
 
 from tiled.client.context import Context
+from tiled.structures.root import Structure
+
 from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily
 from ..structures.data_source import DataSource
@@ -131,8 +132,8 @@ def __init__(
         *,
         item,
         structure_clients,
-        structure=None,
-        include_data_sources=False,
+        structure: Optional[Structure] = None,
+        include_data_sources: bool = False,
     ):
         self._context = context
         self._item = item
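Note on the client change above: `structure` was previously an untyped parameter; it now carries one of the structure dataclasses typed throughout the rest of this patch. A minimal sketch of the kind of object involved (assumes numpy is installed; the expected output is inferred from `ArrayStructure.from_array` later in this diff):

import numpy

from tiled.structures.array import ArrayStructure

structure = ArrayStructure.from_array(numpy.zeros((3, 3)))
print(structure.shape)   # expected: (3, 3)
print(structure.chunks)  # expected: ((3,), (3,))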
extra="forbid", frozen=True): +class Spec(BaseModel, extra="forbid", frozen=True): name: Annotated[str, StringConstraints(max_length=255)] version: Optional[Annotated[str, StringConstraints(max_length=255)]] = None @@ -97,7 +106,7 @@ class Spec(pydantic.BaseModel, extra="forbid", frozen=True): Specs = Annotated[List[Spec], Field(max_length=MAX_ALLOWED_SPECS)] -class Asset(pydantic.BaseModel): +class Asset(BaseModel): data_uri: str is_directory: bool parameter: Optional[str] = None @@ -123,7 +132,7 @@ def from_assoc_orm(cls, orm): ) -class Revision(pydantic.BaseModel): +class Revision(BaseModel): revision_number: int metadata: dict specs: Specs @@ -143,7 +152,7 @@ def from_orm(cls, orm: tiled.catalog.orm.Revision) -> Revision: ) -class DataSource(pydantic.BaseModel, Generic[StructureT]): +class DataSource(BaseModel, Generic[StructureT]): id: Optional[int] = None structure_family: StructureFamily structure: Optional[StructureT] @@ -152,13 +161,12 @@ class DataSource(pydantic.BaseModel, Generic[StructureT]): assets: List[Asset] = [] management: Management = Management.writable - model_config = pydantic.ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid") @classmethod def from_orm(cls, orm: tiled.catalog.orm.DataSource) -> DataSource: if hasattr(orm.structure, "structure"): - structure_cls = STRUCTURE_TYPES[orm.structure_family] - structure = structure_cls.from_json(orm.structure.structure) + structure = orm.structure.structure else: structure = None return cls( @@ -171,8 +179,24 @@ def from_orm(cls, orm: tiled.catalog.orm.DataSource) -> DataSource: management=orm.management, ) - -class NodeAttributes(pydantic.BaseModel): + @field_validator("structure", mode="before") + @classmethod + def _coerce_structure_family( + cls, value: Any, info: ValidationInfo + ) -> Optional[StructureT]: + "Convert the structure on each data_source from a dict to the appropriate pydantic model." 
+ if isinstance(value, str): + value = json.loads(value) + if isinstance(value, Structure): + return value + if isinstance(value, dict[str, Any]): + family: Optional[StructureFamily] = info.data.get("structure_family") + if family in STRUCTURE_TYPES: + return STRUCTURE_TYPES[family].from_json(value) + return None + + +class NodeAttributes(BaseModel): ancestors: List[str] structure_family: Optional[StructureFamily] = None specs: Optional[Specs] = None @@ -191,7 +215,7 @@ class NodeAttributes(pydantic.BaseModel): sorting: Optional[List[SortingItem]] = None data_sources: Optional[List[DataSource]] = None - model_config = pydantic.ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid") AttributesT = TypeVar("AttributesT") @@ -199,35 +223,35 @@ class NodeAttributes(pydantic.BaseModel): ResourceLinksT = TypeVar("ResourceLinksT") -class SelfLinkOnly(pydantic.BaseModel): +class SelfLinkOnly(BaseModel): self: str -class ContainerLinks(pydantic.BaseModel): +class ContainerLinks(BaseModel): self: str search: str full: str -class ArrayLinks(pydantic.BaseModel): +class ArrayLinks(BaseModel): self: str full: str block: str -class AwkwardLinks(pydantic.BaseModel): +class AwkwardLinks(BaseModel): self: str buffers: str full: str -class DataFrameLinks(pydantic.BaseModel): +class DataFrameLinks(BaseModel): self: str full: str partition: str -class SparseLinks(pydantic.BaseModel): +class SparseLinks(BaseModel): self: str full: str block: str @@ -243,15 +267,15 @@ class SparseLinks(pydantic.BaseModel): } -class EmptyDict(pydantic.BaseModel): +class EmptyDict(BaseModel): pass -class ContainerMeta(pydantic.BaseModel): +class ContainerMeta(BaseModel): count: int -class Resource(pydantic.BaseModel, Generic[AttributesT, ResourceLinksT, ResourceMetaT]): +class Resource(BaseModel, Generic[AttributesT, ResourceLinksT, ResourceMetaT]): "A JSON API Resource" id: Union[str, uuid.UUID] attributes: AttributesT @@ -259,7 +283,7 @@ class Resource(pydantic.BaseModel, Generic[AttributesT, ResourceLinksT, Resource meta: Optional[ResourceMetaT] = None -class AccessAndRefreshTokens(pydantic.BaseModel): +class AccessAndRefreshTokens(BaseModel): access_token: str expires_in: int refresh_token: str @@ -267,11 +291,11 @@ class AccessAndRefreshTokens(pydantic.BaseModel): token_type: str -class RefreshToken(pydantic.BaseModel): +class RefreshToken(BaseModel): refresh_token: str -class DeviceCode(pydantic.BaseModel): +class DeviceCode(BaseModel): device_code: str grant_type: str @@ -281,8 +305,8 @@ class PrincipalType(str, enum.Enum): service = "service" -class Identity(pydantic.BaseModel): - model_config = pydantic.ConfigDict(from_attributes=True) +class Identity(BaseModel): + model_config = ConfigDict(from_attributes=True) id: Annotated[str, StringConstraints(max_length=255)] provider: Annotated[str, StringConstraints(max_length=255)] latest_login: Optional[datetime] = None @@ -292,8 +316,8 @@ def from_orm(cls, orm: tiled.authn_database.orm.Identity) -> Identity: return cls(id=orm.id, provider=orm.provider, latest_login=orm.latest_login) -class Role(pydantic.BaseModel): - model_config = pydantic.ConfigDict(from_attributes=True) +class Role(BaseModel): + model_config = ConfigDict(from_attributes=True) name: str scopes: List[str] # principals @@ -303,8 +327,8 @@ def from_orm(cls, orm: tiled.authn_database.orm.Role) -> Role: return cls(name=orm.name, scopes=orm.scopes) -class APIKey(pydantic.BaseModel): - model_config = pydantic.ConfigDict(from_attributes=True) +class APIKey(BaseModel): + model_config = 
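The `_coerce_structure` validator above replaces the `narrow_structure_type` model validator removed further down. Running in "before" mode, it accepts whatever the caller has on hand, a JSON string as stored by the catalog, a plain dict, or an already-narrowed `Structure`, and dispatches on the sibling `structure_family` field. A sketch of the same dispatch done by hand; the exact JSON keys are assumptions based on the structure classes in this patch, not a documented wire format:

import json

from tiled.structures.core import STRUCTURE_TYPES, StructureFamily

raw = json.loads(
    '{"data_type": {"endianness": "little", "kind": "f", "itemsize": 8,'
    ' "dt_units": null}, "chunks": [[3], [3]], "shape": [3, 3],'
    ' "dims": null, "resizable": false}'
)
structure_cls = STRUCTURE_TYPES[StructureFamily.array]  # lazily resolves to ArrayStructure
structure = structure_cls.from_json(raw)
print(type(structure).__name__)  # expected: ArrayStructure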
@@ -191,7 +215,7 @@ class NodeAttributes(pydantic.BaseModel):
     sorting: Optional[List[SortingItem]] = None
     data_sources: Optional[List[DataSource]] = None
 
-    model_config = pydantic.ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid")
 
 
 AttributesT = TypeVar("AttributesT")
@@ -199,35 +223,35 @@ class NodeAttributes(pydantic.BaseModel):
 ResourceLinksT = TypeVar("ResourceLinksT")
 
 
-class SelfLinkOnly(pydantic.BaseModel):
+class SelfLinkOnly(BaseModel):
     self: str
 
 
-class ContainerLinks(pydantic.BaseModel):
+class ContainerLinks(BaseModel):
     self: str
     search: str
     full: str
 
 
-class ArrayLinks(pydantic.BaseModel):
+class ArrayLinks(BaseModel):
     self: str
     full: str
     block: str
 
 
-class AwkwardLinks(pydantic.BaseModel):
+class AwkwardLinks(BaseModel):
     self: str
     buffers: str
     full: str
 
 
-class DataFrameLinks(pydantic.BaseModel):
+class DataFrameLinks(BaseModel):
     self: str
     full: str
     partition: str
 
 
-class SparseLinks(pydantic.BaseModel):
+class SparseLinks(BaseModel):
     self: str
     full: str
     block: str
@@ -243,15 +267,15 @@ class SparseLinks(pydantic.BaseModel):
 }
 
 
-class EmptyDict(pydantic.BaseModel):
+class EmptyDict(BaseModel):
     pass
 
 
-class ContainerMeta(pydantic.BaseModel):
+class ContainerMeta(BaseModel):
     count: int
 
 
-class Resource(pydantic.BaseModel, Generic[AttributesT, ResourceLinksT, ResourceMetaT]):
+class Resource(BaseModel, Generic[AttributesT, ResourceLinksT, ResourceMetaT]):
     "A JSON API Resource"
     id: Union[str, uuid.UUID]
     attributes: AttributesT
@@ -259,7 +283,7 @@ class Resource(pydantic.BaseModel, Generic[AttributesT, ResourceLinksT, Resource
     meta: Optional[ResourceMetaT] = None
 
 
-class AccessAndRefreshTokens(pydantic.BaseModel):
+class AccessAndRefreshTokens(BaseModel):
     access_token: str
     expires_in: int
     refresh_token: str
@@ -267,11 +291,11 @@ class AccessAndRefreshTokens(pydantic.BaseModel):
     token_type: str
 
 
-class RefreshToken(pydantic.BaseModel):
+class RefreshToken(BaseModel):
     refresh_token: str
 
 
-class DeviceCode(pydantic.BaseModel):
+class DeviceCode(BaseModel):
     device_code: str
     grant_type: str
@@ -281,8 +305,8 @@ class PrincipalType(str, enum.Enum):
     service = "service"
 
 
-class Identity(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(from_attributes=True)
+class Identity(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
     id: Annotated[str, StringConstraints(max_length=255)]
     provider: Annotated[str, StringConstraints(max_length=255)]
     latest_login: Optional[datetime] = None
@@ -292,8 +316,8 @@ def from_orm(cls, orm: tiled.authn_database.orm.Identity) -> Identity:
         return cls(id=orm.id, provider=orm.provider, latest_login=orm.latest_login)
 
 
-class Role(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(from_attributes=True)
+class Role(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
     name: str
     scopes: List[str]
     # principals
@@ -303,8 +327,8 @@ def from_orm(cls, orm: tiled.authn_database.orm.Role) -> Role:
         return cls(name=orm.name, scopes=orm.scopes)
 
 
-class APIKey(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(from_attributes=True)
+class APIKey(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
     first_eight: Annotated[str, StringConstraints(min_length=8, max_length=8)]
     expiration_time: Optional[datetime] = None
     note: Optional[Annotated[str, StringConstraints(max_length=255)]] = None
@@ -339,7 +363,7 @@ def from_orm(
         )
 
 
-class Session(pydantic.BaseModel):
+class Session(BaseModel):
     """
     This related to refresh tokens, which have a session uuid ("sid") claim.
 
@@ -350,7 +374,7 @@ class Session(pydantic.BaseModel):
     # The id field (primary key) is intentionally not exposed to the application.
     # It is left as an internal database concern.
-    model_config = pydantic.ConfigDict(from_attributes=True)
+    model_config = ConfigDict(from_attributes=True)
     uuid: uuid.UUID
     expiration_time: datetime
     revoked: bool
@@ -366,11 +390,11 @@ def from_orm(cls, orm: tiled.authn_database.orm.Session) -> Session:
         )
 
 
-class Principal(pydantic.BaseModel):
+class Principal(BaseModel):
     "Represents a User or Service"
     # The id field (primary key) is intentionally not exposed to the application.
     # It is left as an internal database concern.
-    model_config = pydantic.ConfigDict(from_attributes=True)
+    model_config = ConfigDict(from_attributes=True)
     uuid: uuid.UUID
     type: PrincipalType
     identities: List[Identity] = []
@@ -396,20 +420,18 @@ def from_orm(
         )
 
 
-class APIKeyRequestParams(pydantic.BaseModel):
+class APIKeyRequestParams(BaseModel):
     # Provide an example for expires_in. Otherwise, OpenAPI suggests lifetime=0.
     # If the user is not reading carefully, they will be frustrated when they
     # try to use the instantly-expiring API key!
-    expires_in: Optional[int] = pydantic.Field(
+    expires_in: Optional[int] = Field(
         ..., json_schema_extra={"example": 600}
     )  # seconds
-    scopes: Optional[List[str]] = pydantic.Field(
-        ..., json_schema_extra={"example": ["inherit"]}
-    )
+    scopes: Optional[List[str]] = Field(..., json_schema_extra={"example": ["inherit"]})
     note: Optional[str] = None
 
 
-class PostMetadataRequest(pydantic.BaseModel):
+class PostMetadataRequest(BaseModel):
     id: Optional[str] = None
     structure_family: StructureFamily
     metadata: Dict = {}
@@ -419,7 +441,7 @@ class PostMetadataRequest(pydantic.BaseModel):
 
     # Wait for fix https://github.com/pydantic/pydantic/issues/3957
    # to do this with `unique_items` parameters to `pydantic.constr`.
-    @pydantic.field_validator("specs")
+    @field_validator("specs")
     def specs_uniqueness_validator(cls, v):
         if v is None:
             return None
@@ -428,27 +450,12 @@ def specs_uniqueness_validator(cls, v):
                 raise ValueError
         return v
 
-    @pydantic.model_validator(mode="after")
-    def narrow_structure_type(self):
-        "Convert the structure on each data_source from a dict to the appropriate pydantic model."
-        for data_source in self.data_sources:
-            if self.structure_family not in {
-                StructureFamily.container,
-                StructureFamily.composite,
-            }:
-                structure_cls = STRUCTURE_TYPES[self.structure_family]
-                if data_source.structure is not None:
-                    data_source.structure = structure_cls.from_json(
-                        data_source.structure
-                    )
-        return self
-
-
-class PutDataSourceRequest(pydantic.BaseModel):
+
+class PutDataSourceRequest(BaseModel):
     data_source: DataSource
 
 
-class PostMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
+class PostMetadataResponse(BaseModel, Generic[ResourceLinksT]):
     id: str
     links: Union[ArrayLinks, DataFrameLinks, SparseLinks]
     metadata: Dict
@@ -456,7 +463,7 @@ class PostMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
     access_blob: Dict
 
 
-class PutMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
+class PutMetadataResponse(BaseModel, Generic[ResourceLinksT]):
     id: str
     links: Union[ArrayLinks, DataFrameLinks, SparseLinks]
     # May be None if not altered
@@ -465,18 +472,18 @@ class PutMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
     access_blob: Optional[Dict] = None
 
 
-class DistinctValueInfo(pydantic.BaseModel):
+class DistinctValueInfo(BaseModel):
     value: Any = None
     count: Optional[int] = None
 
 
-class GetDistinctResponse(pydantic.BaseModel):
+class GetDistinctResponse(BaseModel):
     metadata: Optional[Dict[str, List[DistinctValueInfo]]] = None
     structure_families: Optional[List[DistinctValueInfo]] = None
     specs: Optional[List[DistinctValueInfo]] = None
 
 
-class PutMetadataRequest(pydantic.BaseModel):
+class PutMetadataRequest(BaseModel):
     # These fields are optional because None means "no changes; do not update".
     metadata: Optional[Dict] = None
     specs: Optional[Specs] = None
@@ -484,7 +491,7 @@ class PutMetadataRequest(pydantic.BaseModel):
 
     # Wait for fix https://github.com/pydantic/pydantic/issues/3957
     # to do this with `unique_items` parameters to `pydantic.constr`.
-    @pydantic.field_validator("specs")
+    @field_validator("specs")
     def specs_uniqueness_validator(cls, v):
         if v is None:
             return None
@@ -512,7 +519,7 @@ def JSONPatchType(dtype=Any):
 JSONPatchAny = JSONPatchType(Any)
 
 
-class HyphenizedBaseModel(pydantic.BaseModel):
+class HyphenizedBaseModel(BaseModel):
     # This model configuration allows aliases like "content-type"
     model_config = ConfigDict(alias_generator=lambda f: f.replace("_", "-"))
 
@@ -537,7 +544,7 @@ class PatchMetadataRequest(HyphenizedBaseModel):
         alias="access_blob", default=None
     )
 
-    @pydantic.field_validator("specs")
+    @field_validator("specs")
     def specs_uniqueness_validator(cls, v):
         if v is None:
             return None
@@ -556,7 +563,7 @@ def specs_uniqueness_validator(cls, v):
         return v
 
 
-class PatchMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
+class PatchMetadataResponse(BaseModel, Generic[ResourceLinksT]):
     id: str
     links: Union[ArrayLinks, DataFrameLinks, SparseLinks]
     # May be None if not altered
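Bounding `StructureT` to `Structure` (done in `tiled/server/schemas.py` above and again in `tiled/structures/data_source.py` below) lets type checkers hold generic code to the shared `from_json` interface. A hedged illustration; `round_trip` is an invented helper, not part of tiled:

from dataclasses import asdict
from typing import TypeVar

from tiled.structures.root import Structure

StructureT = TypeVar("StructureT", bound=Structure)


def round_trip(structure: StructureT) -> Structure:
    # Valid for any Structure subclass: serialize to plain data, re-hydrate.
    return type(structure).from_json(asdict(structure))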
- @pydantic.field_validator("specs") + @field_validator("specs") def specs_uniqueness_validator(cls, v): if v is None: return None @@ -512,7 +519,7 @@ def JSONPatchType(dtype=Any): JSONPatchAny = JSONPatchType(Any) -class HyphenizedBaseModel(pydantic.BaseModel): +class HyphenizedBaseModel(BaseModel): # This model configuration allows aliases like "content-type" model_config = ConfigDict(alias_generator=lambda f: f.replace("_", "-")) @@ -537,7 +544,7 @@ class PatchMetadataRequest(HyphenizedBaseModel): alias="access_blob", default=None ) - @pydantic.field_validator("specs") + @field_validator("specs") def specs_uniqueness_validator(cls, v): if v is None: return None @@ -556,7 +563,7 @@ def specs_uniqueness_validator(cls, v): return v -class PatchMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]): +class PatchMetadataResponse(BaseModel, Generic[ResourceLinksT]): id: str links: Union[ArrayLinks, DataFrameLinks, SparseLinks] # May be None if not altered diff --git a/tiled/structures/array.py b/tiled/structures/array.py index 83bb989fb..084cba24f 100644 --- a/tiled/structures/array.py +++ b/tiled/structures/array.py @@ -1,11 +1,18 @@ import enum import os import sys +from collections.abc import Mapping from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Any, ClassVar, List, Optional, Tuple, Union import numpy +from tiled.structures.root import Structure + +# from dtype.descr +FieldDescr = Union[Tuple[str, str], Tuple[str, str, Tuple[int, ...]]] +NumpyDescr = List[FieldDescr] + class Endianness(str, enum.Enum): """ @@ -57,7 +64,7 @@ class Kind(str, enum.Enum): object = "O" # Object (i.e. the memory contains a pointer to PyObject) @classmethod - def _missing_(cls, key): + def _missing_(cls, key: str): if key == "O": raise ObjectArrayTypeDisabled( "Numpy 'object'-type arrays are not enabled by default " @@ -78,17 +85,19 @@ class BuiltinDtype: itemsize: int dt_units: Optional[str] = None - __endianness_map = { + __endianness_map: ClassVar[Mapping[str, str]] = { ">": "big", "<": "little", "=": sys.byteorder, "|": "not_applicable", } - __endianness_reverse_map = {"big": ">", "little": "<", "not_applicable": "|"} + __endianness_reverse_map: ClassVar[Mapping[str, str]] = { + v: k for k, v in __endianness_map.items() if k != "=" + } @classmethod - def from_numpy_dtype(cls, dtype) -> "BuiltinDtype": + def from_numpy_dtype(cls, dtype: numpy.dtype) -> "BuiltinDtype": # Extract datetime units from the dtype string representation, # e.g. `' "BuiltinDtype": def to_numpy_dtype(self) -> numpy.dtype: return numpy.dtype(self.to_numpy_str()) - def to_numpy_str(self): + def to_numpy_str(self) -> str: endianness = self.__endianness_reverse_map[self.endianness] # dtype.itemsize always reports bytes. 
The format string from the # numeric types the string format is: {type_code}{byte_count} so we can @@ -125,7 +134,7 @@ def to_numpy_descr(self): return self.to_numpy_str() @classmethod - def from_json(cls, structure): + def from_json(cls, structure: Mapping[str, Any]) -> "BuiltinDtype": return cls( kind=Kind(structure["kind"]), itemsize=structure["itemsize"], @@ -141,7 +150,7 @@ class Field: shape: Optional[Tuple[int, ...]] @classmethod - def from_numpy_descr(cls, field): + def from_numpy_descr(cls, field: FieldDescr) -> "Field": name, *rest = field if name == "": raise ValueError( @@ -159,7 +168,7 @@ def from_numpy_descr(cls, field): FType = StructDtype.from_numpy_dtype(numpy.dtype(f_type)) return cls(name=name, dtype=FType, shape=shape) - def to_numpy_descr(self): + def to_numpy_descr(self) -> FieldDescr: if isinstance(self.dtype, BuiltinDtype): base = [self.name, self.dtype.to_numpy_str()] else: @@ -170,7 +179,7 @@ def to_numpy_descr(self): return tuple(base + [self.shape]) @classmethod - def from_json(cls, structure): + def from_json(cls, structure: Mapping[str, Any]) -> "Field": name = structure["name"] if "fields" in structure["dtype"]: ftype = StructDtype.from_json(structure["dtype"]) @@ -185,7 +194,7 @@ class StructDtype: fields: List[Field] @classmethod - def from_numpy_dtype(cls, dtype): + def from_numpy_dtype(cls, dtype: numpy.dtype) -> "StructDtype": # subdtypes push extra dimensions into arrays, we should handle these # a layer up and report an array with bigger dimensions. if dtype.subdtype is not None: @@ -198,20 +207,20 @@ def from_numpy_dtype(cls, dtype): fields=[Field.from_numpy_descr(f) for f in dtype.descr], ) - def to_numpy_dtype(self): + def to_numpy_dtype(self) -> numpy.dtype: return numpy.dtype(self.to_numpy_descr()) - def to_numpy_descr(self): + def to_numpy_descr(self) -> NumpyDescr: return [f.to_numpy_descr() for f in self.fields] - def max_depth(self): + def max_depth(self) -> int: return max( 1 if isinstance(f.dtype, BuiltinDtype) else 1 + f.dtype.max_depth() for f in self.fields ) @classmethod - def from_json(cls, structure): + def from_json(cls, structure: Mapping[str, Any]) -> "StructDtype": return cls( itemsize=structure["itemsize"], fields=[Field.from_json(f) for f in structure["fields"]], @@ -219,7 +228,7 @@ def from_json(cls, structure): @dataclass -class ArrayStructure: +class ArrayStructure(Structure): data_type: Union[BuiltinDtype, StructDtype] chunks: Tuple[Tuple[int, ...], ...] # tuple-of-tuples-of-ints like ((3,), (3,)) shape: Tuple[int, ...] 
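A round-trip sketch for the dtype helpers annotated above (assumes numpy is installed; byte-order characters shown for a little-endian machine):

import numpy

from tiled.structures.array import BuiltinDtype, StructDtype

bd = BuiltinDtype.from_numpy_dtype(numpy.dtype("f8"))
print(bd.to_numpy_str())  # expected: "<f8" on little-endian platforms

descr = [("x", "<f8"), ("y", "<i4", (3,))]  # a NumpyDescr, per the new alias
sd = StructDtype.from_numpy_dtype(numpy.dtype(descr))
print(sd.to_numpy_dtype() == numpy.dtype(descr))  # expected: True
print(sd.max_depth())  # expected: 1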
diff --git a/tiled/structures/awkward.py b/tiled/structures/awkward.py
index edc19286f..e10cdff58 100644
--- a/tiled/structures/awkward.py
+++ b/tiled/structures/awkward.py
@@ -1,19 +1,24 @@
 from dataclasses import dataclass
+from typing import Any, Iterable, Mapping, Optional
 
 import awkward
 
+from tiled.structures.root import Structure
+
 
 @dataclass
-class AwkwardStructure:
+class AwkwardStructure(Structure):
     length: int
     form: dict
 
     @classmethod
-    def from_json(cls, structure):
+    def from_json(cls, structure: Mapping[str, Any]) -> "AwkwardStructure":
         return cls(**structure)
 
 
-def project_form(form, form_keys_touched):
+def project_form(
+    form: awkward.forms.Form, form_keys_touched: Iterable[str]
+) -> Optional[awkward.forms.Form]:
     # See https://github.com/bluesky/tiled/issues/450
     if isinstance(form, awkward.forms.RecordForm):
         if form.fields is None:
diff --git a/tiled/structures/core.py b/tiled/structures/core.py
index 3af87bde1..4ca34865e 100644
--- a/tiled/structures/core.py
+++ b/tiled/structures/core.py
@@ -9,6 +9,8 @@
 from dataclasses import asdict, dataclass
 from typing import Dict, Optional
 
+from tiled.structures.root import Structure
+
 from ..utils import OneShotCachedMap
 
 
@@ -47,8 +49,7 @@ def dict(self) -> Dict[str, Optional[str]]:
 
     model_dump = dict  # For easy interoperability with pydantic 2.x models
 
-# TODO: make type[Structure] after #1036
-STRUCTURE_TYPES = OneShotCachedMap[StructureFamily, type](
+STRUCTURE_TYPES = OneShotCachedMap[StructureFamily, type[Structure]](
     {
         StructureFamily.array: lambda: importlib.import_module(
             "...structures.array", StructureFamily.__module__
diff --git a/tiled/structures/data_source.py b/tiled/structures/data_source.py
index cd5352a8e..d211fdfc7 100644
--- a/tiled/structures/data_source.py
+++ b/tiled/structures/data_source.py
@@ -1,6 +1,9 @@
 import dataclasses
 import enum
-from typing import Generic, List, Optional, TypeVar
+from collections.abc import Mapping
+from typing import Any, Generic, List, Optional, TypeVar
+
+from tiled.structures.root import Structure
 
 from .core import StructureFamily
 
@@ -21,13 +24,13 @@ class Asset:
     id: Optional[int] = None
 
 
-StructureT = TypeVar("StructureT")
+StructureT = TypeVar("StructureT", bound=Structure)
 
 
 @dataclasses.dataclass
 class DataSource(Generic[StructureT]):
     structure_family: StructureFamily
-    structure: StructureT
+    structure: Optional[StructureT]
     id: Optional[int] = None
     mimetype: Optional[str] = None
     parameters: dict = dataclasses.field(default_factory=dict)
@@ -35,7 +38,7 @@ class DataSource(Generic[StructureT]):
     management: Management = Management.writable
 
     @classmethod
-    def from_json(cls, d):
-        d = d.copy()
+    def from_json(cls, structure: Mapping[str, Any]) -> "DataSource":
+        d = dict(structure)
         assets = [Asset(**a) for a in d.pop("assets")]
         return cls(assets=assets, **d)
diff --git a/tiled/structures/root.py b/tiled/structures/root.py
new file mode 100644
index 000000000..27bf5a25b
--- /dev/null
+++ b/tiled/structures/root.py
@@ -0,0 +1,12 @@
+import dataclasses
+from abc import ABC
+from collections.abc import Mapping
+from typing import Any
+
+
+@dataclasses.dataclass
+class Structure(ABC):
+    # TODO: When dropping support for Python 3.10 replace with -> Self
+    @classmethod
+    def from_json(cls, structure: Mapping[str, Any]) -> "Structure":
+        ...
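The new base class gives every concrete structure a common `from_json` entry point, which is what allows `STRUCTURE_TYPES` in core.py above to be typed as `type[Structure]`. A sketch of a conforming subclass; `ExampleStructure` is hypothetical, for illustration only:

from dataclasses import asdict, dataclass
from typing import Any, Mapping

from tiled.structures.root import Structure


@dataclass
class ExampleStructure(Structure):
    length: int

    @classmethod
    def from_json(cls, structure: Mapping[str, Any]) -> "ExampleStructure":
        return cls(length=structure["length"])


original = ExampleStructure(length=3)
print(ExampleStructure.from_json(asdict(original)) == original)  # True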
diff --git a/tiled/structures/sparse.py b/tiled/structures/sparse.py
index 49461a2c5..2d5ea5e18 100644
--- a/tiled/structures/sparse.py
+++ b/tiled/structures/sparse.py
@@ -1,6 +1,9 @@
 import enum
+from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union
+
+from tiled.structures.root import Structure
 
 from .array import BuiltinDtype, Endianness, Kind, StructDtype
 
@@ -12,7 +15,7 @@ class SparseLayout(str, enum.Enum):
 
 
 @dataclass
-class COOStructure:
+class COOStructure(Structure):
     chunks: Tuple[Tuple[int, ...], ...]  # tuple-of-tuples-of-ints like ((3,), (3,))
     shape: Tuple[int, ...]  # tuple of ints like (3, 3)
     data_type: Optional[Union[BuiltinDtype, StructDtype]] = None
@@ -27,7 +30,7 @@ class COOStructure:
     # TODO Include fill_value?
 
     @classmethod
-    def from_json(cls, structure):
+    def from_json(cls, structure: Mapping[str, Any]) -> "COOStructure":
         data_type = structure.get("data_type", None)
         if data_type is not None and "fields" in data_type:
             data_type = StructDtype.from_json(data_type)
diff --git a/tiled/structures/table.py b/tiled/structures/table.py
index 797742650..2a197d1ae 100644
--- a/tiled/structures/table.py
+++ b/tiled/structures/table.py
@@ -1,15 +1,24 @@
 import base64
 import io
+from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import List, Tuple, Union
+from typing import Any, List, Tuple, Union
 
 import pyarrow
 
+from tiled.structures.root import Structure
+
 B64_ENCODED_PREFIX = "data:application/vnd.apache.arrow.file;base64,"
 
 
+def _uri_from_schema(pyarrow_schema: pyarrow.Schema) -> str:
+    schema_bytes = pyarrow_schema.serialize()
+    schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
+    return B64_ENCODED_PREFIX + schema_b64
+
+
 @dataclass
-class TableStructure:
+class TableStructure(Structure):
     # This holds a Arrow schema, base64-encoded so that it can be transported
     # as JSON. For clarity, the encoded data (...) is prefixed like:
     #
@@ -33,68 +42,63 @@ def __post_init__(self):
     @classmethod
     def from_dask_dataframe(cls, ddf) -> "TableStructure":
         import dask.dataframe.utils
-        import pyarrow
 
         # Make a pandas Table with 0 rows.
         # We can use this to define an Arrow schema without loading any row data.
         meta = dask.dataframe.utils.make_meta(ddf)
-        schema_bytes = pyarrow.Table.from_pandas(meta).schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
+        schema = pyarrow.Table.from_pandas(meta).schema
         return cls(
-            arrow_schema=data_uri,
+            arrow_schema=_uri_from_schema(schema),
             npartitions=ddf.npartitions,
             columns=list(ddf.columns),
         )
 
     @classmethod
-    def from_pandas(cls, df):
-        import pyarrow
-
-        schema_bytes = pyarrow.Table.from_pandas(df).schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
-        return cls(arrow_schema=data_uri, npartitions=1, columns=list(df.columns))
+    def from_pandas(cls, df) -> "TableStructure":
+        schema = pyarrow.Table.from_pandas(df).schema
+        return cls(
+            arrow_schema=_uri_from_schema(schema),
+            npartitions=1,
+            columns=list(df.columns),
+        )
 
     @classmethod
-    def from_dict(cls, d):
-        import pyarrow
-
-        schema_bytes = pyarrow.Table.from_pydict(d).schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
-        return cls(arrow_schema=data_uri, npartitions=1, columns=list(d.keys()))
-
-    def from_arrays(cls, arr, names):
-        import pyarrow
+    def from_dict(cls, d: Mapping[str, Any]) -> "TableStructure":
+        schema = pyarrow.Table.from_pydict(d).schema
+        return cls(
+            arrow_schema=_uri_from_schema(schema), npartitions=1, columns=list(d.keys())
+        )
 
-        schema_bytes = pyarrow.Table.from_arrays(arr, names).schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
-        return cls(arrow_schema=data_uri, npartitions=1, columns=list(names))
+    @classmethod
+    def from_arrays(cls, arr, names: List[str]) -> "TableStructure":
+        schema = pyarrow.Table.from_arrays(arr, names).schema
+        return cls(
+            arrow_schema=_uri_from_schema(schema), npartitions=1, columns=list(names)
+        )
 
     @classmethod
-    def from_schema(cls, schema: pyarrow.Schema, npartitions: int = 1):
-        schema_bytes = schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
-        return cls(arrow_schema=data_uri, npartitions=npartitions, columns=schema.names)
+    def from_schema(
+        cls, schema: pyarrow.Schema, npartitions: int = 1
+    ) -> "TableStructure":
+        return cls(
+            arrow_schema=_uri_from_schema(schema),
+            npartitions=npartitions,
+            columns=schema.names,
+        )
 
     @classmethod
-    def from_arrow_table(cls, table, npartitions=1) -> "TableStructure":
-        schema_bytes = table.schema.serialize()
-        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
-        data_uri = B64_ENCODED_PREFIX + schema_b64
+    def from_arrow_table(
+        cls, table: pyarrow.Table, npartitions: int = 1
+    ) -> "TableStructure":
+        schema = table.schema
         return cls(
-            arrow_schema=data_uri,
+            arrow_schema=_uri_from_schema(schema),
             npartitions=npartitions,
             columns=list(table.column_names),
         )
 
     @property
-    def arrow_schema_decoded(self):
-        import pyarrow
-
+    def arrow_schema_decoded(self) -> pyarrow.Schema:
         if not self.arrow_schema.startswith(B64_ENCODED_PREFIX):
             raise ValueError(
                 f"Expected base64-encoded data prefixed with {B64_ENCODED_PREFIX}."
@@ -110,10 +114,10 @@ def meta(self):
         return self.arrow_schema_decoded.empty_table().to_pandas()
 
     @classmethod
-    def from_json(cls, content):
+    def from_json(cls, structure: Mapping[str, Any]) -> "TableStructure":
         return cls(
-            arrow_schema=content["arrow_schema"],
-            npartitions=content["npartitions"],
-            columns=content["columns"],
-            resizable=content["resizable"],
+            arrow_schema=structure["arrow_schema"],
+            npartitions=structure["npartitions"],
+            columns=structure["columns"],
+            resizable=structure["resizable"],
         )
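A usage sketch for the consolidated Arrow-schema encoding in table.py (assumes pyarrow is installed; `_uri_from_schema` is the private helper added above):

import pyarrow

from tiled.structures.table import B64_ENCODED_PREFIX, TableStructure

schema = pyarrow.schema([("x", pyarrow.float64()), ("y", pyarrow.int64())])
ts = TableStructure.from_schema(schema, npartitions=2)
print(ts.arrow_schema.startswith(B64_ENCODED_PREFIX))  # expected: True
print(ts.arrow_schema_decoded.names)  # expected: ['x', 'y']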