Skip to content

Commit

Permalink
fix: use Buffer instead of bytes for store tests
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed May 22, 2024
1 parent 07fc249 commit 8915ff3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 48 deletions.
24 changes: 24 additions & 0 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def make_store_path(store_like: StoreLike) -> StorePath:
elif isinstance(store_like, str):
return StorePath(LocalStore(Path(store_like)))
raise TypeError


def _normalize_interval_index(
data: Buffer, interval: None | tuple[int | None, int | None]
) -> tuple[int, int]:
"""
Convert an implicit interval into an explicit start and length
"""
if interval is None:
start = 0
length = len(data)
else:
maybe_start, maybe_len = interval
if maybe_start is None:
start = 0
else:
start = maybe_start

if maybe_len is None:
length = len(data) - start
else:
length = maybe_len

return (start, length)
6 changes: 3 additions & 3 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map
from zarr.store.core import _normalize_interval_index


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
Expand All @@ -31,9 +32,8 @@ async def get(
assert isinstance(key, str)
try:
value = self._store_dict[key]
if byte_range is not None:
value = value[byte_range[0] : byte_range[1]]
return value
start, length = _normalize_interval_index(value, byte_range)
return value[start : start + length]
except KeyError:
return None

Expand Down
67 changes: 29 additions & 38 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,22 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.store.core import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal


def _normalize_byte_range(
data: bytes, byte_range: None | tuple[int | None, int | None]
) -> tuple[int, int]:
"""
Convert an implicit byte range into an explicit start and length
"""
if byte_range is None:
start = 0
length = len(data)
else:
maybe_start, maybe_len = byte_range
if maybe_start is None:
start = 0
else:
start = maybe_start

if maybe_len is None:
length = len(data) - start
else:
length = maybe_len

return (start, length)


S = TypeVar("S", bound=Store)


class StoreTests(Generic[S]):
store_cls: type[S]

def set(self, store: S, key: str, value: bytes) -> None:
def set(self, store: S, key: str, value: Buffer) -> None:
"""
Insert key: value pairs into a store without using the store methods.
"""
raise NotImplementedError

def get(self, store: S, key: str) -> bytes:
def get(self, store: S, key: str) -> Buffer:
"""
Retrieve values from a store without using the store methods.
"""
Expand Down Expand Up @@ -77,17 +53,20 @@ async def test_get(
self, store: S, key: str, data: bytes, byte_range: None | tuple[int | None, int | None]
) -> None:
# insert values into the store
self.set(store, key, data)
data_buf = Buffer.from_bytes(data)
self.set(store, key, data_buf)
observed = await store.get(key, byte_range=byte_range)
start, length = _normalize_byte_range(data, byte_range=byte_range)
expected = Buffer.from_bytes(data[start : start + length])
assert observed == expected
start, length = _normalize_interval_index(data_buf, interval=byte_range)
expected = data_buf[start : start + length]
assert_bytes_equal(observed, expected)

@pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_set(self, store: S, key: str, data: bytes) -> None:
await store.set(key, Buffer.from_bytes(data))
assert self.get(store, key) == data
data_buf = Buffer.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
assert_bytes_equal(observed, data_buf)

@pytest.mark.parametrize(
"key_ranges",
Expand All @@ -103,15 +82,27 @@ async def test_get_partial_values(
) -> None:
# put all of the data
for key, _ in key_ranges:
self.set(store, key, bytes(key, encoding="utf-8"))
self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8")))

# read back just part of it
observed = await store.get_partial_values(key_ranges=key_ranges)
expected = []
observed_maybe = await store.get_partial_values(key_ranges=key_ranges)

observed: list[Buffer] = []
expected: list[Buffer] = []

for obs in observed_maybe:
assert obs is not None
observed.append(obs)

for idx in range(len(observed)):
key, byte_range = key_ranges[idx]
expected.append(await store.get(key, byte_range=byte_range))
assert observed == expected
result = await store.get(key, byte_range=byte_range)
assert result is not None
expected.append(result)

assert all(
obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
)

async def test_exists(self, store: S) -> None:
assert not await store.exists("foo")
Expand Down
13 changes: 6 additions & 7 deletions tests/v3/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
from zarr.store.local import LocalStore
from zarr.store.memory import MemoryStore
from zarr.testing.store import StoreTests
from zarr.testing.utils import assert_bytes_equal


@pytest.mark.parametrize("store_dict", (None, {}))
class TestMemoryStore(StoreTests[MemoryStore]):
store_cls = MemoryStore

def set(self, store: MemoryStore, key: str, value: bytes) -> None:
def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
store._store_dict[key] = value

def get(self, store: MemoryStore, key: str) -> bytes:
def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(scope="function")
Expand All @@ -44,14 +43,14 @@ def test_list_prefix(self, store: MemoryStore) -> None:
class TestLocalStore(StoreTests[LocalStore]):
store_cls = LocalStore

def get(self, store: LocalStore, key: str) -> bytes:
return (store.root / key).read_bytes()
def get(self, store: LocalStore, key: str) -> Buffer:
return Buffer.from_bytes((store.root / key).read_bytes())

def set(self, store: LocalStore, key: str, value: bytes) -> None:
def set(self, store: LocalStore, key: str, value: Buffer) -> None:
parent = (store.root / key).parent
if not parent.exists():
parent.mkdir(parents=True)
(store.root / key).write_bytes(value)
(store.root / key).write_bytes(value.to_bytes())

@pytest.fixture(scope="function")
def store(self, tmpdir) -> LocalStore:
Expand Down

0 comments on commit 8915ff3

Please sign in to comment.