62 changes: 36 additions & 26 deletions py_hamt/sharded_zarr_store.py
@@ -340,7 +340,15 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]:
# 1. Exclude .json files immediately (metadata)
if key.endswith(".json"):
return None
excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"}
excluded_array_prefixes = {
"time",
"lat",
"lon",
"latitude",
"longitude",
"forecast_reference_time",
"step",
}

chunk_marker = "/c/"
marker_idx = key.rfind(chunk_marker) # Use rfind for robustness
@@ -360,7 +368,7 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]:
if path_before_c:
actual_array_name = path_before_c.split("/")[-1]

# 2. If the determined array name is in our exclusion list, return None.
# If the determined array name is in our exclusion list, return None.
if actual_array_name in excluded_array_prefixes:
return None
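
For reference, a minimal standalone sketch of the classification rule this hunk extends; the `classify` helper below is illustrative only, not the store's actual `_parse_chunk_key`:

from typing import Optional, Tuple

EXCLUDED = {
    "time", "lat", "lon", "latitude", "longitude",
    "forecast_reference_time", "step",
}

def classify(key: str) -> Optional[Tuple[int, ...]]:
    if key.endswith(".json"):  # metadata documents are never chunks
        return None
    marker_idx = key.rfind("/c/")  # rfind mirrors the store's robustness choice
    if marker_idx == -1:
        return None
    array_name = key[:marker_idx].split("/")[-1]
    if array_name in EXCLUDED:  # coordinate arrays bypass shard indexing
        return None
    return tuple(int(part) for part in key[marker_idx + 3 :].split("/"))

assert classify("t2m/c/0/1/2/3") == (0, 1, 2, 3)
assert classify("forecast_reference_time/c/0") is None  # newly excluded
assert classify("step/c/0") is None                     # newly excluded
assert classify("t2m/zarr.json") is None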

@@ -619,7 +627,32 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None:
raise PermissionError("Cannot write to a read-only store.")
await self._resize_complete.wait()

if (
key.endswith("zarr.json")
and not key.startswith("time/")
and not key.startswith(("lat/", "latitude/"))
and not key.startswith(("lon/", "longitude/"))
and not key == "zarr.json"
):
if key.endswith("zarr.json") and not key == "zarr.json":
metadata_json = json.loads(value.to_bytes().decode("utf-8"))
new_array_shape = metadata_json.get("shape")
if not new_array_shape:
raise ValueError("Shape not found in metadata.")
if tuple(new_array_shape) != self._array_shape:
async with self._resize_lock:
# Double-check after acquiring the lock, in case another task
# just finished this exact resize while we were waiting.
if tuple(new_array_shape) != self._array_shape:
# Block all other tasks until resize is complete.
self._resize_complete.clear()
try:
await self.resize_store(new_shape=tuple(new_array_shape))
finally:
# All waiting tasks will now un-pause and proceed safely.
self._resize_complete.set()
# Some metadata entries (e.g., group metadata) do not have a shape field.
if new_array_shape:
# Only resize when the metadata shape represents the primary array.
if (
len(new_array_shape) == len(self._array_shape)
and tuple(new_array_shape) != self._array_shape
):
async with self._resize_lock:
# Double-check after acquiring the lock, in case another task
# just finished this exact resize while we were waiting.
if (
len(new_array_shape) == len(self._array_shape)
and tuple(new_array_shape) != self._array_shape
):
# Block all other tasks until resize is complete.
self._resize_complete.clear()
try:
await self.resize_store(
new_shape=tuple(new_array_shape)
)
finally:
# All waiting tasks will now un-pause and proceed safely.
self._resize_complete.set()

raw_data_bytes = value.to_bytes()
# Save the data to CAS first to get its CID.
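
The guard above pairs an asyncio.Lock with an asyncio.Event: writers wait on the event, a single task performs the resize, and the shape is re-checked after the lock is acquired so a resize another task just finished is not repeated. The added rank comparison keeps lower-dimensional coordinate metadata from resizing the primary array. A minimal sketch of the same pattern under assumed names (ResizeGuard and _do_resize are illustrative, not store APIs):

import asyncio

class ResizeGuard:
    """Illustrative double-checked resize guard, not the store's real class."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape
        self._lock = asyncio.Lock()
        self._complete = asyncio.Event()
        self._complete.set()  # no resize in flight initially

    async def _do_resize(self, new_shape: tuple[int, ...]) -> None:
        await asyncio.sleep(0)  # stand-in for the real resize work
        self.shape = new_shape

    async def maybe_resize(self, new_shape: tuple[int, ...]) -> None:
        await self._complete.wait()  # pause while another resize runs
        # Only a same-rank, different-shape request targets the primary array.
        if len(new_shape) != len(self.shape) or new_shape == self.shape:
            return
        async with self._lock:
            # Double-check after acquiring the lock: another task may have
            # completed this exact resize while we were waiting.
            if len(new_shape) == len(self.shape) and new_shape != self.shape:
                self._complete.clear()  # block other writers
                try:
                    await self._do_resize(new_shape)
                finally:
                    self._complete.set()  # release waiting writers

asyncio.run(ResizeGuard((10, 10)).maybe_resize((20, 10)))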
@@ -706,10 +717,9 @@ async def delete(self, key: str) -> None:

chunk_coords = self._parse_chunk_key(key)
if chunk_coords is None: # Metadata
if self._root_obj["metadata"].pop(key, None):
# Coordinate/metadata deletions should be idempotent for caller convenience.
if self._root_obj["metadata"].pop(key, None) is not None:
self._dirty_root = True
else:
raise KeyError(f"Metadata key '{key}' not found.")
return None

linear_chunk_index = self._get_linear_chunk_index(chunk_coords)
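
Two behaviors change in this hunk: deleting an absent metadata key is now a silent no-op rather than a KeyError, and the explicit `is not None` test means a present-but-falsy value still marks the root dirty, which the old truthiness check would have missed. A minimal sketch of the new contract with a plain-dict stand-in (names illustrative):

metadata = {"step/c/0": "bafy...example-cid"}

def delete_metadata(key: str) -> bool:
    """Return True when an entry was actually removed (root marked dirty)."""
    return metadata.pop(key, None) is not None

assert delete_metadata("step/c/0") is True   # first delete removes the entry
assert delete_metadata("step/c/0") is False  # repeat delete is a no-op, not a KeyError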
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "py-hamt"
version = "3.3.0"
version = "3.3.1"
description = "HAMT implementation for a content-addressed storage system."
readme = "README.md"
requires-python = ">=3.12"
4 changes: 0 additions & 4 deletions tests/test_sharded_store_deleting.py
@@ -137,10 +137,6 @@ async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]):
await store.flush()
assert not store._shard_data_cache._dirty_shards # No dirty shards after flush

# Try to delete nonexistent metadata key
with pytest.raises(KeyError, match="Metadata key 'nonexistent.json' not found"):
await store.delete("nonexistent.json")

# Try to delete nonexistent chunk key (out of bounds)
with pytest.raises(IndexError, match="Chunk coordinate"):
await store.delete("temp/c/3/0") # Out of bounds for 2x2 chunk grid
12 changes: 0 additions & 12 deletions tests/test_sharded_store_resizing.py
@@ -296,18 +296,6 @@ async def test_resize_variable_invalid_cases(
):
await store.resize_variable("nonexistent", new_shape=(150, 18, 36))

# Test invalid metadata (simulate by setting invalid metadata)
invalid_metadata = json.dumps({"not_shape": [1, 2, 3]}).encode("utf-8")
invalid_cid = await kubo_cas.save(invalid_metadata, codec="raw")
store._root_obj["metadata"]["invalid/zarr.json"] = invalid_cid
with pytest.raises(ValueError, match="Shape not found in metadata"):
await store.set(
"invalid/zarr.json",
zarr.core.buffer.default_buffer_prototype().buffer.from_bytes(
invalid_metadata
),
)


@pytest.mark.asyncio
async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]):
144 changes: 144 additions & 0 deletions tests/test_sharded_zarr_store.py
@@ -1,4 +1,6 @@
import asyncio
import json
import math

import dag_cbor
import numpy as np
@@ -76,6 +78,148 @@ async def test_sharded_zarr_store_write_read(
await store_read.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data"))


@pytest.mark.asyncio
async def test_sharded_zarr_store_forecast_step_coordinates(
create_ipfs: tuple[str, str],
):
"""Ensure datasets with forecast_reference_time/step coordinates write and read."""
rpc_base_url, gateway_base_url = create_ipfs
forecast_reference_time = pd.date_range("2024-01-01", periods=3, freq="6H")
step = pd.to_timedelta([0, 6, 12, 18], unit="h")
latitude = np.linspace(-90, 90, 4)
longitude = np.linspace(-180, 180, 8)

temperature = np.random.randn(
len(forecast_reference_time),
len(step),
len(latitude),
len(longitude),
)

ds = xr.Dataset(
{
"t2m": (
["forecast_reference_time", "step", "latitude", "longitude"],
temperature,
)
},
coords={
"forecast_reference_time": forecast_reference_time,
"step": step,
"latitude": latitude,
"longitude": longitude,
},
).chunk({
"forecast_reference_time": 2,
"step": 2,
"latitude": 2,
"longitude": 4,
})

ordered_dims = list(ds.sizes)
array_shape_tuple = tuple(ds.sizes[dim] for dim in ordered_dims)
chunk_shape_tuple = tuple(ds.chunks[dim][0] for dim in ordered_dims)

async with KuboCAS(
rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
) as kubo_cas:
store_write = await ShardedZarrStore.open(
cas=kubo_cas,
read_only=False,
array_shape=array_shape_tuple,
chunk_shape=chunk_shape_tuple,
chunks_per_shard=32,
)
ds.to_zarr(store=store_write, mode="w")
# The primary array geometry should remain unchanged after writing
assert store_write._array_shape == array_shape_tuple
assert store_write._chunks_per_dim == tuple(
math.ceil(a / c)
for a, c in zip(array_shape_tuple, chunk_shape_tuple, strict=True)
)
root_cid = await store_write.flush()

store_read = await ShardedZarrStore.open(
cas=kubo_cas, read_only=True, root_cid=root_cid
)
ds_read = xr.open_zarr(store=store_read)
xr.testing.assert_identical(ds, ds_read)

# Coordinate arrays should be stored and retrievable
assert await store_read.exists("forecast_reference_time/c/0")
assert await store_read.exists("step/c/0")


@pytest.mark.asyncio
async def test_coordinate_metadata_delete_idempotent(create_ipfs: tuple[str, str]):
"""Deleting coordinate metadata keys should be a no-op if already absent."""
rpc_base_url, gateway_base_url = create_ipfs
ds = xr.Dataset(
{
"t": (
["forecast_reference_time", "step", "latitude", "longitude"],
np.random.randn(1, 1, 2, 2),
)
},
coords={
"forecast_reference_time": pd.date_range("2024-01-01", periods=1),
"step": pd.to_timedelta([0], unit="h"),
"latitude": np.array([0.0, 1.0]),
"longitude": np.array([0.0, 1.0]),
},
).chunk({
"forecast_reference_time": 1,
"step": 1,
"latitude": 1,
"longitude": 1,
})

ordered_dims = list(ds.sizes)
array_shape_tuple = tuple(ds.sizes[dim] for dim in ordered_dims)
chunk_shape_tuple = tuple(ds.chunks[dim][0] for dim in ordered_dims)

async with KuboCAS(
rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
) as kubo_cas:
store = await ShardedZarrStore.open(
cas=kubo_cas,
read_only=False,
array_shape=array_shape_tuple,
chunk_shape=chunk_shape_tuple,
chunks_per_shard=8,
)
ds.to_zarr(store=store, mode="w")
# Delete coordinate metadata once (present) and again (absent) without errors.
await store.delete("forecast_reference_time/c/0")
await store.delete("forecast_reference_time/c/0")
await store.delete("step/c/0")
await store.delete("step/c/0")


@pytest.mark.asyncio
async def test_metadata_without_shape_does_not_resize(create_ipfs: tuple[str, str]):
"""Metadata files lacking a shape should not trigger resize or errors."""
rpc_base_url, gateway_base_url = create_ipfs
async with KuboCAS(
rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
) as kubo_cas:
store = await ShardedZarrStore.open(
cas=kubo_cas,
read_only=False,
array_shape=(10, 10),
chunk_shape=(5, 5),
chunks_per_shard=4,
)
proto = zarr.core.buffer.default_buffer_prototype()
metadata_without_shape = proto.buffer.from_bytes(
json.dumps({"zarr_format": 3}).encode("utf-8")
)
await store.set("group/zarr.json", metadata_without_shape)
# No resize should occur; geometry stays the same.
assert store._array_shape == (10, 10)
assert store._chunks_per_dim == (2, 2)
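
The expected (2, 2) grid follows from ceiling division of array shape by chunk shape, the same rule asserted on store_write._chunks_per_dim in the forecast test above; a quick check of that arithmetic (illustrative only):

import math

assert tuple(math.ceil(a / c) for a, c in zip((10, 10), (5, 5))) == (2, 2)
# Non-divisible extents round up: (10, 10) with (3, 3) chunks gives (4, 4).
assert tuple(math.ceil(a / c) for a, c in zip((10, 10), (3, 3))) == (4, 4)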


@pytest.mark.asyncio
async def test_load_or_initialize_shard_cache_concurrent_loads(
create_ipfs: tuple[str, str],
9 changes: 4 additions & 5 deletions tests/test_sharded_zarr_store_coverage.py
@@ -205,11 +205,10 @@ async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str]):
with pytest.raises(NotImplementedError):
await store.set_partial_values([])

# Test ValueError when shape is not found in metadata during set
with pytest.raises(ValueError, match="Shape not found in metadata."):
await store.set(
"test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')
)
# Metadata lacking shape should be accepted without resizing
await store.set(
"test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')
)


@pytest.mark.asyncio