From b996aff13996fd172f85ddd5554737149221f7c2 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Fri, 10 May 2024 10:34:41 +0200 Subject: [PATCH] fix: define utility for converting asyncarray to array, and similar for group, largely to appease mypy --- src/zarr/group.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/zarr/group.py b/src/zarr/group.py index f9dbea3a7..055ef5327 100644 --- a/src/zarr/group.py +++ b/src/zarr/group.py @@ -20,6 +20,7 @@ from zarr.config import RuntimeConfiguration, SyncConfiguration from zarr.store import StoreLike, StorePath, make_store_path from zarr.sync import SyncMixin, sync +from typing import overload logger = logging.getLogger("zarr.group") @@ -41,6 +42,26 @@ def parse_attributes(data: Any) -> dict[str, Any]: raise TypeError(msg) +@overload +def _parse_async_node(node: AsyncArray) -> Array: ... + + +@overload +def _parse_async_node(node: AsyncGroup) -> Group: ... + + +def _parse_async_node(node: AsyncArray | AsyncGroup) -> Array | Group: + """ + Wrap an AsyncArray in an Array, or an AsyncGroup in a Group. + """ + if isinstance(node, AsyncArray): + return Array(node) + elif isinstance(node, Group): + return Group(node) + else: + assert False + + @dataclass(frozen=True) class GroupMetadata(Metadata): attributes: dict[str, Any] = field(default_factory=dict) @@ -509,11 +530,10 @@ def members(self) -> tuple[tuple[str, Array | Group], ...]: Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group) pairs """ - _members: list[AsyncArray | AsyncGroup] = self._sync_iter(self._async_group.members()) - return tuple( - (key, Array(value)) if isinstance(value, AsyncArray) else (key, Group(value)) - for key, value in _members - ) + _members = self._sync_iter(self._async_group.members()) + + result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members)) + return result def __contains__(self, member) -> bool: return self._sync(self._async_group.contains(member))