Skip to content

Commit

Permalink
start recursive ser/deserialization tests, fix issues in the Array API
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed May 26, 2024
1 parent 5aa0c17 commit 94a60ae
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/zarr/abc/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@dataclass(frozen=True)
class Metadata:
def to_dict(self) -> JSON:
def to_dict(self) -> dict[str, JSON]:
"""
Recursively serialize this model to a dictionary.
This method inspects the fields of self and calls `x.to_dict()` for any fields that
Expand All @@ -37,7 +37,7 @@ def to_dict(self) -> JSON:
return out_dict

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self:
"""
Create an instance of the model from a dictionary
"""
Expand Down
18 changes: 8 additions & 10 deletions src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,13 @@ async def _create_v2(
return array

@classmethod
def from_dict(
cls,
store_path: StorePath,
data: dict[str, JSON],
async def from_dict(
cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None
) -> AsyncArray:
metadata = parse_array_metadata(data)
async_array = cls(metadata=metadata, store_path=store_path)
data_parsed = parse_array_metadata(data)
async_array = cls(metadata=data_parsed, store_path=store_path, order=order)
# weird that this method doesn't use the metadata attribute
await async_array._save_metadata(async_array.metadata)
return async_array

@classmethod
Expand Down Expand Up @@ -535,11 +535,9 @@ def create(

@classmethod
def from_dict(
cls,
store_path: StorePath,
data: dict[str, JSON],
cls, store_path: StorePath, data: dict[str, JSON], order: Literal["C", "F"] | None = None
) -> Array:
async_array = AsyncArray.from_dict(store_path=store_path, data=data)
async_array = sync(AsyncArray.from_dict(store_path=store_path, data=data))
return cls(async_array)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def from_dict(cls, data: Iterable[JSON | Codec], *, batch_size: int | None = Non
out.append(get_codec_class(name_parsed).from_dict(c)) # type: ignore[arg-type]
return cls.from_list(out, batch_size=batch_size)

def to_dict(self) -> JSON:
def to_dict(self) -> list[JSON]:
return [c.to_dict() for c in self]

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def getitem(
if zarr_json["node_type"] == "group":
return type(self).from_dict(store_path, zarr_json)
elif zarr_json["node_type"] == "array":
return AsyncArray.from_dict(store_path, zarr_json)
return sync(AsyncArray.from_dict(store_path, zarr_json))
else:
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
elif self.metadata.zarr_format == 2:
Expand All @@ -242,7 +242,7 @@ async def getitem(
if zarray is not None:
# TODO: update this once the V2 array support is part of the primary array class
zarr_json = {**zarray, "attributes": zattrs}
return AsyncArray.from_dict(store_path, zarray)
return sync(AsyncArray.from_dict(store_path, zarray))
else:
zgroup = (
json.loads(zgroup_bytes.to_bytes())
Expand Down
46 changes: 32 additions & 14 deletions src/zarr/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from __future__ import annotations

from typing import Any
from dataclasses import dataclass, field

from typing_extensions import Self

Expand All @@ -28,23 +28,29 @@ class ArrayModel(ArrayV3Metadata):
"""

@classmethod
def from_stored(cls: type[Self], node: Array):
def from_stored(cls: type[Self], node: Array) -> Self:
"""
Create an array model from a stored array.
"""
return cls.from_dict(node.metadata.to_dict())

def to_stored(self, store_path: StorePath) -> Array:
def to_stored(self, store_path: StorePath, exists_ok: bool = False) -> Array:
"""
Create a stored version of this array.
"""
# exists_ok kwarg is unhandled until we wire it up to the
# array creation routines

return Array.from_dict(store_path=store_path, data=self.to_dict())


@dataclass(frozen=True)
class GroupModel(GroupMetadata):
"""
A model of a Zarr v3 group.
"""

members: dict[str, GroupModel | ArrayModel] | None

@classmethod
def from_dict(cls: type[Self], data: dict[str, Any]):
return cls(**data)
members: dict[str, GroupModel | ArrayModel] | None = field(default_factory=dict)

@classmethod
def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Self:
Expand All @@ -53,7 +59,7 @@ def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Se
controlled by the `depth` argument, which is either None (no depth limit) or a finite natural number
specifying how deep into the hierarchy to parse.
"""
members: dict[str, GroupModel | ArrayModel]
members: dict[str, GroupModel | ArrayModel] = {}

if depth is None:
new_depth = depth
Expand All @@ -64,16 +70,18 @@ def from_stored(cls: type[Self], node: Group, *, depth: int | None = None) -> Se
return cls(**node.metadata.to_dict(), members=None)

else:
for name, member in node.members():
for name, member in node.members:
item_out: ArrayModel | GroupModel
if isinstance(member, Array):
item_out = ArrayModel.from_stored(member)
else:
item_out = GroupModel.from_stored(member, depth=new_depth)

members[name] = item_out

return cls(**node.metadata.to_dict(), members=members)
return cls(attributes=node.metadata.attributes, members=members)

# todo: make this async
def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
"""
Serialize this GroupModel to storage.
Expand All @@ -90,15 +98,18 @@ def to_stored(self, store_path: StorePath, *, exists_ok: bool = False) -> Group:
def to_flat(
node: ArrayModel | GroupModel, root_path: str = ""
) -> dict[str, ArrayModel | GroupModel]:
"""
Generate a dict representation of an ArrayModel or GroupModel, where the hierarchy structure
is represented by the keys of the dict.
"""
result = {}
model_copy: ArrayModel | GroupModel
node_dict = node.to_dict()
if isinstance(node, ArrayModel):
model_copy = ArrayModel(**node_dict)
else:
members = node_dict.pop("members")
model_copy = GroupModel(node_dict)
if members is not None:
model_copy = GroupModel(**node_dict)
if node.members is not None:
for name, value in node.members.items():
result.update(to_flat(value, "/".join([root_path, name])))

Expand All @@ -109,6 +120,9 @@ def to_flat(


def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupModel:
"""
Create a GroupModel or ArrayModel from a dict representation.
"""
# minimal check that the keys are valid
invalid_keys = []
for key in data.keys():
Expand All @@ -125,6 +139,10 @@ def from_flat(data: dict[str, ArrayModel | GroupModel]) -> ArrayModel | GroupMod


def from_flat_group(data: dict[str, ArrayModel | GroupModel]) -> GroupModel:
"""
Create a GroupModel from a hierarchy represented as a dict with string keys and ArrayModel
or GroupModel values.
"""
root_name = ""
sep = "/"
# arrays that will be members of the returned GroupModel
Expand Down
20 changes: 11 additions & 9 deletions src/zarr/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,18 +267,19 @@ def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
}

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> ArrayV3Metadata:
def from_dict(cls: type[Self], data: dict[str, JSON]) -> Self:
data_copy = data.copy()
# check that the zarr_format attribute is correct
_ = parse_zarr_format_v3(data.pop("zarr_format"))
_ = parse_zarr_format_v3(data_copy.pop("zarr_format"))
# check that the node_type attribute is correct
_ = parse_node_type_array(data.pop("node_type"))
_ = parse_node_type_array(data_copy.pop("node_type"))

data["dimension_names"] = data.pop("dimension_names", None)
data_copy["dimension_names"] = data_copy.pop("dimension_names", None)

# TODO: Remove the ignores and use a TypedDict to type `data`
return cls(**data) # type: ignore[arg-type]
return cls(**data_copy) # type: ignore[arg-type]

def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, JSON]:
out_dict = super().to_dict()

if not isinstance(out_dict, dict):
Expand Down Expand Up @@ -391,11 +392,12 @@ def _json_convert(

@classmethod
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
data_copy = data.copy()
# check that the zarr_format attribute is correct
_ = parse_zarr_format_v2(data.pop("zarr_format"))
return cls(**data)
_ = parse_zarr_format_v2(data_copy.pop("zarr_format"))
return cls(**data_copy)

def to_dict(self) -> JSON:
def to_dict(self) -> dict[str, JSON]:
zarray_dict = super().to_dict()

assert isinstance(zarray_dict, dict)
Expand Down
48 changes: 47 additions & 1 deletion tests/v3/test_hierarchy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import pytest

from zarr.array import Array
from zarr.chunk_grids import RegularChunkGrid
from zarr.chunk_key_encodings import DefaultChunkKeyEncoding
from zarr.group import GroupMetadata
from zarr.hierarchy import ArrayModel, GroupModel
from zarr.metadata import ArrayV3Metadata
from zarr.store.core import StorePath
from zarr.store.memory import MemoryStore


Expand Down Expand Up @@ -58,4 +61,47 @@ def test_groupmodel_from_dict() -> None:
assert model.to_dict() == {**group_meta.to_dict(), "members": None}


def test_groupmodel_to_stored(): ...
@pytest.mark.parametrize("attributes", ({}, {"foo": 100}))
@pytest.mark.parametrize(
"members",
(
None,
{},
{
"foo": ArrayModel(
shape=(100,),
data_type="uint8",
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
chunk_key_encoding=DefaultChunkKeyEncoding(),
fill_value=0,
attributes={"foo": 10},
),
"bar": GroupModel(
attributes={"name": "bar"},
members={
"subarray": ArrayModel(
shape=(100,),
data_type="uint8",
chunk_grid=RegularChunkGrid(chunk_shape=(10,)),
chunk_key_encoding=DefaultChunkKeyEncoding(),
fill_value=0,
attributes={"foo": 10},
)
},
),
},
),
)
def test_groupmodel_to_stored(
memory_store: MemoryStore,
attributes: dict[str, int],
members: None | dict[str, ArrayModel | GroupModel],
):
model = GroupModel(attributes=attributes, members=members)
group = model.to_stored(StorePath(memory_store, path="test"))
model_rt = GroupModel.from_stored(group)
assert model_rt.attributes == model.attributes
if members is not None:
assert model_rt.members == model.members
else:
assert model_rt.members == {}

0 comments on commit 94a60ae

Please sign in to comment.