diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py
index 52093d4..1a16a00 100644
--- a/py_hamt/sharded_zarr_store.py
+++ b/py_hamt/sharded_zarr_store.py
@@ -340,7 +340,15 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]:
         # 1. Exclude .json files immediately (metadata)
         if key.endswith(".json"):
             return None
-        excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"}
+        excluded_array_prefixes = {
+            "time",
+            "lat",
+            "lon",
+            "latitude",
+            "longitude",
+            "forecast_reference_time",
+            "step",
+        }
 
         chunk_marker = "/c/"
         marker_idx = key.rfind(chunk_marker)  # Use rfind for robustness
@@ -360,7 +368,7 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]:
 
         if path_before_c:
             actual_array_name = path_before_c.split("/")[-1]
-            # 2. If the determined array name is in our exclusion list, return None.
+            # If the determined array name is in our exclusion list, return None.
             if actual_array_name in excluded_array_prefixes:
                 return None
 
@@ -619,29 +627,32 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None:
             raise PermissionError("Cannot write to a read-only store.")
         await self._resize_complete.wait()
 
-        if (
-            key.endswith("zarr.json")
-            and not key.startswith("time/")
-            and not key.startswith(("lat/", "latitude/"))
-            and not key.startswith(("lon/", "longitude/"))
-            and not key == "zarr.json"
-        ):
+        if key.endswith("zarr.json") and not key == "zarr.json":
             metadata_json = json.loads(value.to_bytes().decode("utf-8"))
             new_array_shape = metadata_json.get("shape")
-            if not new_array_shape:
-                raise ValueError("Shape not found in metadata.")
-            if tuple(new_array_shape) != self._array_shape:
-                async with self._resize_lock:
-                    # Double-check after acquiring the lock, in case another task
-                    # just finished this exact resize while we were waiting.
-                    if tuple(new_array_shape) != self._array_shape:
-                        # Block all other tasks until resize is complete.
-                        self._resize_complete.clear()
-                        try:
-                            await self.resize_store(new_shape=tuple(new_array_shape))
-                        finally:
-                            # All waiting tasks will now un-pause and proceed safely.
-                            self._resize_complete.set()
+            # Some metadata entries (e.g., group metadata) do not have a shape field.
+            if new_array_shape:
+                # Only resize when the metadata shape represents the primary array.
+                if (
+                    len(new_array_shape) == len(self._array_shape)
+                    and tuple(new_array_shape) != self._array_shape
+                ):
+                    async with self._resize_lock:
+                        # Double-check after acquiring the lock, in case another task
+                        # just finished this exact resize while we were waiting.
+                        if (
+                            len(new_array_shape) == len(self._array_shape)
+                            and tuple(new_array_shape) != self._array_shape
+                        ):
+                            # Block all other tasks until resize is complete.
+                            self._resize_complete.clear()
+                            try:
+                                await self.resize_store(
+                                    new_shape=tuple(new_array_shape)
+                                )
+                            finally:
+                                # All waiting tasks will now un-pause and proceed safely.
+                                self._resize_complete.set()
 
         raw_data_bytes = value.to_bytes()
         # Save the data to CAS first to get its CID.
@@ -706,10 +717,9 @@ async def delete(self, key: str) -> None:
         chunk_coords = self._parse_chunk_key(key)
         if chunk_coords is None:
             # Metadata
-            if self._root_obj["metadata"].pop(key, None):
+            # Coordinate/metadata deletions should be idempotent for caller convenience.
+            if self._root_obj["metadata"].pop(key, None) is not None:
                 self._dirty_root = True
-            else:
-                raise KeyError(f"Metadata key '{key}' not found.")
             return None
 
         linear_chunk_index = self._get_linear_chunk_index(chunk_coords)
diff --git a/pyproject.toml b/pyproject.toml
index 24d103b..09ba772 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "py-hamt"
-version = "3.3.0"
+version = "3.3.1"
 description = "HAMT implementation for a content-addressed storage system."
 readme = "README.md"
 requires-python = ">=3.12"
diff --git a/tests/test_sharded_store_deleting.py b/tests/test_sharded_store_deleting.py
index 88222a8..36d81c0 100644
--- a/tests/test_sharded_store_deleting.py
+++ b/tests/test_sharded_store_deleting.py
@@ -137,10 +137,6 @@ async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]):
         await store.flush()
         assert not store._shard_data_cache._dirty_shards  # No dirty shards after flush
 
-        # Try to delete nonexistent metadata key
-        with pytest.raises(KeyError, match="Metadata key 'nonexistent.json' not found"):
-            await store.delete("nonexistent.json")
-
         # Try to delete nonexistent chunk key (out of bounds)
         with pytest.raises(IndexError, match="Chunk coordinate"):
             await store.delete("temp/c/3/0")  # Out of bounds for 2x2 chunk grid
diff --git a/tests/test_sharded_store_resizing.py b/tests/test_sharded_store_resizing.py
index ce43465..29da805 100644
--- a/tests/test_sharded_store_resizing.py
+++ b/tests/test_sharded_store_resizing.py
@@ -296,18 +296,6 @@ async def test_resize_variable_invalid_cases(
         ):
             await store.resize_variable("nonexistent", new_shape=(150, 18, 36))
 
-        # Test invalid metadata (simulate by setting invalid metadata)
-        invalid_metadata = json.dumps({"not_shape": [1, 2, 3]}).encode("utf-8")
-        invalid_cid = await kubo_cas.save(invalid_metadata, codec="raw")
-        store._root_obj["metadata"]["invalid/zarr.json"] = invalid_cid
-        with pytest.raises(ValueError, match="Shape not found in metadata"):
-            await store.set(
-                "invalid/zarr.json",
-                zarr.core.buffer.default_buffer_prototype().buffer.from_bytes(
-                    invalid_metadata
-                ),
-            )
-
 
 @pytest.mark.asyncio
 async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]):
diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py
index ea41734..c84a034 100644
--- a/tests/test_sharded_zarr_store.py
+++ b/tests/test_sharded_zarr_store.py
@@ -1,4 +1,6 @@
 import asyncio
+import json
+import math
 
 import dag_cbor
 import numpy as np
@@ -76,6 +78,148 @@ async def test_sharded_zarr_store_write_read(
         await store_read.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data"))
 
 
+@pytest.mark.asyncio
+async def test_sharded_zarr_store_forecast_step_coordinates(
+    create_ipfs: tuple[str, str],
+):
+    """Ensure datasets with forecast_reference_time/step coordinates write and read."""
+    rpc_base_url, gateway_base_url = create_ipfs
+    forecast_reference_time = pd.date_range("2024-01-01", periods=3, freq="6H")
+    step = pd.to_timedelta([0, 6, 12, 18], unit="h")
+    latitude = np.linspace(-90, 90, 4)
+    longitude = np.linspace(-180, 180, 8)
+
+    temperature = np.random.randn(
+        len(forecast_reference_time),
+        len(step),
+        len(latitude),
+        len(longitude),
+    )
+
+    ds = xr.Dataset(
+        {
+            "t2m": (
+                ["forecast_reference_time", "step", "latitude", "longitude"],
+                temperature,
+            )
+        },
+        coords={
+            "forecast_reference_time": forecast_reference_time,
+            "step": step,
+            "latitude": latitude,
+            "longitude": longitude,
+        },
+    ).chunk({
+        "forecast_reference_time": 2,
+        "step": 2,
+        "latitude": 2,
+        "longitude": 4,
+    })
+
+    ordered_dims = list(ds.sizes)
+    array_shape_tuple = tuple(ds.sizes[dim] for dim in ordered_dims)
+    chunk_shape_tuple = tuple(ds.chunks[dim][0] for dim in ordered_dims)
+
+    async with KuboCAS(
+        rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
+    ) as kubo_cas:
+        store_write = await ShardedZarrStore.open(
+            cas=kubo_cas,
+            read_only=False,
+            array_shape=array_shape_tuple,
+            chunk_shape=chunk_shape_tuple,
+            chunks_per_shard=32,
+        )
+        ds.to_zarr(store=store_write, mode="w")
+        # The primary array geometry should remain unchanged after writing
+        assert store_write._array_shape == array_shape_tuple
+        assert store_write._chunks_per_dim == tuple(
+            math.ceil(a / c)
+            for a, c in zip(array_shape_tuple, chunk_shape_tuple, strict=True)
+        )
+        root_cid = await store_write.flush()
+
+        store_read = await ShardedZarrStore.open(
+            cas=kubo_cas, read_only=True, root_cid=root_cid
+        )
+        ds_read = xr.open_zarr(store=store_read)
+        xr.testing.assert_identical(ds, ds_read)
+
+        # Coordinate arrays should be stored and retrievable
+        assert await store_read.exists("forecast_reference_time/c/0")
+        assert await store_read.exists("step/c/0")
+
+
+@pytest.mark.asyncio
+async def test_coordinate_metadata_delete_idempotent(create_ipfs: tuple[str, str]):
+    """Deleting coordinate metadata keys should be a no-op if already absent."""
+    rpc_base_url, gateway_base_url = create_ipfs
+    ds = xr.Dataset(
+        {
+            "t": (
+                ["forecast_reference_time", "step", "latitude", "longitude"],
+                np.random.randn(1, 1, 2, 2),
+            )
+        },
+        coords={
+            "forecast_reference_time": pd.date_range("2024-01-01", periods=1),
+            "step": pd.to_timedelta([0], unit="h"),
+            "latitude": np.array([0.0, 1.0]),
+            "longitude": np.array([0.0, 1.0]),
+        },
+    ).chunk({
+        "forecast_reference_time": 1,
+        "step": 1,
+        "latitude": 1,
+        "longitude": 1,
+    })
+
+    ordered_dims = list(ds.sizes)
+    array_shape_tuple = tuple(ds.sizes[dim] for dim in ordered_dims)
+    chunk_shape_tuple = tuple(ds.chunks[dim][0] for dim in ordered_dims)
+
+    async with KuboCAS(
+        rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
+    ) as kubo_cas:
+        store = await ShardedZarrStore.open(
+            cas=kubo_cas,
+            read_only=False,
+            array_shape=array_shape_tuple,
+            chunk_shape=chunk_shape_tuple,
+            chunks_per_shard=8,
+        )
+        ds.to_zarr(store=store, mode="w")
+        # Delete coordinate metadata once (present) and again (absent) without errors.
+        await store.delete("forecast_reference_time/c/0")
+        await store.delete("forecast_reference_time/c/0")
+        await store.delete("step/c/0")
+        await store.delete("step/c/0")
+
+
+@pytest.mark.asyncio
+async def test_metadata_without_shape_does_not_resize(create_ipfs: tuple[str, str]):
+    """Metadata files lacking a shape should not trigger resize or errors."""
+    rpc_base_url, gateway_base_url = create_ipfs
+    async with KuboCAS(
+        rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url
+    ) as kubo_cas:
+        store = await ShardedZarrStore.open(
+            cas=kubo_cas,
+            read_only=False,
+            array_shape=(10, 10),
+            chunk_shape=(5, 5),
+            chunks_per_shard=4,
+        )
+        proto = zarr.core.buffer.default_buffer_prototype()
+        metadata_without_shape = proto.buffer.from_bytes(
+            json.dumps({"zarr_format": 3}).encode("utf-8")
+        )
+        await store.set("group/zarr.json", metadata_without_shape)
+        # No resize should occur; geometry stays the same.
+        assert store._array_shape == (10, 10)
+        assert store._chunks_per_dim == (2, 2)
+
+
 @pytest.mark.asyncio
 async def test_load_or_initialize_shard_cache_concurrent_loads(
     create_ipfs: tuple[str, str],
diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py
index 72463c7..b42a39e 100644
--- a/tests/test_sharded_zarr_store_coverage.py
+++ b/tests/test_sharded_zarr_store_coverage.py
@@ -205,11 +205,10 @@ async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str
         with pytest.raises(NotImplementedError):
             await store.set_partial_values([])
 
-        # Test ValueError when shape is not found in metadata during set
-        with pytest.raises(ValueError, match="Shape not found in metadata."):
-            await store.set(
-                "test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')
-            )
+        # Metadata lacking shape should be accepted without resizing
+        await store.set(
+            "test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')
+        )
 
 
 @pytest.mark.asyncio
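
Reviewer note (not part of the patch): the sketch below distills the resize rule the
reworked ShardedZarrStore.set applies. Metadata without a "shape" field is ignored, a
shape whose dimensionality differs from the primary array's is treated as a coordinate
array (e.g., forecast_reference_time or step) rather than a resize request, and only a
genuine shape change acquires the lock and resizes. The names _ResizeGuard, maybe_resize,
and _demo are hypothetical illustrations, not py-hamt API.

import asyncio


class _ResizeGuard:
    """Minimal, self-contained model of the resize decision in set()."""

    def __init__(self, array_shape: tuple[int, ...]) -> None:
        self._array_shape = array_shape
        self._resize_lock = asyncio.Lock()

    async def maybe_resize(self, metadata_json: dict) -> bool:
        new_shape = metadata_json.get("shape")
        if not new_shape:
            # Group metadata (no "shape") must neither raise nor resize.
            return False
        if len(new_shape) != len(self._array_shape):
            # Dimensionality mismatch: a coordinate array, never a primary-array resize.
            return False
        if tuple(new_shape) == self._array_shape:
            return False
        async with self._resize_lock:
            # Double-check under the lock, mirroring the store's pattern.
            if (
                len(new_shape) == len(self._array_shape)
                and tuple(new_shape) != self._array_shape
            ):
                self._array_shape = tuple(new_shape)
                return True
            return False


async def _demo() -> None:
    guard = _ResizeGuard((10, 10))
    assert not await guard.maybe_resize({"zarr_format": 3})  # no shape field
    assert not await guard.maybe_resize({"shape": [10]})     # 1-D coordinate array
    assert await guard.maybe_resize({"shape": [20, 10]})     # genuine resize


asyncio.run(_demo())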