Skip to content

Commit

Permalink
generalize buffer to cover ndbuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Jun 9, 2024
1 parent c2a1d2e commit 0739f30
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 22 deletions.
47 changes: 38 additions & 9 deletions src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Literal,
Protocol,
SupportsIndex,
TypeVar,
runtime_checkable,
)

Expand Down Expand Up @@ -40,9 +41,10 @@ def __getitem__(self, key: slice) -> Self: ...

def __setitem__(self, key: slice, value: Any) -> None: ...


from typing import Generic
TShape = TypeVar('TShape', bound=tuple[int, ...])
@runtime_checkable
class NDArrayLike(Protocol):
class NDArrayLike(Protocol, Generic[TShape]):
"""Protocol for the nd-array-like type that underlie NDBuffer"""

@property
Expand All @@ -55,7 +57,7 @@ def ndim(self) -> int: ...
def size(self) -> int: ...

@property
def shape(self) -> ChunkCoords: ...
def shape(self) -> TShape: ...

def __len__(self) -> int: ...

Expand Down Expand Up @@ -151,7 +153,7 @@ def __call__(self, ndarray_like: NDArrayLike) -> NDBuffer:
"""


class Buffer:
class xBuffer:
"""A flat contiguous memory block
We use Buffer throughout Zarr to represent a contiguous block of memory.
Expand Down Expand Up @@ -278,7 +280,7 @@ def __add__(self, other: Buffer) -> Self:
)


class NDBuffer:
class NDBuffer(Generic[TShape]):
"""A n-dimensional memory block
We use NDBuffer throughout Zarr to represent a n-dimensional memory block.
Expand All @@ -303,7 +305,8 @@ class NDBuffer:
ndarray-like object that is convertible to a regular Numpy array.
"""

def __init__(self, array: NDArrayLike):
def __init__(self,
array: NDArrayLike[TShape]):
assert array.ndim > 0
assert array.dtype != object
self._data = array
Expand Down Expand Up @@ -346,7 +349,7 @@ def create(
return ret

@classmethod
def from_ndarray_like(cls, ndarray_like: NDArrayLike) -> Self:
def from_ndarray_like(cls, ndarray_like: NDArrayLike[TShape]) -> Self:
"""Create a new buffer of a ndarray-like object
Parameters
Expand Down Expand Up @@ -375,7 +378,7 @@ def from_numpy_array(cls, array_like: npt.ArrayLike) -> Self:
"""
return cls.from_ndarray_like(np.asanyarray(array_like))

def as_ndarray_like(self) -> NDArrayLike:
def as_ndarray_like(self) -> NDArrayLike[TShape]:
"""Return the underlying array (host or device memory) of this buffer
This will never copy data.
Expand All @@ -399,12 +402,37 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
"""
return np.asanyarray(self._data)

@classmethod
def from_bytes(
    cls, bytes_like: BytesLike, shape: TShape | None = None, dtype: npt.DTypeLike = "b"
) -> Self:
    """Create a new buffer of a bytes-like object (host memory)

    Parameters
    ----------
    bytes_like
        bytes-like object
    shape
        Shape the decoded elements are reshaped to. When omitted, the buffer
        stays one-dimensional, so call sites that pass only the bytes keep
        working.
    dtype
        Data type used to interpret the bytes; defaults to ``'b'``.

    Returns
    -------
        New buffer representing `bytes_like`
    """
    # np.frombuffer always yields a 1-d view over the bytes (no copy).
    flat = np.frombuffer(bytes_like, dtype=dtype)
    if shape is None:
        return cls.from_ndarray_like(flat)
    return cls.from_ndarray_like(flat.reshape(shape))

@classmethod
def create_zero_length(cls, ndim: int, dtype: npt.DTypeLike = "b") -> Self:
    """Create an empty buffer with length zero

    Parameters
    ----------
    ndim
        Number of dimensions of the new buffer; must be at least 1.
    dtype
        Data type of the (empty) buffer; defaults to ``'b'``.

    Returns
    -------
        New empty 0-length buffer with exactly `ndim` dimensions

    Raises
    ------
    ValueError
        If `ndim` is smaller than 1.
    """
    if ndim < 1:
        raise ValueError(f"ndim must be at least 1, got {ndim}")
    # Every axis has extent 0, so the result has exactly `ndim` dimensions.
    # Expanding an already 1-d empty array along `ndim` new axes would
    # instead produce ndim + 1 dimensions (e.g. shape (1, 0) for ndim=1).
    return cls(np.zeros((0,) * ndim, dtype=dtype))

# Read-only accessor: returns the dtype reported by the wrapped array (`self._data`).
@property
def dtype(self) -> np.dtype[Any]:
return self._data.dtype

@property
def shape(self) -> tuple[int, ...]:
def shape(self) -> TShape:
return self._data.shape

@property
Expand Down Expand Up @@ -447,6 +475,7 @@ def copy(self) -> Self:
# Transposes the wrapped array with `axes` and wraps the result in a new
# instance of the same class as `self`.
def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Self:
return self.__class__(self._data.transpose(axes))

# Alias for the one-dimensional case of NDBuffer.
# NOTE: this is a subscripted generic alias, not a class, so it cannot be used
# as the second argument of isinstance(); check against NDBuffer instead.
Buffer = NDBuffer[tuple[int]]

def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffer) -> Buffer:
"""Converts the input of `func` to a numpy array and the output back to `Buffer`.
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/codecs/blosc.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def _encode_single(
# Since blosc only takes bytes, we convert the input and output of the encoding
# between bytes and Buffer
return await to_thread(
lambda chunk: Buffer.from_bytes(self._blosc_codec.encode(chunk.as_array_like())),
lambda chunk: Buffer.from_bytes(self._blosc_codec.encode(chunk.as_ndarray_like())),
chunk_bytes,
)

Expand Down
6 changes: 3 additions & 3 deletions src/zarr/codecs/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def _decode_single(
chunk_bytes: Buffer,
chunk_spec: ArraySpec,
) -> NDBuffer:
assert isinstance(chunk_bytes, Buffer)
assert isinstance(chunk_bytes, NDBuffer) and len(chunk_bytes.shape) == 1
if chunk_spec.dtype.itemsize > 0:
if self.endian == Endian.little:
prefix = "<"
Expand All @@ -76,7 +76,7 @@ async def _decode_single(
else:
dtype = np.dtype(f"|{chunk_spec.dtype.str[1:]}")

as_array_like = chunk_bytes.as_array_like()
as_array_like = chunk_bytes.as_ndarray_like()
if isinstance(as_array_like, NDArrayLike):
as_nd_array_like = as_array_like
else:
Expand Down Expand Up @@ -106,7 +106,7 @@ async def _encode_single(
as_nd_array_like = chunk_array.as_ndarray_like()
# Flatten the nd-array (only copy if needed)
as_nd_array_like = as_nd_array_like.ravel().view(dtype="b")
return Buffer.from_array_like(as_nd_array_like)
return Buffer.from_ndarray_like(as_nd_array_like)

def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
return input_byte_length
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def from_bytes(
def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardReader:
index = _ShardIndex.create_empty(chunks_per_shard)
obj = cls()
obj.buf = Buffer.create_zero_length()
obj.buf = Buffer.create_zero_length(ndim=1)
obj.index = index
return obj

Expand Down Expand Up @@ -217,7 +217,7 @@ def merge_with_morton_order(
@classmethod
def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardBuilder:
obj = cls()
obj.buf = Buffer.create_zero_length()
obj.buf = Buffer.create_zero_length(ndim=1)
obj.index = _ShardIndex.create_empty(chunks_per_shard)
return obj

Expand Down
4 changes: 2 additions & 2 deletions src/zarr/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
config: dict[str, Any] = o.get_config()
return config
raise TypeError

data = json.dumps(self.to_dict(), default=_json_convert).encode()
return {
ZARR_JSON: Buffer.from_bytes(json.dumps(self.to_dict(), default=_json_convert).encode())
ZARR_JSON: Buffer.from_bytes(data, shape=(len(data),))
}

@classmethod
Expand Down
7 changes: 4 additions & 3 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, NDBuffer
from zarr.common import OpenMode, concurrent_map, to_thread


Expand All @@ -32,7 +32,8 @@ def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer

end = (start + byte_range[1]) if byte_range[1] is not None else None
else:
return Buffer.from_bytes(path.read_bytes())
data = path.read_bytes()
return Buffer.from_bytes(data, shape=(len(data),))
with path.open("rb") as f:
size = f.seek(0, io.SEEK_END)
if start is not None:
Expand Down Expand Up @@ -123,7 +124,7 @@ async def set(self, key: str, value: Buffer) -> None:
if isinstance(value, bytes | bytearray): # type:ignore[unreachable]
# TODO: to support the v2 tests, we convert bytes to Buffer here
value = Buffer.from_bytes(value) # type:ignore[unreachable]
if not isinstance(value, Buffer):
if not isinstance(value, NDBuffer):
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
path = self.root / key
await to_thread(_put, path, value)
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import AsyncGenerator, MutableMapping

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, NDBuffer
from zarr.common import OpenMode, concurrent_map
from zarr.store.core import _normalize_interval_index

Expand Down Expand Up @@ -55,7 +55,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
if isinstance(value, bytes | bytearray): # type:ignore[unreachable]
# TODO: to support the v2 tests, we convert bytes to Buffer here
value = Buffer.from_bytes(value) # type:ignore[unreachable]
if not isinstance(value, Buffer):
if not isinstance(value, NDBuffer):
raise TypeError(f"Expected Buffer. Got {type(value)}.")

if byte_range is not None:
Expand Down

0 comments on commit 0739f30

Please sign in to comment.