From b1f4c509abaee1cb8dec18e3a973e1199226011a Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 22 May 2024 17:09:18 +0200 Subject: [PATCH] Protocols for `Buffer` and `NDBuffer` (#1899) --- src/zarr/array.py | 4 +- src/zarr/buffer.py | 135 ++++++++++++++++++++++---------------- src/zarr/codecs/bytes.py | 16 ++++- src/zarr/testing/store.py | 8 ++- src/zarr/testing/utils.py | 18 +++++ tests/v3/conftest.py | 9 +++ tests/v3/test_buffer.py | 15 +++-- tests/v3/test_codecs.py | 23 +++---- tests/v3/test_store.py | 5 +- 9 files changed, 151 insertions(+), 82 deletions(-) create mode 100644 src/zarr/testing/utils.py diff --git a/src/zarr/array.py b/src/zarr/array.py index 2828e25119..7da39c285e 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -582,12 +582,12 @@ def store_path(self) -> StorePath: def order(self) -> Literal["C", "F"]: return self._async_array.order - def __getitem__(self, selection: Selection) -> npt.NDArray[Any]: + def __getitem__(self, selection: Selection) -> NDArrayLike: return sync( self._async_array.getitem(selection), ) - def __setitem__(self, selection: Selection, value: npt.NDArray[Any]) -> None: + def __setitem__(self, selection: Selection, value: NDArrayLike) -> None: sync( self._async_array.setitem(selection, value), ) diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index e9aa1120f8..0f055093c1 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -1,28 +1,94 @@ from __future__ import annotations import sys -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from typing import ( TYPE_CHECKING, Any, Literal, Protocol, - TypeAlias, + SupportsIndex, + runtime_checkable, ) import numpy as np import numpy.typing as npt +from zarr.common import ChunkCoords + if TYPE_CHECKING: from typing_extensions import Self from zarr.codecs.bytes import Endian from zarr.common import BytesLike -# TODO: create a protocol for the attributes we need, for now we alias Numpy's ndarray -# both for the array-like and ndarray-like -ArrayLike: TypeAlias = npt.NDArray[Any] -NDArrayLike: TypeAlias = npt.NDArray[Any] + +@runtime_checkable +class ArrayLike(Protocol): + """Protocol for the array-like type that underlie Buffer""" + + @property + def dtype(self) -> np.dtype[Any]: ... + + @property + def ndim(self) -> int: ... + + @property + def size(self) -> int: ... + + def __getitem__(self, key: slice) -> Self: ... + + def __setitem__(self, key: slice, value: Any) -> None: ... + + +@runtime_checkable +class NDArrayLike(Protocol): + """Protocol for the nd-array-like type that underlie NDBuffer""" + + @property + def dtype(self) -> np.dtype[Any]: ... + + @property + def ndim(self) -> int: ... + + @property + def size(self) -> int: ... + + @property + def shape(self) -> ChunkCoords: ... + + def __len__(self) -> int: ... + + def __getitem__(self, key: slice) -> Self: ... + + def __setitem__(self, key: slice, value: Any) -> None: ... + + def reshape(self, shape: ChunkCoords, *, order: Literal["A", "C", "F"] = ...) -> Self: ... + + def view(self, dtype: npt.DTypeLike) -> Self: ... + + def astype(self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = ...) -> Self: ... + + def fill(self, value: Any) -> None: ... + + def copy(self) -> Self: ... + + def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Self: ... + + def ravel(self, order: Literal["K", "A", "C", "F"] = "C") -> Self: ... + + def all(self) -> bool: ... + + def __eq__(self, other: Any) -> Self: # type: ignore + """Element-wise equal + + Notice + ------ + Type checkers such as mypy complains because the return type isn't a bool like + its supertype "object", which violates the Liskov substitution principle. + This is true, but since NumPy's ndarray is defined as an element-wise equal, + our hands are tied. + """ def check_item_key_is_1d_contiguous(key: Any) -> None: @@ -124,7 +190,7 @@ def create_zero_length(cls) -> Self: return cls(np.array([], dtype="b")) @classmethod - def from_array_like(cls, array_like: NDArrayLike) -> Self: + def from_array_like(cls, array_like: ArrayLike) -> Self: """Create a new buffer of a array-like object Parameters @@ -153,7 +219,7 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self: """ return cls.from_array_like(np.frombuffer(bytes_like, dtype="b")) - def as_array_like(self) -> NDArrayLike: + def as_array_like(self) -> ArrayLike: """Return the underlying array (host or device memory) of this buffer This will never copy data. @@ -164,22 +230,6 @@ def as_array_like(self) -> NDArrayLike: """ return self._data - def as_nd_buffer(self, *, dtype: npt.DTypeLike) -> NDBuffer: - """Create a new NDBuffer from this one. - - This will never copy data. - - Parameters - ---------- - dtype - The datatype of the returned buffer (reinterpretation of the bytes) - - Return - ------ - New NDbuffer representing `self.as_array_like()` - """ - return NDBuffer.from_ndarray_like(self._data.view(dtype=dtype)) - def as_numpy_array(self) -> npt.NDArray[Any]: """Return the buffer as a NumPy array (host memory). @@ -223,17 +273,8 @@ def __add__(self, other: Buffer) -> Self: other_array = other.as_array_like() assert other_array.dtype == np.dtype("b") - return self.__class__(np.concatenate((self._data, other_array))) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, bytes | bytearray): - # Many of the tests compares `Buffer` with `bytes` so we - # convert the bytes to a Buffer and try again - return self == self.from_bytes(other) - if isinstance(other, Buffer): - return (self._data == other.as_array_like()).all() - raise ValueError( - f"equal operator not supported between {self.__class__} and {other.__class__}" + return self.__class__( + np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array))) ) @@ -345,22 +386,6 @@ def as_ndarray_like(self) -> NDArrayLike: """ return self._data - def as_buffer(self) -> Buffer: - """Create a new Buffer from this one. - - Warning - ------- - Copies data if the buffer is non-contiguous. - - Return - ------ - The new buffer (might be data copy) - """ - data = self._data - if not self._data.flags.contiguous: - data = np.ascontiguousarray(self._data) - return Buffer(data.reshape(-1).view(dtype="b")) # Flatten the array without copy - def as_numpy_array(self) -> npt.NDArray[Any]: """Return the buffer as a NumPy array (host memory). @@ -393,8 +418,8 @@ def byteorder(self) -> Endian: else: return Endian(sys.byteorder) - def reshape(self, newshape: Iterable[int]) -> Self: - return self.__class__(self._data.reshape(tuple(newshape))) + def reshape(self, newshape: ChunkCoords) -> Self: + return self.__class__(self._data.reshape(newshape)) def astype(self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = "K") -> Self: return self.__class__(self._data.astype(dtype=dtype, order=order)) @@ -419,8 +444,8 @@ def fill(self, value: Any) -> None: def copy(self) -> Self: return self.__class__(self._data.copy()) - def transpose(self, *axes: np.SupportsIndex) -> Self: # type: ignore[name-defined] - return self.__class__(self._data.transpose(*axes)) + def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Self: + return self.__class__(self._data.transpose(axes)) def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffer) -> Buffer: diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index aebaf94e76..f275ae37d1 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -8,7 +8,7 @@ import numpy as np from zarr.abc.codec import ArrayBytesCodec -from zarr.buffer import Buffer, NDBuffer +from zarr.buffer import Buffer, NDArrayLike, NDBuffer from zarr.codecs.registry import register_codec from zarr.common import parse_enum, parse_named_configuration @@ -75,7 +75,13 @@ async def _decode_single( dtype = np.dtype(f"{prefix}{chunk_spec.dtype.str[1:]}") else: dtype = np.dtype(f"|{chunk_spec.dtype.str[1:]}") - chunk_array = chunk_bytes.as_nd_buffer(dtype=dtype) + + as_array_like = chunk_bytes.as_array_like() + if isinstance(as_array_like, NDArrayLike): + as_nd_array_like = as_array_like + else: + as_nd_array_like = np.asanyarray(as_array_like) + chunk_array = NDBuffer.from_ndarray_like(as_nd_array_like.view(dtype=dtype)) # ensure correct chunk shape if chunk_array.shape != chunk_spec.shape: @@ -96,7 +102,11 @@ async def _encode_single( # see https://github.com/numpy/numpy/issues/26473 new_dtype = chunk_array.dtype.newbyteorder(self.endian.name) # type: ignore[arg-type] chunk_array = chunk_array.astype(new_dtype) - return chunk_array.as_buffer() + + as_nd_array_like = chunk_array.as_ndarray_like() + # Flatten the nd-array (only copy if needed) + as_nd_array_like = as_nd_array_like.ravel().view(dtype="b") + return Buffer.from_array_like(as_nd_array_like) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 99f8021594..1e6fe09a9f 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -2,6 +2,7 @@ from zarr.abc.store import Store from zarr.buffer import Buffer +from zarr.testing.utils import assert_bytes_equal class StoreTests: @@ -27,7 +28,7 @@ def test_store_capabilities(self, store: Store) -> None: @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None: await store.set(key, Buffer.from_bytes(data)) - assert await store.get(key) == data + assert_bytes_equal(await store.get(key), data) @pytest.mark.parametrize("key", ["foo/c/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) @@ -36,11 +37,12 @@ async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> await store.set(key, Buffer.from_bytes(data)) # read back just part of it vals = await store.get_partial_values([(key, (0, 2))]) - assert vals == [data[0:2]] + assert_bytes_equal(vals[0], data[0:2]) # read back multiple parts of it at once vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))]) - assert vals == [data[0:2], data[2:4]] + assert_bytes_equal(vals[0], data[0:2]) + assert_bytes_equal(vals[1], data[2:4]) async def test_exists(self, store: Store) -> None: assert not await store.exists("foo") diff --git a/src/zarr/testing/utils.py b/src/zarr/testing/utils.py new file mode 100644 index 0000000000..04b05d1b1c --- /dev/null +++ b/src/zarr/testing/utils.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from zarr.buffer import Buffer +from zarr.common import BytesLike + + +def assert_bytes_equal(b1: Buffer | BytesLike | None, b2: Buffer | BytesLike | None) -> None: + """Help function to assert if two bytes-like or Buffers are equal + + Warning + ------- + Always copies data, only use for testing and debugging + """ + if isinstance(b1, Buffer): + b1 = b1.to_bytes() + if isinstance(b2, Buffer): + b2 = b2.to_bytes() + assert b1 == b2 diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index b9f56014bc..21dc58197e 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import Iterator +from types import ModuleType from typing import TYPE_CHECKING from zarr.common import ZarrFormat @@ -81,3 +83,10 @@ async def async_group(request: pytest.FixtureRequest, tmpdir) -> AsyncGroup: exists_ok=False, ) return agroup + + +@pytest.fixture(params=["numpy", "cupy"]) +def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]: + """Fixture to parametrize over numpy-like libraries""" + + yield pytest.importorskip(request.param) diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index 4ab92768b4..2f58d116fe 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -8,9 +8,7 @@ import pytest from zarr.array import AsyncArray -from zarr.buffer import NDBuffer -from zarr.store.core import StorePath -from zarr.store.memory import MemoryStore +from zarr.buffer import ArrayLike, NDArrayLike, NDBuffer if TYPE_CHECKING: from typing_extensions import Self @@ -41,12 +39,17 @@ def create( return ret +def test_nd_array_like(xp): + ary = xp.arange(10) + assert isinstance(ary, ArrayLike) + assert isinstance(ary, NDArrayLike) + + @pytest.mark.asyncio -async def test_async_array_factory(): - store = StorePath(MemoryStore()) +async def test_async_array_factory(store_path): expect = np.zeros((9, 9), dtype="uint16", order="F") a = await AsyncArray.create( - store / "test_async_array", + store_path, shape=expect.shape, chunk_shape=(5, 5), dtype=expect.dtype, diff --git a/tests/v3/test_codecs.py b/tests/v3/test_codecs.py index 5f94114ede..a595b12494 100644 --- a/tests/v3/test_codecs.py +++ b/tests/v3/test_codecs.py @@ -25,6 +25,7 @@ from zarr.config import config from zarr.indexing import morton_order_iter from zarr.store import MemoryStore, StorePath +from zarr.testing.utils import assert_bytes_equal @dataclass(frozen=True) @@ -294,7 +295,7 @@ async def test_order( fill_value=1, ) z[:, :] = data - assert (await (store / "order/0.0").get()) == z._store["0.0"] + assert_bytes_equal(await (store / "order/0.0").get(), z._store["0.0"]) @pytest.mark.parametrize("input_order", ["F", "C"]) @@ -665,10 +666,10 @@ async def test_zarr_compat(store: Store): assert np.array_equal(data, await _AsyncArrayProxy(a)[:16, :18].get()) assert np.array_equal(data, z2[:16, :18]) - assert z2._store["0.0"] == await (store / "zarr_compat3/0.0").get() - assert z2._store["0.1"] == await (store / "zarr_compat3/0.1").get() - assert z2._store["1.0"] == await (store / "zarr_compat3/1.0").get() - assert z2._store["1.1"] == await (store / "zarr_compat3/1.1").get() + assert_bytes_equal(z2._store["0.0"], await (store / "zarr_compat3/0.0").get()) + assert_bytes_equal(z2._store["0.1"], await (store / "zarr_compat3/0.1").get()) + assert_bytes_equal(z2._store["1.0"], await (store / "zarr_compat3/1.0").get()) + assert_bytes_equal(z2._store["1.1"], await (store / "zarr_compat3/1.1").get()) async def test_zarr_compat_F(store: Store): @@ -698,10 +699,10 @@ async def test_zarr_compat_F(store: Store): assert np.array_equal(data, await _AsyncArrayProxy(a)[:16, :18].get()) assert np.array_equal(data, z2[:16, :18]) - assert z2._store["0.0"] == await (store / "zarr_compatF3/0.0").get() - assert z2._store["0.1"] == await (store / "zarr_compatF3/0.1").get() - assert z2._store["1.0"] == await (store / "zarr_compatF3/1.0").get() - assert z2._store["1.1"] == await (store / "zarr_compatF3/1.1").get() + assert_bytes_equal(z2._store["0.0"], await (store / "zarr_compatF3/0.0").get()) + assert_bytes_equal(z2._store["0.1"], await (store / "zarr_compatF3/0.1").get()) + assert_bytes_equal(z2._store["1.0"], await (store / "zarr_compatF3/1.0").get()) + assert_bytes_equal(z2._store["1.1"], await (store / "zarr_compatF3/1.1").get()) async def test_dimension_names(store: Store): @@ -795,7 +796,7 @@ async def test_endian(store: Store, endian: Literal["big", "little"]): fill_value=1, ) z[:, :] = data - assert await (store / "endian/0.0").get() == z._store["0.0"] + assert_bytes_equal(await (store / "endian/0.0").get(), z._store["0.0"]) @pytest.mark.parametrize("dtype_input_endian", [">u2", "