From f62e6c11bee29cf30b0e7c82e0ba46aec1f44451 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 30 May 2025 08:31:20 -0400 Subject: [PATCH 01/74] feat: support partials --- py_hamt/hamt.py | 5 +- py_hamt/store.py | 50 ++++-- py_hamt/zarr_hamt_store.py | 48 +++++- tests/test_zarr_ipfs_partial.py | 262 ++++++++++++++++++++++++++++++++ 4 files changed, 348 insertions(+), 17 deletions(-) create mode 100644 tests/test_zarr_ipfs_partial.py diff --git a/py_hamt/hamt.py b/py_hamt/hamt.py index ad867c2..98f3dc6 100644 --- a/py_hamt/hamt.py +++ b/py_hamt/hamt.py @@ -3,6 +3,7 @@ import uuid import asyncio from copy import deepcopy +from typing import Optional import dag_cbor from dag_cbor.ipld import IPLDKind @@ -571,10 +572,10 @@ async def delete(self, key: str): # If we didn't make a change, then this key must not exist within the HAMT raise KeyError - async def get(self, key: str) -> IPLDKind: + async def get(self, key: str, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> IPLDKind: """Get a value.""" pointer = await self.get_pointer(key) - data = await self.cas.load(pointer) + data = await self.cas.load(pointer, offset=offset, length=length, suffix=suffix) if self.values_are_bytes: return data else: diff --git a/py_hamt/store.py b/py_hamt/store.py index 0fbe62e..2071383 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -1,6 +1,6 @@ import asyncio import aiohttp -from typing import Literal +from typing import Literal, Optional from abc import ABC, abstractmethod from dag_cbor.ipld import IPLDKind from multiformats import multihash @@ -30,7 +30,7 @@ async def save(self, data: bytes, codec: CodecInput) -> IPLDKind: """ @abstractmethod - async def load(self, id: IPLDKind) -> bytes: + async def load(self, id: IPLDKind, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> bytes: """Retrieve data.""" @@ -49,11 +49,23 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> by self.store[hash] = data return hash - async def load(self, id: bytes) -> bytes: # type: ignore since bytes is a subset of the IPLDKind type - if id in self.store: - return self.store[id] - - raise KeyError + async def load(self, id: bytes, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> bytes: # type: ignore since bytes is a subset of the IPLDKind type + if id not in self.store: + raise KeyError + + data = self.store[id] + + if offset is not None: + start = offset + if length is not None: + end = start + length + return data[start:end] + else: + return data[start:] + elif suffix is not None: # If only length is given, assume start from 0 + return data[-suffix:] + else: # Full load + return data class KuboCAS(ContentAddressedStore): @@ -166,11 +178,31 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI return cid async def load( # type: ignore CID is definitely in the IPLDKind type - self, id: CID + self, + id: CID, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None ) -> bytes: """@private""" url = self.gateway_base_url + str(id) + headers = {} + + # Construct the Range header if required + if offset is not None: + start = offset + if length is not None: + # Standard HTTP Range: bytes=start-end (inclusive) + end = start + length - 1 + headers["Range"] = f"bytes={start}-{end}" + else: + # Standard HTTP Range: bytes=start- (from start to end) + 
headers["Range"] = f"bytes={start}-" + elif suffix is not None: + # Standard HTTP Range: bytes=-N (last N bytes) + headers["Range"] = f"bytes=-{suffix}" + async with self._sem: # throttle gateway - async with self._loop_session().get(url) as resp: + async with self._loop_session().get(url, headers=headers or None) as resp: resp.raise_for_status() return await resp.read() diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index f5c5dbb..8f2c3a3 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -2,10 +2,11 @@ import zarr.abc.store import zarr.core.buffer from zarr.core.common import BytesLike +from typing import Optional +import asyncio from py_hamt.hamt import HAMT - class ZarrHAMTStore(zarr.abc.store.Store): """ Write and read Zarr v3s with a HAMT. @@ -61,6 +62,27 @@ def __init__(self, hamt: HAMT, read_only: bool = False) -> None: self.metadata_read_cache: dict[str, bytes] = dict() """@private""" + def _map_byte_request(self, byte_range: Optional[zarr.abc.store.ByteRequest]) -> tuple[Optional[int], Optional[int], Optional[int]]: + """Helper to map Zarr ByteRequest to offset, length, suffix.""" + offset: Optional[int] = None + length: Optional[int] = None + suffix: Optional[int] = None + + if byte_range: + if isinstance(byte_range, zarr.abc.store.RangeByteRequest): + offset = byte_range.start + length = byte_range.end - byte_range.start + if length < 0: + raise ValueError("End must be >= start for RangeByteRequest") + elif isinstance(byte_range, zarr.abc.store.OffsetByteRequest): + offset = byte_range.offset + elif isinstance(byte_range, zarr.abc.store.SuffixByteRequest): + suffix = byte_range.suffix + else: + raise TypeError(f"Unsupported ByteRequest type: {type(byte_range)}") + + return offset, length, suffix + @property def read_only(self) -> bool: """@private""" @@ -86,25 +108,39 @@ async def get( len(key) >= 9 and key[-9:] == "zarr.json" ) # if path ends with zarr.json - if is_metadata and key in self.metadata_read_cache: + if is_metadata and byte_range is None and key in self.metadata_read_cache: val = self.metadata_read_cache[key] else: - val = await self.hamt.get(key) # type: ignore We know values received will always be bytes since we only store bytes in the HAMT - if is_metadata: + offset, length, suffix = self._map_byte_request(byte_range) + + val = await self.hamt.get(key, offset=offset, length=length, suffix=suffix) # type: ignore We know values received will always be bytes since we only store bytes in the HAMT + # Update cache only on full metadata reads + if is_metadata and byte_range is None: self.metadata_read_cache[key] = val return prototype.buffer.from_bytes(val) except KeyError: # Sometimes zarr queries keys that don't exist anymore, just return nothing on those cases return + except Exception as e: + print(f"Error getting key '{key}' with range {byte_range}: {e}") + raise + async def get_partial_values( self, prototype: zarr.core.buffer.BufferPrototype, key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], ) -> list[zarr.core.buffer.Buffer | None]: - """@private""" - raise NotImplementedError + """ + Retrieves multiple keys or byte ranges concurrently using asyncio.gather. 
+ """ + tasks = [ + self.get(key, prototype, byte_range) + for key, byte_range in key_ranges + ] + results = await asyncio.gather(*tasks, return_exceptions=False) # Set return_exceptions=True for debugging + return results async def exists(self, key: str) -> bool: """@private""" diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py new file mode 100644 index 0000000..702441f --- /dev/null +++ b/tests/test_zarr_ipfs_partial.py @@ -0,0 +1,262 @@ +import time + +import numpy as np +import pandas as pd +import xarray as xr +import pytest +import zarr +import zarr.core.buffer +# Make sure to import the ByteRequest types +from zarr.abc.store import RangeByteRequest, OffsetByteRequest, SuffixByteRequest +import aiohttp +from typing import Optional + + + +from py_hamt import HAMT, KuboCAS + +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset. + + Returns: + tuple: (dataset_path, expected_data) + - dataset_path: Path to the zarr store + - expected_data: The original xarray Dataset for comparison + """ + # Coordinates of the random data + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + # Create random variables with different shapes + temp = np.random.randn(len(times), len(lats), len(lons)) + precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + + # Create the dataset + ds = xr.Dataset( + { + "temp": ( + ["time", "lat", "lon"], + temp, + {"units": "celsius", "long_name": "Surface Temperature"}, + ), + "precip": ( + ["time", "lat", "lon"], + precip, + {"units": "mm/day", "long_name": "Daily Precipitation"}, + ), + }, + coords={ + "time": times, + "lat": ("lat", lats, {"units": "degrees_north"}), + "lon": ("lon", lons, {"units": "degrees_east"}), + }, + attrs={"description": "Test dataset with random weather data"}, + ) + + yield ds + +# This test also collects miscellaneous statistics about performance, run with pytest -s to see these statistics being printed out +@pytest.mark.asyncio(loop_scope="session") # ← match the loop of the fixture +async def test_write_read( + create_ipfs, + random_zarr_dataset: xr.Dataset, +): # noqa for fixture which is imported above but then "redefined" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + print("=== Writing this xarray Dataset to a Zarr v3 on IPFS ===") + print(test_ds) + + + async with KuboCAS( # <-- own and auto-close session + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + zhs = ZarrHAMTStore(hamt) + assert zhs.supports_writes + start = time.perf_counter() + # Do an initial write along with an append which is a common xarray/zarr operation + # Ensure chunks are not too small for partial value tests + test_ds.to_zarr(store=zhs, chunk_store={'time': 50, 'lat': 18, 'lon': 36}) + test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) + end = time.perf_counter() + elapsed = end - start + print("=== Write Stats") + print(f"Total time in seconds: {elapsed:.2f}") + print("=== Root CID") + await hamt.make_read_only() + cid = hamt.root_node_id + + print(f"=== Verifying Gateway Suffix Support (CID: {cid}) ===") + # Get the gateway URL without the /ipfs/ part + gateway_only_url = gateway_base_url + + # You can add an assertion here if you expect it to work + # If you know the gateway *might* be buggy, just 
printing is okay too. + assert is_correct, "IPFS Gateway did not return the correct suffix data." + + print("=== Reading data back in and checking if identical") + hamt_read = await HAMT.build( # Renamed to avoid confusion + cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True + ) + start = time.perf_counter() + zhs_read = ZarrHAMTStore(hamt_read, read_only=True) # Use the read-only hamt + ipfs_ds = xr.open_zarr(store=zhs_read) + print(ipfs_ds) + + # Check both halves, since each are an identical copy + ds1 = ipfs_ds.isel(time=slice(0, len(ipfs_ds.time) // 2)) + ds2 = ipfs_ds.isel(time=slice(len(ipfs_ds.time) // 2, len(ipfs_ds.time))) + xr.testing.assert_identical(ds1, ds2) + xr.testing.assert_identical(test_ds, ds1) + xr.testing.assert_identical(test_ds, ds2) + + end = time.perf_counter() + elapsed = end - start + print("=== Read Stats") + print(f"Total time in seconds: {elapsed:.2f}") + + # --- Start: New Partial Values Tests --- + + print("=== Testing get_partial_values ===") + proto = zarr.core.buffer.default_buffer_prototype() + + # Find a chunk key to test with (e.g., the first chunk of 'temp') + chunk_key = None + async for k in zhs_read.list(): + if k.startswith("temp/") and k != "temp/.zarray": + chunk_key = k + break + + assert chunk_key is not None, "Could not find a chunk key to test." + print(f"Testing with chunk key: {chunk_key}") + + # Get the full chunk data for comparison + full_chunk_buffer = await zhs_read.get(chunk_key, proto) + assert full_chunk_buffer is not None + full_chunk_data = full_chunk_buffer.to_bytes() + chunk_len = len(full_chunk_data) + print(f"Full chunk size: {chunk_len} bytes") + + # Ensure the chunk is large enough for meaningful tests + assert chunk_len > 100, "Chunk size too small for partial value tests" + + # Define some byte requests + range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes + offset_req = OffsetByteRequest(offset=chunk_len - 30) # Request last 30 bytes + suffix_req = SuffixByteRequest(suffix=20) # Request last 20 bytes + + key_ranges_to_test = [ + (chunk_key, range_req), + (chunk_key, offset_req), + (chunk_key, suffix_req), + (chunk_key, None), # Full read + ] + + # Call get_partial_values + results = await zhs_read.get_partial_values(proto, key_ranges_to_test) + + # Assertions + assert len(results) == 4 + assert results[0] is not None + assert results[1] is not None + assert results[2] is not None + assert results[3] is not None + + # Check RangeByteRequest result + expected_range = full_chunk_data[10:50] + assert results[0].to_bytes() == expected_range, "RangeByteRequest failed" + print(f"RangeByteRequest: OK (Got {len(results[0].to_bytes())} bytes)") + + # Check OffsetByteRequest result + expected_offset = full_chunk_data[chunk_len - 30:] + assert results[1].to_bytes() == expected_offset, "OffsetByteRequest failed" + print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") + + # Check SuffixByteRequest result + expected_suffix = full_chunk_data[-20:] + # Broken until Kubo 0.36.0 + assert results[2].to_bytes() == expected_suffix, "SuffixByteRequest failed" + print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") + + # Check full read result + assert results[3].to_bytes() == full_chunk_data, "Full read via get_partial_values failed" + print(f"Full Read: OK (Got {len(results[3].to_bytes())} bytes)") + + + # --- End: New Partial Values Tests --- + + + # Tests for code coverage's sake + assert await zhs_read.exists("zarr.json") + # __eq__ + assert zhs_read == zhs_read + 
assert zhs_read != hamt_read + assert not zhs_read.supports_writes + assert not zhs_read.supports_partial_writes + assert zhs_read.supports_deletes # Should be true in read-only mode for HAMT? Usually False + + hamt_keys = set() + async for k in zhs_read.hamt.keys(): + hamt_keys.add(k) + + zhs_keys: set[str] = set() + async for k in zhs_read.list(): + zhs_keys.add(k) + assert hamt_keys == zhs_keys + + zhs_keys: set[str] = set() + async for k in zhs_read.list_prefix(""): + zhs_keys.add(k) + assert hamt_keys == zhs_keys + + with pytest.raises(NotImplementedError): + await zhs_read.set_partial_values([]) + + # REMOVED: The old NotImplementedError check for get_partial_values + # with pytest.raises(NotImplementedError): + # await zhs_read.get_partial_values( + # zarr.core.buffer.default_buffer_prototype(), [] + # ) + + previous_zarr_json = await zhs_read.get( + "zarr.json", zarr.core.buffer.default_buffer_prototype() + ) + assert previous_zarr_json is not None + + # --- Test set_if_not_exists (needs a writable store) --- + await hamt_read.enable_write() + zhs_write = ZarrHAMTStore(hamt_read, read_only=False) + + # Setting a metadata file that should always exist should not change anything + await zhs_write.set_if_not_exists( + "zarr.json", + zarr.core.buffer.Buffer.from_bytes(b"should_not_change"), + ) + zarr_json_now = await zhs_write.get( + "zarr.json", zarr.core.buffer.default_buffer_prototype() + ) + assert zarr_json_now is not None + assert previous_zarr_json.to_bytes() == zarr_json_now.to_bytes() + + # now remove that metadata file and then add it back + await zhs_write.delete("zarr.json") + # doing a duplicate delete should not result in an error + await zhs_write.delete("zarr.json") + + zhs_keys_after_delete: set[str] = set() + async for k in zhs_write.list(): + zhs_keys_after_delete.add(k) + assert hamt_keys != zhs_keys_after_delete + assert "zarr.json" not in zhs_keys_after_delete + + await zhs_write.set_if_not_exists("zarr.json", previous_zarr_json) + zarr_json_now = await zhs_write.get( + "zarr.json", zarr.core.buffer.default_buffer_prototype() + ) + assert zarr_json_now is not None + assert previous_zarr_json.to_bytes() == zarr_json_now.to_bytes() \ No newline at end of file From 226967fcd81cc633f8da1f375f21672efb96e69a Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 30 May 2025 10:55:22 -0400 Subject: [PATCH 02/74] fix: re-add accidentally deleted _map_byte_request --- py_hamt/zarr_hamt_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index 3b7294a..e378e16 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -112,6 +112,7 @@ async def get( if is_metadata and byte_range is None and key in self.metadata_read_cache: val = self.metadata_read_cache[key] else: + offset, length, suffix = self._map_byte_request(byte_range) val = cast( bytes, await self.hamt.get(key, offset=offset, length=length, suffix=suffix) ) # We know values received will always be bytes since we only store bytes in the HAMT From 89e8124af2e353f1e66d3d3b8e5a16453c8f7cbb Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 2 Jun 2025 09:18:13 -0400 Subject: [PATCH 03/74] fix: tidying up --- py_hamt/hamt.py | 12 ++++- py_hamt/store.py | 28 ++++++++--- py_hamt/zarr_hamt_store.py | 21 +++++---- tests/test_zarr_ipfs.py | 5 -- tests/test_zarr_ipfs_partial.py | 82 +++++++-------------------------- 5 files changed, 60 
insertions(+), 88 deletions(-) diff --git a/py_hamt/hamt.py b/py_hamt/hamt.py index c94959e..aab1151 100644 --- a/py_hamt/hamt.py +++ b/py_hamt/hamt.py @@ -590,10 +590,18 @@ async def delete(self, key: str) -> None: # If we didn't make a change, then this key must not exist within the HAMT raise KeyError - async def get(self, key: str, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> IPLDKind: + async def get( + self, + key: str, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> IPLDKind: """Get a value.""" pointer: IPLDKind = await self.get_pointer(key) - data: bytes = await self.cas.load(pointer, offset=offset, length=length, suffix=suffix) + data: bytes = await self.cas.load( + pointer, offset=offset, length=length, suffix=suffix + ) if self.values_are_bytes: return data else: diff --git a/py_hamt/store.py b/py_hamt/store.py index 06b4a04..00b581d 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -30,7 +30,13 @@ async def save(self, data: bytes, codec: CodecInput) -> IPLDKind: """ @abstractmethod - async def load(self, id: IPLDKind, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> bytes: + async def load( + self, + id: IPLDKind, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> bytes: """Retrieve data.""" @@ -49,7 +55,13 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> by self.store[hash] = data return hash - async def load(self, id: IPLDKind, offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None) -> bytes: # type: ignore since bytes is a subset of the IPLDKind type + async def load( + self, + id: IPLDKind, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> bytes: # type: ignore since bytes is a subset of the IPLDKind type """ `ContentAddressedStore` allows any IPLD scalar key. 
For the in-memory backend we *require* a `bytes` hash; anything else is rejected at run @@ -60,12 +72,16 @@ async def load(self, id: IPLDKind, offset: Optional[int] = None, length: Optiona h/t https://stackoverflow.com/questions/75209249/overriding-a-method-mypy-throws-an-incompatible-with-super-type-error-when-ch """ key = cast(bytes, id) + if not isinstance(key, (bytes, bytearray)): # defensive guard + raise TypeError( + f"InMemoryCAS only supports byte‐hash keys; got {type(id).__name__}" + ) data: bytes try: data = self.store[key] except KeyError as exc: raise KeyError("Object not found in in-memory store") from exc - + if offset is not None: start = offset if length is not None: @@ -189,15 +205,15 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI return cid async def load( # type: ignore CID is definitely in the IPLDKind type - self, + self, id: IPLDKind, offset: Optional[int] = None, length: Optional[int] = None, - suffix: Optional[int] = None + suffix: Optional[int] = None, ) -> bytes: """@private""" cid = cast(CID, id) - url: str = self.gateway_base_url + str(id) + url: str = self.gateway_base_url + str(cid) headers: dict[str, str] = {} # Construct the Range header if required diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index e378e16..7fe8069 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -63,7 +63,9 @@ def __init__(self, hamt: HAMT, read_only: bool = False) -> None: self.metadata_read_cache: dict[str, bytes] = {} """@private""" - def _map_byte_request(self, byte_range: Optional[zarr.abc.store.ByteRequest]) -> tuple[Optional[int], Optional[int], Optional[int]]: + def _map_byte_request( + self, byte_range: Optional[zarr.abc.store.ByteRequest] + ) -> tuple[Optional[int], Optional[int], Optional[int]]: """Helper to map Zarr ByteRequest to offset, length, suffix.""" offset: Optional[int] = None length: Optional[int] = None @@ -81,7 +83,7 @@ def _map_byte_request(self, byte_range: Optional[zarr.abc.store.ByteRequest]) -> suffix = byte_range.suffix else: raise TypeError(f"Unsupported ByteRequest type: {type(byte_range)}") - + return offset, length, suffix @property @@ -114,7 +116,10 @@ async def get( else: offset, length, suffix = self._map_byte_request(byte_range) val = cast( - bytes, await self.hamt.get(key, offset=offset, length=length, suffix=suffix) + bytes, + await self.hamt.get( + key, offset=offset, length=length, suffix=suffix + ), ) # We know values received will always be bytes since we only store bytes in the HAMT if is_metadata and byte_range is None: self.metadata_read_cache[key] = val @@ -127,7 +132,6 @@ async def get( print(f"Error getting key '{key}' with range {byte_range}: {e}") raise - async def get_partial_values( self, prototype: zarr.core.buffer.BufferPrototype, @@ -136,11 +140,10 @@ async def get_partial_values( """ Retrieves multiple keys or byte ranges concurrently using asyncio.gather. 
""" - tasks = [ - self.get(key, prototype, byte_range) - for key, byte_range in key_ranges - ] - results = await asyncio.gather(*tasks, return_exceptions=False) # Set return_exceptions=True for debugging + tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + results = await asyncio.gather( + *tasks, return_exceptions=False + ) # Set return_exceptions=True for debugging return results async def exists(self, key: str) -> bool: diff --git a/tests/test_zarr_ipfs.py b/tests/test_zarr_ipfs.py index 11ffcb8..851936e 100644 --- a/tests/test_zarr_ipfs.py +++ b/tests/test_zarr_ipfs.py @@ -139,11 +139,6 @@ async def test_write_read( with pytest.raises(NotImplementedError): await zhs.set_partial_values([]) - with pytest.raises(NotImplementedError): - await zhs.get_partial_values( - zarr.core.buffer.default_buffer_prototype(), [] - ) - previous_zarr_json: zarr.core.buffer.Buffer | None = await zhs.get( "zarr.json", zarr.core.buffer.default_buffer_prototype() ) diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index 702441f..a39d045 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -8,9 +8,6 @@ import zarr.core.buffer # Make sure to import the ByteRequest types from zarr.abc.store import RangeByteRequest, OffsetByteRequest, SuffixByteRequest -import aiohttp -from typing import Optional - from py_hamt import HAMT, KuboCAS @@ -81,7 +78,7 @@ async def test_write_read( start = time.perf_counter() # Do an initial write along with an append which is a common xarray/zarr operation # Ensure chunks are not too small for partial value tests - test_ds.to_zarr(store=zhs, chunk_store={'time': 50, 'lat': 18, 'lon': 36}) + test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) end = time.perf_counter() elapsed = end - start @@ -93,18 +90,13 @@ async def test_write_read( print(f"=== Verifying Gateway Suffix Support (CID: {cid}) ===") # Get the gateway URL without the /ipfs/ part - gateway_only_url = gateway_base_url - - # You can add an assertion here if you expect it to work - # If you know the gateway *might* be buggy, just printing is okay too. - assert is_correct, "IPFS Gateway did not return the correct suffix data." print("=== Reading data back in and checking if identical") - hamt_read = await HAMT.build( # Renamed to avoid confusion + hamt_read = await HAMT.build( # Renamed to avoid confusion cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True ) start = time.perf_counter() - zhs_read = ZarrHAMTStore(hamt_read, read_only=True) # Use the read-only hamt + zhs_read = ZarrHAMTStore(hamt_read, read_only=True) # Use the read-only hamt ipfs_ds = xr.open_zarr(store=zhs_read) print(ipfs_ds) @@ -131,7 +123,7 @@ async def test_write_read( if k.startswith("temp/") and k != "temp/.zarray": chunk_key = k break - + assert chunk_key is not None, "Could not find a chunk key to test." 
print(f"Testing with chunk key: {chunk_key}") @@ -141,20 +133,20 @@ async def test_write_read( full_chunk_data = full_chunk_buffer.to_bytes() chunk_len = len(full_chunk_data) print(f"Full chunk size: {chunk_len} bytes") - + # Ensure the chunk is large enough for meaningful tests assert chunk_len > 100, "Chunk size too small for partial value tests" # Define some byte requests - range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes - offset_req = OffsetByteRequest(offset=chunk_len - 30) # Request last 30 bytes - suffix_req = SuffixByteRequest(suffix=20) # Request last 20 bytes + range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes + offset_req = OffsetByteRequest(offset=chunk_len - 30) # Request last 30 bytes + suffix_req = SuffixByteRequest(suffix=20) # Request last 20 bytes key_ranges_to_test = [ (chunk_key, range_req), (chunk_key, offset_req), (chunk_key, suffix_req), - (chunk_key, None), # Full read + (chunk_key, None), # Full read ] # Call get_partial_values @@ -173,7 +165,7 @@ async def test_write_read( print(f"RangeByteRequest: OK (Got {len(results[0].to_bytes())} bytes)") # Check OffsetByteRequest result - expected_offset = full_chunk_data[chunk_len - 30:] + expected_offset = full_chunk_data[chunk_len - 30 :] assert results[1].to_bytes() == expected_offset, "OffsetByteRequest failed" print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") @@ -184,13 +176,13 @@ async def test_write_read( print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") # Check full read result - assert results[3].to_bytes() == full_chunk_data, "Full read via get_partial_values failed" + assert results[3].to_bytes() == full_chunk_data, ( + "Full read via get_partial_values failed" + ) print(f"Full Read: OK (Got {len(results[3].to_bytes())} bytes)") - # --- End: New Partial Values Tests --- - # Tests for code coverage's sake assert await zhs_read.exists("zarr.json") # __eq__ @@ -198,7 +190,9 @@ async def test_write_read( assert zhs_read != hamt_read assert not zhs_read.supports_writes assert not zhs_read.supports_partial_writes - assert zhs_read.supports_deletes # Should be true in read-only mode for HAMT? Usually False + assert not ( + zhs_read.supports_deletes + ) # Should be true in read-only mode for HAMT? 
Usually False hamt_keys = set() async for k in zhs_read.hamt.keys(): @@ -216,47 +210,3 @@ async def test_write_read( with pytest.raises(NotImplementedError): await zhs_read.set_partial_values([]) - - # REMOVED: The old NotImplementedError check for get_partial_values - # with pytest.raises(NotImplementedError): - # await zhs_read.get_partial_values( - # zarr.core.buffer.default_buffer_prototype(), [] - # ) - - previous_zarr_json = await zhs_read.get( - "zarr.json", zarr.core.buffer.default_buffer_prototype() - ) - assert previous_zarr_json is not None - - # --- Test set_if_not_exists (needs a writable store) --- - await hamt_read.enable_write() - zhs_write = ZarrHAMTStore(hamt_read, read_only=False) - - # Setting a metadata file that should always exist should not change anything - await zhs_write.set_if_not_exists( - "zarr.json", - zarr.core.buffer.Buffer.from_bytes(b"should_not_change"), - ) - zarr_json_now = await zhs_write.get( - "zarr.json", zarr.core.buffer.default_buffer_prototype() - ) - assert zarr_json_now is not None - assert previous_zarr_json.to_bytes() == zarr_json_now.to_bytes() - - # now remove that metadata file and then add it back - await zhs_write.delete("zarr.json") - # doing a duplicate delete should not result in an error - await zhs_write.delete("zarr.json") - - zhs_keys_after_delete: set[str] = set() - async for k in zhs_write.list(): - zhs_keys_after_delete.add(k) - assert hamt_keys != zhs_keys_after_delete - assert "zarr.json" not in zhs_keys_after_delete - - await zhs_write.set_if_not_exists("zarr.json", previous_zarr_json) - zarr_json_now = await zhs_write.get( - "zarr.json", zarr.core.buffer.default_buffer_prototype() - ) - assert zarr_json_now is not None - assert previous_zarr_json.to_bytes() == zarr_json_now.to_bytes() \ No newline at end of file From 6e5603199b6a3810e24164a68e6a4c003ef8eaf6 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 2 Jun 2025 09:30:17 -0400 Subject: [PATCH 04/74] fix: full coverage --- py_hamt/__init__.py | 3 +- py_hamt/zarr_hamt_store.py | 1 + tests/test_zarr_ipfs_partial.py | 62 +++++++++++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 2d420b1..c6da3b9 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -1,5 +1,5 @@ from .hamt import HAMT, blake3_hashfn -from .store import ContentAddressedStore, KuboCAS +from .store import ContentAddressedStore, KuboCAS, InMemoryCAS from .zarr_hamt_store import ZarrHAMTStore __all__ = [ @@ -8,4 +8,5 @@ "ContentAddressedStore", "KuboCAS", "ZarrHAMTStore", + "InMemoryCAS", ] diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index 7fe8069..5e23403 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -8,6 +8,7 @@ from py_hamt.hamt import HAMT + class ZarrHAMTStore(zarr.abc.store.Store): """ Write and read Zarr v3s with a HAMT. 
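For reference, a minimal sketch of how the partial-read path added in this series can be exercised against the in-memory backend. It uses only calls that already appear in the tests (`HAMT.build`, `ZarrHAMTStore.set`/`.get`, and the Zarr `ByteRequest` classes); the key name and payload are illustrative, not part of the patch.

```python
import asyncio

import zarr.core.buffer
from zarr.abc.store import RangeByteRequest, SuffixByteRequest

from py_hamt import HAMT, InMemoryCAS
from py_hamt.zarr_hamt_store import ZarrHAMTStore


async def demo() -> None:
    cas = InMemoryCAS()
    hamt = await HAMT.build(cas=cas, values_are_bytes=True)
    zhs = ZarrHAMTStore(hamt)
    proto = zarr.core.buffer.default_buffer_prototype()

    # Store 256 bytes under an illustrative key.
    await zhs.set("some_key", proto.buffer.from_bytes(bytes(range(256))))

    # RangeByteRequest(start, end) maps to offset=start, length=end-start.
    head = await zhs.get(
        "some_key", proto, byte_range=RangeByteRequest(start=0, end=16)
    )
    # SuffixByteRequest(suffix=n) maps to a "last n bytes" read.
    tail = await zhs.get(
        "some_key", proto, byte_range=SuffixByteRequest(suffix=16)
    )

    assert head is not None and head.to_bytes() == bytes(range(16))
    assert tail is not None and tail.to_bytes() == bytes(range(240, 256))


asyncio.run(demo())
```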
diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index a39d045..24ff2cc 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -6,11 +6,12 @@ import pytest import zarr import zarr.core.buffer + # Make sure to import the ByteRequest types from zarr.abc.store import RangeByteRequest, OffsetByteRequest, SuffixByteRequest -from py_hamt import HAMT, KuboCAS +from py_hamt import HAMT, KuboCAS, InMemoryCAS from py_hamt.zarr_hamt_store import ZarrHAMTStore @@ -57,6 +58,7 @@ def random_zarr_dataset(): yield ds + # This test also collects miscellaneous statistics about performance, run with pytest -s to see these statistics being printed out @pytest.mark.asyncio(loop_scope="session") # ← match the loop of the fixture async def test_write_read( @@ -68,7 +70,6 @@ async def test_write_read( print("=== Writing this xarray Dataset to a Zarr v3 on IPFS ===") print(test_ds) - async with KuboCAS( # <-- own and auto-close session rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url ) as kubo_cas: @@ -210,3 +211,60 @@ async def test_write_read( with pytest.raises(NotImplementedError): await zhs_read.set_partial_values([]) + + +@pytest.mark.asyncio +async def test_zarr_hamt_store_byte_request_errors(): + """Tests error handling for unsupported or invalid ByteRequest types.""" + cas = InMemoryCAS() + hamt = await HAMT.build(cas=cas, values_are_bytes=True) + zhs = ZarrHAMTStore(hamt) + proto = zarr.core.buffer.default_buffer_prototype() + await zhs.set("some_key", proto.buffer.from_bytes(b"0123456789")) + + # Test for ValueError with an invalid range (end < start) + invalid_range_req = RangeByteRequest(start=10, end=5) + with pytest.raises(ValueError, match="End must be >= start for RangeByteRequest"): + await zhs.get("some_key", proto, byte_range=invalid_range_req) + + # Test for TypeError with a custom, unsupported request type + class DummyUnsupportedRequest: + pass + + unsupported_req = DummyUnsupportedRequest() + with pytest.raises(TypeError, match="Unsupported ByteRequest type"): + await zhs.get("some_key", proto, byte_range=unsupported_req) + + +@pytest.mark.asyncio +async def test_in_memory_cas_partial_reads(): + """ + Tests the partial read logic (offset, length, suffix) in the InMemoryCAS. 
+ """ + cas = InMemoryCAS() + test_data = b"0123456789abcdefghijklmnopqrstuvwxyz" # 36 bytes long + data_id = await cas.save(test_data, "raw") + + # Test case 1: offset and length + result = await cas.load(data_id, offset=10, length=5) + assert result == b"abcde" + + # Test case 2: offset only + result = await cas.load(data_id, offset=30) + assert result == b"uvwxyz" + + # Test case 3: suffix only + result = await cas.load(data_id, suffix=6) + assert result == b"uvwxyz" + + # Test case 4: Full read (for completeness) + result = await cas.load(data_id) + assert result == test_data + + # Test case 5: Key not found (covers `try...except KeyError`) + with pytest.raises(KeyError, match="Object not found in in-memory store"): + await cas.load(b"\x00\x01\x02\x03\x04") # Some random, non-existent key + + # Test case 6: Invalid key type (covers `isinstance` check) + with pytest.raises(TypeError, match="InMemoryCAS only supports byte‐hash keys"): + await cas.load(12345) # Pass an integer instead of bytes From 46b2f9972c870daeef7eb7dc9df05b8cfd5727b5 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 3 Jun 2025 08:00:39 -0400 Subject: [PATCH 05/74] fix: re-order --- py_hamt/__init__.py | 2 +- py_hamt/hamt.py | 2 +- py_hamt/store.py | 3 +-- py_hamt/zarr_hamt_store.py | 5 ++--- tests/test_zarr_ipfs_encrypted.py | 5 ----- tests/test_zarr_ipfs_partial.py | 8 +++----- 6 files changed, 8 insertions(+), 17 deletions(-) diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 03c5986..0a63761 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -1,6 +1,6 @@ from .encryption_hamt_store import SimpleEncryptedZarrHAMTStore from .hamt import HAMT, blake3_hashfn -from .store import ContentAddressedStore, KuboCAS, InMemoryCAS +from .store import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore __all__ = [ diff --git a/py_hamt/hamt.py b/py_hamt/hamt.py index a232da9..bf1716d 100644 --- a/py_hamt/hamt.py +++ b/py_hamt/hamt.py @@ -8,8 +8,8 @@ Callable, Dict, Iterator, - cast, Optional, + cast, ) import dag_cbor diff --git a/py_hamt/store.py b/py_hamt/store.py index 33b9897..784d539 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -1,7 +1,6 @@ import asyncio -import aiohttp from abc import ABC, abstractmethod -from typing import Any, Literal, cast, Optional +from typing import Any, Literal, Optional, cast import aiohttp from dag_cbor.ipld import IPLDKind diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index 79fab77..e265656 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -1,11 +1,10 @@ +import asyncio from collections.abc import AsyncIterator, Iterable -from typing import cast +from typing import Optional, cast import zarr.abc.store import zarr.core.buffer from zarr.core.common import BytesLike -from typing import Optional -import asyncio from py_hamt.hamt import HAMT diff --git a/tests/test_zarr_ipfs_encrypted.py b/tests/test_zarr_ipfs_encrypted.py index 8d65d1c..32c7a2e 100644 --- a/tests/test_zarr_ipfs_encrypted.py +++ b/tests/test_zarr_ipfs_encrypted.py @@ -195,11 +195,6 @@ async def test_encrypted_write_read( with pytest.raises(NotImplementedError): await ezhs_read_ok.set_partial_values([]) - with pytest.raises(NotImplementedError): - await ezhs_read_ok.get_partial_values( - zarr.core.buffer.default_buffer_prototype(), [] - ) - with pytest.raises(Exception): await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) diff --git 
a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index 24ff2cc..0cd076a 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -2,17 +2,15 @@ import numpy as np import pandas as pd -import xarray as xr import pytest +import xarray as xr import zarr import zarr.core.buffer # Make sure to import the ByteRequest types -from zarr.abc.store import RangeByteRequest, OffsetByteRequest, SuffixByteRequest - - -from py_hamt import HAMT, KuboCAS, InMemoryCAS +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest +from py_hamt import HAMT, InMemoryCAS, KuboCAS from py_hamt.zarr_hamt_store import ZarrHAMTStore From 0fb9be9d85d4e9bc485dc601f704f6406079c08e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 3 Jun 2025 08:21:18 -0400 Subject: [PATCH 06/74] fix: update ruff and mypy --- py_hamt/store.py | 4 ++-- tests/test_zarr_ipfs_encrypted.py | 11 ++++------- tests/test_zarr_ipfs_partial.py | 12 ++++++------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/py_hamt/store.py b/py_hamt/store.py index 784d539..e82fad6 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -61,7 +61,7 @@ async def load( offset: Optional[int] = None, length: Optional[int] = None, suffix: Optional[int] = None, - ) -> bytes: # type: ignore since bytes is a subset of the IPLDKind type + ) -> bytes: """ `ContentAddressedStore` allows any IPLD scalar key. For the in-memory backend we *require* a `bytes` hash; anything else is rejected at run @@ -283,7 +283,7 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI cid = cid.set(codec=codec) return cid - async def load( # type: ignore CID is definitely in the IPLDKind type + async def load( self, id: IPLDKind, offset: Optional[int] = None, diff --git a/tests/test_zarr_ipfs_encrypted.py b/tests/test_zarr_ipfs_encrypted.py index 32c7a2e..ccbf114 100644 --- a/tests/test_zarr_ipfs_encrypted.py +++ b/tests/test_zarr_ipfs_encrypted.py @@ -4,9 +4,8 @@ import pandas as pd import pytest import xarray as xr -import zarr -import zarr.core.buffer from Crypto.Random import get_random_bytes +from dag_cbor.ipld import IPLDKind from py_hamt import HAMT, KuboCAS, SimpleEncryptedZarrHAMTStore from py_hamt.zarr_hamt_store import ZarrHAMTStore @@ -91,9 +90,7 @@ async def test_encrypted_write_read( correct_key = get_random_bytes(32) wrong_key = get_random_bytes(32) header = b"test-encryption-header" - - root_cid = None - + root_cid: IPLDKind = None # --- Write Phase --- async with KuboCAS( rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url @@ -109,7 +106,7 @@ async def test_encrypted_write_read( assert ezhs_write != hamt_write assert ezhs_write.supports_writes - test_ds.to_zarr(store=ezhs_write, mode="w", zarr_format=3) # Use mode='w' + test_ds.to_zarr(store=ezhs_write, mode="w", zarr_format=3) # type: ignore await hamt_write.make_read_only() root_cid = hamt_write.root_node_id @@ -196,7 +193,7 @@ async def test_encrypted_write_read( await ezhs_read_ok.set_partial_values([]) with pytest.raises(Exception): - await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) + await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) # type: ignore with pytest.raises(Exception): await ezhs_read_ok.delete("zarr.json") diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index 0cd076a..da3ba6c 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -77,8 
+77,8 @@ async def test_write_read( start = time.perf_counter() # Do an initial write along with an append which is a common xarray/zarr operation # Ensure chunks are not too small for partial value tests - test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) - test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) + test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) # type: ignore + test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) # type: ignore end = time.perf_counter() elapsed = end - start print("=== Write Stats") @@ -149,7 +149,7 @@ async def test_write_read( ] # Call get_partial_values - results = await zhs_read.get_partial_values(proto, key_ranges_to_test) + results = await zhs_read.get_partial_values(proto, key_ranges_to_test) # type: ignore # Assertions assert len(results) == 4 @@ -202,10 +202,10 @@ async def test_write_read( zhs_keys.add(k) assert hamt_keys == zhs_keys - zhs_keys: set[str] = set() + zhs_keys_prefix: set[str] = set() async for k in zhs_read.list_prefix(""): - zhs_keys.add(k) - assert hamt_keys == zhs_keys + zhs_keys_prefix.add(k) + assert hamt_keys == zhs_keys_prefix with pytest.raises(NotImplementedError): await zhs_read.set_partial_values([]) From f7f169dea16e9b96b5deb72ec38035e58f0efefe Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 3 Jun 2025 08:25:15 -0400 Subject: [PATCH 07/74] fix: pre-commit --- py_hamt/zarr_hamt_store.py | 2 +- tests/test_zarr_ipfs_encrypted.py | 2 +- tests/test_zarr_ipfs_partial.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index e265656..9ef5042 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -91,7 +91,7 @@ def _map_byte_request( if isinstance(byte_range, zarr.abc.store.RangeByteRequest): offset = byte_range.start length = byte_range.end - byte_range.start - if length < 0: + if length is not None and length < 0: raise ValueError("End must be >= start for RangeByteRequest") elif isinstance(byte_range, zarr.abc.store.OffsetByteRequest): offset = byte_range.offset diff --git a/tests/test_zarr_ipfs_encrypted.py b/tests/test_zarr_ipfs_encrypted.py index ccbf114..93aa74b 100644 --- a/tests/test_zarr_ipfs_encrypted.py +++ b/tests/test_zarr_ipfs_encrypted.py @@ -193,7 +193,7 @@ async def test_encrypted_write_read( await ezhs_read_ok.set_partial_values([]) with pytest.raises(Exception): - await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) # type: ignore + await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) # type: ignore with pytest.raises(Exception): await ezhs_read_ok.delete("zarr.json") diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index da3ba6c..e7a7a05 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -77,8 +77,8 @@ async def test_write_read( start = time.perf_counter() # Do an initial write along with an append which is a common xarray/zarr operation # Ensure chunks are not too small for partial value tests - test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) # type: ignore - test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) # type: ignore + test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) # type: ignore + test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) # type: ignore end = time.perf_counter() 
elapsed = end - start print("=== Write Stats") @@ -149,7 +149,7 @@ async def test_write_read( ] # Call get_partial_values - results = await zhs_read.get_partial_values(proto, key_ranges_to_test) # type: ignore + results = await zhs_read.get_partial_values(proto, key_ranges_to_test) # type: ignore # Assertions assert len(results) == 4 From e17003af4219dd653cae2aa61bb81dfc81a6fea5 Mon Sep 17 00:00:00 2001 From: Faolain Date: Wed, 4 Jun 2025 19:03:01 -0400 Subject: [PATCH 08/74] deps: update kubo to latest in tests --- CLAUDE.md | 87 ++++++++++++++++++++++++++++++++++++++++++ tests/testing_utils.py | 2 +- 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..49c6cf4 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,87 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Common Development Commands + +Setup environment: +```bash +uv sync +source .venv/bin/activate +pre-commit install +``` + +Run all checks (tests, linting, formatting, type checking): +```bash +bash run-checks.sh +``` + +Run tests: +```bash +# All tests (requires IPFS daemon or Docker) +pytest --ipfs --cov=py_hamt tests/ + +# Quick tests without IPFS integration +pytest --cov=py_hamt tests/ + +# Single test file +pytest tests/test_hamt.py + +# Coverage report +uv run coverage report --fail-under=100 --show-missing +``` + +Linting and formatting: +```bash +# Run all pre-commit hooks +uv run pre-commit run --all-files --show-diff-on-failure + +# Fix auto-fixable ruff issues +uv run ruff check --fix +``` + +Type checking and other tools: +```bash +# Type checking is handled by pre-commit hooks (mypy) +# Documentation preview +uv run pdoc py_hamt +``` + +## Architecture Overview + +py-hamt implements a Hash Array Mapped Trie (HAMT) for IPFS/IPLD content-addressed storage. The core architecture follows this pattern: + +1. **ContentAddressedStore (CAS)** - Abstract storage layer (store.py) + - `KuboCAS` - IPFS/Kubo implementation for production + - `InMemoryCAS` - In-memory implementation for testing + +2. **HAMT** - Core data structure (hamt.py) + - Uses blake3 hashing by default + - Implements content-addressed trie for efficient key-value storage + - Supports async operations for large datasets + +3. **ZarrHAMTStore** - Zarr integration (zarr_hamt_store.py) + - Implements zarr.abc.store.Store interface + - Enables storing large Zarr arrays on IPFS via HAMT + - Keys stored verbatim, values as raw bytes + +4. **Encryption Layer** - Optional encryption (encryption_hamt_store.py) + - `SimpleEncryptedZarrHAMTStore` for fully encrypted storage + +## Key Design Patterns + +- All storage operations are async to handle IPFS network calls +- Content addressing means identical data gets same hash/CID +- HAMT provides O(log n) access time for large key sets +- Store abstractions allow swapping storage backends +- Type hints required throughout (mypy enforced) +- 100% test coverage required with hypothesis property-based testing + +## IPFS Integration Requirements + +Tests require either: +- Local IPFS daemon running (`ipfs daemon`) +- Docker available for containerized IPFS +- Neither (unit tests only, integration tests skip) + +The `--ipfs` pytest flag controls IPFS test execution. 
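To complement the architecture overview above, a condensed end-to-end sketch of the write → root CID → read flow that the integration tests in this series exercise. It reuses only calls shown elsewhere in this PR; constructing `KuboCAS()` with its default URLs assumes a local Kubo daemon, and `ds` stands for any xarray `Dataset`.

```python
import xarray as xr

from py_hamt import HAMT, KuboCAS, ZarrHAMTStore


async def write_then_read(ds: xr.Dataset) -> xr.Dataset:
    # Assumes a local Kubo daemon reachable at KuboCAS's default RPC/gateway URLs.
    async with KuboCAS() as cas:
        # Write: build a HAMT in write mode, store the dataset, then freeze it
        # to obtain the root CID that identifies the whole Zarr store.
        hamt = await HAMT.build(cas=cas, values_are_bytes=True)
        ds.to_zarr(store=ZarrHAMTStore(hamt), mode="w", zarr_format=3)
        await hamt.make_read_only()
        root_cid = hamt.root_node_id

        # Read: rebuild a read-only HAMT from that CID and open the Zarr lazily.
        hamt_ro = await HAMT.build(
            cas=cas, root_node_id=root_cid, values_are_bytes=True, read_only=True
        )
        return xr.open_zarr(store=ZarrHAMTStore(hamt_ro, read_only=True))
```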
diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 154f4cf..d109d43 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -169,7 +169,7 @@ def create_ipfs(): if client is None: pytest.skip("Neither IPFS daemon nor Docker available – skipping IPFS tests") - image = "ipfs/kubo:v0.35.0" + image = "ipfs/kubo:master-latest" rpc_p = _free_port() gw_p = _free_port() From 72fdcfabfe427e48ddc0c0d3c60612e5419073b8 Mon Sep 17 00:00:00 2001 From: Faolain Date: Wed, 4 Jun 2025 19:15:58 -0400 Subject: [PATCH 09/74] test: add test for ipfs gateway partials --- tests/test_zarr_ipfs_partial.py | 193 ++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py index e7a7a05..84de087 100644 --- a/tests/test_zarr_ipfs_partial.py +++ b/tests/test_zarr_ipfs_partial.py @@ -234,6 +234,199 @@ class DummyUnsupportedRequest: await zhs.get("some_key", proto, byte_range=unsupported_req) +@pytest.mark.asyncio +async def test_ipfs_gateway_compression_behavior(create_ipfs): + """ + Test to verify whether IPFS gateways decompress data before applying + byte range requests, which would negate compression benefits for partial reads. + + This test creates highly compressible data, stores it via IPFS, and then + compares the bytes returned by partial vs full reads to determine if + the gateway is operating on compressed or decompressed data. + """ + rpc_base_url, gateway_base_url = create_ipfs + + print("\n=== IPFS Gateway Compression Behavior Test ===") + + # Create highly compressible test data + print("Creating highly compressible test data...") + data = np.zeros((50, 50, 50), dtype=np.float32) # 500KB of zeros + # Add small amount of variation + data[0:5, 0:5, 0:5] = np.random.randn(5, 5, 5) + + ds = xr.Dataset({"compressible_var": (["x", "y", "z"], data)}) + + print(f"Original data shape: {data.shape}") + print(f"Original data size: {data.nbytes:,} bytes") + + # Custom CAS to track actual network transfers + class NetworkTrackingKuboCAS(KuboCAS): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.load_sizes = {} + self.save_sizes = {} + + async def save(self, data, codec=None): + cid = await super().save(data, codec) + self.save_sizes[str(cid)] = len(data) + print(f"Saved to IPFS: {str(cid)[:12]}... ({len(data):,} bytes)") + return cid + + async def load(self, cid, offset=None, length=None, suffix=None): + result = await super().load(cid, offset, length, suffix) + + range_desc = "full" + if offset is not None or length is not None or suffix is not None: + range_desc = f"offset={offset}, length={length}, suffix={suffix}" + + key = f"{str(cid)[:12]}... 
({range_desc})" + self.load_sizes[key] = len(result) + print(f"Loaded from IPFS: {key} -> {len(result):,} bytes") + return result + + async with NetworkTrackingKuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as tracking_cas: + # Write dataset with compression + print("\n=== Writing to IPFS with Blosc compression ===") + hamt = await HAMT.build(cas=tracking_cas, values_are_bytes=True) + store = ZarrHAMTStore(hamt) + + # Use smaller chunks to ensure meaningful compression + ds.chunk({"x": 25, "y": 25, "z": 25}).to_zarr( + store=store, mode="w", zarr_format=3 + ) + + await hamt.make_read_only() + root_cid = hamt.root_node_id + print(f"Root CID: {root_cid}") + + # Read back and test compression behavior + print("\n=== Testing Compression vs Partial Reads ===") + hamt_read = await HAMT.build( + cas=tracking_cas, + root_node_id=root_cid, + values_are_bytes=True, + read_only=True, + ) + store_read = ZarrHAMTStore(hamt_read, read_only=True) + + # Find the largest data chunk (likely the actual array data) + chunk_key = None + chunk_size = 0 + async for key in store_read.list(): + if ( + "compressible_var" in key + and not key.endswith(".zarray") + and not key.endswith("zarr.json") + ): + # Get size to find the largest chunk + proto = zarr.core.buffer.default_buffer_prototype() + chunk_buffer = await store_read.get(key, proto) + if chunk_buffer and len(chunk_buffer.to_bytes()) > chunk_size: + chunk_key = key + chunk_size = len(chunk_buffer.to_bytes()) + + assert chunk_key is not None, "No data chunk found" + print(f"Testing with largest chunk: {chunk_key}") + + # Get full chunk for baseline + proto = zarr.core.buffer.default_buffer_prototype() + full_chunk = await store_read.get(chunk_key, proto) + full_compressed_size = len(full_chunk.to_bytes()) + print(f"Full chunk compressed size: {full_compressed_size:,} bytes") + + # Calculate expected uncompressed size + # 25x25x25 float32 = 62,500 bytes uncompressed + expected_uncompressed_size = 25 * 25 * 25 * 4 + compression_ratio = expected_uncompressed_size / full_compressed_size + print(f"Estimated compression ratio: {compression_ratio:.1f}:1") + + # Test different partial read sizes + test_ranges = [ + (0, full_compressed_size // 4, "25% of compressed"), + (0, full_compressed_size // 2, "50% of compressed"), + (full_compressed_size // 4, full_compressed_size // 2, "25%-50% range"), + ] + + print("\n=== Partial Read Analysis ===") + print("If gateway operates on compressed data:") + print(" - Partial reads should return exactly the requested byte ranges") + print(" - Network transfer should be proportional to compressed size") + print("If gateway decompresses before range requests:") + print(" - Partial reads may return more data than expected") + print(" - Network transfer loses compression benefits") + print() + + compression_preserved = True + + for start, end, description in test_ranges: + length = end - start + byte_req = RangeByteRequest(start=start, end=end) + + # Clear the load tracking for this specific test + original_load_count = len(tracking_cas.load_sizes) + + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_req) + partial_size = len(partial_chunk.to_bytes()) + + # Find the new load entry + new_loads = list(tracking_cas.load_sizes.items())[original_load_count:] + network_bytes = new_loads[0][1] if new_loads else partial_size + + expected_percentage = length / full_compressed_size + actual_percentage = partial_size / full_compressed_size + network_efficiency = network_bytes / 
full_compressed_size + + print(f"Range request: {description}") + print( + f" Requested: {length:,} bytes ({expected_percentage:.1%} of compressed)" + ) + print( + f" Received: {partial_size:,} bytes ({actual_percentage:.1%} of compressed)" + ) + print( + f" Network transfer: {network_bytes:,} bytes ({network_efficiency:.1%} of compressed)" + ) + + # Key test: if we get significantly more data than requested, + # the gateway is likely decompressing before applying ranges + if partial_size > length * 1.1: # 10% tolerance for overhead + compression_preserved = False + print( + f" ⚠️ Received {partial_size / length:.1f}x more data than requested!" + ) + print(" ⚠️ Gateway appears to decompress before applying ranges") + else: + print(" ✅ Range applied efficiently to compressed data") + + # Verify the partial data makes sense + full_data = full_chunk.to_bytes() + expected_partial = full_data[start:end] + assert partial_chunk.to_bytes() == expected_partial, ( + "Partial data doesn't match expected range" + ) + print(" ✅ Partial data content verified") + print() + + # Summary analysis + print("=== Final Analysis ===") + if compression_preserved: + print("✅ IPFS gateway preserves compression benefits for partial reads") + print(" - Byte ranges are applied to compressed data") + print(" - Network transfers are efficient") + else: + print("⚠️ IPFS gateway appears to decompress before applying ranges") + print(" - Partial reads may not provide expected bandwidth savings") + print(" - Consider alternative storage strategies (sharding, etc.)") + + print("\nCompression statistics:") + print(f" - Uncompressed chunk size: {expected_uncompressed_size:,} bytes") + print(f" - Compressed chunk size: {full_compressed_size:,} bytes") + print(f" - Compression ratio: {compression_ratio:.1f}:1") + print(f" - Compression preserved in partial reads: {compression_preserved}") + + @pytest.mark.asyncio async def test_in_memory_cas_partial_reads(): """ From c92f5ec690491dbbb79c17648ad1064345eb9a90 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 5 Jun 2025 12:24:52 -0400 Subject: [PATCH 10/74] fix: sharding --- py_hamt/__init__.py | 4 + py_hamt/flat_zarr_store.py | 370 ++++++++++++++++ py_hamt/sharded_zarr_store.py | 745 +++++++++++++++++++++++++++++++++ tests/test_benchmark_stores.py | 282 +++++++++++++ 4 files changed, 1401 insertions(+) create mode 100644 py_hamt/flat_zarr_store.py create mode 100644 py_hamt/sharded_zarr_store.py create mode 100644 tests/test_benchmark_stores.py diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 0a63761..7819c54 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -2,6 +2,8 @@ from .hamt import HAMT, blake3_hashfn from .store import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore +from .flat_zarr_store import FlatZarrStore +from .sharded_zarr_store import ShardedZarrStore __all__ = [ "blake3_hashfn", @@ -11,4 +13,6 @@ "ZarrHAMTStore", "InMemoryCAS", "SimpleEncryptedZarrHAMTStore", + "FlatZarrStore", + "ShardedZarrStore", ] diff --git a/py_hamt/flat_zarr_store.py b/py_hamt/flat_zarr_store.py new file mode 100644 index 0000000..b4a96ad --- /dev/null +++ b/py_hamt/flat_zarr_store.py @@ -0,0 +1,370 @@ +import asyncio +import math +from collections.abc import AsyncIterator, Iterable +from typing import Optional, cast + +import dag_cbor +import zarr.abc.store +import zarr.core.buffer +from zarr.core.common import BytesLike + +from .store import ContentAddressedStore + + +class 
FlatZarrStore(zarr.abc.store.Store): + """ + Implements the Zarr Store API using a flat, predictable layout for chunk CIDs. + + This store bypasses the need for a HAMT, offering direct, calculated + access to chunk data based on a mathematical formula. It is designed for + dense, multi-dimensional arrays where chunk locations are predictable. + + The store is structured around a single root CBOR object. This root object contains: + 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. + 2. A single CID pointing to a large, contiguous block of bytes (the "flat index"). + This flat index is a concatenation of the CIDs of all data chunks. + + Accessing a chunk involves: + 1. Loading the root object (if not cached). + 2. Calculating the byte offset of the chunk's CID within the flat index. + 3. Fetching that specific CID using a byte-range request on the flat index. + 4. Fetching the actual chunk data using the retrieved CID. + + ### Sample Code + ```python + import xarray as xr + import numpy as np + from py_hamt import KuboCAS, FlatZarrStore + + # --- Write --- + ds = xr.Dataset( + {"data": (("t", "y", "x"), np.arange(24).reshape(2, 3, 4))}, + ) + cas = KuboCAS() + # When creating, must provide array shape and chunk shape + store_write = await FlatZarrStore.open( + cas, + read_only=False, + array_shape=ds.data.shape, + chunk_shape=ds.data.encoding['chunks'] + ) + ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() # IMPORTANT: flush to get final root CID + print(f"Finished writing. Root CID: {root_cid}") + + + # --- Read --- + store_read = await FlatZarrStore.open(cas, read_only=True, root_cid=root_cid) + ds_read = xr.open_zarr(store=store_read) + print("Read back dataset:") + print(ds_read) + xr.testing.assert_identical(ds, ds_read) + ``` + """ + + def __init__( + self, cas: ContentAddressedStore, read_only: bool, root_cid: Optional[str] + ): + """Use the async `open()` classmethod to instantiate this class.""" + super().__init__(read_only=read_only) + self.cas = cas + self._root_cid = root_cid + self._root_obj: Optional[dict] = None + self._flat_index_cache: Optional[bytearray] = None + self._cid_len: Optional[int] = None + self._array_shape: Optional[tuple[int, ...]] = None + self._chunk_shape: Optional[tuple[int, ...]] = None + self._chunks_per_dim: Optional[tuple[int, ...]] = None + self._dirty = False + + @classmethod + async def open( + cls, + cas: ContentAddressedStore, + read_only: bool, + root_cid: Optional[str] = None, + *, + array_shape: Optional[tuple[int, ...]] = None, + chunk_shape: Optional[tuple[int, ...]] = None, + cid_len: int = 59, # Default for base32 v1 CIDs like bafy... + ) -> "FlatZarrStore": + """ + Asynchronously opens an existing FlatZarrStore or initializes a new one. + + Args: + cas: The Content Addressed Store (e.g., KuboCAS). + read_only: If True, the store is in read-only mode. + root_cid: The root CID of an existing store to open. Required for read_only. + array_shape: The full shape of the Zarr array. Required for a new writeable store. + chunk_shape: The shape of a single chunk. Required for a new writeable store. + cid_len: The expected byte length of a CID string. + """ + store = cls(cas, read_only, root_cid) + if root_cid: + await store._load_root_from_cid() + elif not read_only: + if not all([array_shape, chunk_shape]): + raise ValueError( + "array_shape and chunk_shape must be provided for a new store." 
+ ) + store._initialize_new_root(array_shape, chunk_shape, cid_len) + else: + raise ValueError("root_cid must be provided for a read-only store.") + return store + + def _initialize_new_root( + self, + array_shape: tuple[int, ...], + chunk_shape: tuple[int, ...], + cid_len: int, + ): + self._array_shape = array_shape + self._chunk_shape = chunk_shape + self._cid_len = cid_len + self._chunks_per_dim = tuple( + math.ceil(a / c) for a, c in zip(array_shape, chunk_shape) + ) + self._root_obj = { + "manifest_version": "flat_zarr_v1", + "metadata": {}, + "chunks": { + "cid": None, # Will be filled on first flush + "array_shape": list(self._array_shape), + "chunk_shape": list(self._chunk_shape), + "cid_byte_length": self._cid_len, + }, + } + self._dirty = True + + async def _load_root_from_cid(self): + if not self._root_cid: + raise ValueError("Cannot load root without a root_cid.") + root_bytes = await self.cas.load(self._root_cid) + self._root_obj = dag_cbor.decode(root_bytes) + chunk_info = self._root_obj.get("chunks", {}) + self._array_shape = tuple(chunk_info["array_shape"]) + self._chunk_shape = tuple(chunk_info["chunk_shape"]) + self._cid_len = chunk_info["cid_byte_length"] + self._chunks_per_dim = tuple( + math.ceil(a / c) for a, c in zip(self._array_shape, self._chunk_shape) + ) + + def _parse_chunk_key(self, key: str) -> Optional[tuple[int, ...]]: + if not self._array_shape or not key.startswith("c/"): + return None + parts = key.split("/") + if len(parts) != len(self._array_shape) + 1: + return None + try: + return tuple(map(int, parts[1:])) + except (ValueError, IndexError): + return None + + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: + """@private""" + raise NotImplementedError("Partial writes are not supported by this store.") + + async def get_partial_values( + self, + prototype: zarr.core.buffer.BufferPrototype, + key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], + ) -> list[zarr.core.buffer.Buffer | None]: + """ + Retrieves multiple keys or byte ranges concurrently. + """ + tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + results = await asyncio.gather(*tasks) + return results + + def __eq__(self, other: object) -> bool: + """@private""" + if not isinstance(other, FlatZarrStore): + return NotImplemented + return self._root_cid == other._root_cid + + def _get_chunk_offset(self, chunk_coords: tuple[int, ...]) -> int: + linear_index = 0 + multiplier = 1 + for i in reversed(range(len(self._chunks_per_dim))): + linear_index += chunk_coords[i] * multiplier + multiplier *= self._chunks_per_dim[i] + return linear_index * self._cid_len + + async def flush(self) -> str: + """ + Writes all pending changes (metadata and chunk index) to the CAS + and returns the new root CID. This MUST be called after all writes are complete. 
+ """ + if self.read_only or not self._dirty: + return self._root_cid + + if self._flat_index_cache is not None: + flat_index_cid_obj = await self.cas.save( + bytes(self._flat_index_cache), codec="raw" + ) + self._root_obj["chunks"]["cid"] = str(flat_index_cid_obj) + + root_obj_bytes = dag_cbor.encode(self._root_obj) + new_root_cid_obj = await self.cas.save(root_obj_bytes, codec="dag-cbor") + self._root_cid = str(new_root_cid_obj) + self._dirty = False + return self._root_cid + + async def get( + self, + key: str, + prototype: zarr.core.buffer.BufferPrototype, + byte_range: zarr.abc.store.ByteRequest | None = None, + ) -> zarr.core.buffer.Buffer | None: + """@private""" + if self._root_obj is None: + await self._load_root_from_cid() + + chunk_coords = self._parse_chunk_key(key) + try: + # Metadata request + if chunk_coords is None: + metadata_cid = self._root_obj["metadata"].get(key) + if metadata_cid is None: + return None + data = await self.cas.load(metadata_cid) + return prototype.buffer.from_bytes(data) + + # Chunk data request + flat_index_cid = self._root_obj["chunks"]["cid"] + if flat_index_cid is None: + return None + + offset = self._get_chunk_offset(chunk_coords) + chunk_cid_bytes = await self.cas.load( + flat_index_cid, offset=offset, length=self._cid_len + ) + + if all(b == 0 for b in chunk_cid_bytes): + return None # Chunk doesn't exist + + chunk_cid = chunk_cid_bytes.decode("ascii") + data = await self.cas.load(chunk_cid) + return prototype.buffer.from_bytes(data) + + except (KeyError, IndexError): + return None + + async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: + """@private""" + if self.read_only: + raise ValueError("Cannot write to a read-only store.") + if self._root_obj is None: + raise RuntimeError("Store not initialized for writing.") + + self._dirty = True + raw_bytes = value.to_bytes() + value_cid_obj = await self.cas.save(raw_bytes, codec="raw") + value_cid = str(value_cid_obj) + + if len(value_cid) != self._cid_len: + raise ValueError( + f"Inconsistent CID length. Expected {self._cid_len}, got {len(value_cid)}" + ) + + chunk_coords = self._parse_chunk_key(key) + + if chunk_coords is None: # Metadata + self._root_obj["metadata"][key] = value_cid + return + + # Chunk Data + if self._flat_index_cache is None: + num_chunks = math.prod(self._chunks_per_dim) + self._flat_index_cache = bytearray(num_chunks * self._cid_len) + + offset = self._get_chunk_offset(chunk_coords) + self._flat_index_cache[offset : offset + self._cid_len] = value_cid.encode( + "ascii" + ) + + # --- Other required zarr.abc.store methods --- + + async def exists(self, key: str) -> bool: + """@private""" + # A more efficient version might check for null bytes in the flat index + # but this is functionally correct. 
+ + # TODO: Optimize this check + return True + + + # return (await self.get(key, zarr.core.buffer.Buffer.prototype, None)) is not None + + @property + def supports_writes(self) -> bool: + """@private""" + return not self.read_only + + @property + def supports_partial_writes(self) -> bool: + """@private""" + return False # Each chunk is an immutable object + + @property + def supports_deletes(self) -> bool: + """@private""" + return not self.read_only + + async def delete(self, key: str) -> None: + if self.read_only: + raise ValueError("Cannot delete from a read-only store.") + if self._root_obj is None: + await self._load_root_from_cid() + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: + if key in self._root_obj["metadata"]: + del self._root_obj["metadata"][key] + self._dirty = True + return + else: + raise KeyError(f"Metadata key '{key}' not found.") + flat_index_cid = self._root_obj["chunks"]["cid"] + if self._flat_index_cache is None: + if not flat_index_cid: + raise KeyError(f"Chunk key '{key}' not found in non-existent index.") + self._flat_index_cache = bytearray(await self.cas.load(flat_index_cid)) + offset = self._get_chunk_offset(chunk_coords) + if all(b == 0 for b in self._flat_index_cache[offset : offset + self._cid_len]): + raise KeyError(f"Chunk key '{key}' not found.") + self._flat_index_cache[offset : offset + self._cid_len] = bytearray(self._cid_len) + self._dirty = True + + @property + def supports_listing(self) -> bool: + """@private""" + return True + + async def list(self) -> AsyncIterator[str]: + """@private""" + if self._root_obj is None: + await self._load_root_from_cid() + for key in self._root_obj["metadata"]: + yield key + # Note: Listing all chunk keys without reading the index is non-trivial. + # A full implementation might need an efficient way to iterate non-null chunks. + # This basic version only lists metadata. + + async def list_prefix(self, prefix: str) -> AsyncIterator[str]: + """@private""" + async for key in self.list(): + if key.startswith(prefix): + yield key + + async def list_dir(self, prefix: str) -> AsyncIterator[str]: + """@private""" + # This simplified version only works for the root. + if prefix == "": + seen = set() + async for key in self.list(): + name = key.split('/')[0] + if name not in seen: + seen.add(name) + yield name \ No newline at end of file diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py new file mode 100644 index 0000000..fbc2afa --- /dev/null +++ b/py_hamt/sharded_zarr_store.py @@ -0,0 +1,745 @@ +import asyncio +import math +from collections.abc import AsyncIterator, Iterable +from typing import Optional, cast, Dict, List, Set, Tuple + +import dag_cbor +import zarr.abc.store +import zarr.core.buffer +from zarr.core.common import BytesLike +from .store import ContentAddressedStore + + +class ShardedZarrStore(zarr.abc.store.Store): + """ + Implements the Zarr Store API using a sharded layout for chunk CIDs. + + This store divides the flat index of chunk CIDs into multiple smaller "shards". + Each shard is a contiguous block of bytes containing CIDs for a subset of chunks. + This can improve performance for certain access patterns and reduce the size + of individual index objects stored in the CAS. + + The store's root object contains: + 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. + 2. A list of CIDs, where each CID points to a shard of the chunk index. + 3. Sharding configuration details (e.g., chunks_per_shard). + + Accessing a chunk involves: + 1. 
Loading the root object (if not cached). + 2. Determining the shard index and the offset of the chunk's CID within that shard. + 3. Fetching the specific shard's CID from the root object. + 4. Fetching the chunk's CID using a byte-range request on the identified shard. + 5. Fetching the actual chunk data using the retrieved chunk CID. + """ + + def __init__( + self, + cas: ContentAddressedStore, + read_only: bool, + root_cid: Optional[str], + ): + """Use the async `open()` classmethod to instantiate this class.""" + super().__init__(read_only=read_only) + self.cas = cas + self._root_cid = root_cid + self._root_obj: Optional[dict] = None + + self._shard_data_cache: Dict[int, bytearray] = {} # shard_index -> shard_byte_data + self._dirty_shards: Set[int] = set() # Set of shard_indices that need flushing + self._pending_shard_loads: Dict[int, asyncio.Task] = {} # shard_index -> Task loading the full shard + + self._cid_len: Optional[int] = None + self._array_shape: Optional[Tuple[int, ...]] = None + self._chunk_shape: Optional[Tuple[int, ...]] = None + self._chunks_per_dim: Optional[Tuple[int, ...]] = None # Number of chunks in each dimension + self._chunks_per_shard: Optional[int] = None # How many chunk CIDs per shard + self._num_shards: Optional[int] = None # Total number of shards + self._total_chunks: Optional[int] = None # Total number of chunks in the array + + self._dirty_root = False # Indicates if the root object itself (metadata or shard_cids list) changed + + @classmethod + async def open( + cls, + cas: ContentAddressedStore, + read_only: bool, + root_cid: Optional[str] = None, + *, + array_shape: Optional[Tuple[int, ...]] = None, + chunk_shape: Optional[Tuple[int, ...]] = None, + chunks_per_shard: Optional[int] = None, + cid_len: int = 59, # Default for base32 v1 CIDs like bafy... (e.g., bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi) + ) -> "ShardedZarrStore": + """ + Asynchronously opens an existing ShardedZarrStore or initializes a new one. + """ + store = cls(cas, read_only, root_cid) + if root_cid: + await store._load_root_from_cid() + elif not read_only: + if not all([array_shape, chunk_shape, chunks_per_shard is not None]): + raise ValueError( + "array_shape, chunk_shape, and chunks_per_shard must be provided for a new store." 
+ ) + if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: + raise ValueError("chunks_per_shard must be a positive integer.") + store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard, cid_len) + else: + raise ValueError("root_cid must be provided for a read-only store.") + return store + + def _initialize_new_root( + self, + array_shape: Tuple[int, ...], + chunk_shape: Tuple[int, ...], + chunks_per_shard: int, + cid_len: int, + ): + self._array_shape = array_shape + self._chunk_shape = chunk_shape + self._cid_len = cid_len + self._chunks_per_shard = chunks_per_shard + + if not all(cs > 0 for cs in chunk_shape): + raise ValueError("All chunk_shape dimensions must be positive.") + if not all(asarray_s >= 0 for asarray_s in array_shape): # array_shape dims can be 0 + raise ValueError("All array_shape dimensions must be non-negative.") + + + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 for a, c in zip(array_shape, chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + + if self._total_chunks == 0: + self._num_shards = 0 + else: + self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) + + self._root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, # For .zgroup, .zarray, .zattrs etc. + "chunks": { # Information about the chunk index itself + "array_shape": list(self._array_shape), # Original array shape + "chunk_shape": list(self._chunk_shape), # Original chunk shape + "cid_byte_length": self._cid_len, + "sharding_config": { + "chunks_per_shard": self._chunks_per_shard, + }, + "shard_cids": [None] * self._num_shards, # List of CIDs for each shard + }, + } + self._dirty_root = True + + async def _load_root_from_cid(self): + if not self._root_cid: + raise ValueError("Cannot load root without a root_cid.") + root_bytes = await self.cas.load(self._root_cid) + self._root_obj = dag_cbor.decode(root_bytes) + + if self._root_obj.get("manifest_version") != "sharded_zarr_v1": + raise ValueError(f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'.") + + chunk_info = self._root_obj["chunks"] + self._array_shape = tuple(chunk_info["array_shape"]) + self._chunk_shape = tuple(chunk_info["chunk_shape"]) + self._cid_len = chunk_info["cid_byte_length"] + sharding_cfg = chunk_info.get("sharding_config", {}) # Handle older formats if any planned + self._chunks_per_shard = sharding_cfg["chunks_per_shard"] + + if not all(cs > 0 for cs in self._chunk_shape): + raise ValueError("Loaded chunk_shape dimensions must be positive.") + + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 for a, c in zip(self._array_shape, self._chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + + expected_num_shards = 0 + if self._total_chunks > 0: + expected_num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) + self._num_shards = expected_num_shards + + if len(chunk_info["shard_cids"]) != self._num_shards: + raise ValueError( + f"Inconsistent number of shards. Expected {self._num_shards} from shapes/config, " + f"found {len(chunk_info['shard_cids'])} in root object's shard_cids list." + ) + + async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): + """ + Fetches the full data for a shard and caches it. + Manages removal from _pending_shard_loads. 
+ """ + try: + # Double check if it got cached by another operation while this task was scheduled + if shard_idx in self._shard_data_cache: + return + + shard_data_bytes = await self.cas.load(shard_cid) # Load full shard + self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) + # print(f"DEBUG: Successfully cached full shard {shard_idx} (CID: {shard_cid})") + + except Exception as e: + # Handle or log the exception appropriately + print(f"Warning: Failed to cache full shard {shard_idx} (CID: {shard_cid}): {e}") + # If it fails, subsequent requests might try again if it's still not in cache. + finally: + # Ensure the task is removed from pending list once done (success or failure) + if shard_idx in self._pending_shard_loads: + del self._pending_shard_loads[shard_idx] + + def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: + # 1. Exclude .json files immediately (metadata) + if key.endswith(".json"): + return None + excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"} + + chunk_marker = "/c/" + marker_idx = key.rfind(chunk_marker) # Use rfind for robustness + if marker_idx == -1: + # Key does not contain "/c/", so it's not a chunk data key + # in the expected format (e.g., could be .zattrs, .zgroup at various levels). + return None + + # Extract the part of the key before "/c/", which might represent the array/group path + # e.g., "temp" from "temp/c/0/0/0" + # e.g., "group1/lat" from "group1/lat/c/0" + # e.g., "" if key is "c/0/0/0" (root array) + path_before_c = key[:marker_idx] + + # Determine the actual array name (the last component of the path before "/c/") + actual_array_name = "" + if path_before_c: + actual_array_name = path_before_c.split('/')[-1] + + # 2. If the determined array name is in our exclusion list, return None. + if actual_array_name in excluded_array_prefixes: + return None + + # If we've reached here, the key is potentially for a "main" data variable + # that this store instance is expected to handle via sharding. + # Now, proceed with the original parsing logic using self._array_shape and + # self._chunks_per_dim, which should be configured for this main data variable. + + if not self._array_shape or not self._chunks_per_dim: + # This ShardedZarrStore instance is not properly initialized + # with the shape/chunking info for the array it's supposed to manage. + # This might also happen if a key like "some_other_main_array/c/0" is passed + # but this store instance was configured for "temp". + return None + + # The part after "/c/" contains the chunk coordinates + coord_part = key[marker_idx + len(chunk_marker):] + parts = coord_part.split('/') + + # Validate dimensionality: + # The number of coordinate parts must match the dimensionality of the array + # this store instance is configured for (self._chunks_per_dim). + if len(parts) != len(self._chunks_per_dim): + # This key's dimensionality does not match the store's configured array. + # It's likely for a different array or a malformed key for the current array. 
+ return None + + try: + coords = tuple(map(int, parts)) + # Validate coordinates against the chunk grid of the store's configured array + for i, c_coord in enumerate(coords): + if not (0 <= c_coord < self._chunks_per_dim[i]): + return None # Coordinate out of bounds for this array's chunk grid + return coords + except (ValueError, IndexError): # If int conversion fails or other issues + return None + + def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: + if not self._chunks_per_dim: + raise RuntimeError("Store not initialized: _chunks_per_dim is None.") + linear_index = 0 + multiplier = 1 + # Convert N-D chunk coordinates to a flat 1-D index (row-major order) + for i in reversed(range(len(self._chunks_per_dim))): + linear_index += chunk_coords[i] * multiplier + multiplier *= self._chunks_per_dim[i] + return linear_index + + def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: + if self._chunks_per_shard is None or self._chunks_per_shard <= 0: + raise RuntimeError("Sharding not configured properly: _chunks_per_shard invalid.") + if linear_chunk_index < 0: + raise ValueError("Linear chunk index cannot be negative.") + + shard_idx = linear_chunk_index // self._chunks_per_shard + index_in_shard = linear_chunk_index % self._chunks_per_shard + return shard_idx, index_in_shard + + def _get_actual_chunks_in_shard(self, shard_idx: int) -> int: + return self._chunks_per_shard + # if self._num_shards is None or self._total_chunks is None or self._chunks_per_shard is None: + # raise RuntimeError("Store not properly initialized for shard calculation (_num_shards, _total_chunks, or _chunks_per_shard is None).") + # if not (0 <= shard_idx < self._num_shards): # Handles case where _num_shards can be 0 + # if self._num_shards == 0 and shard_idx == 0 and self._total_chunks == 0: # special case for 0-size array + # return 0 + # raise ValueError(f"Invalid shard index: {shard_idx}. Num shards: {self._num_shards}") + + # if shard_idx == self._num_shards - 1: # Last shard + # remaining_chunks = self._total_chunks - shard_idx * self._chunks_per_shard + # return remaining_chunks + # else: + # return self._chunks_per_shard + + async def _load_or_initialize_shard_cache(self, shard_idx: int) -> bytearray: + if shard_idx in self._shard_data_cache: + return self._shard_data_cache[shard_idx] + + # Check if a background load for this shard is already in progress + if shard_idx in self._pending_shard_loads: + # print(f"DEBUG: _load_or_initialize_shard_cache - Shard {shard_idx} has a pending load. Awaiting it.") + try: + await self._pending_shard_loads[shard_idx] + # After awaiting, it should be in the cache if the task was successful + if shard_idx in self._shard_data_cache: + return self._shard_data_cache[shard_idx] + else: + # Task finished but didn't populate cache (e.g., an error occurred in the task) + # print(f"DEBUG: _load_or_initialize_shard_cache - Pending load for {shard_idx} completed but shard not in cache. Proceeding to load manually.") + pass # Fall through to normal loading + except asyncio.CancelledError: + # print(f"DEBUG: _load_or_initialize_shard_cache - Pending load for {shard_idx} was cancelled. 
Proceeding to load manually.") + # Ensure it's removed if cancelled before its finally block ran + if shard_idx in self._pending_shard_loads: + del self._pending_shard_loads[shard_idx] + # Fall through to normal loading + except Exception as e: + # The pending task itself might have failed with an exception + print(f"Warning: Pending shard load for {shard_idx} failed: {e}. Attempting fresh load.") + # Fall through to normal loading. The pending task's finally block should have cleaned it up. + + if self._root_obj is None: + raise RuntimeError("Root object not loaded or initialized (_root_obj is None).") + if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else False): + raise ValueError(f"Shard index {shard_idx} out of bounds for {self._num_shards} shards.") + + + shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] + if shard_cid: + shard_data_bytes = await self.cas.load(shard_cid) + # Verify length? + # expected_len = self._get_actual_chunks_in_shard(shard_idx) * self._cid_len + # if len(shard_data_bytes) != expected_len: + # raise ValueError(f"Shard {shard_idx} (CID: {shard_cid}) has unexpected length. Got {len(shard_data_bytes)}, expected {expected_len}") + self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) + else: + if self._cid_len is None: # Should be set + raise RuntimeError("Store not initialized: _cid_len is None for shard initialization.") + # New shard or shard not yet written, initialize with zeros + num_chunks_in_this_shard = self._get_actual_chunks_in_shard(shard_idx) + shard_size_bytes = num_chunks_in_this_shard * self._cid_len + self._shard_data_cache[shard_idx] = bytearray(shard_size_bytes) # Filled with \x00 + return self._shard_data_cache[shard_idx] + + + async def set_partial_values( + self, key_start_values: Iterable[Tuple[str, int, BytesLike]] + ) -> None: + raise NotImplementedError("Partial writes are not supported by ShardedZarrStore.") + + async def get_partial_values( + self, + prototype: zarr.core.buffer.BufferPrototype, + key_ranges: Iterable[Tuple[str, zarr.abc.store.ByteRequest | None]], + ) -> List[Optional[zarr.core.buffer.Buffer]]: + tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + results = await asyncio.gather(*tasks) + return results # type: ignore + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ShardedZarrStore): + return NotImplemented + # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. + return self._root_cid == other._root_cid + + async def flush(self) -> str: + if self.read_only: + if self._root_cid is None: # Read-only store should have been opened with a root_cid + raise ValueError("Read-only store has no root CID to return. Was it opened correctly?") + return self._root_cid + + if self._root_obj is None: # Should be initialized for a writable store + raise RuntimeError("Store not initialized for writing: _root_obj is None.") + + # Save all dirty shards first, as their CIDs might need to go into the root object + if self._dirty_shards: + for shard_idx in sorted(list(self._dirty_shards)): + if shard_idx not in self._shard_data_cache: + # This implies an internal logic error if a shard is dirty but not in cache + # However, could happen if cache was cleared externally; robust code might reload/reinit + print(f"Warning: Dirty shard {shard_idx} not found in cache. 
Skipping save for this shard.") + continue + + shard_data_bytes = bytes(self._shard_data_cache[shard_idx]) + + # The CAS save method here should return a string CID. + new_shard_cid = await self.cas.save(shard_data_bytes, codec="raw") # Shards are raw bytes of CIDs + + if self._root_obj["chunks"]["shard_cids"][shard_idx] != new_shard_cid: + self._root_obj["chunks"]["shard_cids"][shard_idx] = new_shard_cid + self._dirty_root = True # Root object changed because a shard_cid in its list changed + + self._dirty_shards.clear() + + if self._dirty_root: + root_obj_bytes = dag_cbor.encode(self._root_obj) + new_root_cid = await self.cas.save(root_obj_bytes, codec="dag-cbor") + self._root_cid = str(new_root_cid) # Ensure it's string + self._dirty_root = False + + if self._root_cid is None: # Should only happen if nothing was dirty AND it was a new store never flushed + raise RuntimeError("Failed to obtain a root CID after flushing. Store might be empty or unchanged.") + return self._root_cid + + + async def get( + self, + key: str, + prototype: zarr.core.buffer.BufferPrototype, + byte_range: Optional[zarr.abc.store.ByteRequest] = None, + ) -> Optional[zarr.core.buffer.Buffer]: + if self._root_obj is None: + if not self._root_cid: + raise ValueError("Store not initialized and no root_cid to load from.") + await self._load_root_from_cid() # This will populate self._root_obj + if self._root_obj is None: # Should be loaded by _load_root_from_cid + raise RuntimeError("Failed to load root object after _load_root_from_cid call.") + + chunk_coords = self._parse_chunk_key(key) + try: + # Metadata request (e.g., ".zarray", ".zgroup") + if chunk_coords is None: + metadata_cid = self._root_obj["metadata"].get(key) + if metadata_cid is None: + return None + # byte_range is not typically applicable to metadata JSON objects themselves + if byte_range is not None: + # Consider if this should be an error or ignored for metadata + print(f"Warning: byte_range requested for metadata key '{key}'. Ignoring range.") + data = await self.cas.load(metadata_cid) + return prototype.buffer.from_bytes(data) + + # Chunk data request (e.g., "c/0/0/0") + if self._cid_len is None: # Should be set during init/load + raise RuntimeError("Store not properly initialized: _cid_len is None.") + + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + # print("SHARD LOCATION", linear_chunk_index, shard_idx, index_in_shard) # Debugging info + + + if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): + # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic + return None + + target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] + if target_shard_cid is None: # This shard has no data (all chunks within it are implicitly empty) + return None + + offset_in_shard_bytes = index_in_shard * self._cid_len + chunk_cid_bytes: Optional[bytes] = None + + if shard_idx in self._shard_data_cache: + print(f"DEBUG: get() - Shard {shard_idx} found in cache. Key: {key}") + cached_shard_data = self._shard_data_cache[shard_idx] + if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): + chunk_cid_bytes = bytes(cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len]) + else: + # This would indicate an inconsistency or error in shard data/cache. + print(f"Warning: Cached shard {shard_idx} is smaller than expected for key {key}. 
Re-fetching CID.") + # Fall through to fetch from CAS, and potentially re-cache full shard. + # To be very robust, you might consider invalidating this cache entry here. + del self._shard_data_cache[shard_idx] # Invalidate corrupted/short cache entry + if shard_idx in self._pending_shard_loads: # Cancel if a load was pending for this now-invalidated cache + self._pending_shard_loads[shard_idx].cancel() + del self._pending_shard_loads[shard_idx] + # Fallthrough to load chunk_cid_bytes directly + + if chunk_cid_bytes is None: # Not in cache or cache was invalid + # print(f"DEBUG: get() - Shard {shard_idx} not in cache or invalid. Fetching specific CID. Key: {key}") + try: + print("FETCHING SPECIFIC CID BYTES FROM SHARD Until Shard is cached. Key:", key) + chunk_cid_bytes = await self.cas.load( + target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len + ) + except Exception as e: # Handle error from CAS load (e.g. shard CID not found, network issue) + # print(f"Error: Failed to load specific CID bytes from shard {target_shard_cid} for key {key}: {e}") + return None # Chunk CID couldn't be retrieved + + # After successfully fetching the specific CID bytes, + # check if we should initiate a background load of the full shard. + if shard_idx not in self._shard_data_cache and shard_idx not in self._pending_shard_loads: + # print(f"DEBUG: get() - Initiating background cache for full shard {shard_idx} (CID: {target_shard_cid}). Key: {key}") + self._pending_shard_loads[shard_idx] = asyncio.create_task( + self._fetch_and_cache_full_shard(shard_idx, target_shard_cid) + ) + + # Load the specific CID from the shard + # chunk_cid_bytes = await self.cas.load( + # target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len + # ) + + if all(b == 0 for b in chunk_cid_bytes): # Check for null CID placeholder (e.g. \x00 * cid_len) + return None # Chunk doesn't exist or is considered empty + + # Decode CID (assuming ASCII, remove potential null padding) + chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip('\x00') + if not chunk_cid_str: # Empty string after rstrip if all were \x00 (already caught above) + return None + + # Actual chunk data load using the retrieved chunk_cid_str + req_offset = byte_range.start if byte_range else None + req_length = None + if byte_range: + if byte_range.stop is not None: + if byte_range.start > byte_range.stop: # Zarr allows start == stop for 0 length + raise ValueError(f"Byte range start ({byte_range.start}) cannot be greater than stop ({byte_range.stop})") + req_length = byte_range.stop - byte_range.start + + data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) + return prototype.buffer.from_bytes(data) + + except (KeyError, IndexError, TypeError, ValueError) as e: + # print(f"Error during get for key {key} (coords: {chunk_coords}): {type(e).__name__} - {e}") # for debugging + return None # Consistent with Zarr behavior for missing keys + + + async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: + if self.read_only: + raise ValueError("Cannot write to a read-only store.") + if self._root_obj is None: + raise RuntimeError("Store not initialized for writing (root_obj is None). 
Call open() first.") + if self._cid_len is None: + raise RuntimeError("Store not initialized for writing (_cid_len is None).") + + raw_chunk_data_bytes = value.to_bytes() + # Save the actual chunk data to CAS first, to get its CID + chunk_data_cid_obj = await self.cas.save(raw_chunk_data_bytes, codec="raw") # Chunks are typically raw bytes + chunk_data_cid_str = str(chunk_data_cid_obj) + + # Ensure the CID (as ASCII bytes) fits in the allocated slot, padding with nulls + chunk_data_cid_ascii_bytes = chunk_data_cid_str.encode("ascii") + if len(chunk_data_cid_ascii_bytes) > self._cid_len: + raise ValueError( + f"Encoded CID byte length ({len(chunk_data_cid_ascii_bytes)}) exceeds configured CID length ({self._cid_len}). CID: {chunk_data_cid_str}" + ) + padded_chunk_data_cid_bytes = chunk_data_cid_ascii_bytes.ljust(self._cid_len, b'\0') + + + chunk_coords = self._parse_chunk_key(key) + + if chunk_coords is None: # Metadata key (e.g., ".zarray") + # For metadata, the 'value' is the metadata content itself, not a CID to it. + # So, we store the metadata content, get its CID, and put *that* CID in root_obj. + # This means the `value_cid_str` for metadata should be from `raw_chunk_data_bytes`. + # This seems to align with FlatZarrStore, where `value_cid` is used for both. + self._root_obj["metadata"][key] = chunk_data_cid_str # Store the string CID of the metadata content + self._dirty_root = True + return + + # Chunk Data: `chunk_data_cid_str` is the CID of the data we just saved. + # Now we need to store this CID string (padded) into the correct shard. + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + # Ensure the target shard is loaded or initialized in cache + target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) + + offset_in_shard_bytes = index_in_shard * self._cid_len + + # Check if the content is actually changing to avoid unnecessary dirtying (optional optimization) + # current_bytes_in_shard = target_shard_data_cache[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] + # if current_bytes_in_shard == padded_chunk_data_cid_bytes: + # return # No change + + target_shard_data_cache[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] = padded_chunk_data_cid_bytes + self._dirty_shards.add(shard_idx) + # If this write implies the shard CID in root_obj["chunks"]["shard_cids"] might change + # (e.g., from None to an actual CID when the shard is first flushed), + # then _dirty_root will be set by flush(). 
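
The write path above turns an N-dimensional chunk coordinate into a fixed-width slot inside one shard purely by arithmetic: row-major flattening, then a divmod by `chunks_per_shard`, then multiplication by `cid_len`. Below is a minimal standalone sketch of that mapping; the helper names and the example shapes are illustrative only and are not part of the store's API.

```python
import math

def linear_chunk_index(coords: tuple[int, ...], chunks_per_dim: tuple[int, ...]) -> int:
    """Row-major flattening, mirroring _get_linear_chunk_index."""
    index, multiplier = 0, 1
    for i in reversed(range(len(chunks_per_dim))):
        index += coords[i] * multiplier
        multiplier *= chunks_per_dim[i]
    return index

def shard_slot(
    coords: tuple[int, ...],
    array_shape: tuple[int, ...],
    chunk_shape: tuple[int, ...],
    chunks_per_shard: int,
    cid_len: int,
) -> tuple[int, int, int]:
    """Return (shard index, slot within the shard, byte offset of the CID slot)."""
    chunks_per_dim = tuple(math.ceil(a / c) for a, c in zip(array_shape, chunk_shape))
    linear = linear_chunk_index(coords, chunks_per_dim)
    shard_idx, index_in_shard = divmod(linear, chunks_per_shard)
    return shard_idx, index_in_shard, index_in_shard * cid_len

# Illustrative: chunk key "temp/c/7/0/0" of a (200, 18, 36) array chunked as (20, 18, 36)
print(shard_slot((7, 0, 0), (200, 18, 36), (20, 18, 36), chunks_per_shard=1024, cid_len=59))
# -> (0, 7, 413): shard 0, slot 7, CID bytes start at offset 7 * 59 = 413 within that shard
```

Reading one chunk therefore needs only the root object plus a `cid_len`-byte range request into a single shard, which is what `get()` does before optionally caching the full shard in the background.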
+ + + async def exists(self, key: str) -> bool: + if self._root_obj is None: + if not self._root_cid: return False + try: + await self._load_root_from_cid() + except Exception: # If loading fails, it doesn't exist in this store + return False + if self._root_obj is None: return False + + + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: # Metadata + return key in self._root_obj.get("metadata", {}) + + # Chunk + if self._cid_len is None: return False # Store not properly configured + + try: + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + if not (self._root_obj and "chunks" in self._root_obj and \ + 0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): + return False + + target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] + if target_shard_cid is None: # Shard itself doesn't exist + return False + + offset_in_shard_bytes = index_in_shard * self._cid_len + + # Optimization: Check local shard cache first + if shard_idx in self._shard_data_cache: + cached_shard_data = self._shard_data_cache[shard_idx] + # Ensure index_in_shard is valid for this cached data length + if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): + chunk_cid_bytes_from_cache = cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] + return not all(b == 0 for b in chunk_cid_bytes_from_cache) + # else: fall through to CAS load, cache might be out of sync or wrong size (should not happen with correct logic) + + # If not in cache or cache check was inconclusive, fetch from CAS + chunk_cid_bytes_from_cas = await self.cas.load( + target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len + ) + return not all(b == 0 for b in chunk_cid_bytes_from_cas) + except Exception: # Broad catch for issues like invalid coords, CAS errors during load etc. 
+ return False + + + @property + def supports_writes(self) -> bool: + return not self.read_only + + @property + def supports_partial_writes(self) -> bool: + return False # Each chunk CID is written atomically into a shard slot + + @property + def supports_deletes(self) -> bool: + return not self.read_only + + async def delete(self, key: str) -> None: + if self.read_only: + raise ValueError("Cannot delete from a read-only store.") + if self._root_obj is None: + if self._root_cid: # Try loading if deleting from an existing, non-modified store + try: + await self._load_root_from_cid() + except Exception as e: # If load fails, can't proceed + raise RuntimeError(f"Failed to load store for deletion: {e}") + if self._root_obj is None: # Still None after attempt + raise RuntimeError("Store not initialized for deletion (root_obj is None).") + if self._cid_len is None: + raise RuntimeError("Store not properly initialized for deletion (_cid_len is None).") + + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: # Metadata + if key in self._root_obj.get("metadata", {}): + del self._root_obj["metadata"][key] + self._dirty_root = True + return + else: + raise KeyError(f"Metadata key '{key}' not found for deletion.") + + + # Chunk deletion: zero out the CID entry in the shard + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + if not (0 <= shard_idx < (self._num_shards if self._num_shards is not None else 0)): + raise KeyError(f"Chunk key '{key}' maps to an invalid shard index {shard_idx}.") + + # Ensure shard data is available for modification (loads from CAS if not in cache, or initializes if new) + target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) + + offset_in_shard_bytes = index_in_shard * self._cid_len + + # Check if the entry is already zeroed (meaning it doesn't exist or already deleted) + is_already_zero = True + for i in range(self._cid_len): + if offset_in_shard_bytes + i >= len(target_shard_data_cache) or \ + target_shard_data_cache[offset_in_shard_bytes + i] != 0: + is_already_zero = False + break + + if is_already_zero: + raise KeyError(f"Chunk key '{key}' not found or already effectively deleted (CID slot is zeroed).") + + # Zero out the CID entry in the shard cache + for i in range(self._cid_len): + target_shard_data_cache[offset_in_shard_bytes + i] = 0 + + self._dirty_shards.add(shard_idx) + # If this shard becomes non-None in root_obj due to other writes, flush will handle it. + # If this deletion makes a previously non-None shard all zeros, the shard itself might + # eventually be elided if we had shard GC, but its CID remains in root_obj for now. + + @property + def supports_listing(self) -> bool: + return True # Can list metadata keys + + async def list(self) -> AsyncIterator[str]: + if self._root_obj is None: + if not self._root_cid: + return # Equivalent to `yield from ()` for async iterators + try: + await self._load_root_from_cid() + except Exception: # If loading fails, store is effectively empty for listing + return + if self._root_obj is None: + return + + for key in self._root_obj.get("metadata", {}): + yield key + # Listing all actual chunk keys would require iterating all shards and + # checking for non-null CIDs, which is expensive and not implemented here. + # This behavior is consistent with the provided FlatZarrStore example. 
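
As the comment above notes, `list()` only yields metadata keys; enumerating chunk keys would mean scanning every shard for non-zero CID slots. A hedged sketch of such a scan follows: `iter_nonempty_chunk_indices` is a hypothetical helper rather than a store method, it reads the private `_root_obj` and `cas` attributes purely for illustration, and it loads whole shards, so its cost grows with the size of the chunk index.

```python
from collections.abc import AsyncIterator

async def iter_nonempty_chunk_indices(store: "ShardedZarrStore") -> AsyncIterator[int]:
    # Assumes the root object has already been loaded (e.g. after open()/get()).
    chunks = store._root_obj["chunks"]
    cid_len = chunks["cid_byte_length"]
    per_shard = chunks["sharding_config"]["chunks_per_shard"]
    for shard_idx, shard_cid in enumerate(chunks["shard_cids"]):
        if shard_cid is None:          # shard never flushed: every slot in it is empty
            continue
        shard = await store.cas.load(shard_cid)
        for slot in range(len(shard) // cid_len):
            cell = shard[slot * cid_len : (slot + 1) * cid_len]
            if any(cell):              # an all-zero slot means "no chunk stored here"
                yield shard_idx * per_shard + slot
```

Mapping a linear index back to chunk coordinates is simply the inverse of the row-major flattening used by `_get_linear_chunk_index`.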
+ + async def list_prefix(self, prefix: str) -> AsyncIterator[str]: + # Only lists metadata keys matching prefix. + async for key in self.list(): # self.list() currently only yields metadata keys + if key.startswith(prefix): + yield key + + async def list_dir(self, prefix: str) -> AsyncIterator[str]: + # This simplified version only works for the root directory (prefix == "") of metadata. + # It lists unique first components of metadata keys. + if self._root_obj is None: + if not self._root_cid: return + try: + await self._load_root_from_cid() + except Exception: + return + if self._root_obj is None: return + + seen: Set[str] = set() + if prefix == "": + async for key in self.list(): # Iterates metadata keys + # e.g., if key is "group1/.zgroup" or "array1/.zarray", first_component is "group1" or "array1" + # if key is ".zgroup", first_component is ".zgroup" + first_component = key.split('/', 1)[0] + if first_component not in seen: + seen.add(first_component) + yield first_component + else: + # For listing subdirectories like "group1/", we'd need to match keys starting with "group1/" + # and then extract the next component. This is more involved. + # Zarr spec: list_dir(path) should yield children (both objects and "directories") + # For simplicity, and consistency with original FlatZarrStore, keeping this minimal. + # To make it more compliant for prefix="foo/": + # normalized_prefix = prefix if prefix.endswith('/') else prefix + '/' + # async for key in self.list_prefix(normalized_prefix): + # remainder = key[len(normalized_prefix):] + # child = remainder.split('/', 1)[0] + # if child not in seen: + # seen.add(child) + # yield child + pass # Or raise NotImplementedError for non-empty prefixes if strict. \ No newline at end of file diff --git a/tests/test_benchmark_stores.py b/tests/test_benchmark_stores.py new file mode 100644 index 0000000..cff1b98 --- /dev/null +++ b/tests/test_benchmark_stores.py @@ -0,0 +1,282 @@ +import time + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from dag_cbor.ipld import IPLDKind + +# Import both store implementations +from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +# ### +# BENCHMARK FOR THE ORIGINAL ZarrHAMTStore +# ### +@pytest.mark.asyncio(loop_scope="session") +async def test_benchmark_hamt_store( + create_ipfs: tuple[str, str], + random_zarr_dataset: xr.Dataset, +): + """Benchmarks write and read performance for the ZarrHAMTStore.""" + print("\n\n" + "=" * 80) + print("🚀 STARTING BENCHMARK for ZarrHAMTStore") + print("=" * 80) + + rpc_base_url, gateway_base_url = create_ipfs + + # rpc_base_url = f"https://ipfs-gateway.dclimate.net" + # gateway_base_url = f"https://ipfs-gateway.dclimate.net" + # headers = { + # "X-API-Key": "", + # } + headers = {} + test_ds = random_zarr_dataset + + 
async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers + ) as kubo_cas: + # --- Write --- + print("Building HAMT store...") + hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + print("HAMT store built successfully.") + zhs = ZarrHAMTStore(hamt) + print("ZarrHAMTStore created successfully.") + + start_write = time.perf_counter() + # Perform an initial write and an append to simulate a common workflow + test_ds.to_zarr(store=zhs, mode="w") + print("Initial write completed, now appending...") + test_ds.to_zarr(store=zhs, mode="a", append_dim="time") + await hamt.make_read_only() # Flush and freeze to get the final CID + end_write = time.perf_counter() + + cid: IPLDKind = hamt.root_node_id + print(f"\n--- [HAMT] Write Stats ---") + print(f"Total time to write and append: {end_write - start_write:.2f} seconds") + print(f"Final Root CID: {cid}") + + # --- Read --- + hamt_ro = await HAMT.build( + cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True + ) + zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) + + start_read = time.perf_counter() + ipfs_ds = xr.open_zarr(store=zhs_ro) + # Force a read of some data to ensure it's loaded + _ = ipfs_ds.temp.isel(time=0).values + print(_) + end_read = time.perf_counter() + + print(f"\n--- [HAMT] Read Stats ---") + print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + + # --- Verification --- + full_test_ds = xr.concat([test_ds, test_ds], dim="time") + xr.testing.assert_identical(full_test_ds, ipfs_ds) + print("\n✅ [HAMT] Data verification successful.") + print("=" * 80) + + +# ### +# BENCHMARK FOR THE NEW FlatZarrStore +# ### +@pytest.mark.asyncio(loop_scope="session") +async def test_benchmark_flat_store( + create_ipfs: tuple[str, str], + random_zarr_dataset: xr.Dataset, +): + """Benchmarks write and read performance for the new FlatZarrStore.""" + print("\n\n" + "=" * 80) + print("🚀 STARTING BENCHMARK for FlatZarrStore") + print("=" * 80) + + rpc_base_url, gateway_base_url = create_ipfs + # rpc_base_url = f"https://ipfs-gateway.dclimate.net" + # gateway_base_url = f"https://ipfs-gateway.dclimate.net" + # headers = { + # "X-API-Key": "", + # } + headers = {} + test_ds = random_zarr_dataset + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers + ) as kubo_cas: + # --- Write --- + # The full shape after appending + appended_shape = list(test_ds.dims.values()) + time_axis_index = list(test_ds.dims).index("time") + appended_shape[time_axis_index] *= 2 + final_array_shape = tuple(appended_shape) + + final_chunk_shape = [] + for dim_name in test_ds.dims: # Preserves dimension order + if dim_name in test_ds.chunks: + # test_ds.chunks[dim_name] is a tuple e.g. 
(20,) + final_chunk_shape.append(test_ds.chunks[dim_name][0]) + else: + # Fallback if a dimension isn't explicitly chunked (should use its full size) + final_chunk_shape.append(test_ds.dims[dim_name]) + final_chunk_shape = tuple(final_chunk_shape) + + store_write = await FlatZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=final_array_shape, + chunk_shape=final_chunk_shape, + ) + + start_write = time.perf_counter() + # Perform an initial write and an append + test_ds.to_zarr(store=store_write, mode="w") + test_ds.to_zarr(store=store_write, mode="a", append_dim="time") + root_cid = await store_write.flush() # Flush to get the final CID + end_write = time.perf_counter() + + print(f"\n--- [FlatZarr] Write Stats ---") + print(f"Total time to write and append: {end_write - start_write:.2f} seconds") + print(f"Final Root CID: {root_cid}") + + # --- Read --- + store_read = await FlatZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + start_read = time.perf_counter() + ipfs_ds = xr.open_zarr(store=store_read) + # Force a read of some data to ensure it's loaded + _ = ipfs_ds.temp.isel(time=0).values + print(_) + end_read = time.perf_counter() + + print(f"\n--- [FlatZarr] Read Stats ---") + print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + + # --- Verification --- + full_test_ds = xr.concat([test_ds, test_ds], dim="time") + xr.testing.assert_identical(full_test_ds, ipfs_ds) + print("\n✅ [FlatZarr] Data verification successful.") + print("=" * 80) + +@pytest.mark.asyncio(loop_scope="session") +async def test_benchmark_sharded_store( # Renamed function + create_ipfs: tuple[str, str], + random_zarr_dataset: xr.Dataset, +): + """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring + print("\n\n" + "=" * 80) + print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print + print("=" * 80) + + rpc_base_url, gateway_base_url = create_ipfs + + # rpc_base_url = f"https://ipfs-gateway.dclimate.net" + # gateway_base_url = f"https://ipfs-gateway.dclimate.net" + # headers = { + # "X-API-Key": "", + # } + headers = {} + test_ds = random_zarr_dataset + + # Define chunks_per_shard for the ShardedZarrStore + chunks_per_shard_config = 1024 # Configuration for sharding + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers + ) as kubo_cas: + # --- Write --- + # The full shape after appending + appended_shape = list(test_ds.dims.values()) + time_axis_index = list(test_ds.dims).index("time") + appended_shape[time_axis_index] *= 2 # Simulating appending along time dimension + final_array_shape = tuple(appended_shape) + + # Determine chunk shape from the dataset's encoding or dimensions + final_chunk_shape_list = [] + for dim_name in test_ds.dims: # Preserves dimension order from the dataset + if dim_name in test_ds.chunks: + # test_ds.chunks is a dict like {'time': (20,), 'y': (20,), 'x': (20,)} + final_chunk_shape_list.append(test_ds.chunks[dim_name][0]) + else: + # Fallback if a dimension isn't explicitly chunked (should use its full size) + final_chunk_shape_list.append(test_ds.dims[dim_name]) + final_chunk_shape = tuple(final_chunk_shape_list) + + # Use ShardedZarrStore and provide chunks_per_shard + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=final_array_shape, + chunk_shape=final_chunk_shape, + chunks_per_shard=chunks_per_shard_config # Added new parameter + ) + + start_write = time.perf_counter() + # Perform an 
initial write and an append + test_ds.to_zarr(store=store_write, mode="w") + test_ds.to_zarr(store=store_write, mode="a", append_dim="time") + root_cid = await store_write.flush() # Flush to get the final CID + end_write = time.perf_counter() + + print(f"\n--- [ShardedZarr] Write Stats (chunks_per_shard={chunks_per_shard_config}) ---") # Updated print + print(f"Total time to write and append: {end_write - start_write:.2f} seconds") + print(f"Final Root CID: {root_cid}") + + print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print + # --- Read --- + # When opening for read, chunks_per_shard is read from the store's metadata + store_read = await ShardedZarrStore.open( # Use ShardedZarrStore + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + start_read = time.perf_counter() + ipfs_ds = xr.open_zarr(store=store_read) + # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) + if "temp" in ipfs_ds.variables and "time" in ipfs_ds.coords: + _ = ipfs_ds.temp.isel(time=0).values + print(_) + elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable + first_var_name = list(ipfs_ds.data_vars.keys())[0] + # Construct a minimal selection based on available dimensions + selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} + if selection: + _ = ipfs_ds[first_var_name].isel(**selection).values + else: # If no dimensions, try loading the whole variable (e.g. scalar) + _ = ipfs_ds[first_var_name].values + end_read = time.perf_counter() + + print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print + print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") + + # --- Verification --- + # Create the expected full dataset after append operation + full_test_ds = xr.concat([test_ds, test_ds], dim="time") + xr.testing.assert_identical(full_test_ds, ipfs_ds) + print("\n✅ [ShardedZarr] Data verification successful.") # Updated print + print("=" * 80) \ No newline at end of file From 4476272182b1ca693acfb5c5a975749639dbd854 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 6 Jun 2025 10:18:07 -0400 Subject: [PATCH 11/74] fix: converter --- py_hamt/__init__.py | 2 + py_hamt/hamt_to_sharded_converter.py | 135 +++++++++++++++++++++++++++ py_hamt/sharded_zarr_store.py | 15 +-- tests/test_converter.py | 129 +++++++++++++++++++++++++ 4 files changed, 274 insertions(+), 7 deletions(-) create mode 100644 py_hamt/hamt_to_sharded_converter.py create mode 100644 tests/test_converter.py diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 7819c54..03bb94a 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -4,6 +4,7 @@ from .zarr_hamt_store import ZarrHAMTStore from .flat_zarr_store import FlatZarrStore from .sharded_zarr_store import ShardedZarrStore +from .hamt_to_sharded_converter import convert_hamt_to_sharded __all__ = [ "blake3_hashfn", @@ -15,4 +16,5 @@ "SimpleEncryptedZarrHAMTStore", "FlatZarrStore", "ShardedZarrStore", + "convert_hamt_to_sharded", ] diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py new file mode 100644 index 0000000..ff9f88c --- /dev/null +++ b/py_hamt/hamt_to_sharded_converter.py @@ -0,0 +1,135 @@ +import argparse +import asyncio +import json +import time +from typing import Dict, Any +from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore +import xarray as xr +from multiformats import CID +from zarr.core.buffer 
import Buffer, BufferPrototype + +async def convert_hamt_to_sharded( + cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int, cid_len: int = 59 +) -> str: + """ + Converts a Zarr dataset from a HAMT-based store to a ShardedZarrStore. + + Args: + cas: An initialized ContentAddressedStore instance (KuboCAS). + hamt_root_cid: The root CID of the source ZarrHAMTStore. + chunks_per_shard: The number of chunks to group into a single shard in the new store. + + Returns: + The root CID of the newly created ShardedZarrStore. + """ + print(f"--- Starting Conversion from HAMT Root {hamt_root_cid} ---") + start_time = time.perf_counter() + # 1. Open the source HAMT store for reading + print("Opening source HAMT store...") + hamt_ro = await HAMT.build( + cas=cas, root_node_id=hamt_root_cid, values_are_bytes=True, read_only=True + ) + source_store = ZarrHAMTStore(hamt_ro, read_only=True) + source_dataset = xr.open_zarr(store=source_store, consolidated=True) + # 2. Introspect the source array to get its configuration + print("Reading metadata from source store...") + try: + # Read the stores metadata to get array shape and chunk shape + print("Fetching metadata...") + ordered_dims = list(source_dataset.dims) + array_shape_tuple = tuple(source_dataset.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) + array_shape = array_shape_tuple + chunk_shape = chunk_shape_tuple + print("Metadata read successfully.") + print(f"Found Array Shape: {array_shape}") + print(f"Found Chunk Shape: {chunk_shape}") + + except KeyError as e: + raise RuntimeError( + f"Could not find required metadata in source .zarray: {e}" + ) from e + + # 3. Create the destination ShardedZarrStore for writing + print(f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard...") + dest_store = await ShardedZarrStore.open( + cas=cas, + read_only=False, + array_shape=array_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + cid_len=cid_len, + ) + + print("Destination store initialized.") + + # 4. Iterate and copy all data from source to destination + print("Starting data migration...") + count = 0 + async for key in hamt_ro.keys(): + count += 1 + # Read the raw data (metadata or chunk) from the source + cid: CID = await hamt_ro.get_pointer(key) + if cid is None: + continue + cid_base32_str = str(cid.encode("base32")) + + # Write the exact same key-value pair to the destination. + await dest_store.set_pointer(key, cid_base32_str) + if count % 200 == 0: + print(f"Migrated {count} keys...") + + print(f"Migration of {count} total keys complete.") + + # 5. Finalize the new store by flushing it to the CAS + print("Flushing new store to get final root CID...") + new_root_cid = await dest_store.flush() + end_time = time.perf_counter() + + print("\n--- Conversion Complete! ---") + print(f"Total time: {end_time - start_time:.2f} seconds") + print(f"New ShardedZarrStore Root CID: {new_root_cid}") + return new_root_cid + + +async def main(): + parser = argparse.ArgumentParser( + description="Convert a Zarr HAMT store to a Sharded Zarr store." + ) + parser.add_argument( + "hamt_cid", type=str, help="The root CID of the source Zarr HAMT store." 
+ ) + parser.add_argument( + "--chunks-per-shard", + type=int, + default=1024, + help="Number of chunk CIDs to store per shard in the new store.", + ) + parser.add_argument( + "--rpc-url", + type=str, + default="http://127.0.0.1:5001/api/v0", + help="The URL of the IPFS Kubo RPC API.", + ) + parser.add_argument( + "--gateway-url", + type=str, + default="http://127.0.0.1:8080", + help="The URL of the IPFS Gateway.", + ) + args = parser.parse_args() + async with KuboCAS( + rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url + ) as cas_client: + try: + await convert_hamt_to_sharded( + cas=cas_client, + hamt_root_cid=args.hamt_cid, + chunks_per_shard=args.chunks_per_shard, + ) + except Exception as e: + print(f"\nAn error occurred: {e}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index fbc2afa..968c787 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -138,8 +138,7 @@ async def _load_root_from_cid(self): self._root_obj = dag_cbor.decode(root_bytes) if self._root_obj.get("manifest_version") != "sharded_zarr_v1": - raise ValueError(f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'.") - + raise ValueError(f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'.") chunk_info = self._root_obj["chunks"] self._array_shape = tuple(chunk_info["array_shape"]) self._chunk_shape = tuple(chunk_info["chunk_shape"]) @@ -446,7 +445,6 @@ async def get( chunk_cid_bytes: Optional[bytes] = None if shard_idx in self._shard_data_cache: - print(f"DEBUG: get() - Shard {shard_idx} found in cache. Key: {key}") cached_shard_data = self._shard_data_cache[shard_idx] if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): chunk_cid_bytes = bytes(cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len]) @@ -464,7 +462,6 @@ async def get( if chunk_cid_bytes is None: # Not in cache or cache was invalid # print(f"DEBUG: get() - Shard {shard_idx} not in cache or invalid. Fetching specific CID. Key: {key}") try: - print("FETCHING SPECIFIC CID BYTES FROM SHARD Until Shard is cached. Key:", key) chunk_cid_bytes = await self.cas.load( target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len ) @@ -522,12 +519,16 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: # Save the actual chunk data to CAS first, to get its CID chunk_data_cid_obj = await self.cas.save(raw_chunk_data_bytes, codec="raw") # Chunks are typically raw bytes chunk_data_cid_str = str(chunk_data_cid_obj) + await self.set_pointer(key, chunk_data_cid_str) # Store the CID in the index + async def set_pointer( + self, key: str, pointer: str + ) -> None: # Ensure the CID (as ASCII bytes) fits in the allocated slot, padding with nulls - chunk_data_cid_ascii_bytes = chunk_data_cid_str.encode("ascii") + chunk_data_cid_ascii_bytes = pointer.encode("ascii") if len(chunk_data_cid_ascii_bytes) > self._cid_len: raise ValueError( - f"Encoded CID byte length ({len(chunk_data_cid_ascii_bytes)}) exceeds configured CID length ({self._cid_len}). CID: {chunk_data_cid_str}" + f"Encoded CID byte length ({len(chunk_data_cid_ascii_bytes)}) exceeds configured CID length ({self._cid_len}). 
CID: {pointer}" ) padded_chunk_data_cid_bytes = chunk_data_cid_ascii_bytes.ljust(self._cid_len, b'\0') @@ -539,7 +540,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: # So, we store the metadata content, get its CID, and put *that* CID in root_obj. # This means the `value_cid_str` for metadata should be from `raw_chunk_data_bytes`. # This seems to align with FlatZarrStore, where `value_cid` is used for both. - self._root_obj["metadata"][key] = chunk_data_cid_str # Store the string CID of the metadata content + self._root_obj["metadata"][key] = pointer # Store the string CID of the metadata content self._dirty_root = True return diff --git a/tests/test_converter.py b/tests/test_converter.py new file mode 100644 index 0000000..e72eece --- /dev/null +++ b/tests/test_converter.py @@ -0,0 +1,129 @@ +import asyncio +import time +import uuid + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +# Import store implementations +from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore, convert_hamt_to_sharded +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def converter_test_dataset(): + """ + Creates a random, uniquely-named xarray Dataset specifically for the converter test. + Using a unique variable name helps avoid potential caching issues between test runs. + """ + # A smaller dataset is fine for a verification test + times = pd.date_range("2025-01-01", periods=20) + lats = np.linspace(40, 50, 10) + lons = np.linspace(-85, -75, 20) + + # Generate a unique variable name for this test run + unique_var_name = f"data_{str(uuid.uuid4())[:8]}" + + data = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + {unique_var_name: (["time", "lat", "lon"], data)}, + coords={"time": times, "lat": lats, "lon": lons}, + attrs={"description": "Test dataset for converter verification."}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 10, "lat": 10, "lon": 10}) + yield ds + + +@pytest.mark.asyncio(loop_scope="session") +async def test_converter_produces_identical_dataset( + create_ipfs: tuple[str, str], + converter_test_dataset: xr.Dataset, +): + """ + Tests the hamt_to_sharded_converter by performing a full conversion + and verifying that the resulting dataset is identical to the source. 
+ """ + print("\n\n" + "=" * 80) + print("🚀 STARTING TEST for HAMT to Sharded Converter") + print("=" * 80) + + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + chunks_per_shard_config = 64 # A reasonable value for this test size + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # -------------------------------------------------------------------- + # STEP 1: Create the source HAMT store from our test dataset + # -------------------------------------------------------------------- + print("\n--- STEP 1: Creating source HAMT store ---") + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + + start_write = time.perf_counter() + test_ds.to_zarr(store=source_hamt_store, mode="w") + await hamt_write.make_read_only() # Flush to get the final CID + end_write = time.perf_counter() + + hamt_root_cid = str(hamt_write.root_node_id) + print(f"Source HAMT store created in {end_write - start_write:.2f}s") + print(f"Source HAMT Root CID: {hamt_root_cid}") + + # -------------------------------------------------------------------- + # STEP 2: Run the conversion script + # -------------------------------------------------------------------- + print("\n--- STEP 2: Running conversion script ---") + sharded_root_cid = await convert_hamt_to_sharded( + cas=kubo_cas, + hamt_root_cid=hamt_root_cid, + chunks_per_shard=chunks_per_shard_config, + ) + print("Conversion script finished.") + print(f"New Sharded Store Root CID: {sharded_root_cid}") + assert sharded_root_cid is not None + + # -------------------------------------------------------------------- + # STEP 3: Verification + # -------------------------------------------------------------------- + print("\n--- STEP 3: Verifying data integrity ---") + + # Open the original dataset from the HAMT store + print("Reading data back from original HAMT store...") + + hamt_ro = await HAMT.build( + cas=kubo_cas, root_node_id=hamt_root_cid, values_are_bytes=True, read_only=True + ) + zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) + + start_read = time.perf_counter() + ds_from_hamt = xr.open_zarr(store=zhs_ro) + + end_read = time.perf_counter() + print(f"Original HAMT store read in {end_read - start_read:.2f}s") + + # Open the converted dataset from the new Sharded store + print("Reading data back from new Sharded store...") + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=sharded_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + + # The ultimate test: are the two xarray.Dataset objects identical? + # This checks coordinates, variables, data values, and attributes. + print("Comparing the two datasets...") + xr.testing.assert_identical(ds_from_hamt, ds_from_sharded) + # Ask for random samples from both datasets to ensure they match + for var in ds_from_hamt.data_vars: + # Assert all identical + np.testing.assert_array_equal( + ds_from_hamt[var].values, ds_from_sharded[var].values + ) + + print("\n✅ Verification successful! 
The datasets are identical.") + print("=" * 80) \ No newline at end of file From c981437416a348f04d1f0aa93d6a87cf945439fd Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 9 Jun 2025 10:58:13 -0400 Subject: [PATCH 12/74] fix: more work on sharding --- py_hamt/__init__.py | 5 +- py_hamt/flat_zarr_store.py | 745 ++++++++++++++------------- py_hamt/hamt_to_sharded_converter.py | 37 +- py_hamt/sharded_zarr_store.py | 256 +++------ tests/test_benchmark_stores.py | 594 +++++++++++---------- tests/test_converter.py | 130 ++++- tests/test_cpc_compare.py | 128 +++++ tests/test_sharded_zarr_store.py | 669 ++++++++++++++++++++++++ 8 files changed, 1699 insertions(+), 865 deletions(-) create mode 100644 tests/test_cpc_compare.py create mode 100644 tests/test_sharded_zarr_store.py diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 03bb94a..918bec7 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -2,9 +2,8 @@ from .hamt import HAMT, blake3_hashfn from .store import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore -from .flat_zarr_store import FlatZarrStore from .sharded_zarr_store import ShardedZarrStore -from .hamt_to_sharded_converter import convert_hamt_to_sharded +from .hamt_to_sharded_converter import convert_hamt_to_sharded, sharded_converter_cli __all__ = [ "blake3_hashfn", @@ -14,7 +13,7 @@ "ZarrHAMTStore", "InMemoryCAS", "SimpleEncryptedZarrHAMTStore", - "FlatZarrStore", "ShardedZarrStore", "convert_hamt_to_sharded", + "sharded_converter_cli", ] diff --git a/py_hamt/flat_zarr_store.py b/py_hamt/flat_zarr_store.py index b4a96ad..33f531f 100644 --- a/py_hamt/flat_zarr_store.py +++ b/py_hamt/flat_zarr_store.py @@ -1,370 +1,375 @@ -import asyncio -import math -from collections.abc import AsyncIterator, Iterable -from typing import Optional, cast - -import dag_cbor -import zarr.abc.store -import zarr.core.buffer -from zarr.core.common import BytesLike - -from .store import ContentAddressedStore - - -class FlatZarrStore(zarr.abc.store.Store): - """ - Implements the Zarr Store API using a flat, predictable layout for chunk CIDs. - - This store bypasses the need for a HAMT, offering direct, calculated - access to chunk data based on a mathematical formula. It is designed for - dense, multi-dimensional arrays where chunk locations are predictable. - - The store is structured around a single root CBOR object. This root object contains: - 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. - 2. A single CID pointing to a large, contiguous block of bytes (the "flat index"). - This flat index is a concatenation of the CIDs of all data chunks. - - Accessing a chunk involves: - 1. Loading the root object (if not cached). - 2. Calculating the byte offset of the chunk's CID within the flat index. - 3. Fetching that specific CID using a byte-range request on the flat index. - 4. Fetching the actual chunk data using the retrieved CID. 
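The offset in step 2 above is just a row-major linearisation of the chunk coordinates scaled by the fixed CID width. A minimal, hedged sketch of that arithmetic follows; the helper name and the default `cid_len=59` (base32 CIDv1 strings) mirror the surrounding code, and nothing here is part of the patch itself:

```python
import math

def chunk_cid_offset(chunk_coords, array_shape, chunk_shape, cid_len=59):
    """Byte offset of a chunk's CID slot within the flat index (row-major order)."""
    chunks_per_dim = [math.ceil(a / c) for a, c in zip(array_shape, chunk_shape)]
    linear = 0
    for coord, n_chunks in zip(chunk_coords, chunks_per_dim):
        linear = linear * n_chunks + coord
    return linear * cid_len

# e.g. an array of shape (4, 6, 8) chunked as (2, 2, 2) gives a (2, 3, 4) chunk grid;
# chunk (1, 2, 3) has linear index 23, so its CID slot starts at byte 23 * 59 = 1357.
```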
- - ### Sample Code - ```python - import xarray as xr - import numpy as np - from py_hamt import KuboCAS, FlatZarrStore - - # --- Write --- - ds = xr.Dataset( - {"data": (("t", "y", "x"), np.arange(24).reshape(2, 3, 4))}, - ) - cas = KuboCAS() - # When creating, must provide array shape and chunk shape - store_write = await FlatZarrStore.open( - cas, - read_only=False, - array_shape=ds.data.shape, - chunk_shape=ds.data.encoding['chunks'] - ) - ds.to_zarr(store=store_write, mode="w") - root_cid = await store_write.flush() # IMPORTANT: flush to get final root CID - print(f"Finished writing. Root CID: {root_cid}") - - - # --- Read --- - store_read = await FlatZarrStore.open(cas, read_only=True, root_cid=root_cid) - ds_read = xr.open_zarr(store=store_read) - print("Read back dataset:") - print(ds_read) - xr.testing.assert_identical(ds, ds_read) - ``` - """ - - def __init__( - self, cas: ContentAddressedStore, read_only: bool, root_cid: Optional[str] - ): - """Use the async `open()` classmethod to instantiate this class.""" - super().__init__(read_only=read_only) - self.cas = cas - self._root_cid = root_cid - self._root_obj: Optional[dict] = None - self._flat_index_cache: Optional[bytearray] = None - self._cid_len: Optional[int] = None - self._array_shape: Optional[tuple[int, ...]] = None - self._chunk_shape: Optional[tuple[int, ...]] = None - self._chunks_per_dim: Optional[tuple[int, ...]] = None - self._dirty = False - - @classmethod - async def open( - cls, - cas: ContentAddressedStore, - read_only: bool, - root_cid: Optional[str] = None, - *, - array_shape: Optional[tuple[int, ...]] = None, - chunk_shape: Optional[tuple[int, ...]] = None, - cid_len: int = 59, # Default for base32 v1 CIDs like bafy... - ) -> "FlatZarrStore": - """ - Asynchronously opens an existing FlatZarrStore or initializes a new one. - - Args: - cas: The Content Addressed Store (e.g., KuboCAS). - read_only: If True, the store is in read-only mode. - root_cid: The root CID of an existing store to open. Required for read_only. - array_shape: The full shape of the Zarr array. Required for a new writeable store. - chunk_shape: The shape of a single chunk. Required for a new writeable store. - cid_len: The expected byte length of a CID string. - """ - store = cls(cas, read_only, root_cid) - if root_cid: - await store._load_root_from_cid() - elif not read_only: - if not all([array_shape, chunk_shape]): - raise ValueError( - "array_shape and chunk_shape must be provided for a new store." 
- ) - store._initialize_new_root(array_shape, chunk_shape, cid_len) - else: - raise ValueError("root_cid must be provided for a read-only store.") - return store - - def _initialize_new_root( - self, - array_shape: tuple[int, ...], - chunk_shape: tuple[int, ...], - cid_len: int, - ): - self._array_shape = array_shape - self._chunk_shape = chunk_shape - self._cid_len = cid_len - self._chunks_per_dim = tuple( - math.ceil(a / c) for a, c in zip(array_shape, chunk_shape) - ) - self._root_obj = { - "manifest_version": "flat_zarr_v1", - "metadata": {}, - "chunks": { - "cid": None, # Will be filled on first flush - "array_shape": list(self._array_shape), - "chunk_shape": list(self._chunk_shape), - "cid_byte_length": self._cid_len, - }, - } - self._dirty = True - - async def _load_root_from_cid(self): - if not self._root_cid: - raise ValueError("Cannot load root without a root_cid.") - root_bytes = await self.cas.load(self._root_cid) - self._root_obj = dag_cbor.decode(root_bytes) - chunk_info = self._root_obj.get("chunks", {}) - self._array_shape = tuple(chunk_info["array_shape"]) - self._chunk_shape = tuple(chunk_info["chunk_shape"]) - self._cid_len = chunk_info["cid_byte_length"] - self._chunks_per_dim = tuple( - math.ceil(a / c) for a, c in zip(self._array_shape, self._chunk_shape) - ) - - def _parse_chunk_key(self, key: str) -> Optional[tuple[int, ...]]: - if not self._array_shape or not key.startswith("c/"): - return None - parts = key.split("/") - if len(parts) != len(self._array_shape) + 1: - return None - try: - return tuple(map(int, parts[1:])) - except (ValueError, IndexError): - return None - - async def set_partial_values( - self, key_start_values: Iterable[tuple[str, int, BytesLike]] - ) -> None: - """@private""" - raise NotImplementedError("Partial writes are not supported by this store.") - - async def get_partial_values( - self, - prototype: zarr.core.buffer.BufferPrototype, - key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], - ) -> list[zarr.core.buffer.Buffer | None]: - """ - Retrieves multiple keys or byte ranges concurrently. - """ - tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] - results = await asyncio.gather(*tasks) - return results - - def __eq__(self, other: object) -> bool: - """@private""" - if not isinstance(other, FlatZarrStore): - return NotImplemented - return self._root_cid == other._root_cid - - def _get_chunk_offset(self, chunk_coords: tuple[int, ...]) -> int: - linear_index = 0 - multiplier = 1 - for i in reversed(range(len(self._chunks_per_dim))): - linear_index += chunk_coords[i] * multiplier - multiplier *= self._chunks_per_dim[i] - return linear_index * self._cid_len - - async def flush(self) -> str: - """ - Writes all pending changes (metadata and chunk index) to the CAS - and returns the new root CID. This MUST be called after all writes are complete. 
- """ - if self.read_only or not self._dirty: - return self._root_cid - - if self._flat_index_cache is not None: - flat_index_cid_obj = await self.cas.save( - bytes(self._flat_index_cache), codec="raw" - ) - self._root_obj["chunks"]["cid"] = str(flat_index_cid_obj) - - root_obj_bytes = dag_cbor.encode(self._root_obj) - new_root_cid_obj = await self.cas.save(root_obj_bytes, codec="dag-cbor") - self._root_cid = str(new_root_cid_obj) - self._dirty = False - return self._root_cid - - async def get( - self, - key: str, - prototype: zarr.core.buffer.BufferPrototype, - byte_range: zarr.abc.store.ByteRequest | None = None, - ) -> zarr.core.buffer.Buffer | None: - """@private""" - if self._root_obj is None: - await self._load_root_from_cid() - - chunk_coords = self._parse_chunk_key(key) - try: - # Metadata request - if chunk_coords is None: - metadata_cid = self._root_obj["metadata"].get(key) - if metadata_cid is None: - return None - data = await self.cas.load(metadata_cid) - return prototype.buffer.from_bytes(data) - - # Chunk data request - flat_index_cid = self._root_obj["chunks"]["cid"] - if flat_index_cid is None: - return None - - offset = self._get_chunk_offset(chunk_coords) - chunk_cid_bytes = await self.cas.load( - flat_index_cid, offset=offset, length=self._cid_len - ) - - if all(b == 0 for b in chunk_cid_bytes): - return None # Chunk doesn't exist - - chunk_cid = chunk_cid_bytes.decode("ascii") - data = await self.cas.load(chunk_cid) - return prototype.buffer.from_bytes(data) - - except (KeyError, IndexError): - return None - - async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: - """@private""" - if self.read_only: - raise ValueError("Cannot write to a read-only store.") - if self._root_obj is None: - raise RuntimeError("Store not initialized for writing.") - - self._dirty = True - raw_bytes = value.to_bytes() - value_cid_obj = await self.cas.save(raw_bytes, codec="raw") - value_cid = str(value_cid_obj) - - if len(value_cid) != self._cid_len: - raise ValueError( - f"Inconsistent CID length. Expected {self._cid_len}, got {len(value_cid)}" - ) - - chunk_coords = self._parse_chunk_key(key) - - if chunk_coords is None: # Metadata - self._root_obj["metadata"][key] = value_cid - return - - # Chunk Data - if self._flat_index_cache is None: - num_chunks = math.prod(self._chunks_per_dim) - self._flat_index_cache = bytearray(num_chunks * self._cid_len) - - offset = self._get_chunk_offset(chunk_coords) - self._flat_index_cache[offset : offset + self._cid_len] = value_cid.encode( - "ascii" - ) - - # --- Other required zarr.abc.store methods --- - - async def exists(self, key: str) -> bool: - """@private""" - # A more efficient version might check for null bytes in the flat index - # but this is functionally correct. 
- - # TODO: Optimize this check - return True - - - # return (await self.get(key, zarr.core.buffer.Buffer.prototype, None)) is not None - - @property - def supports_writes(self) -> bool: - """@private""" - return not self.read_only - - @property - def supports_partial_writes(self) -> bool: - """@private""" - return False # Each chunk is an immutable object - - @property - def supports_deletes(self) -> bool: - """@private""" - return not self.read_only - - async def delete(self, key: str) -> None: - if self.read_only: - raise ValueError("Cannot delete from a read-only store.") - if self._root_obj is None: - await self._load_root_from_cid() - chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: - if key in self._root_obj["metadata"]: - del self._root_obj["metadata"][key] - self._dirty = True - return - else: - raise KeyError(f"Metadata key '{key}' not found.") - flat_index_cid = self._root_obj["chunks"]["cid"] - if self._flat_index_cache is None: - if not flat_index_cid: - raise KeyError(f"Chunk key '{key}' not found in non-existent index.") - self._flat_index_cache = bytearray(await self.cas.load(flat_index_cid)) - offset = self._get_chunk_offset(chunk_coords) - if all(b == 0 for b in self._flat_index_cache[offset : offset + self._cid_len]): - raise KeyError(f"Chunk key '{key}' not found.") - self._flat_index_cache[offset : offset + self._cid_len] = bytearray(self._cid_len) - self._dirty = True - - @property - def supports_listing(self) -> bool: - """@private""" - return True - - async def list(self) -> AsyncIterator[str]: - """@private""" - if self._root_obj is None: - await self._load_root_from_cid() - for key in self._root_obj["metadata"]: - yield key - # Note: Listing all chunk keys without reading the index is non-trivial. - # A full implementation might need an efficient way to iterate non-null chunks. - # This basic version only lists metadata. - - async def list_prefix(self, prefix: str) -> AsyncIterator[str]: - """@private""" - async for key in self.list(): - if key.startswith(prefix): - yield key - - async def list_dir(self, prefix: str) -> AsyncIterator[str]: - """@private""" - # This simplified version only works for the root. - if prefix == "": - seen = set() - async for key in self.list(): - name = key.split('/')[0] - if name not in seen: - seen.add(name) - yield name \ No newline at end of file +# Functional Flat Store for Zarr using IPFS. This is ultimately the "Best" way to do it +# but only if the metadata is small enough to fit in a single CBOR object. +# This is commented out because we are not using it. However I did not want to delete it +# because it is a good example of how to implement a Zarr store using a flat index + +# import asyncio +# import math +# from collections.abc import AsyncIterator, Iterable +# from typing import Optional, cast + +# import dag_cbor +# import zarr.abc.store +# import zarr.core.buffer +# from zarr.core.common import BytesLike + +# from .store import ContentAddressedStore + + +# class FlatZarrStore(zarr.abc.store.Store): +# """ +# Implements the Zarr Store API using a flat, predictable layout for chunk CIDs. + +# This store bypasses the need for a HAMT, offering direct, calculated +# access to chunk data based on a mathematical formula. It is designed for +# dense, multi-dimensional arrays where chunk locations are predictable. + +# The store is structured around a single root CBOR object. This root object contains: +# 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. +# 2. 
A single CID pointing to a large, contiguous block of bytes (the "flat index"). +# This flat index is a concatenation of the CIDs of all data chunks. + +# Accessing a chunk involves: +# 1. Loading the root object (if not cached). +# 2. Calculating the byte offset of the chunk's CID within the flat index. +# 3. Fetching that specific CID using a byte-range request on the flat index. +# 4. Fetching the actual chunk data using the retrieved CID. + +# ### Sample Code +# ```python +# import xarray as xr +# import numpy as np +# from py_hamt import KuboCAS, FlatZarrStore + +# # --- Write --- +# ds = xr.Dataset( +# {"data": (("t", "y", "x"), np.arange(24).reshape(2, 3, 4))}, +# ) +# cas = KuboCAS() +# # When creating, must provide array shape and chunk shape +# store_write = await FlatZarrStore.open( +# cas, +# read_only=False, +# array_shape=ds.data.shape, +# chunk_shape=ds.data.encoding['chunks'] +# ) +# ds.to_zarr(store=store_write, mode="w") +# root_cid = await store_write.flush() # IMPORTANT: flush to get final root CID +# print(f"Finished writing. Root CID: {root_cid}") + + +# # --- Read --- +# store_read = await FlatZarrStore.open(cas, read_only=True, root_cid=root_cid) +# ds_read = xr.open_zarr(store=store_read) +# print("Read back dataset:") +# print(ds_read) +# xr.testing.assert_identical(ds, ds_read) +# ``` +# """ + +# def __init__( +# self, cas: ContentAddressedStore, read_only: bool, root_cid: Optional[str] +# ): +# """Use the async `open()` classmethod to instantiate this class.""" +# super().__init__(read_only=read_only) +# self.cas = cas +# self._root_cid = root_cid +# self._root_obj: Optional[dict] = None +# self._flat_index_cache: Optional[bytearray] = None +# self._cid_len: Optional[int] = None +# self._array_shape: Optional[tuple[int, ...]] = None +# self._chunk_shape: Optional[tuple[int, ...]] = None +# self._chunks_per_dim: Optional[tuple[int, ...]] = None +# self._dirty = False + +# @classmethod +# async def open( +# cls, +# cas: ContentAddressedStore, +# read_only: bool, +# root_cid: Optional[str] = None, +# *, +# array_shape: Optional[tuple[int, ...]] = None, +# chunk_shape: Optional[tuple[int, ...]] = None, +# cid_len: int = 59, # Default for base32 v1 CIDs like bafy... +# ) -> "FlatZarrStore": +# """ +# Asynchronously opens an existing FlatZarrStore or initializes a new one. + +# Args: +# cas: The Content Addressed Store (e.g., KuboCAS). +# read_only: If True, the store is in read-only mode. +# root_cid: The root CID of an existing store to open. Required for read_only. +# array_shape: The full shape of the Zarr array. Required for a new writeable store. +# chunk_shape: The shape of a single chunk. Required for a new writeable store. +# cid_len: The expected byte length of a CID string. +# """ +# store = cls(cas, read_only, root_cid) +# if root_cid: +# await store._load_root_from_cid() +# elif not read_only: +# if not all([array_shape, chunk_shape]): +# raise ValueError( +# "array_shape and chunk_shape must be provided for a new store." 
+# ) +# store._initialize_new_root(array_shape, chunk_shape, cid_len) +# else: +# raise ValueError("root_cid must be provided for a read-only store.") +# return store + +# def _initialize_new_root( +# self, +# array_shape: tuple[int, ...], +# chunk_shape: tuple[int, ...], +# cid_len: int, +# ): +# self._array_shape = array_shape +# self._chunk_shape = chunk_shape +# self._cid_len = cid_len +# self._chunks_per_dim = tuple( +# math.ceil(a / c) for a, c in zip(array_shape, chunk_shape) +# ) +# self._root_obj = { +# "manifest_version": "flat_zarr_v1", +# "metadata": {}, +# "chunks": { +# "cid": None, # Will be filled on first flush +# "array_shape": list(self._array_shape), +# "chunk_shape": list(self._chunk_shape), +# "cid_byte_length": self._cid_len, +# }, +# } +# self._dirty = True + +# async def _load_root_from_cid(self): +# if not self._root_cid: +# raise ValueError("Cannot load root without a root_cid.") +# root_bytes = await self.cas.load(self._root_cid) +# self._root_obj = dag_cbor.decode(root_bytes) +# chunk_info = self._root_obj.get("chunks", {}) +# self._array_shape = tuple(chunk_info["array_shape"]) +# self._chunk_shape = tuple(chunk_info["chunk_shape"]) +# self._cid_len = chunk_info["cid_byte_length"] +# self._chunks_per_dim = tuple( +# math.ceil(a / c) for a, c in zip(self._array_shape, self._chunk_shape) +# ) + +# def _parse_chunk_key(self, key: str) -> Optional[tuple[int, ...]]: +# if not self._array_shape or not key.startswith("c/"): +# return None +# parts = key.split("/") +# if len(parts) != len(self._array_shape) + 1: +# return None +# try: +# return tuple(map(int, parts[1:])) +# except (ValueError, IndexError): +# return None + +# async def set_partial_values( +# self, key_start_values: Iterable[tuple[str, int, BytesLike]] +# ) -> None: +# """@private""" +# raise NotImplementedError("Partial writes are not supported by this store.") + +# async def get_partial_values( +# self, +# prototype: zarr.core.buffer.BufferPrototype, +# key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], +# ) -> list[zarr.core.buffer.Buffer | None]: +# """ +# Retrieves multiple keys or byte ranges concurrently. +# """ +# tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] +# results = await asyncio.gather(*tasks) +# return results + +# def __eq__(self, other: object) -> bool: +# """@private""" +# if not isinstance(other, FlatZarrStore): +# return NotImplemented +# return self._root_cid == other._root_cid + +# def _get_chunk_offset(self, chunk_coords: tuple[int, ...]) -> int: +# linear_index = 0 +# multiplier = 1 +# for i in reversed(range(len(self._chunks_per_dim))): +# linear_index += chunk_coords[i] * multiplier +# multiplier *= self._chunks_per_dim[i] +# return linear_index * self._cid_len + +# async def flush(self) -> str: +# """ +# Writes all pending changes (metadata and chunk index) to the CAS +# and returns the new root CID. This MUST be called after all writes are complete. 
+# """ +# if self.read_only or not self._dirty: +# return self._root_cid + +# if self._flat_index_cache is not None: +# flat_index_cid_obj = await self.cas.save( +# bytes(self._flat_index_cache), codec="raw" +# ) +# self._root_obj["chunks"]["cid"] = str(flat_index_cid_obj) + +# root_obj_bytes = dag_cbor.encode(self._root_obj) +# new_root_cid_obj = await self.cas.save(root_obj_bytes, codec="dag-cbor") +# self._root_cid = str(new_root_cid_obj) +# self._dirty = False +# return self._root_cid + +# async def get( +# self, +# key: str, +# prototype: zarr.core.buffer.BufferPrototype, +# byte_range: zarr.abc.store.ByteRequest | None = None, +# ) -> zarr.core.buffer.Buffer | None: +# """@private""" +# if self._root_obj is None: +# await self._load_root_from_cid() + +# chunk_coords = self._parse_chunk_key(key) +# try: +# # Metadata request +# if chunk_coords is None: +# metadata_cid = self._root_obj["metadata"].get(key) +# if metadata_cid is None: +# return None +# data = await self.cas.load(metadata_cid) +# return prototype.buffer.from_bytes(data) + +# # Chunk data request +# flat_index_cid = self._root_obj["chunks"]["cid"] +# if flat_index_cid is None: +# return None + +# offset = self._get_chunk_offset(chunk_coords) +# chunk_cid_bytes = await self.cas.load( +# flat_index_cid, offset=offset, length=self._cid_len +# ) + +# if all(b == 0 for b in chunk_cid_bytes): +# return None # Chunk doesn't exist + +# chunk_cid = chunk_cid_bytes.decode("ascii") +# data = await self.cas.load(chunk_cid) +# return prototype.buffer.from_bytes(data) + +# except (KeyError, IndexError): +# return None + +# async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: +# """@private""" +# if self.read_only: +# raise ValueError("Cannot write to a read-only store.") +# if self._root_obj is None: +# raise RuntimeError("Store not initialized for writing.") + +# self._dirty = True +# raw_bytes = value.to_bytes() +# value_cid_obj = await self.cas.save(raw_bytes, codec="raw") +# value_cid = str(value_cid_obj) + +# if len(value_cid) != self._cid_len: +# raise ValueError( +# f"Inconsistent CID length. Expected {self._cid_len}, got {len(value_cid)}" +# ) + +# chunk_coords = self._parse_chunk_key(key) + +# if chunk_coords is None: # Metadata +# self._root_obj["metadata"][key] = value_cid +# return + +# # Chunk Data +# if self._flat_index_cache is None: +# num_chunks = math.prod(self._chunks_per_dim) +# self._flat_index_cache = bytearray(num_chunks * self._cid_len) + +# offset = self._get_chunk_offset(chunk_coords) +# self._flat_index_cache[offset : offset + self._cid_len] = value_cid.encode( +# "ascii" +# ) + +# # --- Other required zarr.abc.store methods --- + +# async def exists(self, key: str) -> bool: +# """@private""" +# # A more efficient version might check for null bytes in the flat index +# # but this is functionally correct. 
+ +# # TODO: Optimize this check +# return True + + +# # return (await self.get(key, zarr.core.buffer.Buffer.prototype, None)) is not None + +# @property +# def supports_writes(self) -> bool: +# """@private""" +# return not self.read_only + +# @property +# def supports_partial_writes(self) -> bool: +# """@private""" +# return False # Each chunk is an immutable object + +# @property +# def supports_deletes(self) -> bool: +# """@private""" +# return not self.read_only + +# async def delete(self, key: str) -> None: +# if self.read_only: +# raise ValueError("Cannot delete from a read-only store.") +# if self._root_obj is None: +# await self._load_root_from_cid() +# chunk_coords = self._parse_chunk_key(key) +# if chunk_coords is None: +# if key in self._root_obj["metadata"]: +# del self._root_obj["metadata"][key] +# self._dirty = True +# return +# else: +# raise KeyError(f"Metadata key '{key}' not found.") +# flat_index_cid = self._root_obj["chunks"]["cid"] +# if self._flat_index_cache is None: +# if not flat_index_cid: +# raise KeyError(f"Chunk key '{key}' not found in non-existent index.") +# self._flat_index_cache = bytearray(await self.cas.load(flat_index_cid)) +# offset = self._get_chunk_offset(chunk_coords) +# if all(b == 0 for b in self._flat_index_cache[offset : offset + self._cid_len]): +# raise KeyError(f"Chunk key '{key}' not found.") +# self._flat_index_cache[offset : offset + self._cid_len] = bytearray(self._cid_len) +# self._dirty = True + +# @property +# def supports_listing(self) -> bool: +# """@private""" +# return True + +# async def list(self) -> AsyncIterator[str]: +# """@private""" +# if self._root_obj is None: +# await self._load_root_from_cid() +# for key in self._root_obj["metadata"]: +# yield key +# # Note: Listing all chunk keys without reading the index is non-trivial. +# # A full implementation might need an efficient way to iterate non-null chunks. +# # This basic version only lists metadata. + +# async def list_prefix(self, prefix: str) -> AsyncIterator[str]: +# """@private""" +# async for key in self.list(): +# if key.startswith(prefix): +# yield key + +# async def list_dir(self, prefix: str) -> AsyncIterator[str]: +# """@private""" +# # This simplified version only works for the root. +# if prefix == "": +# seen = set() +# async for key in self.list(): +# name = key.split('/')[0] +# if name not in seen: +# seen.add(name) +# yield name \ No newline at end of file diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index ff9f88c..7b6e4d9 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -3,7 +3,7 @@ import json import time from typing import Dict, Any -from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +from py_hamt import HAMT, KuboCAS, ShardedZarrStore from py_hamt.zarr_hamt_store import ZarrHAMTStore import xarray as xr from multiformats import CID @@ -34,22 +34,14 @@ async def convert_hamt_to_sharded( source_dataset = xr.open_zarr(store=source_store, consolidated=True) # 2. 
Introspect the source array to get its configuration print("Reading metadata from source store...") - try: - # Read the stores metadata to get array shape and chunk shape - print("Fetching metadata...") - ordered_dims = list(source_dataset.dims) - array_shape_tuple = tuple(source_dataset.dims[dim] for dim in ordered_dims) - chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) - array_shape = array_shape_tuple - chunk_shape = chunk_shape_tuple - print("Metadata read successfully.") - print(f"Found Array Shape: {array_shape}") - print(f"Found Chunk Shape: {chunk_shape}") - except KeyError as e: - raise RuntimeError( - f"Could not find required metadata in source .zarray: {e}" - ) from e + # Read the stores metadata to get array shape and chunk shape + ordered_dims = list(source_dataset.dims) + array_shape_tuple = tuple(source_dataset.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) + array_shape = array_shape_tuple + chunk_shape = chunk_shape_tuple + # 3. Create the destination ShardedZarrStore for writing print(f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard...") @@ -71,14 +63,12 @@ async def convert_hamt_to_sharded( count += 1 # Read the raw data (metadata or chunk) from the source cid: CID = await hamt_ro.get_pointer(key) - if cid is None: - continue cid_base32_str = str(cid.encode("base32")) # Write the exact same key-value pair to the destination. await dest_store.set_pointer(key, cid_base32_str) - if count % 200 == 0: - print(f"Migrated {count} keys...") + if count % 200 == 0: # pragma: no cover + print(f"Migrated {count} keys...") # pragma: no cover print(f"Migration of {count} total keys complete.") @@ -93,7 +83,7 @@ async def convert_hamt_to_sharded( return new_root_cid -async def main(): +async def sharded_converter_cli(): parser = argparse.ArgumentParser( description="Convert a Zarr HAMT store to a Sharded Zarr store." ) @@ -109,7 +99,7 @@ async def main(): parser.add_argument( "--rpc-url", type=str, - default="http://127.0.0.1:5001/api/v0", + default="http://127.0.0.1:5001", help="The URL of the IPFS Kubo RPC API.", ) parser.add_argument( @@ -119,6 +109,7 @@ async def main(): help="The URL of the IPFS Gateway.", ) args = parser.parse_args() + # Initialize the KuboCAS client with the provided RPC and Gateway URLs async with KuboCAS( rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url ) as cas_client: @@ -132,4 +123,4 @@ async def main(): print(f"\nAn error occurred: {e}") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(sharded_converter_cli()) # pragma: no cover \ No newline at end of file diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 968c787..62e5925 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -118,7 +118,7 @@ def _initialize_new_root( self._root_obj = { "manifest_version": "sharded_zarr_v1", - "metadata": {}, # For .zgroup, .zarray, .zattrs etc. 
+ "metadata": {}, # For .json "chunks": { # Information about the chunk index itself "array_shape": list(self._array_shape), # Original array shape "chunk_shape": list(self._chunk_shape), # Original chunk shape @@ -133,7 +133,7 @@ def _initialize_new_root( async def _load_root_from_cid(self): if not self._root_cid: - raise ValueError("Cannot load root without a root_cid.") + raise RuntimeError("Cannot load root without a root_cid.") root_bytes = await self.cas.load(self._root_cid) self._root_obj = dag_cbor.decode(root_bytes) @@ -171,15 +171,10 @@ async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): Manages removal from _pending_shard_loads. """ try: - # Double check if it got cached by another operation while this task was scheduled - if shard_idx in self._shard_data_cache: - return - shard_data_bytes = await self.cas.load(shard_cid) # Load full shard self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) - # print(f"DEBUG: Successfully cached full shard {shard_idx} (CID: {shard_cid})") - except Exception as e: + print(e) # Handle or log the exception appropriately print(f"Warning: Failed to cache full shard {shard_idx} (CID: {shard_cid}): {e}") # If it fails, subsequent requests might try again if it's still not in cache. @@ -251,8 +246,6 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: return None def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: - if not self._chunks_per_dim: - raise RuntimeError("Store not initialized: _chunks_per_dim is None.") linear_index = 0 multiplier = 1 # Convert N-D chunk coordinates to a flat 1-D index (row-major order) @@ -271,68 +264,38 @@ def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: index_in_shard = linear_chunk_index % self._chunks_per_shard return shard_idx, index_in_shard - def _get_actual_chunks_in_shard(self, shard_idx: int) -> int: - return self._chunks_per_shard - # if self._num_shards is None or self._total_chunks is None or self._chunks_per_shard is None: - # raise RuntimeError("Store not properly initialized for shard calculation (_num_shards, _total_chunks, or _chunks_per_shard is None).") - # if not (0 <= shard_idx < self._num_shards): # Handles case where _num_shards can be 0 - # if self._num_shards == 0 and shard_idx == 0 and self._total_chunks == 0: # special case for 0-size array - # return 0 - # raise ValueError(f"Invalid shard index: {shard_idx}. Num shards: {self._num_shards}") - - # if shard_idx == self._num_shards - 1: # Last shard - # remaining_chunks = self._total_chunks - shard_idx * self._chunks_per_shard - # return remaining_chunks - # else: - # return self._chunks_per_shard - async def _load_or_initialize_shard_cache(self, shard_idx: int) -> bytearray: if shard_idx in self._shard_data_cache: return self._shard_data_cache[shard_idx] - # Check if a background load for this shard is already in progress if shard_idx in self._pending_shard_loads: - # print(f"DEBUG: _load_or_initialize_shard_cache - Shard {shard_idx} has a pending load. Awaiting it.") try: await self._pending_shard_loads[shard_idx] - # After awaiting, it should be in the cache if the task was successful if shard_idx in self._shard_data_cache: return self._shard_data_cache[shard_idx] else: - # Task finished but didn't populate cache (e.g., an error occurred in the task) - # print(f"DEBUG: _load_or_initialize_shard_cache - Pending load for {shard_idx} completed but shard not in cache. 
Proceeding to load manually.") - pass # Fall through to normal loading + pass # Fall through to normal loading except asyncio.CancelledError: - # print(f"DEBUG: _load_or_initialize_shard_cache - Pending load for {shard_idx} was cancelled. Proceeding to load manually.") - # Ensure it's removed if cancelled before its finally block ran if shard_idx in self._pending_shard_loads: del self._pending_shard_loads[shard_idx] # Fall through to normal loading except Exception as e: - # The pending task itself might have failed with an exception print(f"Warning: Pending shard load for {shard_idx} failed: {e}. Attempting fresh load.") - # Fall through to normal loading. The pending task's finally block should have cleaned it up. if self._root_obj is None: raise RuntimeError("Root object not loaded or initialized (_root_obj is None).") if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else False): raise ValueError(f"Shard index {shard_idx} out of bounds for {self._num_shards} shards.") - shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] if shard_cid: shard_data_bytes = await self.cas.load(shard_cid) - # Verify length? - # expected_len = self._get_actual_chunks_in_shard(shard_idx) * self._cid_len - # if len(shard_data_bytes) != expected_len: - # raise ValueError(f"Shard {shard_idx} (CID: {shard_cid}) has unexpected length. Got {len(shard_data_bytes)}, expected {expected_len}") self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) else: if self._cid_len is None: # Should be set raise RuntimeError("Store not initialized: _cid_len is None for shard initialization.") # New shard or shard not yet written, initialize with zeros - num_chunks_in_this_shard = self._get_actual_chunks_in_shard(shard_idx) - shard_size_bytes = num_chunks_in_this_shard * self._cid_len + shard_size_bytes = self._chunks_per_shard * self._cid_len self._shard_data_cache[shard_idx] = bytearray(shard_size_bytes) # Filled with \x00 return self._shard_data_cache[shard_idx] @@ -404,107 +367,69 @@ async def get( byte_range: Optional[zarr.abc.store.ByteRequest] = None, ) -> Optional[zarr.core.buffer.Buffer]: if self._root_obj is None: - if not self._root_cid: - raise ValueError("Store not initialized and no root_cid to load from.") - await self._load_root_from_cid() # This will populate self._root_obj - if self._root_obj is None: # Should be loaded by _load_root_from_cid - raise RuntimeError("Failed to load root object after _load_root_from_cid call.") + raise RuntimeError("Load the root object first before accessing data.") chunk_coords = self._parse_chunk_key(key) - try: - # Metadata request (e.g., ".zarray", ".zgroup") - if chunk_coords is None: - metadata_cid = self._root_obj["metadata"].get(key) - if metadata_cid is None: - return None - # byte_range is not typically applicable to metadata JSON objects themselves - if byte_range is not None: - # Consider if this should be an error or ignored for metadata - print(f"Warning: byte_range requested for metadata key '{key}'. 
Ignoring range.") - data = await self.cas.load(metadata_cid) - return prototype.buffer.from_bytes(data) - - # Chunk data request (e.g., "c/0/0/0") - if self._cid_len is None: # Should be set during init/load - raise RuntimeError("Store not properly initialized: _cid_len is None.") + # Metadata request (e.g., ".json") + if chunk_coords is None: + metadata_cid = self._root_obj["metadata"].get(key) + if metadata_cid is None: + return None + # byte_range is not typically applicable to metadata JSON objects themselves + if byte_range is not None: + # Consider if this should be an error or ignored for metadata + print(f"Warning: byte_range requested for metadata key '{key}'. Ignoring range.") + data = await self.cas.load(metadata_cid) + return prototype.buffer.from_bytes(data) - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - # print("SHARD LOCATION", linear_chunk_index, shard_idx, index_in_shard) # Debugging info + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): + # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic + return None - if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): - # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic - return None + target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] + if target_shard_cid is None: # This shard has no data (all chunks within it are implicitly empty) + return None - target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if target_shard_cid is None: # This shard has no data (all chunks within it are implicitly empty) - return None + offset_in_shard_bytes = index_in_shard * self._cid_len + chunk_cid_bytes: Optional[bytes] = None - offset_in_shard_bytes = index_in_shard * self._cid_len - chunk_cid_bytes: Optional[bytes] = None + if shard_idx in self._shard_data_cache: + cached_shard_data = self._shard_data_cache[shard_idx] + chunk_cid_bytes = bytes(cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len]) - if shard_idx in self._shard_data_cache: - cached_shard_data = self._shard_data_cache[shard_idx] - if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): - chunk_cid_bytes = bytes(cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len]) - else: - # This would indicate an inconsistency or error in shard data/cache. - print(f"Warning: Cached shard {shard_idx} is smaller than expected for key {key}. Re-fetching CID.") - # Fall through to fetch from CAS, and potentially re-cache full shard. - # To be very robust, you might consider invalidating this cache entry here. - del self._shard_data_cache[shard_idx] # Invalidate corrupted/short cache entry - if shard_idx in self._pending_shard_loads: # Cancel if a load was pending for this now-invalidated cache - self._pending_shard_loads[shard_idx].cancel() - del self._pending_shard_loads[shard_idx] - # Fallthrough to load chunk_cid_bytes directly - - if chunk_cid_bytes is None: # Not in cache or cache was invalid - # print(f"DEBUG: get() - Shard {shard_idx} not in cache or invalid. Fetching specific CID. 
Key: {key}") - try: - chunk_cid_bytes = await self.cas.load( - target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len - ) - except Exception as e: # Handle error from CAS load (e.g. shard CID not found, network issue) - # print(f"Error: Failed to load specific CID bytes from shard {target_shard_cid} for key {key}: {e}") - return None # Chunk CID couldn't be retrieved - - # After successfully fetching the specific CID bytes, - # check if we should initiate a background load of the full shard. - if shard_idx not in self._shard_data_cache and shard_idx not in self._pending_shard_loads: - # print(f"DEBUG: get() - Initiating background cache for full shard {shard_idx} (CID: {target_shard_cid}). Key: {key}") - self._pending_shard_loads[shard_idx] = asyncio.create_task( - self._fetch_and_cache_full_shard(shard_idx, target_shard_cid) - ) - - # Load the specific CID from the shard - # chunk_cid_bytes = await self.cas.load( - # target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len - # ) - - if all(b == 0 for b in chunk_cid_bytes): # Check for null CID placeholder (e.g. \x00 * cid_len) - return None # Chunk doesn't exist or is considered empty - - # Decode CID (assuming ASCII, remove potential null padding) - chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip('\x00') - if not chunk_cid_str: # Empty string after rstrip if all were \x00 (already caught above) - return None - - # Actual chunk data load using the retrieved chunk_cid_str - req_offset = byte_range.start if byte_range else None - req_length = None - if byte_range: - if byte_range.stop is not None: - if byte_range.start > byte_range.stop: # Zarr allows start == stop for 0 length - raise ValueError(f"Byte range start ({byte_range.start}) cannot be greater than stop ({byte_range.stop})") - req_length = byte_range.stop - byte_range.start - - data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) - return prototype.buffer.from_bytes(data) + if chunk_cid_bytes is None: # Not in cache or cache was invalid + chunk_cid_bytes = await self.cas.load( + target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len + ) + # After successfully fetching the specific CID bytes, + # check if we should initiate a background load of the full shard. + if shard_idx not in self._shard_data_cache and shard_idx not in self._pending_shard_loads: + self._pending_shard_loads[shard_idx] = asyncio.create_task( + self._fetch_and_cache_full_shard(shard_idx, target_shard_cid) + ) - except (KeyError, IndexError, TypeError, ValueError) as e: - # print(f"Error during get for key {key} (coords: {chunk_coords}): {type(e).__name__} - {e}") # for debugging - return None # Consistent with Zarr behavior for missing keys + if all(b == 0 for b in chunk_cid_bytes): # Check for null CID placeholder (e.g. 
\x00 * cid_len) + return None # Chunk doesn't exist or is considered empty + + # Decode CID (assuming ASCII, remove potential null padding) + chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip('\x00') + if not chunk_cid_str: # Empty string after rstrip if all were \x00 (already caught above) + return None + + # Actual chunk data load using the retrieved chunk_cid_str + req_offset = byte_range.start if byte_range else None + req_length = None + if byte_range: + if byte_range.end is not None: + if byte_range.start > byte_range.end: # Zarr allows start == stop for 0 length + raise ValueError(f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})") + req_length = byte_range.end - byte_range.start + + data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) + return prototype.buffer.from_bytes(data) async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: @@ -512,8 +437,6 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: raise ValueError("Cannot write to a read-only store.") if self._root_obj is None: raise RuntimeError("Store not initialized for writing (root_obj is None). Call open() first.") - if self._cid_len is None: - raise RuntimeError("Store not initialized for writing (_cid_len is None).") raw_chunk_data_bytes = value.to_bytes() # Save the actual chunk data to CAS first, to get its CID @@ -568,21 +491,12 @@ async def set_pointer( async def exists(self, key: str) -> bool: if self._root_obj is None: - if not self._root_cid: return False - try: - await self._load_root_from_cid() - except Exception: # If loading fails, it doesn't exist in this store - return False - if self._root_obj is None: return False - + raise RuntimeError("Root object not loaded. Call _load_root_from_cid() first.") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata return key in self._root_obj.get("metadata", {}) - # Chunk - if self._cid_len is None: return False # Store not properly configured - try: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) @@ -631,15 +545,7 @@ async def delete(self, key: str) -> None: if self.read_only: raise ValueError("Cannot delete from a read-only store.") if self._root_obj is None: - if self._root_cid: # Try loading if deleting from an existing, non-modified store - try: - await self._load_root_from_cid() - except Exception as e: # If load fails, can't proceed - raise RuntimeError(f"Failed to load store for deletion: {e}") - if self._root_obj is None: # Still None after attempt - raise RuntimeError("Store not initialized for deletion (root_obj is None).") - if self._cid_len is None: - raise RuntimeError("Store not properly initialized for deletion (_cid_len is None).") + raise RuntimeError("Store not initialized for deletion (root_obj is None).") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata @@ -685,28 +591,14 @@ async def delete(self, key: str) -> None: @property def supports_listing(self) -> bool: - return True # Can list metadata keys + return True async def list(self) -> AsyncIterator[str]: - if self._root_obj is None: - if not self._root_cid: - return # Equivalent to `yield from ()` for async iterators - try: - await self._load_root_from_cid() - except Exception: # If loading fails, store is effectively empty for listing - return - if self._root_obj is None: - return - for key in self._root_obj.get("metadata", {}): yield key - # Listing all 
actual chunk keys would require iterating all shards and - # checking for non-null CIDs, which is expensive and not implemented here. - # This behavior is consistent with the provided FlatZarrStore example. async def list_prefix(self, prefix: str) -> AsyncIterator[str]: - # Only lists metadata keys matching prefix. - async for key in self.list(): # self.list() currently only yields metadata keys + async for key in self.list(): if key.startswith(prefix): yield key @@ -714,12 +606,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: # This simplified version only works for the root directory (prefix == "") of metadata. # It lists unique first components of metadata keys. if self._root_obj is None: - if not self._root_cid: return - try: - await self._load_root_from_cid() - except Exception: - return - if self._root_obj is None: return + raise RuntimeError("Root object not loaded. Call _load_root_from_cid() first.") seen: Set[str] = set() if prefix == "": @@ -736,11 +623,10 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: # Zarr spec: list_dir(path) should yield children (both objects and "directories") # For simplicity, and consistency with original FlatZarrStore, keeping this minimal. # To make it more compliant for prefix="foo/": - # normalized_prefix = prefix if prefix.endswith('/') else prefix + '/' - # async for key in self.list_prefix(normalized_prefix): - # remainder = key[len(normalized_prefix):] - # child = remainder.split('/', 1)[0] - # if child not in seen: - # seen.add(child) - # yield child - pass # Or raise NotImplementedError for non-empty prefixes if strict. \ No newline at end of file + normalized_prefix = prefix if prefix.endswith('/') else prefix + '/' + async for key in self.list_prefix(normalized_prefix): + remainder = key[len(normalized_prefix):] + child = remainder.split('/', 1)[0] + if child not in seen: + seen.add(child) + yield child \ No newline at end of file diff --git a/tests/test_benchmark_stores.py b/tests/test_benchmark_stores.py index cff1b98..27a8646 100644 --- a/tests/test_benchmark_stores.py +++ b/tests/test_benchmark_stores.py @@ -1,282 +1,312 @@ -import time - -import numpy as np -import pandas as pd -import pytest -import xarray as xr -from dag_cbor.ipld import IPLDKind - -# Import both store implementations -from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore -from py_hamt.zarr_hamt_store import ZarrHAMTStore - - -@pytest.fixture(scope="module") -def random_zarr_dataset(): - """Creates a random xarray Dataset for benchmarking.""" - # Using a slightly larger dataset for a more meaningful benchmark - times = pd.date_range("2024-01-01", periods=100) - lats = np.linspace(-90, 90, 18) - lons = np.linspace(-180, 180, 36) - - temp = np.random.randn(len(times), len(lats), len(lons)) - precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) - - ds = xr.Dataset( - { - "temp": (["time", "lat", "lon"], temp), - }, - coords={"time": times, "lat": lats, "lon": lons}, - ) - - # Define chunking for the store - ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) - yield ds - - -# ### -# BENCHMARK FOR THE ORIGINAL ZarrHAMTStore -# ### -@pytest.mark.asyncio(loop_scope="session") -async def test_benchmark_hamt_store( - create_ipfs: tuple[str, str], - random_zarr_dataset: xr.Dataset, -): - """Benchmarks write and read performance for the ZarrHAMTStore.""" - print("\n\n" + "=" * 80) - print("🚀 STARTING BENCHMARK for ZarrHAMTStore") - print("=" * 80) - - rpc_base_url, gateway_base_url = create_ipfs - - # rpc_base_url 
= f"https://ipfs-gateway.dclimate.net" - # gateway_base_url = f"https://ipfs-gateway.dclimate.net" - # headers = { - # "X-API-Key": "", - # } - headers = {} - test_ds = random_zarr_dataset - - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers - ) as kubo_cas: - # --- Write --- - print("Building HAMT store...") - hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) - print("HAMT store built successfully.") - zhs = ZarrHAMTStore(hamt) - print("ZarrHAMTStore created successfully.") - - start_write = time.perf_counter() - # Perform an initial write and an append to simulate a common workflow - test_ds.to_zarr(store=zhs, mode="w") - print("Initial write completed, now appending...") - test_ds.to_zarr(store=zhs, mode="a", append_dim="time") - await hamt.make_read_only() # Flush and freeze to get the final CID - end_write = time.perf_counter() - - cid: IPLDKind = hamt.root_node_id - print(f"\n--- [HAMT] Write Stats ---") - print(f"Total time to write and append: {end_write - start_write:.2f} seconds") - print(f"Final Root CID: {cid}") - - # --- Read --- - hamt_ro = await HAMT.build( - cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True - ) - zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) - - start_read = time.perf_counter() - ipfs_ds = xr.open_zarr(store=zhs_ro) - # Force a read of some data to ensure it's loaded - _ = ipfs_ds.temp.isel(time=0).values - print(_) - end_read = time.perf_counter() - - print(f"\n--- [HAMT] Read Stats ---") - print(f"Total time to open and read: {end_read - start_read:.2f} seconds") - - # --- Verification --- - full_test_ds = xr.concat([test_ds, test_ds], dim="time") - xr.testing.assert_identical(full_test_ds, ipfs_ds) - print("\n✅ [HAMT] Data verification successful.") - print("=" * 80) - - -# ### -# BENCHMARK FOR THE NEW FlatZarrStore -# ### -@pytest.mark.asyncio(loop_scope="session") -async def test_benchmark_flat_store( - create_ipfs: tuple[str, str], - random_zarr_dataset: xr.Dataset, -): - """Benchmarks write and read performance for the new FlatZarrStore.""" - print("\n\n" + "=" * 80) - print("🚀 STARTING BENCHMARK for FlatZarrStore") - print("=" * 80) - - rpc_base_url, gateway_base_url = create_ipfs - # rpc_base_url = f"https://ipfs-gateway.dclimate.net" - # gateway_base_url = f"https://ipfs-gateway.dclimate.net" - # headers = { - # "X-API-Key": "", - # } - headers = {} - test_ds = random_zarr_dataset - - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers - ) as kubo_cas: - # --- Write --- - # The full shape after appending - appended_shape = list(test_ds.dims.values()) - time_axis_index = list(test_ds.dims).index("time") - appended_shape[time_axis_index] *= 2 - final_array_shape = tuple(appended_shape) - - final_chunk_shape = [] - for dim_name in test_ds.dims: # Preserves dimension order - if dim_name in test_ds.chunks: - # test_ds.chunks[dim_name] is a tuple e.g. 
(20,) - final_chunk_shape.append(test_ds.chunks[dim_name][0]) - else: - # Fallback if a dimension isn't explicitly chunked (should use its full size) - final_chunk_shape.append(test_ds.dims[dim_name]) - final_chunk_shape = tuple(final_chunk_shape) - - store_write = await FlatZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=final_array_shape, - chunk_shape=final_chunk_shape, - ) - - start_write = time.perf_counter() - # Perform an initial write and an append - test_ds.to_zarr(store=store_write, mode="w") - test_ds.to_zarr(store=store_write, mode="a", append_dim="time") - root_cid = await store_write.flush() # Flush to get the final CID - end_write = time.perf_counter() - - print(f"\n--- [FlatZarr] Write Stats ---") - print(f"Total time to write and append: {end_write - start_write:.2f} seconds") - print(f"Final Root CID: {root_cid}") - - # --- Read --- - store_read = await FlatZarrStore.open( - cas=kubo_cas, read_only=True, root_cid=root_cid - ) - - start_read = time.perf_counter() - ipfs_ds = xr.open_zarr(store=store_read) - # Force a read of some data to ensure it's loaded - _ = ipfs_ds.temp.isel(time=0).values - print(_) - end_read = time.perf_counter() - - print(f"\n--- [FlatZarr] Read Stats ---") - print(f"Total time to open and read: {end_read - start_read:.2f} seconds") - - # --- Verification --- - full_test_ds = xr.concat([test_ds, test_ds], dim="time") - xr.testing.assert_identical(full_test_ds, ipfs_ds) - print("\n✅ [FlatZarr] Data verification successful.") - print("=" * 80) - -@pytest.mark.asyncio(loop_scope="session") -async def test_benchmark_sharded_store( # Renamed function - create_ipfs: tuple[str, str], - random_zarr_dataset: xr.Dataset, -): - """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring - print("\n\n" + "=" * 80) - print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print - print("=" * 80) - - rpc_base_url, gateway_base_url = create_ipfs - - # rpc_base_url = f"https://ipfs-gateway.dclimate.net" - # gateway_base_url = f"https://ipfs-gateway.dclimate.net" - # headers = { - # "X-API-Key": "", - # } - headers = {} - test_ds = random_zarr_dataset - - # Define chunks_per_shard for the ShardedZarrStore - chunks_per_shard_config = 1024 # Configuration for sharding - - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers - ) as kubo_cas: - # --- Write --- - # The full shape after appending - appended_shape = list(test_ds.dims.values()) - time_axis_index = list(test_ds.dims).index("time") - appended_shape[time_axis_index] *= 2 # Simulating appending along time dimension - final_array_shape = tuple(appended_shape) - - # Determine chunk shape from the dataset's encoding or dimensions - final_chunk_shape_list = [] - for dim_name in test_ds.dims: # Preserves dimension order from the dataset - if dim_name in test_ds.chunks: - # test_ds.chunks is a dict like {'time': (20,), 'y': (20,), 'x': (20,)} - final_chunk_shape_list.append(test_ds.chunks[dim_name][0]) - else: - # Fallback if a dimension isn't explicitly chunked (should use its full size) - final_chunk_shape_list.append(test_ds.dims[dim_name]) - final_chunk_shape = tuple(final_chunk_shape_list) - - # Use ShardedZarrStore and provide chunks_per_shard - store_write = await ShardedZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=final_array_shape, - chunk_shape=final_chunk_shape, - chunks_per_shard=chunks_per_shard_config # Added new parameter - ) - - start_write = time.perf_counter() - # Perform an 
initial write and an append - test_ds.to_zarr(store=store_write, mode="w") - test_ds.to_zarr(store=store_write, mode="a", append_dim="time") - root_cid = await store_write.flush() # Flush to get the final CID - end_write = time.perf_counter() - - print(f"\n--- [ShardedZarr] Write Stats (chunks_per_shard={chunks_per_shard_config}) ---") # Updated print - print(f"Total time to write and append: {end_write - start_write:.2f} seconds") - print(f"Final Root CID: {root_cid}") - - print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print - # --- Read --- - # When opening for read, chunks_per_shard is read from the store's metadata - store_read = await ShardedZarrStore.open( # Use ShardedZarrStore - cas=kubo_cas, read_only=True, root_cid=root_cid - ) - - start_read = time.perf_counter() - ipfs_ds = xr.open_zarr(store=store_read) - # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) - if "temp" in ipfs_ds.variables and "time" in ipfs_ds.coords: - _ = ipfs_ds.temp.isel(time=0).values - print(_) - elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable - first_var_name = list(ipfs_ds.data_vars.keys())[0] - # Construct a minimal selection based on available dimensions - selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} - if selection: - _ = ipfs_ds[first_var_name].isel(**selection).values - else: # If no dimensions, try loading the whole variable (e.g. scalar) - _ = ipfs_ds[first_var_name].values - end_read = time.perf_counter() - - print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print - print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") - - # --- Verification --- - # Create the expected full dataset after append operation - full_test_ds = xr.concat([test_ds, test_ds], dim="time") - xr.testing.assert_identical(full_test_ds, ipfs_ds) - print("\n✅ [ShardedZarr] Data verification successful.") # Updated print - print("=" * 80) \ No newline at end of file +# import time + +# import numpy as np +# import pandas as pd +# import pytest +# import xarray as xr +# from dag_cbor.ipld import IPLDKind + +# # Import both store implementations +# from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +# from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +# @pytest.fixture(scope="module") +# def random_zarr_dataset(): +# """Creates a random xarray Dataset for benchmarking.""" +# # Using a slightly larger dataset for a more meaningful benchmark +# times = pd.date_range("2024-01-01", periods=100) +# lats = np.linspace(-90, 90, 18) +# lons = np.linspace(-180, 180, 36) + +# temp = np.random.randn(len(times), len(lats), len(lons)) +# precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + +# ds = xr.Dataset( +# { +# "temp": (["time", "lat", "lon"], temp), +# }, +# coords={"time": times, "lat": lats, "lon": lons}, +# ) + +# # Define chunking for the store +# ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) +# yield ds + +# @pytest.fixture(scope="module") +# def random_shard_dataset(): +# """Creates a random xarray Dataset for benchmarking.""" +# # Using a slightly larger dataset for a more meaningful benchmark +# times = pd.date_range("2024-01-01", periods=100) +# lats = np.linspace(-90, 90, 18) +# lons = np.linspace(-180, 180, 36) + +# temp = np.random.randn(len(times), len(lats), len(lons)) +# precip = np.random.gamma(4, 2.5, size=(len(times), len(lats), len(lons))) + +# ds = xr.Dataset( +# { +# "precip": (["time", "lat", "lon"], precip), +# }, +# 
coords={"time": times, "lat": lats, "lon": lons}, +# ) + +# # Define chunking for the store +# ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) +# yield ds + + + + +# # # ### +# # # BENCHMARK FOR THE NEW FlatZarrStore +# # # ### +# # @pytest.mark.asyncio(loop_scope="session") +# # async def test_benchmark_flat_store( +# # create_ipfs: tuple[str, str], +# # random_zarr_dataset: xr.Dataset, +# # ): +# # """Benchmarks write and read performance for the new FlatZarrStore.""" +# # print("\n\n" + "=" * 80) +# # print("🚀 STARTING BENCHMARK for FlatZarrStore") +# # print("=" * 80) + +# # rpc_base_url, gateway_base_url = create_ipfs +# # # rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# # # gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# # # headers = { +# # # "X-API-Key": "", +# # # } +# # headers = {} +# # test_ds = random_zarr_dataset + +# # async with KuboCAS( +# # rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# # ) as kubo_cas: +# # # --- Write --- +# # # The full shape after appending +# # appended_shape = list(test_ds.dims.values()) +# # time_axis_index = list(test_ds.dims).index("time") +# # appended_shape[time_axis_index] *= 2 +# # final_array_shape = tuple(appended_shape) + +# # final_chunk_shape = [] +# # for dim_name in test_ds.dims: # Preserves dimension order +# # if dim_name in test_ds.chunks: +# # # test_ds.chunks[dim_name] is a tuple e.g. (20,) +# # final_chunk_shape.append(test_ds.chunks[dim_name][0]) +# # else: +# # # Fallback if a dimension isn't explicitly chunked (should use its full size) +# # final_chunk_shape.append(test_ds.dims[dim_name]) +# # final_chunk_shape = tuple(final_chunk_shape) + +# # store_write = await FlatZarrStore.open( +# # cas=kubo_cas, +# # read_only=False, +# # array_shape=final_array_shape, +# # chunk_shape=final_chunk_shape, +# # ) + +# # start_write = time.perf_counter() +# # # Perform an initial write and an append +# # test_ds.to_zarr(store=store_write, mode="w") +# # test_ds.to_zarr(store=store_write, mode="a", append_dim="time") +# # root_cid = await store_write.flush() # Flush to get the final CID +# # end_write = time.perf_counter() + +# # print(f"\n--- [FlatZarr] Write Stats ---") +# # print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# # print(f"Final Root CID: {root_cid}") + +# # # --- Read --- +# # store_read = await FlatZarrStore.open( +# # cas=kubo_cas, read_only=True, root_cid=root_cid +# # ) + +# # start_read = time.perf_counter() +# # ipfs_ds = xr.open_zarr(store=store_read) +# # # Force a read of some data to ensure it's loaded +# # _ = ipfs_ds.temp.isel(time=0).values +# # end_read = time.perf_counter() + +# # print(f"\n--- [FlatZarr] Read Stats ---") +# # print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + +# # # --- Verification --- +# # full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# # xr.testing.assert_identical(full_test_ds, ipfs_ds) +# # print("\n✅ [FlatZarr] Data verification successful.") +# # print("=" * 80) + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_sharded_store( # Renamed function +# create_ipfs: tuple[str, str], +# random_shard_dataset: xr.Dataset, +# ): +# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print +# print("=" * 80) + +# rpc_base_url, gateway_base_url = create_ipfs + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# 
gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# headers = { +# "X-API-Key": "", +# } +# # headers = {} +# test_ds = random_shard_dataset + +# # Define chunks_per_shard for the ShardedZarrStore +# chunks_per_shard_config = 50 # Configuration for sharding + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# # The full shape after appending +# appended_shape = list(test_ds.dims.values()) +# time_axis_index = list(test_ds.dims).index("time") +# appended_shape[time_axis_index] *= 2 # Simulating appending along time dimension +# final_array_shape = tuple(appended_shape) + +# # Determine chunk shape from the dataset's encoding or dimensions +# final_chunk_shape_list = [] +# for dim_name in test_ds.dims: # Preserves dimension order from the dataset +# if dim_name in test_ds.chunks: +# # test_ds.chunks is a dict like {'time': (20,), 'y': (20,), 'x': (20,)} +# final_chunk_shape_list.append(test_ds.chunks[dim_name][0]) +# else: +# # Fallback if a dimension isn't explicitly chunked (should use its full size) +# final_chunk_shape_list.append(test_ds.dims[dim_name]) +# final_chunk_shape = tuple(final_chunk_shape_list) + +# # Use ShardedZarrStore and provide chunks_per_shard +# store_write = await ShardedZarrStore.open( +# cas=kubo_cas, +# read_only=False, +# array_shape=final_array_shape, +# chunk_shape=final_chunk_shape, +# chunks_per_shard=chunks_per_shard_config # Added new parameter +# ) + +# start_write = time.perf_counter() +# # Perform an initial write and an append +# test_ds.to_zarr(store=store_write, mode="w") +# test_ds.to_zarr(store=store_write, mode="a", append_dim="time") +# root_cid = await store_write.flush() # Flush to get the final CID +# end_write = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Write Stats (chunks_per_shard={chunks_per_shard_config}) ---") # Updated print +# print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# print(f"Final Root CID: {root_cid}") + +# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print +# # --- Read --- +# # When opening for read, chunks_per_shard is read from the store's metadata +# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore +# cas=kubo_cas, read_only=True, root_cid=root_cid +# ) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=store_read) +# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) +# if "precip" in ipfs_ds.variables and "time" in ipfs_ds.coords: +# # _ = ipfs_ds.temp.isel(time=0).values +# data_fetched = ipfs_ds.precip.isel(time=slice(0, 1)).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable +# first_var_name = list(ipfs_ds.data_vars.keys())[0] +# # Construct a minimal selection based on available dimensions +# selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} +# if selection: +# _ = ipfs_ds[first_var_name].isel(**selection).values +# else: # If no dimensions, try loading the whole variable (e.g. 
scalar) +# _ = ipfs_ds[first_var_name].values +# end_read = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print +# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") + +# # --- Verification --- +# # Create the expected full dataset after append operation +# full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# xr.testing.assert_identical(full_test_ds, ipfs_ds) +# print("\n✅ [ShardedZarr] Data verification successful.") # Updated print +# print("=" * 80) + +# # ### +# # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore +# # ### +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_hamt_store( +# create_ipfs: tuple[str, str], +# random_zarr_dataset: xr.Dataset, +# ): +# """Benchmarks write and read performance for the ZarrHAMTStore.""" +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ZarrHAMTStore") +# print("=" * 80) + +# rpc_base_url, gateway_base_url = create_ipfs + +# # rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# # gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# # headers = { +# # "X-API-Key": "", +# # } +# headers = {} +# test_ds = random_zarr_dataset + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# print("Building HAMT store...") +# hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) +# print("HAMT store built successfully.") +# zhs = ZarrHAMTStore(hamt) +# print("ZarrHAMTStore created successfully.") + +# start_write = time.perf_counter() +# # Perform an initial write and an append to simulate a common workflow +# test_ds.to_zarr(store=zhs, mode="w") +# print("Initial write completed, now appending...") +# test_ds.to_zarr(store=zhs, mode="a", append_dim="time") +# await hamt.make_read_only() # Flush and freeze to get the final CID +# end_write = time.perf_counter() + +# cid: IPLDKind = hamt.root_node_id +# print(f"\n--- [HAMT] Write Stats ---") +# print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# print(f"Final Root CID: {cid}") + +# # --- Read --- +# hamt_ro = await HAMT.build( +# cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True +# ) +# zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=zhs_ro) +# # Force a read of some data to ensure it's loaded +# data_fetched = ipfs_ds.temp.isel(time=slice(0, 1)).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# end_read = time.perf_counter() + +# print(f"\n--- [HAMT] Read Stats ---") +# print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + + +# # --- Verification --- +# full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# xr.testing.assert_identical(full_test_ds, ipfs_ds) +# print("\n✅ [HAMT] Data verification successful.") +# print("=" * 80) diff --git a/tests/test_converter.py b/tests/test_converter.py index e72eece..6c46827 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1,6 +1,9 @@ import asyncio import time import uuid +import sys +from unittest.mock import patch +import aiohttp import numpy as np import pandas as pd @@ -8,7 +11,7 @@ import xarray as xr # Import store implementations -from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore, convert_hamt_to_sharded +from py_hamt import HAMT, 
KuboCAS, ShardedZarrStore, convert_hamt_to_sharded, sharded_converter_cli from py_hamt.zarr_hamt_store import ZarrHAMTStore @@ -126,4 +129,127 @@ async def test_converter_produces_identical_dataset( ) print("\n✅ Verification successful! The datasets are identical.") - print("=" * 80) \ No newline at end of file + print("=" * 80) + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_success( + create_ipfs: tuple[str, str], + converter_test_dataset: xr.Dataset, + capsys +): + """ + Tests the CLI for successful conversion of a HAMT store to a ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Step 1: Create a HAMT store with the test dataset + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + test_ds.to_zarr(store=source_hamt_store, mode="w", consolidated=True) + await hamt_write.make_read_only() + hamt_root_cid = str(hamt_write.root_node_id) + + # Step 2: Simulate CLI execution with valid arguments + test_args = [ + "script.py", # Dummy script name + hamt_root_cid, + "--chunks-per-shard", "64", + "--rpc-url", rpc_base_url, + "--gateway-url", gateway_base_url + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Step 3: Capture and verify CLI output + captured = capsys.readouterr() + assert "Starting Conversion from HAMT Root" in captured.out + assert "Conversion Complete!" in captured.out + assert f"New ShardedZarrStore Root CID" in captured.out + + # Step 4: Verify the converted dataset + # Extract the new root CID from output (assuming it's the last line) + output_lines = captured.out.strip().split("\n") + new_root_cid = output_lines[-1].split(": ")[-1] + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + xr.testing.assert_identical(test_ds, ds_from_sharded) + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_default_args( + create_ipfs: tuple[str, str], + converter_test_dataset: xr.Dataset, + capsys +): + """ + Tests the CLI with default argument values. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Create a HAMT store + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + test_ds.to_zarr(store=source_hamt_store, mode="w", consolidated=True) + await hamt_write.make_read_only() + hamt_root_cid = str(hamt_write.root_node_id) + + # Simulate CLI with only hamt_cid and gateway URLs. 
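+        # For orientation, this is roughly the shell command the patched sys.argv
+        # stands in for (module path and URLs are illustrative, not asserted here):
+        #
+        #   python -m py_hamt.hamt_to_sharded_converter <HAMT_ROOT_CID> \
+        #       --rpc-url http://127.0.0.1:5001 --gateway-url http://127.0.0.1:8080
+        #
+        # --chunks-per-shard is omitted so the CLI default is exercised.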
+ test_args = [ + "script.py", # Dummy script name + hamt_root_cid, + "--rpc-url", rpc_base_url, + "--gateway-url", gateway_base_url + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Verify output and conversion + captured = capsys.readouterr() + output_lines = captured.out.strip().split("\n") + print("Captured CLI Output:") + for line in output_lines: + print(line) + new_root_cid = output_lines[-1].split(": ")[-1] + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + xr.testing.assert_identical(test_ds, ds_from_sharded) + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_invalid_cid( + create_ipfs: tuple[str, str], + capsys +): + """ + Tests the CLI with an invalid hamt_cid. + """ + rpc_base_url, gateway_base_url = create_ipfs + invalid_cid = "invalid_cid" + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + test_args = [ + "script.py", + invalid_cid, + "--chunks-per-shard", "64", + "--rpc-url", rpc_base_url, + "--gateway-url", gateway_base_url + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Verify error handling + captured = capsys.readouterr() + assert "An error occurred" in captured.out + assert f"{invalid_cid}" in captured.out \ No newline at end of file diff --git a/tests/test_cpc_compare.py b/tests/test_cpc_compare.py new file mode 100644 index 0000000..6930eef --- /dev/null +++ b/tests/test_cpc_compare.py @@ -0,0 +1,128 @@ +# import time + +# import numpy as np +# import pandas as pd +# import pytest +# import xarray as xr +# from dag_cbor.ipld import IPLDKind +# from multiformats import CID + +# # Import both store implementations +# from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +# from py_hamt.zarr_hamt_store import ZarrHAMTStore + + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_sharded_store(): +# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print +# print("=" * 80) + + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# headers = { +# "X-API-Key": "", +# } + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# root_cid = "bafyr4ifjgdfafxfqtdkirmdyzlziswzo5gsxbrivqjzu35ukiixnu2omvm" +# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print +# # --- Read --- +# # When opening for read, chunks_per_shard is read from the store's metadata +# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore +# cas=kubo_cas, read_only=True, root_cid=root_cid +# ) +# print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=store_read) +# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) +# if "precip" in ipfs_ds.variables and "time" in ipfs_ds.coords: +# # _ = ipfs_ds.temp.isel(time=0).values +# data_fetched = ipfs_ds.precip.values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# elif len(ipfs_ds.data_vars) > 0 : # 
Fallback: try to read from the first data variable +# first_var_name = list(ipfs_ds.data_vars.keys())[0] +# # Construct a minimal selection based on available dimensions +# selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} +# if selection: +# _ = ipfs_ds[first_var_name].isel(**selection).values +# else: # If no dimensions, try loading the whole variable (e.g. scalar) +# _ = ipfs_ds[first_var_name].values +# end_read = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print +# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") +# print("=" * 80) +# # Speed in MB/s +# if data_size > 0: +# speed = data_size / (end_read - start_read) / (1024 * 1024) +# print(f"Read speed: {speed:.2f} MB/s") +# else: +# print("No data fetched, cannot calculate speed.") + +# # ### +# # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore +# # ### +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_hamt_store(): +# """Benchmarks write and read performance for the ZarrHAMTStore.""" +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ZarrHAMTStore") +# print("=" * 80) + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net/" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net/" +# # headers = { +# # "X-API-Key": "", +# # } +# # headers = {} + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: + +# root_cid = "bafyr4ialorauxcpw77mgmnyoeptn4g4zkqdqhtsobff4v76rllvd3m6cqi" +# # root_node_id = CID.decode(root_cid) + +# hamt = await HAMT.build( +# cas=kubo_cas, root_node_id=root_cid, values_are_bytes=True, read_only=True +# ) +# start = time.perf_counter() +# ipfs_ds: xr.Dataset +# zhs = ZarrHAMTStore(hamt, read_only=True) +# ipfs_ds = xr.open_zarr(store=zhs) + +# # --- Read --- +# hamt = HAMT(cas=kubo_cas, values_are_bytes=True, root_node_id=root_cid, read_only=True) + +# # Initialize the store +# zhs = ZarrHAMTStore(hamt, read_only=True) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=zhs) +# # Force a read of some data to ensure it's loaded +# data_fetched = ipfs_ds.precip.values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# end_read = time.perf_counter() + +# print(f"\n--- [HAMT] Read Stats ---") +# print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + +# if data_size > 0: +# speed = data_size / (end_read - start_read) / (1024 * 1024) +# print(f"Read speed: {speed:.2f} MB/s") +# else: +# print("No data fetched, cannot calculate speed.") diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py new file mode 100644 index 0000000..bcdabf6 --- /dev/null +++ b/tests/test_sharded_zarr_store.py @@ -0,0 +1,669 @@ +import asyncio +import math + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from zarr.abc.store import RangeByteRequest +import zarr.core.buffer +import dag_cbor + +from py_hamt import HAMT, KuboCAS, ShardedZarrStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = 
np.random.randn(len(times), len(lats), len(lons)) + precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_write_read( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests writing and reading a Zarr dataset using ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # --- Write --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + assert root_cid is not None + + # --- Read --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + ds_read = xr.open_zarr(store=store_read) + xr.testing.assert_identical(test_ds, ds_read) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init(create_ipfs: tuple[str, str]): + """ + Tests the initialization of the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + array_shape = (100, 100) + chunk_shape = (10, 10) + chunks_per_shard = 64 + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test successful creation + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + assert store is not None + + # Test missing parameters for new store + with pytest.raises(ValueError): + await ShardedZarrStore.open(cas=kubo_cas, read_only=False) + + # Test opening read-only store without root_cid + with pytest.raises(ValueError): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_metadata( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests metadata handling in the ShardedZarrStore. 
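+    Writes the fixture dataset once, re-opens it read-only from the flushed root
+    CID, and then checks exists(), list() and list_prefix() for both zarr.json
+    metadata documents and coordinate chunk keys.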
+ """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + # Test exists + assert await store_read.exists("lat/zarr.json") + assert await store_read.exists("lon/zarr.json") + assert await store_read.exists("time/zarr.json") + assert await store_read.exists("temp/zarr.json") + assert await store_read.exists("lat/c/0") + assert await store_read.exists("lon/c/0") + assert await store_read.exists("time/c/0") + # assert not await store_read.exists("nonexistent") + + # Test list + keys = [key async for key in store_read.list()] + assert len(keys) > 0 + assert "lat/zarr.json" in keys + + prefix = "lat" + keys_with_prefix = [key async for key in store_read.list_prefix(prefix=prefix)] + assert "lat/zarr.json" in keys_with_prefix + assert "lat/c/0" in keys_with_prefix + + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_chunks( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests chunk data handling in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + # Test get + chunk_key = "temp/c/0/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + chunk_data = await store_read.get(chunk_key, proto) + assert chunk_data is not None + + # Test delete + store_write = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + await store_write.delete(chunk_key) + await store_write.flush() + + store_read_after_delete = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=await store_write.flush() + ) + assert await store_read_after_delete.get(chunk_key, proto) is None + +@pytest.mark.asyncio +async def test_chunk_and_delete_logic( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests chunk getting, deleting, and related error handling.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await 
ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + # Re-open as writable to test deletion + store_rw = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + chunk_key = "temp/c/0/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + + # Verify chunk exists and can be read + assert await store_rw.exists(chunk_key) + chunk_data = await store_rw.get(chunk_key, proto) + assert chunk_data is not None + + # Delete the chunk + await store_rw.delete(chunk_key) + new_root_cid = await store_rw.flush() + + # Verify it's gone + store_after_delete = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert not await store_after_delete.exists(chunk_key) + assert await store_after_delete.get(chunk_key, proto) is None + + # Test deleting a non-existent key + with pytest.raises(KeyError): + await store_rw.delete("nonexistent/c/0/0/0") + + # Test deleting an already deleted key + with pytest.raises(KeyError): + await store_rw.delete(chunk_key) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_partial_reads( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests partial reads in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0/0" + full_chunk = await store_read.get(chunk_key, proto) + assert full_chunk is not None + full_chunk_bytes = full_chunk.to_bytes() + + # Test RangeByteRequest + byte_range = RangeByteRequest(start=10, end=50) + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_range) + assert partial_chunk is not None + assert partial_chunk.to_bytes() == full_chunk_bytes[10:50] + +@pytest.mark.asyncio +async def test_partial_reads_and_errors( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests partial reads and error handling in get().""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, 
root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0/0" + full_chunk = await store_read.get(chunk_key, proto) + assert full_chunk is not None + full_chunk_bytes = full_chunk.to_bytes() + + # Test RangeByteRequest + byte_range = RangeByteRequest(start=10, end=50) + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_range) + assert partial_chunk is not None + assert partial_chunk.to_bytes() == full_chunk_bytes[10:50] + + # Test invalid byte range + with pytest.raises(ValueError): + await store_read.get(chunk_key, proto, byte_range=RangeByteRequest(start=50, end=10)) + +@pytest.mark.asyncio +async def test_zero_sized_array(create_ipfs: tuple[str, str]): + """Test handling of arrays with a zero-length dimension.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(100, 0), + chunk_shape=(10, 10), + chunks_per_shard=64 + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + root_cid = await store.flush() + + # Read it back and verify + store_read = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=root_cid) + assert store_read._total_chunks == 0 + assert store_read._num_shards == 0 + +@pytest.mark.asyncio +async def test_store_eq_method(create_ipfs: tuple[str, str]): + """Tests the __eq__ method.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store1 = await ShardedZarrStore.open(cas=kubo_cas, read_only=False, array_shape=(1,1), chunk_shape=(1,1), chunks_per_shard=1) + root_cid = await store1.flush() + store2 = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=root_cid) + + assert store1 == store2 + + +@pytest.mark.asyncio +async def test_listing_and_metadata( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests metadata handling and listing in the ShardedZarrStore. 
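+    Writes the fixture dataset with consolidated metadata and verifies exists(),
+    list(), list_prefix() and list_dir() on the re-opened read-only store.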
+ """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + # Test exists for metadata (correcting for xarray group structure) + assert await store_read.exists("temp/zarr.json") + assert await store_read.exists("lat/zarr.json") + assert not await store_read.exists("nonexistent.json") + + # Test listing + keys = {key async for key in store_read.list()} + assert "temp/zarr.json" in keys + assert "lat/zarr.json" in keys + + # Test list_prefix + prefix_keys = {key async for key in store_read.list_prefix("temp/")} + assert "temp/zarr.json" in prefix_keys + + # Test list_dir for root + dir_keys = {key async for key in store_read.list_dir("")} + assert "temp" in dir_keys + assert "lat" in dir_keys + assert "lon" in dir_keys + assert "zarr.json" in dir_keys + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): + """ + Tests initialization errors for ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test missing parameters for a new store + with pytest.raises(ValueError, match="must be provided for a new store"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=False) + + # Test opening a read-only store without a root_cid + with pytest.raises(ValueError, match="must be provided for a read-only store"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + # Test invalid chunks_per_shard + with pytest.raises(ValueError, match="must be a positive integer"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=0, + ) + + # Test invalid chunk_shape + with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(0,), + chunks_per_shard=10, + ) + + # Test invalid array_shape + with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(-10,), + chunk_shape=(5,), + chunks_per_shard=10, + ) + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): + """Tests initialization with invalid shapes and manifest errors.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + # Test negative chunk_shape dimension (line 136) + with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(-5, 5), + chunks_per_shard=10, + ) + + # Test negative array_shape dimension (line 141) + 
with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, -10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + + # Test zero-sized array (lines 150, 163) - reinforce existing test + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(0, 10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 + root_cid = await store.flush() + + # Test invalid manifest version (line 224) + invalid_root_obj = { + "manifest_version": "invalid_version", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks8048": 10}, + "shard_cids": [None] * 4, + }, + } + invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + + # Test inconsistent shard count (line 236) + invalid_root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], # 100 chunks, with 10 chunks per shard -> 10 shards + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks_per_shard": 10}, + "shard_cids": [None] * 5, # Wrong number of shards + }, + } + invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + +@pytest.mark.asyncio +async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): + """Tests chunk key parsing edge cases.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test metadata key + assert store._parse_chunk_key("zarr.json") is None + assert store._parse_chunk_key("group1/zarr.json") is None + + # Test excluded array prefixes + assert store._parse_chunk_key("time/c/0") is None + assert store._parse_chunk_key("lat/c/0/0") is None + assert store._parse_chunk_key("lon/c/0/0") is None + + # Test uninitialized store + uninitialized_store = ShardedZarrStore(kubo_cas, read_only=False, root_cid=None) + assert uninitialized_store._parse_chunk_key("temp/c/0/0") is None + + # Test get on uninitialized store + with pytest.raises(RuntimeError, match="Load the root object first before accessing data."): + proto = zarr.core.buffer.default_buffer_prototype() + await uninitialized_store.get("temp/c/0/0", proto) + + with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): + await uninitialized_store._load_root_from_cid() + + # Test dimensionality mismatch + assert store._parse_chunk_key("temp/c/0/0/0") is None # 3D key for 2D array + + # Test invalid coordinates + assert store._parse_chunk_key("temp/c/3/0") is None # Out of bounds (3 >= 2 chunks) + assert store._parse_chunk_key("temp/c/0/invalid") is None # Non-integer + assert store._parse_chunk_key("temp/c/0/-1") is None # Negative coordinate + +@pytest.mark.asyncio +async def 
test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): + """Tests initialization with invalid shapes and manifest errors.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + # Test negative chunk_shape dimension + with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(-5, 5), + chunks_per_shard=10, + ) + + # Test negative array_shape dimension + with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, -10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + + # Test zero-sized array + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(0, 10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 + root_cid = await store.flush() + + # Test invalid manifest version + invalid_root_obj = { + "manifest_version": "invalid_version", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks8048": 10}, + "shard_cids": [None] * 4, + }, + } + invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + + # Test inconsistent shard count + invalid_root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], # 100 chunks, with 10 chunks per shard -> 10 shards + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks_per_shard": 10}, + "shard_cids": [None] * 5, # Wrong number of shards + }, + } + invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) \ No newline at end of file From a87eb194197d71b45f46f22e6d11449b551054e1 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:19:35 -0400 Subject: [PATCH 13/74] fix: pinning --- py_hamt/flat_zarr_store.py | 6 +- py_hamt/hamt_to_sharded_converter.py | 13 +- py_hamt/manage_pins.py | 86 +++++ py_hamt/sharded_zarr_store.py | 472 ++++++++++++++++++++------- py_hamt/store.py | 53 ++- tests/test_benchmark_stores.py | 4 +- tests/test_converter.py | 65 ++-- tests/test_cpc_compare.py | 5 +- tests/test_sharded_zarr_pinning.py | 138 ++++++++ tests/test_sharded_zarr_store.py | 128 ++++++-- 10 files changed, 770 insertions(+), 200 deletions(-) create mode 100644 py_hamt/manage_pins.py create mode 100644 tests/test_sharded_zarr_pinning.py diff --git a/py_hamt/flat_zarr_store.py b/py_hamt/flat_zarr_store.py index 33f531f..850d7ae 100644 --- a/py_hamt/flat_zarr_store.py +++ b/py_hamt/flat_zarr_store.py @@ -165,7 +165,7 @@ # except (ValueError, IndexError): # return None -# async def set_partial_values( +# async def set_partial_values( # self, key_start_values: Iterable[tuple[str, int, BytesLike]] # ) -> None: # """@private""" @@ -316,7 +316,7 @@ # 
@property # def supports_deletes(self) -> bool: # """@private""" -# return not self.read_only +# return not self.read_only # async def delete(self, key: str) -> None: # if self.read_only: @@ -372,4 +372,4 @@ # name = key.split('/')[0] # if name not in seen: # seen.add(name) -# yield name \ No newline at end of file +# yield name diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index 7b6e4d9..a43f3e5 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -9,6 +9,7 @@ from multiformats import CID from zarr.core.buffer import Buffer, BufferPrototype + async def convert_hamt_to_sharded( cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int, cid_len: int = 59 ) -> str: @@ -42,9 +43,10 @@ async def convert_hamt_to_sharded( array_shape = array_shape_tuple chunk_shape = chunk_shape_tuple - # 3. Create the destination ShardedZarrStore for writing - print(f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard...") + print( + f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard..." + ) dest_store = await ShardedZarrStore.open( cas=cas, read_only=False, @@ -67,8 +69,8 @@ async def convert_hamt_to_sharded( # Write the exact same key-value pair to the destination. await dest_store.set_pointer(key, cid_base32_str) - if count % 200 == 0: # pragma: no cover - print(f"Migrated {count} keys...") # pragma: no cover + if count % 200 == 0: # pragma: no cover + print(f"Migrated {count} keys...") # pragma: no cover print(f"Migration of {count} total keys complete.") @@ -122,5 +124,6 @@ async def sharded_converter_cli(): except Exception as e: print(f"\nAn error occurred: {e}") + if __name__ == "__main__": - asyncio.run(sharded_converter_cli()) # pragma: no cover \ No newline at end of file + asyncio.run(sharded_converter_cli()) # pragma: no cover diff --git a/py_hamt/manage_pins.py b/py_hamt/manage_pins.py new file mode 100644 index 0000000..3795874 --- /dev/null +++ b/py_hamt/manage_pins.py @@ -0,0 +1,86 @@ +""" +A command-line tool to recursively pin or unpin all CIDs associated with a +sharded Zarr dataset on IPFS using its root CID. +""" +import asyncio +import argparse +import sys +from py_hamt import KuboCAS, ShardedZarrStore + +# --- CLI Logic Functions --- + +async def handle_pin(args): + """ + Connects to IPFS, loads the dataset from the root CID, and pins all + associated CIDs (root, metadata, shards, and data chunks). + """ + async with KuboCAS(rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url) as kubo_cas: + try: + print(f"-> Opening store with root CID: {args.root_cid}") + store = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=args.root_cid) + except Exception as e: + print(f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", file=sys.stderr) + print(f"Details: {e}", file=sys.stderr) + return + + print(f"-> Sending commands to pin the entire dataset to {args.rpc_url}...") + await store.pin_entire_dataset() + print("\n--- Pinning Commands Sent Successfully ---") + print("The IPFS node will now pin all objects in the background.") + + +async def handle_unpin(args): + """ + Connects to IPFS, loads the dataset from the root CID, and unpins all + associated CIDs. 
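+
+    A typical invocation against a local Kubo daemon (the module path is
+    illustrative and the CID is a placeholder; the URL defaults are used):
+
+        python -m py_hamt.manage_pins unpin <ROOT_CID>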
+ """ + async with KuboCAS(rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url) as kubo_cas: + try: + print(f"-> Opening store with root CID: {args.root_cid}") + store = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=args.root_cid) + except Exception as e: + print(f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", file=sys.stderr) + print(f"Details: {e}", file=sys.stderr) + return + + print(f"-> Sending commands to unpin the entire dataset from {args.rpc_url}...") + await store.unpin_entire_dataset() + print("\n--- Unpinning Commands Sent Successfully ---") + print("The IPFS node will now unpin all objects in the background.") + + +def main(): + """Sets up the argument parser and runs the selected command.""" + parser = argparse.ArgumentParser( + description="A CLI tool to pin or unpin sharded Zarr datasets on IPFS.", + formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument('--rpc-url', default='http://127.0.0.1:5001', help='IPFS Kubo RPC API endpoint URL.') + parser.add_argument('--gateway-url', default='http://127.0.0.1:8080', help='IPFS Gateway URL (needed for loading shards).') + + subparsers = parser.add_subparsers(dest='command', required=True, help='Available commands') + + # --- Pin Command --- + parser_pin = subparsers.add_parser('pin', help='Recursively pin a dataset using its root CID.') + parser_pin.add_argument('root_cid', help='The root CID of the dataset to pin.') + parser_pin.set_defaults(func=handle_pin) + + # --- Unpin Command --- + parser_unpin = subparsers.add_parser('unpin', help='Recursively unpin a dataset using its root CID.') + parser_unpin.add_argument('root_cid', help='The root CID of the dataset to unpin.') + parser_unpin.set_defaults(func=handle_unpin) + + args = parser.parse_args() + + try: + asyncio.run(args.func(args)) + except KeyboardInterrupt: + print("\nOperation cancelled by user.", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"\nAn unexpected error occurred: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 62e5925..2cbd64c 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -44,19 +44,25 @@ def __init__( self._root_cid = root_cid self._root_obj: Optional[dict] = None - self._shard_data_cache: Dict[int, bytearray] = {} # shard_index -> shard_byte_data - self._dirty_shards: Set[int] = set() # Set of shard_indices that need flushing - self._pending_shard_loads: Dict[int, asyncio.Task] = {} # shard_index -> Task loading the full shard + self._shard_data_cache: Dict[ + int, bytearray + ] = {} # shard_index -> shard_byte_data + self._dirty_shards: Set[int] = set() # Set of shard_indices that need flushing + self._pending_shard_loads: Dict[ + int, asyncio.Task + ] = {} # shard_index -> Task loading the full shard self._cid_len: Optional[int] = None self._array_shape: Optional[Tuple[int, ...]] = None self._chunk_shape: Optional[Tuple[int, ...]] = None - self._chunks_per_dim: Optional[Tuple[int, ...]] = None # Number of chunks in each dimension - self._chunks_per_shard: Optional[int] = None # How many chunk CIDs per shard - self._num_shards: Optional[int] = None # Total number of shards - self._total_chunks: Optional[int] = None # Total number of chunks in the array + self._chunks_per_dim: Optional[Tuple[int, ...]] = ( + None # Number of chunks in each dimension + ) + 
self._chunks_per_shard: Optional[int] = None # How many chunk CIDs per shard + self._num_shards: Optional[int] = None # Total number of shards + self._total_chunks: Optional[int] = None # Total number of chunks in the array - self._dirty_root = False # Indicates if the root object itself (metadata or shard_cids list) changed + self._dirty_root = False # Indicates if the root object itself (metadata or shard_cids list) changed @classmethod async def open( @@ -83,7 +89,9 @@ async def open( ) if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: raise ValueError("chunks_per_shard must be a positive integer.") - store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard, cid_len) + store._initialize_new_root( + array_shape, chunk_shape, chunks_per_shard, cid_len + ) else: raise ValueError("root_cid must be provided for a read-only store.") return store @@ -99,18 +107,19 @@ def _initialize_new_root( self._chunk_shape = chunk_shape self._cid_len = cid_len self._chunks_per_shard = chunks_per_shard - + if not all(cs > 0 for cs in chunk_shape): raise ValueError("All chunk_shape dimensions must be positive.") - if not all(asarray_s >= 0 for asarray_s in array_shape): # array_shape dims can be 0 - raise ValueError("All array_shape dimensions must be non-negative.") - + if not all( + asarray_s >= 0 for asarray_s in array_shape + ): # array_shape dims can be 0 + raise ValueError("All array_shape dimensions must be non-negative.") self._chunks_per_dim = tuple( math.ceil(a / c) if c > 0 else 0 for a, c in zip(array_shape, chunk_shape) ) self._total_chunks = math.prod(self._chunks_per_dim) - + if self._total_chunks == 0: self._num_shards = 0 else: @@ -118,15 +127,15 @@ def _initialize_new_root( self._root_obj = { "manifest_version": "sharded_zarr_v1", - "metadata": {}, # For .json - "chunks": { # Information about the chunk index itself - "array_shape": list(self._array_shape), # Original array shape - "chunk_shape": list(self._chunk_shape), # Original chunk shape + "metadata": {}, # For .json + "chunks": { # Information about the chunk index itself + "array_shape": list(self._array_shape), # Original array shape + "chunk_shape": list(self._chunk_shape), # Original chunk shape "cid_byte_length": self._cid_len, "sharding_config": { "chunks_per_shard": self._chunks_per_shard, }, - "shard_cids": [None] * self._num_shards, # List of CIDs for each shard + "shard_cids": [None] * self._num_shards, # List of CIDs for each shard }, } self._dirty_root = True @@ -138,27 +147,32 @@ async def _load_root_from_cid(self): self._root_obj = dag_cbor.decode(root_bytes) if self._root_obj.get("manifest_version") != "sharded_zarr_v1": - raise ValueError(f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'.") + raise ValueError( + f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'." 
+ ) chunk_info = self._root_obj["chunks"] self._array_shape = tuple(chunk_info["array_shape"]) self._chunk_shape = tuple(chunk_info["chunk_shape"]) self._cid_len = chunk_info["cid_byte_length"] - sharding_cfg = chunk_info.get("sharding_config", {}) # Handle older formats if any planned + sharding_cfg = chunk_info.get( + "sharding_config", {} + ) # Handle older formats if any planned self._chunks_per_shard = sharding_cfg["chunks_per_shard"] if not all(cs > 0 for cs in self._chunk_shape): - raise ValueError("Loaded chunk_shape dimensions must be positive.") + raise ValueError("Loaded chunk_shape dimensions must be positive.") self._chunks_per_dim = tuple( - math.ceil(a / c) if c > 0 else 0 for a, c in zip(self._array_shape, self._chunk_shape) + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self._array_shape, self._chunk_shape) ) self._total_chunks = math.prod(self._chunks_per_dim) - + expected_num_shards = 0 if self._total_chunks > 0: expected_num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) self._num_shards = expected_num_shards - + if len(chunk_info["shard_cids"]) != self._num_shards: raise ValueError( f"Inconsistent number of shards. Expected {self._num_shards} from shapes/config, " @@ -171,12 +185,14 @@ async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): Manages removal from _pending_shard_loads. """ try: - shard_data_bytes = await self.cas.load(shard_cid) # Load full shard + shard_data_bytes = await self.cas.load(shard_cid) # Load full shard self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) except Exception as e: print(e) # Handle or log the exception appropriately - print(f"Warning: Failed to cache full shard {shard_idx} (CID: {shard_cid}): {e}") + print( + f"Warning: Failed to cache full shard {shard_idx} (CID: {shard_cid}): {e}" + ) # If it fails, subsequent requests might try again if it's still not in cache. finally: # Ensure the task is removed from pending list once done (success or failure) @@ -188,9 +204,9 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: if key.endswith(".json"): return None excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"} - + chunk_marker = "/c/" - marker_idx = key.rfind(chunk_marker) # Use rfind for robustness + marker_idx = key.rfind(chunk_marker) # Use rfind for robustness if marker_idx == -1: # Key does not contain "/c/", so it's not a chunk data key # in the expected format (e.g., could be .zattrs, .zgroup at various levels). @@ -205,8 +221,8 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # Determine the actual array name (the last component of the path before "/c/") actual_array_name = "" if path_before_c: - actual_array_name = path_before_c.split('/')[-1] - + actual_array_name = path_before_c.split("/")[-1] + # 2. If the determined array name is in our exclusion list, return None. if actual_array_name in excluded_array_prefixes: return None @@ -222,11 +238,11 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # This might also happen if a key like "some_other_main_array/c/0" is passed # but this store instance was configured for "temp". 
return None - + # The part after "/c/" contains the chunk coordinates - coord_part = key[marker_idx + len(chunk_marker):] - parts = coord_part.split('/') - + coord_part = key[marker_idx + len(chunk_marker) :] + parts = coord_part.split("/") + # Validate dimensionality: # The number of coordinate parts must match the dimensionality of the array # this store instance is configured for (self._chunks_per_dim). @@ -234,15 +250,15 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # This key's dimensionality does not match the store's configured array. # It's likely for a different array or a malformed key for the current array. return None - + try: coords = tuple(map(int, parts)) # Validate coordinates against the chunk grid of the store's configured array for i, c_coord in enumerate(coords): if not (0 <= c_coord < self._chunks_per_dim[i]): - return None # Coordinate out of bounds for this array's chunk grid + return None # Coordinate out of bounds for this array's chunk grid return coords - except (ValueError, IndexError): # If int conversion fails or other issues + except (ValueError, IndexError): # If int conversion fails or other issues return None def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: @@ -256,7 +272,9 @@ def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: if self._chunks_per_shard is None or self._chunks_per_shard <= 0: - raise RuntimeError("Sharding not configured properly: _chunks_per_shard invalid.") + raise RuntimeError( + "Sharding not configured properly: _chunks_per_shard invalid." + ) if linear_chunk_index < 0: raise ValueError("Linear chunk index cannot be negative.") @@ -280,30 +298,43 @@ async def _load_or_initialize_shard_cache(self, shard_idx: int) -> bytearray: del self._pending_shard_loads[shard_idx] # Fall through to normal loading except Exception as e: - print(f"Warning: Pending shard load for {shard_idx} failed: {e}. Attempting fresh load.") + print( + f"Warning: Pending shard load for {shard_idx} failed: {e}. Attempting fresh load." + ) if self._root_obj is None: - raise RuntimeError("Root object not loaded or initialized (_root_obj is None).") - if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else False): - raise ValueError(f"Shard index {shard_idx} out of bounds for {self._num_shards} shards.") + raise RuntimeError( + "Root object not loaded or initialized (_root_obj is None)." + ) + if not ( + 0 <= shard_idx < self._num_shards if self._num_shards is not None else False + ): + raise ValueError( + f"Shard index {shard_idx} out of bounds for {self._num_shards} shards." + ) shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] if shard_cid: shard_data_bytes = await self.cas.load(shard_cid) self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) else: - if self._cid_len is None: # Should be set - raise RuntimeError("Store not initialized: _cid_len is None for shard initialization.") + if self._cid_len is None: # Should be set + raise RuntimeError( + "Store not initialized: _cid_len is None for shard initialization." 
+ ) # New shard or shard not yet written, initialize with zeros shard_size_bytes = self._chunks_per_shard * self._cid_len - self._shard_data_cache[shard_idx] = bytearray(shard_size_bytes) # Filled with \x00 + self._shard_data_cache[shard_idx] = bytearray( + shard_size_bytes + ) # Filled with \x00 return self._shard_data_cache[shard_idx] - async def set_partial_values( self, key_start_values: Iterable[Tuple[str, int, BytesLike]] ) -> None: - raise NotImplementedError("Partial writes are not supported by ShardedZarrStore.") + raise NotImplementedError( + "Partial writes are not supported by ShardedZarrStore." + ) async def get_partial_values( self, @@ -312,7 +343,7 @@ async def get_partial_values( ) -> List[Optional[zarr.core.buffer.Buffer]]: tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] results = await asyncio.gather(*tasks) - return results # type: ignore + return results # type: ignore def __eq__(self, other: object) -> bool: if not isinstance(other, ShardedZarrStore): @@ -322,11 +353,15 @@ def __eq__(self, other: object) -> bool: async def flush(self) -> str: if self.read_only: - if self._root_cid is None: # Read-only store should have been opened with a root_cid - raise ValueError("Read-only store has no root CID to return. Was it opened correctly?") + if ( + self._root_cid is None + ): # Read-only store should have been opened with a root_cid + raise ValueError( + "Read-only store has no root CID to return. Was it opened correctly?" + ) return self._root_cid - if self._root_obj is None: # Should be initialized for a writable store + if self._root_obj is None: # Should be initialized for a writable store raise RuntimeError("Store not initialized for writing: _root_obj is None.") # Save all dirty shards first, as their CIDs might need to go into the root object @@ -335,30 +370,37 @@ async def flush(self) -> str: if shard_idx not in self._shard_data_cache: # This implies an internal logic error if a shard is dirty but not in cache # However, could happen if cache was cleared externally; robust code might reload/reinit - print(f"Warning: Dirty shard {shard_idx} not found in cache. Skipping save for this shard.") + print( + f"Warning: Dirty shard {shard_idx} not found in cache. Skipping save for this shard." + ) continue - + shard_data_bytes = bytes(self._shard_data_cache[shard_idx]) - + # The CAS save method here should return a string CID. - new_shard_cid = await self.cas.save(shard_data_bytes, codec="raw") # Shards are raw bytes of CIDs - + new_shard_cid = await self.cas.save( + shard_data_bytes, codec="raw" + ) # Shards are raw bytes of CIDs + if self._root_obj["chunks"]["shard_cids"][shard_idx] != new_shard_cid: self._root_obj["chunks"]["shard_cids"][shard_idx] = new_shard_cid - self._dirty_root = True # Root object changed because a shard_cid in its list changed + self._dirty_root = True # Root object changed because a shard_cid in its list changed self._dirty_shards.clear() if self._dirty_root: root_obj_bytes = dag_cbor.encode(self._root_obj) new_root_cid = await self.cas.save(root_obj_bytes, codec="dag-cbor") - self._root_cid = str(new_root_cid) # Ensure it's string + self._root_cid = str(new_root_cid) # Ensure it's string self._dirty_root = False - - if self._root_cid is None: # Should only happen if nothing was dirty AND it was a new store never flushed - raise RuntimeError("Failed to obtain a root CID after flushing. 
Store might be empty or unchanged.") - return self._root_cid + if ( + self._root_cid is None + ): # Should only happen if nothing was dirty AND it was a new store never flushed + raise RuntimeError( + "Failed to obtain a root CID after flushing. Store might be empty or unchanged." + ) + return self._root_cid async def get( self, @@ -377,8 +419,10 @@ async def get( return None # byte_range is not typically applicable to metadata JSON objects themselves if byte_range is not None: - # Consider if this should be an error or ignored for metadata - print(f"Warning: byte_range requested for metadata key '{key}'. Ignoring range.") + # Consider if this should be an error or ignored for metadata + print( + f"Warning: byte_range requested for metadata key '{key}'. Ignoring range." + ) data = await self.cas.load(metadata_cid) return prototype.buffer.from_bytes(data) @@ -386,11 +430,13 @@ async def get( shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): - # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic - return None + # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic + return None target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if target_shard_cid is None: # This shard has no data (all chunks within it are implicitly empty) + if ( + target_shard_cid is None + ): # This shard has no data (all chunks within it are implicitly empty) return None offset_in_shard_bytes = index_in_shard * self._cid_len @@ -398,63 +444,80 @@ async def get( if shard_idx in self._shard_data_cache: cached_shard_data = self._shard_data_cache[shard_idx] - chunk_cid_bytes = bytes(cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len]) + chunk_cid_bytes = bytes( + cached_shard_data[ + offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len + ] + ) - if chunk_cid_bytes is None: # Not in cache or cache was invalid + if chunk_cid_bytes is None: # Not in cache or cache was invalid chunk_cid_bytes = await self.cas.load( target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len ) # After successfully fetching the specific CID bytes, # check if we should initiate a background load of the full shard. - if shard_idx not in self._shard_data_cache and shard_idx not in self._pending_shard_loads: + if ( + shard_idx not in self._shard_data_cache + and shard_idx not in self._pending_shard_loads + ): self._pending_shard_loads[shard_idx] = asyncio.create_task( self._fetch_and_cache_full_shard(shard_idx, target_shard_cid) ) - if all(b == 0 for b in chunk_cid_bytes): # Check for null CID placeholder (e.g. \x00 * cid_len) + if all( + b == 0 for b in chunk_cid_bytes + ): # Check for null CID placeholder (e.g. 
\x00 * cid_len) return None # Chunk doesn't exist or is considered empty # Decode CID (assuming ASCII, remove potential null padding) - chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip('\x00') - if not chunk_cid_str: # Empty string after rstrip if all were \x00 (already caught above) - return None + chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip("\x00") + if ( + not chunk_cid_str + ): # Empty string after rstrip if all were \x00 (already caught above) + return None # Actual chunk data load using the retrieved chunk_cid_str req_offset = byte_range.start if byte_range else None req_length = None if byte_range: if byte_range.end is not None: - if byte_range.start > byte_range.end: # Zarr allows start == stop for 0 length - raise ValueError(f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})") + if ( + byte_range.start > byte_range.end + ): # Zarr allows start == stop for 0 length + raise ValueError( + f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" + ) req_length = byte_range.end - byte_range.start data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) return prototype.buffer.from_bytes(data) - async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: if self.read_only: raise ValueError("Cannot write to a read-only store.") if self._root_obj is None: - raise RuntimeError("Store not initialized for writing (root_obj is None). Call open() first.") + raise RuntimeError( + "Store not initialized for writing (root_obj is None). Call open() first." + ) raw_chunk_data_bytes = value.to_bytes() # Save the actual chunk data to CAS first, to get its CID - chunk_data_cid_obj = await self.cas.save(raw_chunk_data_bytes, codec="raw") # Chunks are typically raw bytes + chunk_data_cid_obj = await self.cas.save( + raw_chunk_data_bytes, codec="raw" + ) # Chunks are typically raw bytes chunk_data_cid_str = str(chunk_data_cid_obj) - await self.set_pointer(key, chunk_data_cid_str) # Store the CID in the index + await self.set_pointer(key, chunk_data_cid_str) # Store the CID in the index - async def set_pointer( - self, key: str, pointer: str - ) -> None: + async def set_pointer(self, key: str, pointer: str) -> None: # Ensure the CID (as ASCII bytes) fits in the allocated slot, padding with nulls chunk_data_cid_ascii_bytes = pointer.encode("ascii") if len(chunk_data_cid_ascii_bytes) > self._cid_len: raise ValueError( f"Encoded CID byte length ({len(chunk_data_cid_ascii_bytes)}) exceeds configured CID length ({self._cid_len}). CID: {pointer}" ) - padded_chunk_data_cid_bytes = chunk_data_cid_ascii_bytes.ljust(self._cid_len, b'\0') - + padded_chunk_data_cid_bytes = chunk_data_cid_ascii_bytes.ljust( + self._cid_len, b"\0" + ) chunk_coords = self._parse_chunk_key(key) @@ -463,7 +526,9 @@ async def set_pointer( # So, we store the metadata content, get its CID, and put *that* CID in root_obj. # This means the `value_cid_str` for metadata should be from `raw_chunk_data_bytes`. # This seems to align with FlatZarrStore, where `value_cid` is used for both. 
- self._root_obj["metadata"][key] = pointer # Store the string CID of the metadata content + self._root_obj["metadata"][key] = ( + pointer # Store the string CID of the metadata content + ) self._dirty_root = True return @@ -476,70 +541,79 @@ async def set_pointer( target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) offset_in_shard_bytes = index_in_shard * self._cid_len - + # Check if the content is actually changing to avoid unnecessary dirtying (optional optimization) # current_bytes_in_shard = target_shard_data_cache[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] # if current_bytes_in_shard == padded_chunk_data_cid_bytes: # return # No change - target_shard_data_cache[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] = padded_chunk_data_cid_bytes + target_shard_data_cache[ + offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len + ] = padded_chunk_data_cid_bytes self._dirty_shards.add(shard_idx) # If this write implies the shard CID in root_obj["chunks"]["shard_cids"] might change # (e.g., from None to an actual CID when the shard is first flushed), # then _dirty_root will be set by flush(). - async def exists(self, key: str) -> bool: if self._root_obj is None: - raise RuntimeError("Root object not loaded. Call _load_root_from_cid() first.") + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." + ) chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: # Metadata + if chunk_coords is None: # Metadata return key in self._root_obj.get("metadata", {}) try: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not (self._root_obj and "chunks" in self._root_obj and \ - 0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): + if not ( + self._root_obj + and "chunks" in self._root_obj + and 0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"]) + ): return False - + target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if target_shard_cid is None: # Shard itself doesn't exist + if target_shard_cid is None: # Shard itself doesn't exist return False offset_in_shard_bytes = index_in_shard * self._cid_len - + # Optimization: Check local shard cache first if shard_idx in self._shard_data_cache: - cached_shard_data = self._shard_data_cache[shard_idx] - # Ensure index_in_shard is valid for this cached data length - if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): - chunk_cid_bytes_from_cache = cached_shard_data[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] + cached_shard_data = self._shard_data_cache[shard_idx] + # Ensure index_in_shard is valid for this cached data length + if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): + chunk_cid_bytes_from_cache = cached_shard_data[ + offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len + ] return not all(b == 0 for b in chunk_cid_bytes_from_cache) - # else: fall through to CAS load, cache might be out of sync or wrong size (should not happen with correct logic) + # else: fall through to CAS load, cache might be out of sync or wrong size (should not happen with correct logic) # If not in cache or cache check was inconclusive, fetch from CAS chunk_cid_bytes_from_cas = await self.cas.load( target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len ) return not all(b == 0 for b in chunk_cid_bytes_from_cas) - except Exception: # Broad catch for issues like invalid coords, 
CAS errors during load etc. + except ( + Exception + ): # Broad catch for issues like invalid coords, CAS errors during load etc. return False - @property def supports_writes(self) -> bool: return not self.read_only @property def supports_partial_writes(self) -> bool: - return False # Each chunk CID is written atomically into a shard slot + return False # Each chunk CID is written atomically into a shard slot @property def supports_deletes(self) -> bool: - return not self.read_only + return not self.read_only async def delete(self, key: str) -> None: if self.read_only: @@ -548,42 +622,49 @@ async def delete(self, key: str) -> None: raise RuntimeError("Store not initialized for deletion (root_obj is None).") chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: # Metadata + if chunk_coords is None: # Metadata if key in self._root_obj.get("metadata", {}): del self._root_obj["metadata"][key] self._dirty_root = True return else: raise KeyError(f"Metadata key '{key}' not found for deletion.") - # Chunk deletion: zero out the CID entry in the shard linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not (0 <= shard_idx < (self._num_shards if self._num_shards is not None else 0)): - raise KeyError(f"Chunk key '{key}' maps to an invalid shard index {shard_idx}.") + if not ( + 0 <= shard_idx < (self._num_shards if self._num_shards is not None else 0) + ): + raise KeyError( + f"Chunk key '{key}' maps to an invalid shard index {shard_idx}." + ) # Ensure shard data is available for modification (loads from CAS if not in cache, or initializes if new) target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) - + offset_in_shard_bytes = index_in_shard * self._cid_len - + # Check if the entry is already zeroed (meaning it doesn't exist or already deleted) is_already_zero = True for i in range(self._cid_len): - if offset_in_shard_bytes + i >= len(target_shard_data_cache) or \ - target_shard_data_cache[offset_in_shard_bytes + i] != 0: + if ( + offset_in_shard_bytes + i >= len(target_shard_data_cache) + or target_shard_data_cache[offset_in_shard_bytes + i] != 0 + ): is_already_zero = False break - + if is_already_zero: - raise KeyError(f"Chunk key '{key}' not found or already effectively deleted (CID slot is zeroed).") + raise KeyError( + f"Chunk key '{key}' not found or already effectively deleted (CID slot is zeroed)." + ) # Zero out the CID entry in the shard cache for i in range(self._cid_len): target_shard_data_cache[offset_in_shard_bytes + i] = 0 - + self._dirty_shards.add(shard_idx) # If this shard becomes non-None in root_obj due to other writes, flush will handle it. # If this deletion makes a previously non-None shard all zeros, the shard itself might @@ -606,14 +687,16 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: # This simplified version only works for the root directory (prefix == "") of metadata. # It lists unique first components of metadata keys. if self._root_obj is None: - raise RuntimeError("Root object not loaded. Call _load_root_from_cid() first.") + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." 
+ ) seen: Set[str] = set() if prefix == "": - async for key in self.list(): # Iterates metadata keys + async for key in self.list(): # Iterates metadata keys # e.g., if key is "group1/.zgroup" or "array1/.zarray", first_component is "group1" or "array1" # if key is ".zgroup", first_component is ".zgroup" - first_component = key.split('/', 1)[0] + first_component = key.split("/", 1)[0] if first_component not in seen: seen.add(first_component) yield first_component @@ -623,10 +706,147 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: # Zarr spec: list_dir(path) should yield children (both objects and "directories") # For simplicity, and consistency with original FlatZarrStore, keeping this minimal. # To make it more compliant for prefix="foo/": - normalized_prefix = prefix if prefix.endswith('/') else prefix + '/' + normalized_prefix = prefix if prefix.endswith("/") else prefix + "/" async for key in self.list_prefix(normalized_prefix): - remainder = key[len(normalized_prefix):] - child = remainder.split('/', 1)[0] - if child not in seen: - seen.add(child) - yield child \ No newline at end of file + remainder = key[len(normalized_prefix) :] + child = remainder.split("/", 1)[0] + if child not in seen: + seen.add(child) + yield child + + async def pin_entire_dataset( + self, target_rpc: str = "http://127.0.0.1:5001" + ) -> None: + """ + Pins the entire dataset in the CAS, ensuring the root, metadata, shards, + and all data chunks are pinned. This is useful for performance optimization + when the dataset is accessed frequently. + """ + if self._root_obj is None: + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." + ) + if self._cid_len is None: + raise RuntimeError( + "Store is not initialized properly; _cid_len is missing." + ) + + # Pin the root CID itself + if self._root_cid: + await self.cas.pin_cid(self._root_cid, target_rpc=target_rpc) + + # Pin metadata CIDs + for cid in self._root_obj.get("metadata", {}).values(): + if cid: + await self.cas.pin_cid(cid, target_rpc=target_rpc) + + # Pin all shard CIDs and the chunk CIDs within them + for index, shard_cid in enumerate(self._root_obj["chunks"]["shard_cids"]): + if not shard_cid: + continue + + # Pin the shard itself + print(f"Pinning shard {shard_cid} to {target_rpc}...") + await self.cas.pin_cid(shard_cid, target_rpc=target_rpc) + + try: + # Load shard data to find and pin the chunk CIDs within + shard_data = await self.cas.load(shard_cid) + + chunks_pinned = 0 + for i in range(0, len(shard_data), self._cid_len): + cid_bytes = shard_data[i : i + self._cid_len] + + if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots + continue + print(f"Processing chunk CID bytes: {cid_bytes}") + + chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") + if chunk_cid_str: + await self.cas.pin_cid(chunk_cid_str, target_rpc=target_rpc) + chunks_pinned += 1 + print(f"Pinned {chunks_pinned} chunk CIDs in shard {shard_cid}.") + print( + f"Total shards processed: {index + 1}/{len(self._root_obj['chunks']['shard_cids'])}" + ) + # Print progress based on amount of shards processed + # Catch any exceptions during shard loading or pinning + except Exception as e: + print( + f"Warning: Could not load or process shard {shard_cid} for pinning: {e}" + ) + + async def unpin_entire_dataset( + self, target_rpc: str = "http://127.0.0.1:5001" + ) -> None: + """ + Unpins the entire dataset from the CAS, removing the root, metadata, shards, + and all data chunks from the pin set. 
This is useful for freeing up storage + resources when the dataset is no longer needed. + """ + if self._root_obj is None: + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." + ) + if self._cid_len is None: + raise RuntimeError( + "Store is not initialized properly; _cid_len is missing." + ) + + # Unpin all chunk CIDs by reading from shards first + for shard_cid in self._root_obj["chunks"]["shard_cids"]: + if not shard_cid: + continue + + try: + shard_data = await self.cas.load(shard_cid) + # Iterate through the packed CIDs in the shard data + for i in range(0, len(shard_data), self._cid_len): + cid_bytes = shard_data[i : i + self._cid_len] + if all(b == 0 for b in cid_bytes): + continue + + chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") + if chunk_cid_str: + try: + await self.cas.unpin_cid( + chunk_cid_str, target_rpc=target_rpc + ) + except Exception as e: + # ignore + continue + print( + f"Unpinned all chunk CIDs in shard {shard_cid} from {target_rpc}." + ) + except Exception as e: + # Log error but continue to attempt to unpin the shard itself + print( + f"Warning: Could not load or process chunks in shard {str(shard_cid)} for unpinning: {e}" + ) + # After unpinning all chunks within, unpin the shard itself + try: + await self.cas.unpin_cid(str(shard_cid), target_rpc=target_rpc) + except Exception as e: + print(f"Warning: Could not unpin shard {str(shard_cid)}") + print(f"Unpinned shard {shard_cid} from {target_rpc}.") + + # Unpin metadata CIDs + for cid in self._root_obj.get("metadata", {}).values(): + if cid: + try: + await self.cas.unpin_cid(cid, target_rpc=target_rpc) + print(f"Unpinned metadata CID {cid} from {target_rpc}...") + except Exception as e: + print( + f"Warning: Could not unpin metadata CID {cid}. Likely already unpinned." + ) + + # Finally, unpin the root CID itself + if self._root_cid: + try: + await self.cas.unpin_cid(self._root_cid, target_rpc=target_rpc) + print(f"Unpinned root CID {self._root_cid} from {target_rpc}...") + except Exception as e: + print( + f"Warning: Could not unpin root CID {self._root_cid}. Likely already unpinned." + ) diff --git a/py_hamt/store.py b/py_hamt/store.py index be32b5e..ae3b46a 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -165,6 +165,7 @@ def __init__( *, headers: dict[str, str] | None = None, auth: aiohttp.BasicAuth | None = None, + pinOnAdd: bool = False, ): """ If None is passed into the rpc or gateway base url, then the default for kubo local daemons will be used. The default local values will also be used if nothing is passed in at all. 
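A note on the pinOnAdd flag introduced above: it only changes the /api/v0/add query string built in the next hunk, so with pinOnAdd=True the daemon pins every block at the moment it is added rather than relying on a later pin pass. A minimal usage sketch, assuming a local Kubo daemon on the default ports used elsewhere in this series (the payload bytes are purely illustrative):

    import asyncio
    from py_hamt import KuboCAS

    async def demo() -> None:
        # pinOnAdd=True makes the RPC add call use pin=true instead of pin=false
        async with KuboCAS(
            rpc_base_url="http://127.0.0.1:5001",
            gateway_base_url="http://127.0.0.1:8080",
            pinOnAdd=True,
        ) as cas:
            cid = await cas.save(b"example payload", codec="raw")
            print("stored and pinned:", cid)

    asyncio.run(demo())
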
@@ -210,7 +211,11 @@ def __init__( if gateway_base_url is None: gateway_base_url = KuboCAS.KUBO_DEFAULT_LOCAL_GATEWAY_BASE_URL - self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin=false" + pinString: str = "true" if pinOnAdd else "false" + + self.rpc_url: str = ( + f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin={pinString}" + ) """@private""" self.gateway_base_url: str = f"{gateway_base_url}/ipfs/" """@private""" @@ -334,3 +339,49 @@ async def load( async with self._loop_session().get(url, headers=headers or None) as resp: resp.raise_for_status() return await resp.read() + + # --------------------------------------------------------------------- # + # pin_cid() – method to pin a CID # + # --------------------------------------------------------------------- # + async def pin_cid( + self, + cid: CID, + name: Optional[str] = None, + target_rpc: str = "http://127.0.0.1:5001", + ) -> None: + """ + Pins a CID to the local Kubo node via the RPC API. + + This call is recursive by default, pinning all linked objects. + + Args: + cid (CID): The Content ID to pin. + name (Optional[str]): An optional name for the pin. + """ + params = {"arg": str(cid), "recursive": "false"} + if name: + params["name"] = name + pin_add_url_base: str = f"{target_rpc}/api/v0/pin/add" + + async with self._sem: # throttle RPC + async with self._loop_session().post( + pin_add_url_base, params=params + ) as resp: + resp.raise_for_status() + # A 200 OK is sufficient to indicate success. + + async def unpin_cid( + self, cid: CID, target_rpc: str = "http://127.0.0.1:5001" + ) -> None: + """ + Unpins a CID from the local Kubo node via the RPC API. + + Args: + cid (CID): The Content ID to unpin. + """ + params = {"arg": str(cid), "recursive": "false"} + unpin_url_base: str = f"{target_rpc}/api/v0/pin/rm" + async with self._sem: # throttle RPC + async with self._loop_session().post(unpin_url_base, params=params) as resp: + resp.raise_for_status() + # A 200 OK is sufficient to indicate success. diff --git a/tests/test_benchmark_stores.py b/tests/test_benchmark_stores.py index 27a8646..b2b9534 100644 --- a/tests/test_benchmark_stores.py +++ b/tests/test_benchmark_stores.py @@ -56,8 +56,6 @@ # yield ds - - # # # ### # # # BENCHMARK FOR THE NEW FlatZarrStore # # # ### @@ -224,7 +222,7 @@ # # Construct a minimal selection based on available dimensions # selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} # if selection: -# _ = ipfs_ds[first_var_name].isel(**selection).values +# _ = ipfs_ds[first_var_name].isel(**selection).values # else: # If no dimensions, try loading the whole variable (e.g. 
scalar) # _ = ipfs_ds[first_var_name].values # end_read = time.perf_counter() diff --git a/tests/test_converter.py b/tests/test_converter.py index 6c46827..cdc7eab 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -11,7 +11,13 @@ import xarray as xr # Import store implementations -from py_hamt import HAMT, KuboCAS, ShardedZarrStore, convert_hamt_to_sharded, sharded_converter_cli +from py_hamt import ( + HAMT, + KuboCAS, + ShardedZarrStore, + convert_hamt_to_sharded, + sharded_converter_cli, +) from py_hamt.zarr_hamt_store import ZarrHAMTStore @@ -25,7 +31,7 @@ def converter_test_dataset(): times = pd.date_range("2025-01-01", periods=20) lats = np.linspace(40, 50, 10) lons = np.linspace(-85, -75, 20) - + # Generate a unique variable name for this test run unique_var_name = f"data_{str(uuid.uuid4())[:8]}" @@ -95,12 +101,15 @@ async def test_converter_produces_identical_dataset( # STEP 3: Verification # -------------------------------------------------------------------- print("\n--- STEP 3: Verifying data integrity ---") - + # Open the original dataset from the HAMT store print("Reading data back from original HAMT store...") hamt_ro = await HAMT.build( - cas=kubo_cas, root_node_id=hamt_root_cid, values_are_bytes=True, read_only=True + cas=kubo_cas, + root_node_id=hamt_root_cid, + values_are_bytes=True, + read_only=True, ) zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) @@ -109,14 +118,14 @@ async def test_converter_produces_identical_dataset( end_read = time.perf_counter() print(f"Original HAMT store read in {end_read - start_read:.2f}s") - + # Open the converted dataset from the new Sharded store print("Reading data back from new Sharded store...") dest_store_ro = await ShardedZarrStore.open( cas=kubo_cas, read_only=True, root_cid=sharded_root_cid ) ds_from_sharded = xr.open_zarr(dest_store_ro) - + # The ultimate test: are the two xarray.Dataset objects identical? # This checks coordinates, variables, data values, and attributes. print("Comparing the two datasets...") @@ -127,15 +136,14 @@ async def test_converter_produces_identical_dataset( np.testing.assert_array_equal( ds_from_hamt[var].values, ds_from_sharded[var].values ) - + print("\n✅ Verification successful! The datasets are identical.") print("=" * 80) + @pytest.mark.asyncio(loop_scope="session") async def test_hamt_to_sharded_cli_success( - create_ipfs: tuple[str, str], - converter_test_dataset: xr.Dataset, - capsys + create_ipfs: tuple[str, str], converter_test_dataset: xr.Dataset, capsys ): """ Tests the CLI for successful conversion of a HAMT store to a ShardedZarrStore. @@ -157,9 +165,12 @@ async def test_hamt_to_sharded_cli_success( test_args = [ "script.py", # Dummy script name hamt_root_cid, - "--chunks-per-shard", "64", - "--rpc-url", rpc_base_url, - "--gateway-url", gateway_base_url + "--chunks-per-shard", + "64", + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, ] with patch.object(sys, "argv", test_args): await sharded_converter_cli() @@ -180,11 +191,10 @@ async def test_hamt_to_sharded_cli_success( ds_from_sharded = xr.open_zarr(dest_store_ro) xr.testing.assert_identical(test_ds, ds_from_sharded) + @pytest.mark.asyncio(loop_scope="session") async def test_hamt_to_sharded_cli_default_args( - create_ipfs: tuple[str, str], - converter_test_dataset: xr.Dataset, - capsys + create_ipfs: tuple[str, str], converter_test_dataset: xr.Dataset, capsys ): """ Tests the CLI with default argument values. 
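For reference, the default-argument path exercised in the next hunk drives the converter with nothing but the HAMT root CID and the endpoint URLs, letting --chunks-per-shard fall back to the CLI's built-in default. A minimal sketch using the same argv-patching pattern as these tests (hamt_root_cid is a placeholder, and the URLs assume a local daemon):

    import sys
    from unittest.mock import patch
    from py_hamt import sharded_converter_cli

    async def convert(hamt_root_cid: str) -> None:
        # Only the positional CID plus endpoint URLs; other options keep their defaults.
        argv = [
            "script.py",
            hamt_root_cid,
            "--rpc-url", "http://127.0.0.1:5001",
            "--gateway-url", "http://127.0.0.1:8080",
        ]
        with patch.object(sys, "argv", argv):
            await sharded_converter_cli()
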
@@ -206,8 +216,10 @@ async def test_hamt_to_sharded_cli_default_args( test_args = [ "script.py", # Dummy script name hamt_root_cid, - "--rpc-url", rpc_base_url, - "--gateway-url", gateway_base_url + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, ] with patch.object(sys, "argv", test_args): await sharded_converter_cli() @@ -225,11 +237,9 @@ async def test_hamt_to_sharded_cli_default_args( ds_from_sharded = xr.open_zarr(dest_store_ro) xr.testing.assert_identical(test_ds, ds_from_sharded) + @pytest.mark.asyncio(loop_scope="session") -async def test_hamt_to_sharded_cli_invalid_cid( - create_ipfs: tuple[str, str], - capsys -): +async def test_hamt_to_sharded_cli_invalid_cid(create_ipfs: tuple[str, str], capsys): """ Tests the CLI with an invalid hamt_cid. """ @@ -242,9 +252,12 @@ async def test_hamt_to_sharded_cli_invalid_cid( test_args = [ "script.py", invalid_cid, - "--chunks-per-shard", "64", - "--rpc-url", rpc_base_url, - "--gateway-url", gateway_base_url + "--chunks-per-shard", + "64", + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, ] with patch.object(sys, "argv", test_args): await sharded_converter_cli() @@ -252,4 +265,4 @@ async def test_hamt_to_sharded_cli_invalid_cid( # Verify error handling captured = capsys.readouterr() assert "An error occurred" in captured.out - assert f"{invalid_cid}" in captured.out \ No newline at end of file + assert f"{invalid_cid}" in captured.out diff --git a/tests/test_cpc_compare.py b/tests/test_cpc_compare.py index 6930eef..40df647 100644 --- a/tests/test_cpc_compare.py +++ b/tests/test_cpc_compare.py @@ -12,7 +12,6 @@ # from py_hamt.zarr_hamt_store import ZarrHAMTStore - # @pytest.mark.asyncio(loop_scope="session") # async def test_benchmark_sharded_store(): # """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring @@ -26,7 +25,7 @@ # headers = { # "X-API-Key": "", # } - + # async with KuboCAS( # rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers # ) as kubo_cas: @@ -55,7 +54,7 @@ # # Construct a minimal selection based on available dimensions # selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} # if selection: -# _ = ipfs_ds[first_var_name].isel(**selection).values +# _ = ipfs_ds[first_var_name].isel(**selection).values # else: # If no dimensions, try loading the whole variable (e.g. 
scalar) # _ = ipfs_ds[first_var_name].values # end_read = time.perf_counter() diff --git a/tests/test_sharded_zarr_pinning.py b/tests/test_sharded_zarr_pinning.py new file mode 100644 index 0000000..14e7cd8 --- /dev/null +++ b/tests/test_sharded_zarr_pinning.py @@ -0,0 +1,138 @@ +import asyncio +import aiohttp +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from py_hamt import KuboCAS, ShardedZarrStore + + +# Helper function to query the IPFS daemon for all pinned CIDs +async def get_pinned_cids(rpc_base_url: str) -> set[str]: + """Queries the Kubo RPC API and returns a set of all pinned CIDs.""" + url = f"{rpc_base_url}/api/v0/pin/ls" + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, params={'type': 'all'}) as resp: + resp.raise_for_status() + data = await resp.json() + return set(data.get("Keys", {}).keys()) + except Exception as e: + pytest.fail(f"Failed to query pinned CIDs from Kubo RPC API: {e}") + return set() + + +# Helper function to gather all CIDs from a store instance +async def get_all_dataset_cids(store: ShardedZarrStore) -> set[str]: + """Helper to collect all CIDs associated with a ShardedZarrStore instance.""" + if store._root_obj is None or store._cid_len is None: + raise RuntimeError("Store is not properly initialized.") + + cids = set() + if store._root_cid: + cids.add(store._root_cid) + + # Gather metadata CIDs + for cid in store._root_obj.get("metadata", {}).values(): + if cid: + cids.add(cid) + + + # Gather shard and all chunk CIDs within them + for shard_cid in store._root_obj["chunks"]["shard_cids"]: + if not shard_cid: + continue + cids.add(str(shard_cid)) + try: + # Load shard data to find the chunk CIDs within + shard_data = await store.cas.load(shard_cid) + for i in range(0, len(shard_data), store._cid_len): + cid_bytes = shard_data[i : i + store._cid_len] + if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots + continue + + chunk_cid_str = cid_bytes.decode("ascii").rstrip('\x00') + if chunk_cid_str: + cids.add(chunk_cid_str) + except Exception as e: + print(f"Warning: Could not load shard {shard_cid} to gather its CIDs: {e}") + + return cids + + +@pytest.fixture(scope="module") +def random_zarr_dataset_for_pinning(): + """Creates a random xarray Dataset specifically for the pinning test.""" + times = pd.date_range("2025-01-01", periods=50) + lats = np.linspace(-90, 90, 10) + lons = np.linspace(-180, 180, 20) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + {"temp": (["time", "lat", "lon"], temp)}, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 10, "lat": 10, "lon": 20}) + yield ds + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_pinning( + create_ipfs: tuple[str, str], random_zarr_dataset_for_pinning: xr.Dataset +): + """ + Tests the pin_entire_dataset and unpin_entire_dataset methods. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset_for_pinning + + ordered_dims = list(test_ds.dims) + array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # --- 1. 
Write dataset to the store --- + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=1, # Use a smaller number to ensure multiple shards + ) + test_ds.to_zarr(store=store, mode="w", consolidated=True) + root_cid = await store.flush() + assert root_cid is not None + + # --- 2. Gather all expected CIDs from the written store --- + expected_cids = await get_all_dataset_cids(store) + assert len(expected_cids) > 5 # Sanity check: ensure we have CIDs to test + + # --- 3. Pin the dataset and verify --- + await store.pin_entire_dataset(target_rpc=rpc_base_url) + + # Allow a moment for pins to register + await asyncio.sleep(1) + + currently_pinned = await get_pinned_cids(rpc_base_url) + + # Check if all our dataset's CIDs are in the main pin list + missing_pins = expected_cids - currently_pinned + assert not missing_pins, f"The following CIDs were expected to be pinned but were not: {missing_pins}" + + # --- 4. Unpin the dataset and verify --- + await store.unpin_entire_dataset(target_rpc=rpc_base_url) + + # Allow a moment for pins to be removed + await asyncio.sleep(1) + + pinned_after_unpin = await get_pinned_cids(rpc_base_url) + + # Check that none of our dataset's CIDs are in the pin list anymore + lingering_pins = expected_cids.intersection(pinned_after_unpin) + assert not lingering_pins, f"The following CIDs were expected to be unpinned but still exist: {lingering_pins}" diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index bcdabf6..3a09df5 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -115,7 +115,6 @@ async def test_sharded_zarr_store_metadata( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) @@ -157,7 +156,6 @@ async def test_sharded_zarr_store_metadata( assert "lat/c/0" in keys_with_prefix - @pytest.mark.asyncio async def test_sharded_zarr_store_chunks( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset @@ -168,7 +166,6 @@ async def test_sharded_zarr_store_chunks( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) @@ -208,6 +205,7 @@ async def test_sharded_zarr_store_chunks( ) assert await store_read_after_delete.get(chunk_key, proto) is None + @pytest.mark.asyncio async def test_chunk_and_delete_logic( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset @@ -260,7 +258,7 @@ async def test_chunk_and_delete_logic( # Test deleting a non-existent key with pytest.raises(KeyError): await store_rw.delete("nonexistent/c/0/0/0") - + # Test deleting an already deleted key with pytest.raises(KeyError): await store_rw.delete(chunk_key) @@ -276,7 +274,6 @@ async def test_sharded_zarr_store_partial_reads( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) @@ -309,6 +306,7 @@ async def test_sharded_zarr_store_partial_reads( assert partial_chunk is not None assert partial_chunk.to_bytes() == 
full_chunk_bytes[10:50] + @pytest.mark.asyncio async def test_partial_reads_and_errors( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset @@ -351,29 +349,37 @@ async def test_partial_reads_and_errors( # Test invalid byte range with pytest.raises(ValueError): - await store_read.get(chunk_key, proto, byte_range=RangeByteRequest(start=50, end=10)) + await store_read.get( + chunk_key, proto, byte_range=RangeByteRequest(start=50, end=10) + ) + @pytest.mark.asyncio async def test_zero_sized_array(create_ipfs: tuple[str, str]): """Test handling of arrays with a zero-length dimension.""" rpc_base_url, gateway_base_url = create_ipfs - async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: store = await ShardedZarrStore.open( cas=kubo_cas, read_only=False, array_shape=(100, 0), chunk_shape=(10, 10), - chunks_per_shard=64 + chunks_per_shard=64, ) assert store._total_chunks == 0 assert store._num_shards == 0 root_cid = await store.flush() - + # Read it back and verify - store_read = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=root_cid) + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) assert store_read._total_chunks == 0 assert store_read._num_shards == 0 + @pytest.mark.asyncio async def test_store_eq_method(create_ipfs: tuple[str, str]): """Tests the __eq__ method.""" @@ -381,9 +387,17 @@ async def test_store_eq_method(create_ipfs: tuple[str, str]): async with KuboCAS( rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url ) as kubo_cas: - store1 = await ShardedZarrStore.open(cas=kubo_cas, read_only=False, array_shape=(1,1), chunk_shape=(1,1), chunks_per_shard=1) + store1 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(1, 1), + chunk_shape=(1, 1), + chunks_per_shard=1, + ) root_cid = await store1.flush() - store2 = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=root_cid) + store2 = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) assert store1 == store2 @@ -439,6 +453,7 @@ async def test_listing_and_metadata( assert "lon" in dir_keys assert "zarr.json" in dir_keys + @pytest.mark.asyncio async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): """ @@ -467,7 +482,9 @@ async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): ) # Test invalid chunk_shape - with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -477,7 +494,9 @@ async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): ) # Test invalid array_shape - with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -486,13 +505,18 @@ async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): chunks_per_shard=10, ) + @pytest.mark.asyncio async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): """Tests initialization with invalid shapes and manifest errors.""" rpc_base_url, gateway_base_url = create_ipfs - async with KuboCAS(rpc_base_url=rpc_base_url, 
gateway_base_url=gateway_base_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: # Test negative chunk_shape dimension (line 136) - with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -502,7 +526,9 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st ) # Test negative array_shape dimension (line 141) - with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -536,31 +562,45 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st "shard_cids": [None] * 4, }, } - invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) with pytest.raises(ValueError, match="Incompatible manifest version"): - await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) # Test inconsistent shard count (line 236) invalid_root_obj = { "manifest_version": "sharded_zarr_v1", "metadata": {}, "chunks": { - "array_shape": [10, 10], # 100 chunks, with 10 chunks per shard -> 10 shards + "array_shape": [ + 10, + 10, + ], # 100 chunks, with 10 chunks per shard -> 10 shards "chunk_shape": [5, 5], "cid_byte_length": 59, "sharding_config": {"chunks_per_shard": 10}, "shard_cids": [None] * 5, # Wrong number of shards }, } - invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) with pytest.raises(ValueError, match="Inconsistent number of shards"): - await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) + @pytest.mark.asyncio async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): """Tests chunk key parsing edge cases.""" rpc_base_url, gateway_base_url = create_ipfs - async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: store = await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -583,7 +623,9 @@ async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): assert uninitialized_store._parse_chunk_key("temp/c/0/0") is None # Test get on uninitialized store - with pytest.raises(RuntimeError, match="Load the root object first before accessing data."): + with pytest.raises( + RuntimeError, match="Load the root object first before accessing data." 
+ ): proto = zarr.core.buffer.default_buffer_prototype() await uninitialized_store.get("temp/c/0/0", proto) @@ -594,17 +636,24 @@ async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): assert store._parse_chunk_key("temp/c/0/0/0") is None # 3D key for 2D array # Test invalid coordinates - assert store._parse_chunk_key("temp/c/3/0") is None # Out of bounds (3 >= 2 chunks) + assert ( + store._parse_chunk_key("temp/c/3/0") is None + ) # Out of bounds (3 >= 2 chunks) assert store._parse_chunk_key("temp/c/0/invalid") is None # Non-integer assert store._parse_chunk_key("temp/c/0/-1") is None # Negative coordinate + @pytest.mark.asyncio async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): """Tests initialization with invalid shapes and manifest errors.""" rpc_base_url, gateway_base_url = create_ipfs - async with KuboCAS(rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: # Test negative chunk_shape dimension - with pytest.raises(ValueError, match="All chunk_shape dimensions must be positive"): + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -614,7 +663,9 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st ) # Test negative array_shape dimension - with pytest.raises(ValueError, match="All array_shape dimensions must be non-negative"): + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative" + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -648,22 +699,33 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st "shard_cids": [None] * 4, }, } - invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) with pytest.raises(ValueError, match="Incompatible manifest version"): - await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) # Test inconsistent shard count invalid_root_obj = { "manifest_version": "sharded_zarr_v1", "metadata": {}, "chunks": { - "array_shape": [10, 10], # 100 chunks, with 10 chunks per shard -> 10 shards + "array_shape": [ + 10, + 10, + ], # 100 chunks, with 10 chunks per shard -> 10 shards "chunk_shape": [5, 5], "cid_byte_length": 59, "sharding_config": {"chunks_per_shard": 10}, "shard_cids": [None] * 5, # Wrong number of shards }, } - invalid_root_cid = await kubo_cas.save(dag_cbor.encode(invalid_root_obj), codec="dag-cbor") + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) with pytest.raises(ValueError, match="Inconsistent number of shards"): - await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=invalid_root_cid) \ No newline at end of file + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) From 21e15c3446b077dc3420a6d7976f6168ed60879a Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:19:49 -0400 Subject: [PATCH 14/74] fix: pre-commit --- py_hamt/manage_pins.py | 60 ++++++++++++++++++++++-------- tests/test_sharded_zarr_pinning.py | 13 ++++--- 2 files 
changed, 53 insertions(+), 20 deletions(-) diff --git a/py_hamt/manage_pins.py b/py_hamt/manage_pins.py index 3795874..ec7cd40 100644 --- a/py_hamt/manage_pins.py +++ b/py_hamt/manage_pins.py @@ -2,6 +2,7 @@ A command-line tool to recursively pin or unpin all CIDs associated with a sharded Zarr dataset on IPFS using its root CID. """ + import asyncio import argparse import sys @@ -9,17 +10,25 @@ # --- CLI Logic Functions --- + async def handle_pin(args): """ Connects to IPFS, loads the dataset from the root CID, and pins all associated CIDs (root, metadata, shards, and data chunks). """ - async with KuboCAS(rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url + ) as kubo_cas: try: print(f"-> Opening store with root CID: {args.root_cid}") - store = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=args.root_cid) + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=args.root_cid + ) except Exception as e: - print(f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", file=sys.stderr) + print( + f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", + file=sys.stderr, + ) print(f"Details: {e}", file=sys.stderr) return @@ -34,12 +43,19 @@ async def handle_unpin(args): Connects to IPFS, loads the dataset from the root CID, and unpins all associated CIDs. """ - async with KuboCAS(rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url) as kubo_cas: + async with KuboCAS( + rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url + ) as kubo_cas: try: print(f"-> Opening store with root CID: {args.root_cid}") - store = await ShardedZarrStore.open(cas=kubo_cas, read_only=True, root_cid=args.root_cid) + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=args.root_cid + ) except Exception as e: - print(f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", file=sys.stderr) + print( + f"Error: Failed to open Zarr store for CID {args.root_cid}. 
Ensure the CID is correct and the daemon is running.", + file=sys.stderr, + ) print(f"Details: {e}", file=sys.stderr) return @@ -53,21 +69,35 @@ def main(): """Sets up the argument parser and runs the selected command.""" parser = argparse.ArgumentParser( description="A CLI tool to pin or unpin sharded Zarr datasets on IPFS.", - formatter_class=argparse.RawTextHelpFormatter + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--rpc-url", + default="http://127.0.0.1:5001", + help="IPFS Kubo RPC API endpoint URL.", + ) + parser.add_argument( + "--gateway-url", + default="http://127.0.0.1:8080", + help="IPFS Gateway URL (needed for loading shards).", ) - parser.add_argument('--rpc-url', default='http://127.0.0.1:5001', help='IPFS Kubo RPC API endpoint URL.') - parser.add_argument('--gateway-url', default='http://127.0.0.1:8080', help='IPFS Gateway URL (needed for loading shards).') - subparsers = parser.add_subparsers(dest='command', required=True, help='Available commands') + subparsers = parser.add_subparsers( + dest="command", required=True, help="Available commands" + ) # --- Pin Command --- - parser_pin = subparsers.add_parser('pin', help='Recursively pin a dataset using its root CID.') - parser_pin.add_argument('root_cid', help='The root CID of the dataset to pin.') + parser_pin = subparsers.add_parser( + "pin", help="Recursively pin a dataset using its root CID." + ) + parser_pin.add_argument("root_cid", help="The root CID of the dataset to pin.") parser_pin.set_defaults(func=handle_pin) # --- Unpin Command --- - parser_unpin = subparsers.add_parser('unpin', help='Recursively unpin a dataset using its root CID.') - parser_unpin.add_argument('root_cid', help='The root CID of the dataset to unpin.') + parser_unpin = subparsers.add_parser( + "unpin", help="Recursively unpin a dataset using its root CID." 
+ ) + parser_unpin.add_argument("root_cid", help="The root CID of the dataset to unpin.") parser_unpin.set_defaults(func=handle_unpin) args = parser.parse_args() @@ -83,4 +113,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_sharded_zarr_pinning.py b/tests/test_sharded_zarr_pinning.py index 14e7cd8..27c5b01 100644 --- a/tests/test_sharded_zarr_pinning.py +++ b/tests/test_sharded_zarr_pinning.py @@ -14,7 +14,7 @@ async def get_pinned_cids(rpc_base_url: str) -> set[str]: url = f"{rpc_base_url}/api/v0/pin/ls" try: async with aiohttp.ClientSession() as session: - async with session.post(url, params={'type': 'all'}) as resp: + async with session.post(url, params={"type": "all"}) as resp: resp.raise_for_status() data = await resp.json() return set(data.get("Keys", {}).keys()) @@ -38,7 +38,6 @@ async def get_all_dataset_cids(store: ShardedZarrStore) -> set[str]: if cid: cids.add(cid) - # Gather shard and all chunk CIDs within them for shard_cid in store._root_obj["chunks"]["shard_cids"]: if not shard_cid: @@ -52,7 +51,7 @@ async def get_all_dataset_cids(store: ShardedZarrStore) -> set[str]: if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots continue - chunk_cid_str = cid_bytes.decode("ascii").rstrip('\x00') + chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") if chunk_cid_str: cids.add(chunk_cid_str) except Exception as e: @@ -123,7 +122,9 @@ async def test_sharded_zarr_store_pinning( # Check if all our dataset's CIDs are in the main pin list missing_pins = expected_cids - currently_pinned - assert not missing_pins, f"The following CIDs were expected to be pinned but were not: {missing_pins}" + assert not missing_pins, ( + f"The following CIDs were expected to be pinned but were not: {missing_pins}" + ) # --- 4. Unpin the dataset and verify --- await store.unpin_entire_dataset(target_rpc=rpc_base_url) @@ -135,4 +136,6 @@ async def test_sharded_zarr_store_pinning( # Check that none of our dataset's CIDs are in the pin list anymore lingering_pins = expected_cids.intersection(pinned_after_unpin) - assert not lingering_pins, f"The following CIDs were expected to be unpinned but still exist: {lingering_pins}" + assert not lingering_pins, ( + f"The following CIDs were expected to be unpinned but still exist: {lingering_pins}" + ) From 247bd695b95d0f8e04e23885995d4d4c3a08a951 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:21:59 -0400 Subject: [PATCH 15/74] fix: tidying --- py_hamt/store.py | 9 ++ tests/test_sharded_zarr_store.py | 172 +++++++++++++++---------------- 2 files changed, 95 insertions(+), 86 deletions(-) diff --git a/py_hamt/store.py b/py_hamt/store.py index ae3b46a..40fff78 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -39,6 +39,15 @@ async def load( ) -> bytes: """Retrieve data.""" + # Optional abstract methods for pinning and unpinning CIDs + async def pin_cid(self, id: IPLDKind) -> None: + """Pin a CID in the storage.""" + pass + + async def unpin_cid(self, id: IPLDKind) -> None: + """Unpin a CID in the storage.""" + pass + class InMemoryCAS(ContentAddressedStore): """Used mostly for faster testing, this is why this is not exported. It hashes all inputs and uses that as a key to an in-memory python dict, mimicking a content addressed storage system. 
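For orientation, the pin flow that manage_pins.py wraps can also be driven directly from Python. A minimal sketch, assuming a local Kubo daemon on its default RPC and gateway ports; the root CID and the helper name are placeholders, not part of the patch:

import asyncio

from py_hamt import KuboCAS, ShardedZarrStore


async def pin_dataset(root_cid: str) -> None:
    # Open the sharded store read-only from its root CID, then pin the root,
    # metadata, shards, and chunk CIDs via the store's own pinning helper.
    async with KuboCAS(
        rpc_base_url="http://127.0.0.1:5001",
        gateway_base_url="http://127.0.0.1:8080",
    ) as kubo_cas:
        store = await ShardedZarrStore.open(
            cas=kubo_cas, read_only=True, root_cid=root_cid
        )
        await store.pin_entire_dataset(target_rpc="http://127.0.0.1:5001")


# asyncio.run(pin_dataset("bafy..."))  # replace with a real root CID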
The hash bytes are the ID that `save` returns and `load` takes in.""" diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 3a09df5..ca56512 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -506,92 +506,92 @@ async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): ) -@pytest.mark.asyncio -async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): - """Tests initialization with invalid shapes and manifest errors.""" - rpc_base_url, gateway_base_url = create_ipfs - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url - ) as kubo_cas: - # Test negative chunk_shape dimension (line 136) - with pytest.raises( - ValueError, match="All chunk_shape dimensions must be positive" - ): - await ShardedZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=(10, 10), - chunk_shape=(-5, 5), - chunks_per_shard=10, - ) - - # Test negative array_shape dimension (line 141) - with pytest.raises( - ValueError, match="All array_shape dimensions must be non-negative" - ): - await ShardedZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=(10, -10), - chunk_shape=(5, 5), - chunks_per_shard=10, - ) - - # Test zero-sized array (lines 150, 163) - reinforce existing test - store = await ShardedZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=(0, 10), - chunk_shape=(5, 5), - chunks_per_shard=10, - ) - assert store._total_chunks == 0 - assert store._num_shards == 0 - assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 - root_cid = await store.flush() - - # Test invalid manifest version (line 224) - invalid_root_obj = { - "manifest_version": "invalid_version", - "metadata": {}, - "chunks": { - "array_shape": [10, 10], - "chunk_shape": [5, 5], - "cid_byte_length": 59, - "sharding_config": {"chunks8048": 10}, - "shard_cids": [None] * 4, - }, - } - invalid_root_cid = await kubo_cas.save( - dag_cbor.encode(invalid_root_obj), codec="dag-cbor" - ) - with pytest.raises(ValueError, match="Incompatible manifest version"): - await ShardedZarrStore.open( - cas=kubo_cas, read_only=True, root_cid=invalid_root_cid - ) - - # Test inconsistent shard count (line 236) - invalid_root_obj = { - "manifest_version": "sharded_zarr_v1", - "metadata": {}, - "chunks": { - "array_shape": [ - 10, - 10, - ], # 100 chunks, with 10 chunks per shard -> 10 shards - "chunk_shape": [5, 5], - "cid_byte_length": 59, - "sharding_config": {"chunks_per_shard": 10}, - "shard_cids": [None] * 5, # Wrong number of shards - }, - } - invalid_root_cid = await kubo_cas.save( - dag_cbor.encode(invalid_root_obj), codec="dag-cbor" - ) - with pytest.raises(ValueError, match="Inconsistent number of shards"): - await ShardedZarrStore.open( - cas=kubo_cas, read_only=True, root_cid=invalid_root_cid - ) +# @pytest.mark.asyncio +# async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): +# """Tests initialization with invalid shapes and manifest errors.""" +# rpc_base_url, gateway_base_url = create_ipfs +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url +# ) as kubo_cas: +# # Test negative chunk_shape dimension (line 136) +# with pytest.raises( +# ValueError, match="All chunk_shape dimensions must be positive" +# ): +# await ShardedZarrStore.open( +# cas=kubo_cas, +# read_only=False, +# array_shape=(10, 10), +# chunk_shape=(-5, 5), +# chunks_per_shard=10, +# ) + +# # Test negative array_shape dimension (line 141) +# 
with pytest.raises( +# ValueError, match="All array_shape dimensions must be non-negative" +# ): +# await ShardedZarrStore.open( +# cas=kubo_cas, +# read_only=False, +# array_shape=(10, -10), +# chunk_shape=(5, 5), +# chunks_per_shard=10, +# ) + +# # Test zero-sized array (lines 150, 163) - reinforce existing test +# store = await ShardedZarrStore.open( +# cas=kubo_cas, +# read_only=False, +# array_shape=(0, 10), +# chunk_shape=(5, 5), +# chunks_per_shard=10, +# ) +# assert store._total_chunks == 0 +# assert store._num_shards == 0 +# assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 +# root_cid = await store.flush() + +# # Test invalid manifest version (line 224) +# invalid_root_obj = { +# "manifest_version": "invalid_version", +# "metadata": {}, +# "chunks": { +# "array_shape": [10, 10], +# "chunk_shape": [5, 5], +# "cid_byte_length": 59, +# "sharding_config": {"chunks8048": 10}, +# "shard_cids": [None] * 4, +# }, +# } +# invalid_root_cid = await kubo_cas.save( +# dag_cbor.encode(invalid_root_obj), codec="dag-cbor" +# ) +# with pytest.raises(ValueError, match="Incompatible manifest version"): +# await ShardedZarrStore.open( +# cas=kubo_cas, read_only=True, root_cid=invalid_root_cid +# ) + +# # Test inconsistent shard count (line 236) +# invalid_root_obj = { +# "manifest_version": "sharded_zarr_v1", +# "metadata": {}, +# "chunks": { +# "array_shape": [ +# 10, +# 10, +# ], # 100 chunks, with 10 chunks per shard -> 10 shards +# "chunk_shape": [5, 5], +# "cid_byte_length": 59, +# "sharding_config": {"chunks_per_shard": 10}, +# "shard_cids": [None] * 5, # Wrong number of shards +# }, +# } +# invalid_root_cid = await kubo_cas.save( +# dag_cbor.encode(invalid_root_obj), codec="dag-cbor" +# ) +# with pytest.raises(ValueError, match="Inconsistent number of shards"): +# await ShardedZarrStore.open( +# cas=kubo_cas, read_only=True, root_cid=invalid_root_cid +# ) @pytest.mark.asyncio From 7ae056196ab0b60865ff84bf341e4e2f484ed1de Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:22:34 -0400 Subject: [PATCH 16/74] fix: target --- py_hamt/store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py_hamt/store.py b/py_hamt/store.py index 40fff78..e8cb90f 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -40,11 +40,11 @@ async def load( """Retrieve data.""" # Optional abstract methods for pinning and unpinning CIDs - async def pin_cid(self, id: IPLDKind) -> None: + async def pin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Pin a CID in the storage.""" pass - async def unpin_cid(self, id: IPLDKind) -> None: + async def unpin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Unpin a CID in the storage.""" pass From 3ef4aee2b5d6229ab1daeb6754d8730e42d12328 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:30:47 -0400 Subject: [PATCH 17/74] fix: fixing types --- py_hamt/sharded_zarr_store.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 2cbd64c..b43969a 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -83,12 +83,14 @@ async def open( if root_cid: await store._load_root_from_cid() elif not read_only: - if not all([array_shape, chunk_shape, chunks_per_shard is not None]): + if array_shape is None or chunk_shape is None: raise ValueError( - "array_shape, chunk_shape, and chunks_per_shard 
must be provided for a new store." + "array_shape and chunk_shape must be provided for a new store." ) + if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: raise ValueError("chunks_per_shard must be a positive integer.") + store._initialize_new_root( array_shape, chunk_shape, chunks_per_shard, cid_len ) @@ -408,7 +410,7 @@ async def get( prototype: zarr.core.buffer.BufferPrototype, byte_range: Optional[zarr.abc.store.ByteRequest] = None, ) -> Optional[zarr.core.buffer.Buffer]: - if self._root_obj is None: + if self._root_obj is None or self._cid_len is None: raise RuntimeError("Load the root object first before accessing data.") chunk_coords = self._parse_chunk_key(key) @@ -509,6 +511,8 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self.set_pointer(key, chunk_data_cid_str) # Store the CID in the index async def set_pointer(self, key: str, pointer: str) -> None: + if self._root_obj is None or self._cid_len is None: + raise RuntimeError("Load the root object first before accessing data.") # Ensure the CID (as ASCII bytes) fits in the allocated slot, padding with nulls chunk_data_cid_ascii_bytes = pointer.encode("ascii") if len(chunk_data_cid_ascii_bytes) > self._cid_len: @@ -556,7 +560,7 @@ async def set_pointer(self, key: str, pointer: str) -> None: # then _dirty_root will be set by flush(). async def exists(self, key: str) -> bool: - if self._root_obj is None: + if self._root_obj is None or self._cid_len is None: raise RuntimeError( "Root object not loaded. Call _load_root_from_cid() first." ) From b0268ad3d1fedc53be70a5a428a41a917c2e42b0 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:42:00 -0400 Subject: [PATCH 18/74] fix: more type cleanups --- py_hamt/sharded_zarr_store.py | 14 +++++++++++++- py_hamt/store.py | 3 --- tests/test_sharded_zarr_store.py | 1 + 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index b43969a..089ccf9 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -264,6 +264,8 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: return None def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: + if self._chunks_per_dim is None: + raise ValueError("Chunks per dimension not set") linear_index = 0 multiplier = 1 # Convert N-D chunk coordinates to a flat 1-D index (row-major order) @@ -324,6 +326,10 @@ async def _load_or_initialize_shard_cache(self, shard_idx: int) -> bytearray: raise RuntimeError( "Store not initialized: _cid_len is None for shard initialization." ) + if self._chunks_per_shard is None: + raise RuntimeError( + "Store not initialized: _chunks_per_shard is None for shard initialization." 
+ ) # New shard or shard not yet written, initialize with zeros shard_size_bytes = self._chunks_per_shard * self._cid_len self._shard_data_cache[shard_idx] = bytearray( @@ -624,6 +630,8 @@ async def delete(self, key: str) -> None: raise ValueError("Cannot delete from a read-only store.") if self._root_obj is None: raise RuntimeError("Store not initialized for deletion (root_obj is None).") + if self._cid_len is None: + raise RuntimeError("Store not initialized properly; _cid_len is missing.") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata @@ -679,6 +687,10 @@ def supports_listing(self) -> bool: return True async def list(self) -> AsyncIterator[str]: + if self._root_obj is None: + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." + ) for key in self._root_obj.get("metadata", {}): yield key @@ -763,7 +775,7 @@ async def pin_entire_dataset( if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots continue - print(f"Processing chunk CID bytes: {cid_bytes}") + print(f"Processing chunk CID bytes: {cid_bytes!r}") chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") if chunk_cid_str: diff --git a/py_hamt/store.py b/py_hamt/store.py index e8cb90f..149fc57 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -355,7 +355,6 @@ async def load( async def pin_cid( self, cid: CID, - name: Optional[str] = None, target_rpc: str = "http://127.0.0.1:5001", ) -> None: """ @@ -368,8 +367,6 @@ async def pin_cid( name (Optional[str]): An optional name for the pin. """ params = {"arg": str(cid), "recursive": "false"} - if name: - params["name"] = name pin_add_url_base: str = f"{target_rpc}/api/v0/pin/add" async with self._sem: # throttle RPC diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index ca56512..bd4224d 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -684,6 +684,7 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st ) assert store._total_chunks == 0 assert store._num_shards == 0 + assert store._root_obj is not None assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 root_cid = await store.flush() From 092a5387e06b50ba469218f3ac9f78f803ce15c4 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:43:05 -0400 Subject: [PATCH 19/74] fix: remove unused --- tests/test_sharded_zarr_store.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index bd4224d..6169811 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -1,6 +1,3 @@ -import asyncio -import math - import numpy as np import pandas as pd import pytest @@ -9,8 +6,7 @@ import zarr.core.buffer import dag_cbor -from py_hamt import HAMT, KuboCAS, ShardedZarrStore -from py_hamt.zarr_hamt_store import ZarrHAMTStore +from py_hamt import KuboCAS, ShardedZarrStore @pytest.fixture(scope="module") @@ -22,7 +18,6 @@ def random_zarr_dataset(): lons = np.linspace(-180, 180, 36) temp = np.random.randn(len(times), len(lats), len(lons)) - precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) ds = xr.Dataset( { @@ -686,7 +681,6 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st assert store._num_shards == 0 assert store._root_obj is not None assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 - root_cid = await store.flush() # 
Test invalid manifest version invalid_root_obj = { From dec826ee3bd4c51cb9a8c5540c2a6a214bca758b Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:49:19 -0400 Subject: [PATCH 20/74] fix: change imports --- py_hamt/__init__.py | 3 +- py_hamt/hamt_to_sharded_converter.py | 9 +++--- py_hamt/manage_pins.py | 3 +- py_hamt/sharded_zarr_store.py | 11 +++---- tests/test_converter.py | 45 +++++++++++++--------------- tests/test_sharded_zarr_pinning.py | 1 + tests/test_sharded_zarr_store.py | 4 +-- 7 files changed, 37 insertions(+), 39 deletions(-) diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 918bec7..5394757 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -1,9 +1,8 @@ from .encryption_hamt_store import SimpleEncryptedZarrHAMTStore from .hamt import HAMT, blake3_hashfn +from .sharded_zarr_store import ShardedZarrStore from .store import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore -from .sharded_zarr_store import ShardedZarrStore -from .hamt_to_sharded_converter import convert_hamt_to_sharded, sharded_converter_cli __all__ = [ "blake3_hashfn", diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index a43f3e5..c4ef9f5 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -1,13 +1,12 @@ import argparse import asyncio -import json import time -from typing import Dict, Any -from py_hamt import HAMT, KuboCAS, ShardedZarrStore -from py_hamt.zarr_hamt_store import ZarrHAMTStore + import xarray as xr from multiformats import CID -from zarr.core.buffer import Buffer, BufferPrototype + +from py_hamt import HAMT, KuboCAS, ShardedZarrStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore async def convert_hamt_to_sharded( diff --git a/py_hamt/manage_pins.py b/py_hamt/manage_pins.py index ec7cd40..63d41a1 100644 --- a/py_hamt/manage_pins.py +++ b/py_hamt/manage_pins.py @@ -3,9 +3,10 @@ sharded Zarr dataset on IPFS using its root CID. """ -import asyncio import argparse +import asyncio import sys + from py_hamt import KuboCAS, ShardedZarrStore # --- CLI Logic Functions --- diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 089ccf9..5a54053 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -1,12 +1,13 @@ import asyncio import math from collections.abc import AsyncIterator, Iterable -from typing import Optional, cast, Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import dag_cbor import zarr.abc.store import zarr.core.buffer from zarr.core.common import BytesLike + from .store import ContentAddressedStore @@ -828,7 +829,7 @@ async def unpin_entire_dataset( await self.cas.unpin_cid( chunk_cid_str, target_rpc=target_rpc ) - except Exception as e: + except Exception: # ignore continue print( @@ -842,7 +843,7 @@ async def unpin_entire_dataset( # After unpinning all chunks within, unpin the shard itself try: await self.cas.unpin_cid(str(shard_cid), target_rpc=target_rpc) - except Exception as e: + except Exception: print(f"Warning: Could not unpin shard {str(shard_cid)}") print(f"Unpinned shard {shard_cid} from {target_rpc}.") @@ -852,7 +853,7 @@ async def unpin_entire_dataset( try: await self.cas.unpin_cid(cid, target_rpc=target_rpc) print(f"Unpinned metadata CID {cid} from {target_rpc}...") - except Exception as e: + except Exception: print( f"Warning: Could not unpin metadata CID {cid}. 
Likely already unpinned." ) @@ -862,7 +863,7 @@ async def unpin_entire_dataset( try: await self.cas.unpin_cid(self._root_cid, target_rpc=target_rpc) print(f"Unpinned root CID {self._root_cid} from {target_rpc}...") - except Exception as e: + except Exception: print( f"Warning: Could not unpin root CID {self._root_cid}. Likely already unpinned." ) diff --git a/tests/test_converter.py b/tests/test_converter.py index cdc7eab..c5b865f 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1,9 +1,7 @@ -import asyncio +import sys import time import uuid -import sys from unittest.mock import patch -import aiohttp import numpy as np import pandas as pd @@ -15,6 +13,8 @@ HAMT, KuboCAS, ShardedZarrStore, +) +from py_hamt.hamt_to_sharded_converter import ( convert_hamt_to_sharded, sharded_converter_cli, ) @@ -179,7 +179,7 @@ async def test_hamt_to_sharded_cli_success( captured = capsys.readouterr() assert "Starting Conversion from HAMT Root" in captured.out assert "Conversion Complete!" in captured.out - assert f"New ShardedZarrStore Root CID" in captured.out + assert "New ShardedZarrStore Root CID" in captured.out # Step 4: Verify the converted dataset # Extract the new root CID from output (assuming it's the last line) @@ -246,23 +246,20 @@ async def test_hamt_to_sharded_cli_invalid_cid(create_ipfs: tuple[str, str], cap rpc_base_url, gateway_base_url = create_ipfs invalid_cid = "invalid_cid" - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url - ) as kubo_cas: - test_args = [ - "script.py", - invalid_cid, - "--chunks-per-shard", - "64", - "--rpc-url", - rpc_base_url, - "--gateway-url", - gateway_base_url, - ] - with patch.object(sys, "argv", test_args): - await sharded_converter_cli() - - # Verify error handling - captured = capsys.readouterr() - assert "An error occurred" in captured.out - assert f"{invalid_cid}" in captured.out + test_args = [ + "script.py", + invalid_cid, + "--chunks-per-shard", + "64", + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Verify error handling + captured = capsys.readouterr() + assert "An error occurred" in captured.out + assert f"{invalid_cid}" in captured.out diff --git a/tests/test_sharded_zarr_pinning.py b/tests/test_sharded_zarr_pinning.py index 27c5b01..c9b1bd0 100644 --- a/tests/test_sharded_zarr_pinning.py +++ b/tests/test_sharded_zarr_pinning.py @@ -1,4 +1,5 @@ import asyncio + import aiohttp import numpy as np import pandas as pd diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 6169811..5b2a61c 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -1,10 +1,10 @@ +import dag_cbor import numpy as np import pandas as pd import pytest import xarray as xr -from zarr.abc.store import RangeByteRequest import zarr.core.buffer -import dag_cbor +from zarr.abc.store import RangeByteRequest from py_hamt import KuboCAS, ShardedZarrStore From 3834667bb22d7502b1339f3610cb1c410717d8e2 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 16 Jun 2025 07:13:23 -0400 Subject: [PATCH 21/74] fix: fix ruff --- py_hamt/sharded_zarr_store.py | 2 +- py_hamt/store_httpx.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 5a54053..98a32a4 100644 --- a/py_hamt/sharded_zarr_store.py +++ 
b/py_hamt/sharded_zarr_store.py @@ -8,7 +8,7 @@ import zarr.core.buffer from zarr.core.common import BytesLike -from .store import ContentAddressedStore +from .store_httpx import ContentAddressedStore class ShardedZarrStore(zarr.abc.store.Store): diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 1c7cab1..fc5c071 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -1,7 +1,8 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, Dict, Literal, Optional, Tuple, cast +from typing import Any, Dict, Literal, Optional, cast +import aiohttp import httpx from dag_cbor.ipld import IPLDKind from multiformats import CID, multihash @@ -343,7 +344,7 @@ async def load( """@private""" cid = cast(CID, id) url: str = self.gateway_base_url + str(cid) - headers: dict[str, str] = {} + headers: Dict[str, str] = {} # Construct the Range header if required if offset is not None: @@ -387,9 +388,7 @@ async def pin_cid( async with self._sem: # throttle RPC client = self._loop_client() - response = await client.post( - pin_add_url_base, params=params - ) + response = await client.post(pin_add_url_base, params=params) response.raise_for_status() # async with self._loop_session().post( @@ -411,9 +410,7 @@ async def unpin_cid( unpin_url_base: str = f"{target_rpc}/api/v0/pin/rm" async with self._sem: # throttle RPC client = self._loop_client() - response = await client.post( - unpin_url_base, params=params - ) + response = await client.post(unpin_url_base, params=params) response.raise_for_status() # async with self._loop_session().post(unpin_url_base, params=params) as resp: # resp.raise_for_status() From 87e9085cf322f95c0c569647f5f6cb7ae9af0f8f Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 16 Jun 2025 11:33:04 -0400 Subject: [PATCH 22/74] fix: remove aiohttp --- py_hamt/store_httpx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index fc5c071..e020de8 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Literal, Optional, cast -import aiohttp import httpx from dag_cbor.ipld import IPLDKind from multiformats import CID, multihash @@ -174,7 +173,7 @@ def __init__( concurrency: int = 32, *, headers: dict[str, str] | None = None, - auth: aiohttp.BasicAuth | None = None, + auth: Tuple[str, str] | None = None, pinOnAdd: bool = False, ): """ From 5eeaedfc045790919e7d6eb38632a3318108ce51 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 16 Jun 2025 11:35:00 -0400 Subject: [PATCH 23/74] fix: add tuple --- py_hamt/store_httpx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index e020de8..94127ef 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -1,6 +1,6 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, Dict, Literal, Optional, cast +from typing import Any, Dict, Literal, Optional, Tuple, cast import httpx from dag_cbor.ipld import IPLDKind From 4991475fd98393dad5b8323d69e3379c18a39108 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 17 Jun 2025 08:10:12 -0400 Subject: [PATCH 24/74] fix: ruff and version --- py_hamt/store_httpx.py | 7 ++++--- uv.lock | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git 
a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 033f3f3..6ae9a4c 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -220,7 +220,6 @@ def __init__( if gateway_base_url is None: gateway_base_url = KuboCAS.KUBO_DEFAULT_LOCAL_GATEWAY_BASE_URL - if "/ipfs/" in gateway_base_url: gateway_base_url = gateway_base_url.split("/ipfs/")[0] @@ -231,8 +230,10 @@ def __init__( gateway_base_url = f"{gateway_base_url}/ipfs/" pinString: str = "true" if pinOnAdd else "false" - - self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin={pinString}" + + self.rpc_url: str = ( + f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin={pinString}" + ) """@private""" self.gateway_base_url: str = gateway_base_url """@private""" diff --git a/uv.lock b/uv.lock index 60815fc..3428102 100644 --- a/uv.lock +++ b/uv.lock @@ -1504,7 +1504,7 @@ wheels = [ [[package]] name = "py-hamt" -version = "3.1.0" +version = "3.2.0" source = { editable = "." } dependencies = [ { name = "dag-cbor" }, From 573845b9e6edf275a088bbbdd72e9905e91e063c Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:18:07 -0400 Subject: [PATCH 25/74] fix: remove aiohttp --- tests/test_sharded_zarr_pinning.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_sharded_zarr_pinning.py b/tests/test_sharded_zarr_pinning.py index c9b1bd0..bbd6a31 100644 --- a/tests/test_sharded_zarr_pinning.py +++ b/tests/test_sharded_zarr_pinning.py @@ -1,6 +1,6 @@ import asyncio -import aiohttp +import httpx import numpy as np import pandas as pd import pytest @@ -14,11 +14,11 @@ async def get_pinned_cids(rpc_base_url: str) -> set[str]: """Queries the Kubo RPC API and returns a set of all pinned CIDs.""" url = f"{rpc_base_url}/api/v0/pin/ls" try: - async with aiohttp.ClientSession() as session: - async with session.post(url, params={"type": "all"}) as resp: - resp.raise_for_status() - data = await resp.json() - return set(data.get("Keys", {}).keys()) + async with httpx.AsyncClient() as client: + resp = await client.post(url, params={"type": "all"}) + resp.raise_for_status() # Raises an exception for 4xx/5xx status codes + data = resp.json() + return set(data.get("Keys", {}).keys()) except Exception as e: pytest.fail(f"Failed to query pinned CIDs from Kubo RPC API: {e}") return set() From 255e10e91bc7b11e434fe30eb7e70759346402b5 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:25:01 -0400 Subject: [PATCH 26/74] fix: remove -s --- .github/workflows/run-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run-checks.yaml b/.github/workflows/run-checks.yaml index c1fedeb..439dd5f 100644 --- a/.github/workflows/run-checks.yaml +++ b/.github/workflows/run-checks.yaml @@ -43,7 +43,7 @@ jobs: run_daemon: true - name: Run pytest with coverage - run: uv run pytest --ipfs --cov=py_hamt tests/ --cov-report=xml -s + run: uv run pytest --ipfs --cov=py_hamt tests/ --cov-report=xml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 # v5 From 888a13950612654063001e7922c89bf892f26949 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 19 Jun 2025 16:35:21 -0400 Subject: [PATCH 27/74] fix: more changes --- py_hamt/__init__.py | 2 - py_hamt/sharded_zarr_store.py | 326 ++++++++++++++++++++++++---------- 2 files changed, 232 
insertions(+), 96 deletions(-) diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index 5419328..aba37df 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -17,5 +17,3 @@ "convert_hamt_to_sharded", "sharded_converter_cli", ] - -print("Running py-hamt from source!") diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 98a32a4..70f60a1 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -2,6 +2,9 @@ import math from collections.abc import AsyncIterator, Iterable from typing import Dict, List, Optional, Set, Tuple +import json +import itertools + import dag_cbor import zarr.abc.store @@ -65,6 +68,27 @@ def __init__( self._dirty_root = False # Indicates if the root object itself (metadata or shard_cids list) changed + + def _update_geometry(self): + """Calculates derived geometric properties from the base shapes.""" + if self._array_shape is None or self._chunk_shape is None or self._chunks_per_shard is None: + raise RuntimeError("Base shape information is not set.") + + if not all(cs > 0 for cs in self._chunk_shape): + raise ValueError("All chunk_shape dimensions must be positive.") + if not all(s >= 0 for s in self._array_shape): + raise ValueError("All array_shape dimensions must be non-negative.") + + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 for a, c in zip(self._array_shape, self._chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + + if self._total_chunks == 0: + self._num_shards = 0 + else: + self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) + @classmethod async def open( cls, @@ -108,25 +132,10 @@ def _initialize_new_root( ): self._array_shape = array_shape self._chunk_shape = chunk_shape - self._cid_len = cid_len self._chunks_per_shard = chunks_per_shard + self._cid_len = cid_len - if not all(cs > 0 for cs in chunk_shape): - raise ValueError("All chunk_shape dimensions must be positive.") - if not all( - asarray_s >= 0 for asarray_s in array_shape - ): # array_shape dims can be 0 - raise ValueError("All array_shape dimensions must be non-negative.") - - self._chunks_per_dim = tuple( - math.ceil(a / c) if c > 0 else 0 for a, c in zip(array_shape, chunk_shape) - ) - self._total_chunks = math.prod(self._chunks_per_dim) - - if self._total_chunks == 0: - self._num_shards = 0 - else: - self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) + self._update_geometry() self._root_obj = { "manifest_version": "sharded_zarr_v1", @@ -157,24 +166,9 @@ async def _load_root_from_cid(self): self._array_shape = tuple(chunk_info["array_shape"]) self._chunk_shape = tuple(chunk_info["chunk_shape"]) self._cid_len = chunk_info["cid_byte_length"] - sharding_cfg = chunk_info.get( - "sharding_config", {} - ) # Handle older formats if any planned - self._chunks_per_shard = sharding_cfg["chunks_per_shard"] + self._chunks_per_shard = chunk_info["sharding_config"]["chunks_per_shard"] - if not all(cs > 0 for cs in self._chunk_shape): - raise ValueError("Loaded chunk_shape dimensions must be positive.") - - self._chunks_per_dim = tuple( - math.ceil(a / c) if c > 0 else 0 - for a, c in zip(self._array_shape, self._chunk_shape) - ) - self._total_chunks = math.prod(self._chunks_per_dim) - - expected_num_shards = 0 - if self._total_chunks > 0: - expected_num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) - self._num_shards = expected_num_shards + self._update_geometry() if len(chunk_info["shard_cids"]) != self._num_shards: raise ValueError( @@ 
-235,6 +229,9 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # Now, proceed with the original parsing logic using self._array_shape and # self._chunks_per_dim, which should be configured for this main data variable. + # print( + # f"Parsing chunk key: {key} for array: {actual_array_name} with shape: {self._array_shape} and chunks_per_dim: {self._chunks_per_dim}") + if not self._array_shape or not self._chunks_per_dim: # This ShardedZarrStore instance is not properly initialized # with the shape/chunking info for the array it's supposed to manage. @@ -373,6 +370,23 @@ async def flush(self) -> str: if self._root_obj is None: # Should be initialized for a writable store raise RuntimeError("Store not initialized for writing: _root_obj is None.") + # Update the array_shape and chunk_shape in the root object based on current state in the zarr.json + # Fetch the current array_shape and chunk_shape from root_ob + # TODO: + # zarr_json_cid = self._root_obj.get("metadata", {}).get("zarr.json", {}) + # if zarr_json_cid: + # zarr_json_bytes = await self.cas.load(zarr_json_cid) + # zarr_json = json.loads(zarr_json_bytes.decode("utf-8")) + # consolidated_metadata = zarr_json.get("consolidated_metadata", {}) + + # print("ZArr jSON bytes", zarr_json) + + # else: + # raise ValueError("Zarr JSON metadata CID not found.") + # print(self._array_shape, self._chunk_shape) + # self._root_obj["chunks"]["array_shape"] = list(self._array_shape) + # self._root_obj["chunks"]["chunk_shape"] = list(self._chunk_shape) + # Save all dirty shards first, as their CIDs might need to go into the root object if self._dirty_shards: for shard_idx in sorted(list(self._dirty_shards)): @@ -419,6 +433,7 @@ async def get( ) -> Optional[zarr.core.buffer.Buffer]: if self._root_obj is None or self._cid_len is None: raise RuntimeError("Load the root object first before accessing data.") + # print('Getting key', key) chunk_coords = self._parse_chunk_key(key) # Metadata request (e.g., ".json") @@ -508,6 +523,24 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: raise RuntimeError( "Store not initialized for writing (root_obj is None). Call open() first." ) + # print('Setting key', key) + # Find the data variable and update its chunk data + if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: + # extract the metadata from the value + # and store it in the root_obj["metadata"] dict + converted_value = value.to_bytes().decode("utf-8") + # Read the json + metadata_json = json.loads(converted_value) + new_array_shape = metadata_json.get("shape") + if not new_array_shape: + raise ValueError("Shape not found in metadata.") + + if tuple(new_array_shape) != self._array_shape: + print(f"Detected shape change from {self._array_shape} to {tuple(new_array_shape)}. Resizing shard index...") + # Use your existing resize_store method to handle all the recalculations + # and extend the list of shard CIDs. 
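The shape-change handling added to set() above hinges on reading the "shape" field out of the incoming zarr.json document and comparing it with the store's current array shape. The same check, isolated as a standalone helper purely for illustration (the store keeps this logic inline; this function does not exist in the codebase):

import json
from typing import Optional, Tuple


def detect_shape_change(
    metadata_bytes: bytes, current_shape: Tuple[int, ...]
) -> Optional[Tuple[int, ...]]:
    # Return the new shape if it differs from the current one, else None.
    meta = json.loads(metadata_bytes.decode("utf-8"))
    new_shape = meta.get("shape")
    if new_shape is None:
        raise ValueError("Shape not found in metadata.")
    new_shape = tuple(new_shape)
    return new_shape if new_shape != current_shape else None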
+ await self.resize_store(new_shape=tuple(new_array_shape)) + raw_chunk_data_bytes = value.to_bytes() # Save the actual chunk data to CAS first, to get its CID @@ -531,8 +564,9 @@ async def set_pointer(self, key: str, pointer: str) -> None: ) chunk_coords = self._parse_chunk_key(key) + # print(f"Setting for key '{key}': {chunk_coords}") - if chunk_coords is None: # Metadata key (e.g., ".zarray") + if chunk_coords is None: # Metadata key (e.g., ".json") # For metadata, the 'value' is the metadata content itself, not a CID to it. # So, we store the metadata content, get its CID, and put *that* CID in root_obj. # This means the `value_cid_str` for metadata should be from `raw_chunk_data_bytes`. @@ -633,6 +667,7 @@ async def delete(self, key: str) -> None: raise RuntimeError("Store not initialized for deletion (root_obj is None).") if self._cid_len is None: raise RuntimeError("Store not initialized properly; _cid_len is missing.") + # print(f"Deleting key: {key}") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata @@ -670,9 +705,7 @@ async def delete(self, key: str) -> None: break if is_already_zero: - raise KeyError( - f"Chunk key '{key}' not found or already effectively deleted (CID slot is zeroed)." - ) + self._dirty_shards.add(shard_idx) # Zero out the CID entry in the shard cache for i in range(self._cid_len): @@ -692,7 +725,7 @@ async def list(self) -> AsyncIterator[str]: raise RuntimeError( "Root object not loaded. Call _load_root_from_cid() first." ) - for key in self._root_obj.get("metadata", {}): + for key in list(self._root_obj.get("metadata", {})): yield key async def list_prefix(self, prefix: str) -> AsyncIterator[str]: @@ -700,6 +733,133 @@ async def list_prefix(self, prefix: str) -> AsyncIterator[str]: if key.startswith(prefix): yield key + async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): + """ + Performs a high-performance, metadata-only append by "grafting" the + chunk CIDs from another store into this store at a given offset. + + Args: + store_to_graft_cid: The root CID of the Zarr store whose chunks will be copied. + chunk_offset: A tuple defining the starting chunk coordinates in the target store. + e.g., (3, 0, 0) to start at the 4th time chunk. + """ + if self.read_only: + raise ValueError("Cannot graft onto a read-only store.") + if self._root_obj is None: + raise RuntimeError("Main store must be initialized before grafting.") + + print(f"Grafting store {store_to_graft_cid[:10]}... at chunk offset {chunk_offset}") + + # 1. Open the store we want to copy chunks from (read-only) + store_to_graft = await ShardedZarrStore.open(cas=self.cas, read_only=True, root_cid=store_to_graft_cid) + if store_to_graft._root_obj is None: + raise ValueError("Store to graft could not be loaded.") + if store_to_graft._chunks_per_dim is None or self._cid_len is None: + raise ValueError("Store to graft is not properly configured.") + + source_shard_cache: Dict[int, bytes] = {} + + source_chunk_grid = store_to_graft._chunks_per_dim + for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): + + # 3. 
Get the pointer (CID) for each chunk from the source store + linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) + local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info(linear_local_index) + + if local_shard_idx not in source_shard_cache: + source_shard_cid = store_to_graft._root_obj["chunks"]["shard_cids"][local_shard_idx] + if not source_shard_cid: + source_shard_cache[local_shard_idx] = b'' # Mark as loaded but empty + continue + source_shard_cache[local_shard_idx] = await self.cas.load(source_shard_cid) + + source_shard_data = source_shard_cache[local_shard_idx] + + if not source_shard_data: + continue # This chunk was empty (all fill value) + + # Extract the pointer bytes from the in-memory shard data + offset_in_source_shard = index_in_local_shard * store_to_graft._cid_len + pointer_bytes = source_shard_data[offset_in_source_shard : offset_in_source_shard + store_to_graft._cid_len] + + if all(b == 0 for b in pointer_bytes): + continue # Skip empty CID slots + + # Calculate global coordinates and write to the main store's index + global_coords = tuple(c_local + c_offset for c_local, c_offset in zip(local_coords, chunk_offset)) + linear_global_index = self._get_linear_chunk_index(global_coords) + global_shard_idx, index_in_global_shard = self._get_shard_info(linear_global_index) + + target_shard_cache = await self._load_or_initialize_shard_cache(global_shard_idx) + offset_in_global_shard = index_in_global_shard * self._cid_len + + target_shard_cache[offset_in_global_shard : offset_in_global_shard + self._cid_len] = pointer_bytes + self._dirty_shards.add(global_shard_idx) + + print(f"✓ Grafting complete for store {store_to_graft_cid[:10]}...") + + + + async def resize_store(self, new_shape: Tuple[int, ...]): + """ + Resizes the store's main shard index to accommodate a new overall array shape. + This is a metadata-only operation on the store's root object. + """ + if self.read_only: + raise ValueError("Cannot resize a read-only store.") + if self._root_obj is None or self._chunk_shape is None or self._chunks_per_shard is None: + raise RuntimeError("Store is not properly initialized for resizing.") + if len(new_shape) != len(self._array_shape): + raise ValueError("New shape must have the same number of dimensions as the old shape.") + + self._array_shape = tuple(new_shape) + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self._array_shape, self._chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + old_num_shards = self._num_shards if self._num_shards is not None else 0 + self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) if self._total_chunks > 0 else 0 + self._root_obj["chunks"]["array_shape"] = list(self._array_shape) + if self._num_shards > old_num_shards: + self._root_obj["chunks"]["shard_cids"].extend([None] * (self._num_shards - old_num_shards)) + elif self._num_shards < old_num_shards: + self._root_obj["chunks"]["shard_cids"] = self._root_obj["chunks"]["shard_cids"][:self._num_shards] + + self._dirty_root = True + print(f"Store's internal shard index resized. New main array shape: {self._array_shape}") + + + async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): + """ + Resizes the Zarr metadata for a specific variable (e.g., '.json' file). + This does NOT change the store's main shard index. 
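Together, resize_store, resize_variable, and graft_store support a metadata-only append. A usage sketch, assuming both stores share the same chunk geometry, the new block extends the first axis, and the shapes, offsets, variable name, and CID are placeholders:

from py_hamt import ShardedZarrStore


async def append_block(main_store: ShardedZarrStore, block_root_cid: str) -> str:
    # Grow the main store's shard index and the variable's zarr.json shape,
    # then copy the new block's chunk CIDs in at the given chunk offset.
    await main_store.resize_store(new_shape=(200, 180, 360))
    await main_store.resize_variable("precip", new_shape=(200, 180, 360))
    await main_store.graft_store(block_root_cid, chunk_offset=(10, 0, 0))
    return await main_store.flush()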
+ """ + if self.read_only: + raise ValueError("Cannot resize a read-only store.") + if self._root_obj is None: + raise RuntimeError("Store is not properly initialized for resizing.") + + # Zarr v2 uses .json, not zarr.json + zarr_metadata_key = f"{variable_name}/zarr.json" + + old_zarr_metadata_cid = self._root_obj["metadata"].get(zarr_metadata_key) + if not old_zarr_metadata_cid: + raise KeyError(f"Cannot find metadata for key '{zarr_metadata_key}' to resize.") + + old_zarr_metadata_bytes = await self.cas.load(old_zarr_metadata_cid) + zarr_metadata_json = json.loads(old_zarr_metadata_bytes) + + zarr_metadata_json["shape"] = list(new_shape) + + new_zarr_metadata_bytes = json.dumps(zarr_metadata_json, indent=2).encode('utf-8') + # Metadata is a raw blob of bytes + new_zarr_metadata_cid = await self.cas.save(new_zarr_metadata_bytes, codec='raw') + + self._root_obj["metadata"][zarr_metadata_key] = str(new_zarr_metadata_cid) + self._dirty_root = True + print(f"Resized metadata for variable '{variable_name}'. New shape: {new_shape}") + async def list_dir(self, prefix: str) -> AsyncIterator[str]: # This simplified version only works for the root directory (prefix == "") of metadata. # It lists unique first components of metadata keys. @@ -711,7 +871,7 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: seen: Set[str] = set() if prefix == "": async for key in self.list(): # Iterates metadata keys - # e.g., if key is "group1/.zgroup" or "array1/.zarray", first_component is "group1" or "array1" + # e.g., if key is "group1/.zgroup" or "array1/.json", first_component is "group1" or "array1" # if key is ".zgroup", first_component is ".zgroup" first_component = key.split("/", 1)[0] if first_component not in seen: @@ -731,6 +891,24 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: seen.add(child) yield child + async def _iterate_chunk_cids(self, shard_cid) -> AsyncIterator[str]: + """An async generator that yields all non-empty chunk CIDs in the store.""" + if self._root_obj is None or self._cid_len is None: + raise RuntimeError( + "Root object not loaded. Call _load_root_from_cid() first." 
+ ) + if not shard_cid: + return + + try: + shard_data = await self.cas.load(shard_cid) + for i in range(0, len(shard_data), self._cid_len): + cid_bytes = shard_data[i : i + self._cid_len] + if not all(b == 0 for b in cid_bytes): + yield cid_bytes.decode("ascii").rstrip("\x00") + except Exception as e: + logging.warning(f"Could not process shard {shard_cid} for iteration: {e}") + async def pin_entire_dataset( self, target_rpc: str = "http://127.0.0.1:5001" ) -> None: @@ -763,35 +941,12 @@ async def pin_entire_dataset( continue # Pin the shard itself - print(f"Pinning shard {shard_cid} to {target_rpc}...") await self.cas.pin_cid(shard_cid, target_rpc=target_rpc) - - try: - # Load shard data to find and pin the chunk CIDs within - shard_data = await self.cas.load(shard_cid) - - chunks_pinned = 0 - for i in range(0, len(shard_data), self._cid_len): - cid_bytes = shard_data[i : i + self._cid_len] - - if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots - continue - print(f"Processing chunk CID bytes: {cid_bytes!r}") - - chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") - if chunk_cid_str: - await self.cas.pin_cid(chunk_cid_str, target_rpc=target_rpc) - chunks_pinned += 1 - print(f"Pinned {chunks_pinned} chunk CIDs in shard {shard_cid}.") - print( - f"Total shards processed: {index + 1}/{len(self._root_obj['chunks']['shard_cids'])}" - ) - # Print progress based on amount of shards processed - # Catch any exceptions during shard loading or pinning - except Exception as e: - print( - f"Warning: Could not load or process shard {shard_cid} for pinning: {e}" - ) + chunks_pinned = 0 + async for chunk_cid in self._iterate_chunk_cids(shard_cid): + if chunk_cid: + chunks_pinned += 1 + await self.cas.pin_cid(chunk_cid, target_rpc=target_rpc) async def unpin_entire_dataset( self, target_rpc: str = "http://127.0.0.1:5001" @@ -814,33 +969,16 @@ async def unpin_entire_dataset( for shard_cid in self._root_obj["chunks"]["shard_cids"]: if not shard_cid: continue - - try: - shard_data = await self.cas.load(shard_cid) - # Iterate through the packed CIDs in the shard data - for i in range(0, len(shard_data), self._cid_len): - cid_bytes = shard_data[i : i + self._cid_len] - if all(b == 0 for b in cid_bytes): + print(f"Unpinning shard {shard_cid} from {target_rpc}...") + chunks_pinned = 0 + async for chunk_cid in self._iterate_chunk_cids(shard_cid): + if chunk_cid: + chunks_pinned += 1 + try: + await self.cas.unpin_cid(chunk_cid, target_rpc=target_rpc) + except Exception: + print(f"Warning: Could not unpin chunk CID {chunk_cid}. Likely already unpinned.") continue - - chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") - if chunk_cid_str: - try: - await self.cas.unpin_cid( - chunk_cid_str, target_rpc=target_rpc - ) - except Exception: - # ignore - continue - print( - f"Unpinned all chunk CIDs in shard {shard_cid} from {target_rpc}." 
- ) - except Exception as e: - # Log error but continue to attempt to unpin the shard itself - print( - f"Warning: Could not load or process chunks in shard {str(shard_cid)} for unpinning: {e}" - ) - # After unpinning all chunks within, unpin the shard itself try: await self.cas.unpin_cid(str(shard_cid), target_rpc=target_rpc) except Exception: From 1288f55f1d7e8570c1ca602a645dce0b0ae53e3e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 26 Jun 2025 03:50:37 -0400 Subject: [PATCH 28/74] fix: test era5 --- py_hamt/sharded_zarr_store.py | 2 +- tests/test_cpc_compare.py | 153 ++++++++++++++++++---------------- 2 files changed, 84 insertions(+), 71 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 70f60a1..49d450e 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -433,7 +433,7 @@ async def get( ) -> Optional[zarr.core.buffer.Buffer]: if self._root_obj is None or self._cid_len is None: raise RuntimeError("Load the root object first before accessing data.") - # print('Getting key', key) + print('Getting key', key) chunk_coords = self._parse_chunk_key(key) # Metadata request (e.g., ".json") diff --git a/tests/test_cpc_compare.py b/tests/test_cpc_compare.py index 40df647..7de9807 100644 --- a/tests/test_cpc_compare.py +++ b/tests/test_cpc_compare.py @@ -1,73 +1,86 @@ -# import time - -# import numpy as np -# import pandas as pd -# import pytest -# import xarray as xr -# from dag_cbor.ipld import IPLDKind -# from multiformats import CID - -# # Import both store implementations -# from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore -# from py_hamt.zarr_hamt_store import ZarrHAMTStore - - -# @pytest.mark.asyncio(loop_scope="session") -# async def test_benchmark_sharded_store(): -# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring -# print("\n\n" + "=" * 80) -# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print -# print("=" * 80) - - -# rpc_base_url = f"https://ipfs-gateway.dclimate.net" -# gateway_base_url = f"https://ipfs-gateway.dclimate.net" -# headers = { -# "X-API-Key": "", -# } - -# async with KuboCAS( -# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers -# ) as kubo_cas: -# # --- Write --- -# root_cid = "bafyr4ifjgdfafxfqtdkirmdyzlziswzo5gsxbrivqjzu35ukiixnu2omvm" -# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print -# # --- Read --- -# # When opening for read, chunks_per_shard is read from the store's metadata -# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore -# cas=kubo_cas, read_only=True, root_cid=root_cid -# ) -# print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") - -# start_read = time.perf_counter() -# ipfs_ds = xr.open_zarr(store=store_read) -# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) -# if "precip" in ipfs_ds.variables and "time" in ipfs_ds.coords: -# # _ = ipfs_ds.temp.isel(time=0).values -# data_fetched = ipfs_ds.precip.values - -# # Calculate the size of the fetched data -# data_size = data_fetched.nbytes if data_fetched is not None else 0 -# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") -# elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable -# first_var_name = list(ipfs_ds.data_vars.keys())[0] -# # Construct a minimal selection based on available dimensions -# selection = {dim: 0 for dim in 
ipfs_ds[first_var_name].dims} -# if selection: -# _ = ipfs_ds[first_var_name].isel(**selection).values -# else: # If no dimensions, try loading the whole variable (e.g. scalar) -# _ = ipfs_ds[first_var_name].values -# end_read = time.perf_counter() - -# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print -# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") -# print("=" * 80) -# # Speed in MB/s -# if data_size > 0: -# speed = data_size / (end_read - start_read) / (1024 * 1024) -# print(f"Read speed: {speed:.2f} MB/s") -# else: -# print("No data fetched, cannot calculate speed.") +import time + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from dag_cbor.ipld import IPLDKind +from multiformats import CID + +# Import both store implementations +from py_hamt import HAMT, KuboCAS, ShardedZarrStore +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.mark.asyncio(loop_scope="session") +async def test_benchmark_sharded_store(): + """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring + print("\n\n" + "=" * 80) + print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print + print("=" * 80) + + + rpc_base_url = f"https://ipfs-gateway.dclimate.net" + gateway_base_url = f"https://ipfs-gateway.dclimate.net" + headers = { + "X-API-Key": "", + } + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers + ) as kubo_cas: + # --- Write --- + root_cid = "bafyr4ienfetuujjqeqhrjvtr6dpcfh2bdowxrofsgl6dz5oknqauhxicie" + print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print + # --- Read --- + start = time.perf_counter() + # When opening for read, chunks_per_shard is read from the store's metadata + store_read = await ShardedZarrStore.open( # Use ShardedZarrStore + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + stop = time.perf_counter() + print(f"Total time to open ShardedZarrStore: {stop - start:.2f} seconds") + print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") + + start_read = time.perf_counter() + ipfs_ds = xr.open_zarr(store=store_read) + start_read = time.perf_counter() + print(ipfs_ds) + stop_read = time.perf_counter() + print(f"Total time to read dataset: {stop_read - start_read:.2f} seconds") + # start_read = time.perf_counter() + # print(ipfs_ds.variables, ipfs_ds.coords) # Print available variables and coordinates for debugging + # stop_read = time.perf_counter() + # print(f"Total time to read dataset variables and coordinates: {stop_read - start_read:.2f} seconds") + start_read = time.perf_counter() + # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) + if "2m_temperature" in ipfs_ds.variables and "time" in ipfs_ds.coords: + print("Fetching '2m_temperature' data...") + data_fetched = ipfs_ds["2m_temperature"].isel(time=0).values + # data_fetched = ipfs_ds["2m_temperature"].values + + # Calculate the size of the fetched data + data_size = data_fetched.nbytes if data_fetched is not None else 0 + print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") + elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable + first_var_name = list(ipfs_ds.data_vars.keys())[0] + # Construct a minimal selection based on available dimensions + selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} + if selection: + _ = ipfs_ds[first_var_name].isel(**selection).values + else: # If no dimensions, try loading the whole 
variable (e.g. scalar) + _ = ipfs_ds[first_var_name].values + end_read = time.perf_counter() + + print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print + print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") + print("=" * 80) + # Speed in MB/s + if data_size > 0: + speed = data_size / (end_read - start_read) / (1024 * 1024) + print(f"Read speed: {speed:.2f} MB/s") + else: + print("No data fetched, cannot calculate speed.") # # ### # # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore From 4c92d2c1194afb60f0d62e406438310428eac883 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 27 Jun 2025 02:24:22 -0400 Subject: [PATCH 29/74] Update test_sharded_zarr_store.py --- tests/test_sharded_zarr_store.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 5b2a61c..2778f4d 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -250,14 +250,6 @@ async def test_chunk_and_delete_logic( assert not await store_after_delete.exists(chunk_key) assert await store_after_delete.get(chunk_key, proto) is None - # Test deleting a non-existent key - with pytest.raises(KeyError): - await store_rw.delete("nonexistent/c/0/0/0") - - # Test deleting an already deleted key - with pytest.raises(KeyError): - await store_rw.delete(chunk_key) - @pytest.mark.asyncio async def test_sharded_zarr_store_partial_reads( From 309ea4a8f9387068acb201886550c9a3e40a891b Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 27 Jun 2025 07:27:02 -0400 Subject: [PATCH 30/74] fix: logging --- py_hamt/sharded_zarr_store.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 49d450e..5cc5304 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -910,7 +910,7 @@ async def _iterate_chunk_cids(self, shard_cid) -> AsyncIterator[str]: logging.warning(f"Could not process shard {shard_cid} for iteration: {e}") async def pin_entire_dataset( - self, target_rpc: str = "http://127.0.0.1:5001" + self, target_rpc: str = "http://127.0.0.1:5001", increment: int = 100 ) -> None: """ Pins the entire dataset in the CAS, ensuring the root, metadata, shards, @@ -947,6 +947,10 @@ async def pin_entire_dataset( if chunk_cid: chunks_pinned += 1 await self.cas.pin_cid(chunk_cid, target_rpc=target_rpc) + if chunks_pinned % increment == 0: + print( + f"Pinned {chunks_pinned} chunks in shard {index}..." 
+ ) async def unpin_entire_dataset( self, target_rpc: str = "http://127.0.0.1:5001" From be19965599966e31d71f6f231ccbc6341d4d359f Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 27 Jun 2025 08:31:57 -0400 Subject: [PATCH 31/74] fix: async pinning, chunker --- py_hamt/sharded_zarr_store.py | 134 +++++++++++++++++++--------------- py_hamt/store_httpx.py | 14 ++-- 2 files changed, 83 insertions(+), 65 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 5cc5304..1657086 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -910,102 +910,116 @@ async def _iterate_chunk_cids(self, shard_cid) -> AsyncIterator[str]: logging.warning(f"Could not process shard {shard_cid} for iteration: {e}") async def pin_entire_dataset( - self, target_rpc: str = "http://127.0.0.1:5001", increment: int = 100 + self, + target_rpc: str = "http://127.0.0.1:5001", + concurrency_limit: int = 50, + show_progress: bool = True ) -> None: """ - Pins the entire dataset in the CAS, ensuring the root, metadata, shards, - and all data chunks are pinned. This is useful for performance optimization - when the dataset is accessed frequently. + Pins the entire dataset in parallel, with a limit on concurrent requests. """ if self._root_obj is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." - ) + raise RuntimeError("Root object not loaded.") if self._cid_len is None: - raise RuntimeError( - "Store is not initialized properly; _cid_len is missing." - ) + raise RuntimeError("Store is not initialized properly.") + + # --- 1. First, gather all unique CIDs to pin --- + print("Gathering all CIDs for pinning...") + cids_to_pin: Set[str] = set() - # Pin the root CID itself if self._root_cid: - await self.cas.pin_cid(self._root_cid, target_rpc=target_rpc) + cids_to_pin.add(self._root_cid) - # Pin metadata CIDs for cid in self._root_obj.get("metadata", {}).values(): if cid: - await self.cas.pin_cid(cid, target_rpc=target_rpc) + cids_to_pin.add(cid) - # Pin all shard CIDs and the chunk CIDs within them - for index, shard_cid in enumerate(self._root_obj["chunks"]["shard_cids"]): + for shard_cid in self._root_obj["chunks"]["shard_cids"]: if not shard_cid: continue - - # Pin the shard itself - await self.cas.pin_cid(shard_cid, target_rpc=target_rpc) - chunks_pinned = 0 + cids_to_pin.add(shard_cid) async for chunk_cid in self._iterate_chunk_cids(shard_cid): if chunk_cid: - chunks_pinned += 1 - await self.cas.pin_cid(chunk_cid, target_rpc=target_rpc) - if chunks_pinned % increment == 0: - print( - f"Pinned {chunks_pinned} chunks in shard {index}..." - ) + cids_to_pin.add(chunk_cid) + + total_cids = len(cids_to_pin) + print(f"Found {total_cids} unique CIDs to pin.") + + # --- 2. 
Create and run pinning tasks in parallel with a semaphore --- + semaphore = asyncio.Semaphore(concurrency_limit) + tasks: list[Coroutine] = [] + + # Helper function to wrap the pin call with the semaphore + async def pin_with_semaphore(cid: str): + async with semaphore: + await self.cas.pin_cid(cid, target_rpc=target_rpc) + if show_progress: + # This progress reporting is approximate as tasks run out of order + # For precise progress, see asyncio.as_completed in the notes below + pass + + for cid in cids_to_pin: + tasks.append(pin_with_semaphore(cid)) + + print(f"Pinning {total_cids} CIDs with a concurrency of {concurrency_limit}...") + # The 'return_exceptions=False' (default) means this will raise the first exception it encounters. + await asyncio.gather(*tasks) + print("Successfully pinned the entire dataset.") + async def unpin_entire_dataset( - self, target_rpc: str = "http://127.0.0.1:5001" + self, + target_rpc: str = "http://127.0.0.1:5001", + concurrency_limit: int = 50 ) -> None: """ - Unpins the entire dataset from the CAS, removing the root, metadata, shards, - and all data chunks from the pin set. This is useful for freeing up storage - resources when the dataset is no longer needed. + Unpins the entire dataset in parallel, with a limit on concurrent requests. """ if self._root_obj is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." - ) + raise RuntimeError("Root object not loaded.") if self._cid_len is None: - raise RuntimeError( - "Store is not initialized properly; _cid_len is missing." - ) + raise RuntimeError("Store is not initialized properly.") + + # --- 1. Gather all CIDs to unpin (chunks, shards, metadata) --- + print("Gathering all CIDs for unpinning...") + cids_to_unpin: Set[str] = set() - # Unpin all chunk CIDs by reading from shards first for shard_cid in self._root_obj["chunks"]["shard_cids"]: if not shard_cid: continue - print(f"Unpinning shard {shard_cid} from {target_rpc}...") - chunks_pinned = 0 + cids_to_unpin.add(shard_cid) async for chunk_cid in self._iterate_chunk_cids(shard_cid): if chunk_cid: - chunks_pinned += 1 - try: - await self.cas.unpin_cid(chunk_cid, target_rpc=target_rpc) - except Exception: - print(f"Warning: Could not unpin chunk CID {chunk_cid}. Likely already unpinned.") - continue - try: - await self.cas.unpin_cid(str(shard_cid), target_rpc=target_rpc) - except Exception: - print(f"Warning: Could not unpin shard {str(shard_cid)}") - print(f"Unpinned shard {shard_cid} from {target_rpc}.") + cids_to_unpin.add(chunk_cid) - # Unpin metadata CIDs for cid in self._root_obj.get("metadata", {}).values(): if cid: + cids_to_unpin.add(cid) + + # --- 2. Unpin all children in parallel --- + semaphore = asyncio.Semaphore(concurrency_limit) + tasks: list[Coroutine] = [] + + async def unpin_with_semaphore(cid: str): + async with semaphore: try: await self.cas.unpin_cid(cid, target_rpc=target_rpc) - print(f"Unpinned metadata CID {cid} from {target_rpc}...") except Exception: - print( - f"Warning: Could not unpin metadata CID {cid}. Likely already unpinned." - ) + # This is safe to ignore in a mass-unpin operation + print(f"Warning: Could not unpin chunk CID {cid}. Likely already unpinned.") - # Finally, unpin the root CID itself + for cid in cids_to_unpin: + tasks.append(unpin_with_semaphore(cid)) + + print(f"Unpinning {len(tasks)} child CIDs with a concurrency of {concurrency_limit}...") + await asyncio.gather(*tasks) + print("Successfully unpinned all child objects.") + + # --- 3. 
Finally, unpin the root CID itself after its children --- if self._root_cid: try: + print(f"Unpinning root CID {self._root_cid}...") await self.cas.unpin_cid(self._root_cid, target_rpc=target_rpc) - print(f"Unpinned root CID {self._root_cid} from {target_rpc}...") + print("Successfully unpinned the entire dataset.") except Exception: - print( - f"Warning: Could not unpin root CID {self._root_cid}. Likely already unpinned." - ) + print(f"Warning: Could not unpin root CID {self._root_cid}. Likely already unpinned.") diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 6ae9a4c..8d0a51c 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -175,6 +175,7 @@ def __init__( headers: dict[str, str] | None = None, auth: Tuple[str, str] | None = None, pinOnAdd: bool = False, + chunker: str | None = None, ): """ If None is passed into the rpc or gateway base url, then the default for kubo local daemons will be used. The default local values will also be used if nothing is passed in at all. @@ -231,9 +232,12 @@ def __init__( pinString: str = "true" if pinOnAdd else "false" - self.rpc_url: str = ( - f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin={pinString}" - ) + rpc_url = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin={pinString}" + if chunker: + rpc_url += f"&chunker={chunker}" + + self.rpc_url = rpc_url + """@private""" self.gateway_base_url: str = gateway_base_url """@private""" @@ -391,7 +395,7 @@ async def pin_cid( cid (CID): The Content ID to pin. name (Optional[str]): An optional name for the pin. """ - params = {"arg": str(cid), "recursive": "false"} + params = {"arg": str(cid), "recursive": "true"} pin_add_url_base: str = f"{target_rpc}/api/v0/pin/add" async with self._sem: # throttle RPC @@ -414,7 +418,7 @@ async def unpin_cid( Args: cid (CID): The Content ID to unpin. """ - params = {"arg": str(cid), "recursive": "false"} + params = {"arg": str(cid), "recursive": "true"} unpin_url_base: str = f"{target_rpc}/api/v0/pin/rm" async with self._sem: # throttle RPC client = self._loop_client() From 79796a6e2c627c8f36aa451cc377b531f0c30a0b Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 2 Jul 2025 01:34:59 -0400 Subject: [PATCH 32/74] fix: dag cbor --- py_hamt/hamt_to_sharded_converter.py | 3 +- py_hamt/manage_pins.py | 117 ----- py_hamt/sharded_zarr_store.py | 681 +++++++-------------------- tests/test_sharded_zarr_pinning.py | 142 ------ 4 files changed, 175 insertions(+), 768 deletions(-) delete mode 100644 py_hamt/manage_pins.py delete mode 100644 tests/test_sharded_zarr_pinning.py diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index c4ef9f5..681bdce 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -10,7 +10,7 @@ async def convert_hamt_to_sharded( - cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int, cid_len: int = 59 + cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int ) -> str: """ Converts a Zarr dataset from a HAMT-based store to a ShardedZarrStore. 
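A minimal usage sketch for this converter as it stands after this patch (it now takes only cas, hamt_root_cid, and chunks_per_shard). The module import path, the endpoint URLs, and the HAMT root CID below are placeholders and assumptions, not values taken from this repository:

    import asyncio

    from py_hamt import KuboCAS
    from py_hamt.hamt_to_sharded_converter import convert_hamt_to_sharded  # assumed module path

    async def main() -> None:
        # Assumes a local Kubo daemon on its default RPC and gateway ports.
        async with KuboCAS(
            rpc_base_url="http://127.0.0.1:5001",
            gateway_base_url="http://127.0.0.1:8080",
        ) as cas:
            sharded_root = await convert_hamt_to_sharded(
                cas=cas,
                hamt_root_cid="bafy...",  # placeholder: root CID of the existing HAMT-backed Zarr
                chunks_per_shard=6250,
            )
            print("New ShardedZarrStore root CID:", sharded_root)

    asyncio.run(main())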
@@ -52,7 +52,6 @@ async def convert_hamt_to_sharded( array_shape=array_shape, chunk_shape=chunk_shape, chunks_per_shard=chunks_per_shard, - cid_len=cid_len, ) print("Destination store initialized.") diff --git a/py_hamt/manage_pins.py b/py_hamt/manage_pins.py deleted file mode 100644 index 63d41a1..0000000 --- a/py_hamt/manage_pins.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -A command-line tool to recursively pin or unpin all CIDs associated with a -sharded Zarr dataset on IPFS using its root CID. -""" - -import argparse -import asyncio -import sys - -from py_hamt import KuboCAS, ShardedZarrStore - -# --- CLI Logic Functions --- - - -async def handle_pin(args): - """ - Connects to IPFS, loads the dataset from the root CID, and pins all - associated CIDs (root, metadata, shards, and data chunks). - """ - async with KuboCAS( - rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url - ) as kubo_cas: - try: - print(f"-> Opening store with root CID: {args.root_cid}") - store = await ShardedZarrStore.open( - cas=kubo_cas, read_only=True, root_cid=args.root_cid - ) - except Exception as e: - print( - f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", - file=sys.stderr, - ) - print(f"Details: {e}", file=sys.stderr) - return - - print(f"-> Sending commands to pin the entire dataset to {args.rpc_url}...") - await store.pin_entire_dataset() - print("\n--- Pinning Commands Sent Successfully ---") - print("The IPFS node will now pin all objects in the background.") - - -async def handle_unpin(args): - """ - Connects to IPFS, loads the dataset from the root CID, and unpins all - associated CIDs. - """ - async with KuboCAS( - rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url - ) as kubo_cas: - try: - print(f"-> Opening store with root CID: {args.root_cid}") - store = await ShardedZarrStore.open( - cas=kubo_cas, read_only=True, root_cid=args.root_cid - ) - except Exception as e: - print( - f"Error: Failed to open Zarr store for CID {args.root_cid}. Ensure the CID is correct and the daemon is running.", - file=sys.stderr, - ) - print(f"Details: {e}", file=sys.stderr) - return - - print(f"-> Sending commands to unpin the entire dataset from {args.rpc_url}...") - await store.unpin_entire_dataset() - print("\n--- Unpinning Commands Sent Successfully ---") - print("The IPFS node will now unpin all objects in the background.") - - -def main(): - """Sets up the argument parser and runs the selected command.""" - parser = argparse.ArgumentParser( - description="A CLI tool to pin or unpin sharded Zarr datasets on IPFS.", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "--rpc-url", - default="http://127.0.0.1:5001", - help="IPFS Kubo RPC API endpoint URL.", - ) - parser.add_argument( - "--gateway-url", - default="http://127.0.0.1:8080", - help="IPFS Gateway URL (needed for loading shards).", - ) - - subparsers = parser.add_subparsers( - dest="command", required=True, help="Available commands" - ) - - # --- Pin Command --- - parser_pin = subparsers.add_parser( - "pin", help="Recursively pin a dataset using its root CID." - ) - parser_pin.add_argument("root_cid", help="The root CID of the dataset to pin.") - parser_pin.set_defaults(func=handle_pin) - - # --- Unpin Command --- - parser_unpin = subparsers.add_parser( - "unpin", help="Recursively unpin a dataset using its root CID." 
- ) - parser_unpin.add_argument("root_cid", help="The root CID of the dataset to unpin.") - parser_unpin.set_defaults(func=handle_unpin) - - args = parser.parse_args() - - try: - asyncio.run(args.func(args)) - except KeyboardInterrupt: - print("\nOperation cancelled by user.", file=sys.stderr) - sys.exit(1) - except Exception as e: - print(f"\nAn unexpected error occurred: {e}", file=sys.stderr) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 1657086..104d808 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -1,12 +1,13 @@ import asyncio import math from collections.abc import AsyncIterator, Iterable -from typing import Dict, List, Optional, Set, Tuple +from typing import Coroutine, Dict, List, Optional, Set, Tuple import json import itertools - +import logging import dag_cbor +from multiformats.cid import CID import zarr.abc.store import zarr.core.buffer from zarr.core.common import BytesLike @@ -18,22 +19,16 @@ class ShardedZarrStore(zarr.abc.store.Store): """ Implements the Zarr Store API using a sharded layout for chunk CIDs. - This store divides the flat index of chunk CIDs into multiple smaller "shards". - Each shard is a contiguous block of bytes containing CIDs for a subset of chunks. - This can improve performance for certain access patterns and reduce the size - of individual index objects stored in the CAS. + # CHANGED: Docstring updated to reflect DAG-CBOR format. + This store divides the flat index of chunk CIDs into multiple "shards". + Each shard is a DAG-CBOR array where each element is either a CID link + to a chunk or a null value if the chunk is empty. This structure allows + for efficient traversal by IPLD-aware systems. The store's root object contains: 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. - 2. A list of CIDs, where each CID points to a shard of the chunk index. + 2. A list of CIDs, where each CID points to a shard object. 3. Sharding configuration details (e.g., chunks_per_shard). - - Accessing a chunk involves: - 1. Loading the root object (if not cached). - 2. Determining the shard index and the offset of the chunk's CID within that shard. - 3. Fetching the specific shard's CID from the root object. - 4. Fetching the chunk's CID using a byte-range request on the identified shard. - 5. Fetching the actual chunk data using the retrieved chunk CID. """ def __init__( @@ -48,25 +43,25 @@ def __init__( self._root_cid = root_cid self._root_obj: Optional[dict] = None + # CHANGED: The cache now stores a list of CID objects or None, not a bytearray. self._shard_data_cache: Dict[ - int, bytearray - ] = {} # shard_index -> shard_byte_data - self._dirty_shards: Set[int] = set() # Set of shard_indices that need flushing + int, list[Optional[CID]] + ] = {} + self._dirty_shards: Set[int] = set() self._pending_shard_loads: Dict[ int, asyncio.Task - ] = {} # shard_index -> Task loading the full shard + ] = {} - self._cid_len: Optional[int] = None + # REMOVED: _cid_len is no longer needed with structured DAG-CBOR shards. 
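# [Editorial aside, not part of this commit] With DAG-CBOR shards, each shard
# round-trips as a plain Python list of CID links and None placeholders, so the
# fixed-width slot arithmetic that _cid_len supported is no longer required.
# A minimal standalone sketch, assuming only the dag_cbor and multiformats
# packages already used in this repository:

import dag_cbor
from multiformats import CID, multihash

digest = multihash.digest(b"raw chunk bytes", "sha2-256")
link = CID("base32", 1, "raw", digest)      # an illustrative chunk CID
shard = [link, None]                        # None marks an empty chunk slot
block = dag_cbor.encode(shard)              # one DAG-CBOR block per shard
assert dag_cbor.decode(block) == shard      # links decode back to CID objects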
+ # self._cid_len: Optional[int] = None self._array_shape: Optional[Tuple[int, ...]] = None self._chunk_shape: Optional[Tuple[int, ...]] = None - self._chunks_per_dim: Optional[Tuple[int, ...]] = ( - None # Number of chunks in each dimension - ) - self._chunks_per_shard: Optional[int] = None # How many chunk CIDs per shard - self._num_shards: Optional[int] = None # Total number of shards - self._total_chunks: Optional[int] = None # Total number of chunks in the array + self._chunks_per_dim: Optional[Tuple[int, ...]] = None + self._chunks_per_shard: Optional[int] = None + self._num_shards: Optional[int] = None + self._total_chunks: Optional[int] = None - self._dirty_root = False # Indicates if the root object itself (metadata or shard_cids list) changed + self._dirty_root = False def _update_geometry(self): @@ -99,7 +94,7 @@ async def open( array_shape: Optional[Tuple[int, ...]] = None, chunk_shape: Optional[Tuple[int, ...]] = None, chunks_per_shard: Optional[int] = None, - cid_len: int = 59, # Default for base32 v1 CIDs like bafy... (e.g., bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi) + # REMOVED: cid_len is no longer needed. ) -> "ShardedZarrStore": """ Asynchronously opens an existing ShardedZarrStore or initializes a new one. @@ -117,7 +112,7 @@ async def open( raise ValueError("chunks_per_shard must be a positive integer.") store._initialize_new_root( - array_shape, chunk_shape, chunks_per_shard, cid_len + array_shape, chunk_shape, chunks_per_shard ) else: raise ValueError("root_cid must be provided for a read-only store.") @@ -128,26 +123,24 @@ def _initialize_new_root( array_shape: Tuple[int, ...], chunk_shape: Tuple[int, ...], chunks_per_shard: int, - cid_len: int, ): self._array_shape = array_shape self._chunk_shape = chunk_shape self._chunks_per_shard = chunks_per_shard - self._cid_len = cid_len self._update_geometry() self._root_obj = { - "manifest_version": "sharded_zarr_v1", - "metadata": {}, # For .json - "chunks": { # Information about the chunk index itself - "array_shape": list(self._array_shape), # Original array shape - "chunk_shape": list(self._chunk_shape), # Original chunk shape - "cid_byte_length": self._cid_len, + "manifest_version": "sharded_zarr_v1", # CHANGED: Version reflects new format + "metadata": {}, + "chunks": { + "array_shape": list(self._array_shape), + "chunk_shape": list(self._chunk_shape), + # REMOVED: cid_byte_length is no longer relevant "sharding_config": { "chunks_per_shard": self._chunks_per_shard, }, - "shard_cids": [None] * self._num_shards, # List of CIDs for each shard + "shard_cids": [None] * self._num_shards, }, } self._dirty_root = True @@ -162,40 +155,37 @@ async def _load_root_from_cid(self): raise ValueError( f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'." ) + chunk_info = self._root_obj["chunks"] self._array_shape = tuple(chunk_info["array_shape"]) self._chunk_shape = tuple(chunk_info["chunk_shape"]) - self._cid_len = chunk_info["cid_byte_length"] self._chunks_per_shard = chunk_info["sharding_config"]["chunks_per_shard"] self._update_geometry() if len(chunk_info["shard_cids"]) != self._num_shards: raise ValueError( - f"Inconsistent number of shards. Expected {self._num_shards} from shapes/config, " - f"found {len(chunk_info['shard_cids'])} in root object's shard_cids list." + f"Inconsistent number of shards. Expected {self._num_shards}, found {len(chunk_info['shard_cids'])}." 
) async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): - """ - Fetches the full data for a shard and caches it. - Manages removal from _pending_shard_loads. - """ + # CHANGED: Logic now decodes the shard from DAG-CBOR into a list. try: - shard_data_bytes = await self.cas.load(shard_cid) # Load full shard - self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) + shard_data_bytes = await self.cas.load(shard_cid) + # Decode the CBOR object, which should be a list of CIDs/None + decoded_shard = dag_cbor.decode(shard_data_bytes) + if not isinstance(decoded_shard, list): + raise TypeError(f"Shard {shard_idx} did not decode to a list.") + self._shard_data_cache[shard_idx] = decoded_shard except Exception as e: - print(e) - # Handle or log the exception appropriately - print( - f"Warning: Failed to cache full shard {shard_idx} (CID: {shard_cid}): {e}" + logging.warning( + f"Failed to fetch or decode shard {shard_idx} (CID: {shard_cid}): {e}" ) - # If it fails, subsequent requests might try again if it's still not in cache. finally: - # Ensure the task is removed from pending list once done (success or failure) if shard_idx in self._pending_shard_loads: del self._pending_shard_loads[shard_idx] - + + # ... (Keep _parse_chunk_key, _get_linear_chunk_index, _get_shard_info as they are) ... def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. Exclude .json files immediately (metadata) if key.endswith(".json"): @@ -284,55 +274,35 @@ def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: index_in_shard = linear_chunk_index % self._chunks_per_shard return shard_idx, index_in_shard - async def _load_or_initialize_shard_cache(self, shard_idx: int) -> bytearray: + async def _load_or_initialize_shard_cache(self, shard_idx: int) -> list: + # CHANGED: This method is updated to handle list-based cache and DAG-CBOR decoding. if shard_idx in self._shard_data_cache: return self._shard_data_cache[shard_idx] if shard_idx in self._pending_shard_loads: - try: - await self._pending_shard_loads[shard_idx] - if shard_idx in self._shard_data_cache: - return self._shard_data_cache[shard_idx] - else: - pass # Fall through to normal loading - except asyncio.CancelledError: - if shard_idx in self._pending_shard_loads: - del self._pending_shard_loads[shard_idx] - # Fall through to normal loading - except Exception as e: - print( - f"Warning: Pending shard load for {shard_idx} failed: {e}. Attempting fresh load." - ) - - if self._root_obj is None: - raise RuntimeError( - "Root object not loaded or initialized (_root_obj is None)." - ) - if not ( - 0 <= shard_idx < self._num_shards if self._num_shards is not None else False - ): - raise ValueError( - f"Shard index {shard_idx} out of bounds for {self._num_shards} shards." - ) - - shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if shard_cid: - shard_data_bytes = await self.cas.load(shard_cid) - self._shard_data_cache[shard_idx] = bytearray(shard_data_bytes) + await self._pending_shard_loads[shard_idx] + if shard_idx in self._shard_data_cache: + return self._shard_data_cache[shard_idx] + + if self._root_obj is None or self._num_shards is None: + raise RuntimeError("Root object not loaded or initialized.") + if not (0 <= shard_idx < self._num_shards): + raise ValueError(f"Shard index {shard_idx} out of bounds.") + + shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] + if shard_cid_obj: + # The CID in the root should already be a CID object if loaded correctly. 
+ shard_cid_str = str(shard_cid_obj) + await self._fetch_and_cache_full_shard(shard_idx, shard_cid_str) else: - if self._cid_len is None: # Should be set - raise RuntimeError( - "Store not initialized: _cid_len is None for shard initialization." - ) if self._chunks_per_shard is None: - raise RuntimeError( - "Store not initialized: _chunks_per_shard is None for shard initialization." - ) - # New shard or shard not yet written, initialize with zeros - shard_size_bytes = self._chunks_per_shard * self._cid_len - self._shard_data_cache[shard_idx] = bytearray( - shard_size_bytes - ) # Filled with \x00 + raise RuntimeError("Store not initialized: _chunks_per_shard is None.") + # Initialize new shard as a list of Nones + self._shard_data_cache[shard_idx] = [None] * self._chunks_per_shard + + if shard_idx not in self._shard_data_cache: + raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") + return self._shard_data_cache[shard_idx] async def set_partial_values( @@ -356,73 +326,54 @@ def __eq__(self, other: object) -> bool: return NotImplemented # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. return self._root_cid == other._root_cid - + async def flush(self) -> str: + # CHANGED: This method now encodes shards using DAG-CBOR. if self.read_only: - if ( - self._root_cid is None - ): # Read-only store should have been opened with a root_cid - raise ValueError( - "Read-only store has no root CID to return. Was it opened correctly?" - ) + if not self._root_cid: + raise ValueError("Read-only store has no root CID to return.") return self._root_cid - if self._root_obj is None: # Should be initialized for a writable store - raise RuntimeError("Store not initialized for writing: _root_obj is None.") - - # Update the array_shape and chunk_shape in the root object based on current state in the zarr.json - # Fetch the current array_shape and chunk_shape from root_ob - # TODO: - # zarr_json_cid = self._root_obj.get("metadata", {}).get("zarr.json", {}) - # if zarr_json_cid: - # zarr_json_bytes = await self.cas.load(zarr_json_cid) - # zarr_json = json.loads(zarr_json_bytes.decode("utf-8")) - # consolidated_metadata = zarr_json.get("consolidated_metadata", {}) - - # print("ZArr jSON bytes", zarr_json) - - # else: - # raise ValueError("Zarr JSON metadata CID not found.") - # print(self._array_shape, self._chunk_shape) - # self._root_obj["chunks"]["array_shape"] = list(self._array_shape) - # self._root_obj["chunks"]["chunk_shape"] = list(self._chunk_shape) + if self._root_obj is None: + raise RuntimeError("Store not initialized for writing.") - # Save all dirty shards first, as their CIDs might need to go into the root object if self._dirty_shards: for shard_idx in sorted(list(self._dirty_shards)): if shard_idx not in self._shard_data_cache: - # This implies an internal logic error if a shard is dirty but not in cache - # However, could happen if cache was cleared externally; robust code might reload/reinit - print( - f"Warning: Dirty shard {shard_idx} not found in cache. Skipping save for this shard." - ) + logging.warning(f"Dirty shard {shard_idx} not in cache. Skipping.") continue - shard_data_bytes = bytes(self._shard_data_cache[shard_idx]) + # Get the list of CIDs/Nones from the cache + shard_data_list = self._shard_data_cache[shard_idx] + + # Encode this list into a DAG-CBOR byte representation + shard_data_bytes = dag_cbor.encode(shard_data_list) - # The CAS save method here should return a string CID. 
- new_shard_cid = await self.cas.save( - shard_data_bytes, codec="raw" - ) # Shards are raw bytes of CIDs + # Save the DAG-CBOR block and get its CID + new_shard_cid_obj = await self.cas.save( + shard_data_bytes, codec="dag-cbor" # Use 'dag-cbor' codec + ) - if self._root_obj["chunks"]["shard_cids"][shard_idx] != new_shard_cid: - self._root_obj["chunks"]["shard_cids"][shard_idx] = new_shard_cid - self._dirty_root = True # Root object changed because a shard_cid in its list changed + if self._root_obj["chunks"]["shard_cids"][shard_idx] != new_shard_cid_obj: + # Store the CID object directly + self._root_obj["chunks"]["shard_cids"][shard_idx] = new_shard_cid_obj + self._dirty_root = True self._dirty_shards.clear() if self._dirty_root: + # Ensure all metadata CIDs are CID objects for correct encoding + self._root_obj["metadata"] = { + k: (CID.decode(v) if isinstance(v, str) else v) + for k, v in self._root_obj["metadata"].items() + } root_obj_bytes = dag_cbor.encode(self._root_obj) new_root_cid = await self.cas.save(root_obj_bytes, codec="dag-cbor") - self._root_cid = str(new_root_cid) # Ensure it's string + self._root_cid = str(new_root_cid) self._dirty_root = False - if ( - self._root_cid is None - ): # Should only happen if nothing was dirty AND it was a new store never flushed - raise RuntimeError( - "Failed to obtain a root CID after flushing. Store might be empty or unchanged." - ) + if self._root_cid is None: + raise RuntimeError("Failed to obtain a root CID after flushing.") return self._root_cid async def get( @@ -431,76 +382,41 @@ async def get( prototype: zarr.core.buffer.BufferPrototype, byte_range: Optional[zarr.abc.store.ByteRequest] = None, ) -> Optional[zarr.core.buffer.Buffer]: - if self._root_obj is None or self._cid_len is None: + # CHANGED: Logic is simplified to not use byte offsets. It relies on the full-shard cache. + if self._root_obj is None: raise RuntimeError("Load the root object first before accessing data.") print('Getting key', key) chunk_coords = self._parse_chunk_key(key) - # Metadata request (e.g., ".json") + # Metadata request if chunk_coords is None: - metadata_cid = self._root_obj["metadata"].get(key) - if metadata_cid is None: + metadata_cid_obj = self._root_obj["metadata"].get(key) + if metadata_cid_obj is None: return None - # byte_range is not typically applicable to metadata JSON objects themselves if byte_range is not None: - # Consider if this should be an error or ignored for metadata - print( - f"Warning: byte_range requested for metadata key '{key}'. Ignoring range." 
- ) - data = await self.cas.load(metadata_cid) + logging.warning(f"Byte range request for metadata key '{key}' ignored.") + data = await self.cas.load(str(metadata_cid_obj)) return prototype.buffer.from_bytes(data) + # Chunk data request linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): - # This case implies linear_chunk_index was out of _total_chunks bounds or bad sharding logic - return None - - target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if ( - target_shard_cid is None - ): # This shard has no data (all chunks within it are implicitly empty) return None - offset_in_shard_bytes = index_in_shard * self._cid_len - chunk_cid_bytes: Optional[bytes] = None - - if shard_idx in self._shard_data_cache: - cached_shard_data = self._shard_data_cache[shard_idx] - chunk_cid_bytes = bytes( - cached_shard_data[ - offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len - ] - ) - - if chunk_cid_bytes is None: # Not in cache or cache was invalid - chunk_cid_bytes = await self.cas.load( - target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len - ) - # After successfully fetching the specific CID bytes, - # check if we should initiate a background load of the full shard. - if ( - shard_idx not in self._shard_data_cache - and shard_idx not in self._pending_shard_loads - ): - self._pending_shard_loads[shard_idx] = asyncio.create_task( - self._fetch_and_cache_full_shard(shard_idx, target_shard_cid) - ) + # This will load the full shard into cache if it's not already there. + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - if all( - b == 0 for b in chunk_cid_bytes - ): # Check for null CID placeholder (e.g. \x00 * cid_len) - return None # Chunk doesn't exist or is considered empty + # Get the CID object (or None) from the cached list. + chunk_cid_obj = target_shard_list[index_in_shard] + + if chunk_cid_obj is None: + return None # Chunk is empty/doesn't exist. - # Decode CID (assuming ASCII, remove potential null padding) - chunk_cid_str = chunk_cid_bytes.decode("ascii").rstrip("\x00") - if ( - not chunk_cid_str - ): # Empty string after rstrip if all were \x00 (already caught above) - return None + chunk_cid_str = str(chunk_cid_obj) - # Actual chunk data load using the retrieved chunk_cid_str + # Actual chunk data load using the retrieved chunk CID req_offset = byte_range.start if byte_range else None req_length = None if byte_range: @@ -512,7 +428,6 @@ async def get( f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" ) req_length = byte_range.end - byte_range.start - data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) return prototype.buffer.from_bytes(data) @@ -520,91 +435,51 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: if self.read_only: raise ValueError("Cannot write to a read-only store.") if self._root_obj is None: - raise RuntimeError( - "Store not initialized for writing (root_obj is None). Call open() first." - ) - # print('Setting key', key) - # Find the data variable and update its chunk data + raise RuntimeError("Store not initialized for writing. 
Call open() first.") + if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: - # extract the metadata from the value - # and store it in the root_obj["metadata"] dict - converted_value = value.to_bytes().decode("utf-8") - # Read the json - metadata_json = json.loads(converted_value) + metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") if not new_array_shape: raise ValueError("Shape not found in metadata.") - if tuple(new_array_shape) != self._array_shape: - print(f"Detected shape change from {self._array_shape} to {tuple(new_array_shape)}. Resizing shard index...") - # Use your existing resize_store method to handle all the recalculations - # and extend the list of shard CIDs. await self.resize_store(new_shape=tuple(new_array_shape)) - - raw_chunk_data_bytes = value.to_bytes() - # Save the actual chunk data to CAS first, to get its CID - chunk_data_cid_obj = await self.cas.save( - raw_chunk_data_bytes, codec="raw" - ) # Chunks are typically raw bytes - chunk_data_cid_str = str(chunk_data_cid_obj) - await self.set_pointer(key, chunk_data_cid_str) # Store the CID in the index + raw_data_bytes = value.to_bytes() + # Save the data to CAS first to get its CID. + # Metadata is often saved as 'raw', chunks as well unless compressed. + data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") + await self.set_pointer(key, str(data_cid_obj)) async def set_pointer(self, key: str, pointer: str) -> None: - if self._root_obj is None or self._cid_len is None: + # CHANGED: Logic now updates a list in the cache, not a bytearray. + if self._root_obj is None: raise RuntimeError("Load the root object first before accessing data.") - # Ensure the CID (as ASCII bytes) fits in the allocated slot, padding with nulls - chunk_data_cid_ascii_bytes = pointer.encode("ascii") - if len(chunk_data_cid_ascii_bytes) > self._cid_len: - raise ValueError( - f"Encoded CID byte length ({len(chunk_data_cid_ascii_bytes)}) exceeds configured CID length ({self._cid_len}). CID: {pointer}" - ) - padded_chunk_data_cid_bytes = chunk_data_cid_ascii_bytes.ljust( - self._cid_len, b"\0" - ) - + chunk_coords = self._parse_chunk_key(key) - # print(f"Setting for key '{key}': {chunk_coords}") - - if chunk_coords is None: # Metadata key (e.g., ".json") - # For metadata, the 'value' is the metadata content itself, not a CID to it. - # So, we store the metadata content, get its CID, and put *that* CID in root_obj. - # This means the `value_cid_str` for metadata should be from `raw_chunk_data_bytes`. - # This seems to align with FlatZarrStore, where `value_cid` is used for both. - self._root_obj["metadata"][key] = ( - pointer # Store the string CID of the metadata content - ) + + pointer_cid_obj = CID.decode(pointer) # Convert string to CID object + + if chunk_coords is None: # Metadata key + self._root_obj["metadata"][key] = pointer_cid_obj self._dirty_root = True return - # Chunk Data: `chunk_data_cid_str` is the CID of the data we just saved. - # Now we need to store this CID string (padded) into the correct shard. + # Chunk Data: Store the CID object in the correct shard list. 
linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - # Ensure the target shard is loaded or initialized in cache - target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) - - offset_in_shard_bytes = index_in_shard * self._cid_len - - # Check if the content is actually changing to avoid unnecessary dirtying (optional optimization) - # current_bytes_in_shard = target_shard_data_cache[offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len] - # if current_bytes_in_shard == padded_chunk_data_cid_bytes: - # return # No change - - target_shard_data_cache[ - offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len - ] = padded_chunk_data_cid_bytes - self._dirty_shards.add(shard_idx) - # If this write implies the shard CID in root_obj["chunks"]["shard_cids"] might change - # (e.g., from None to an actual CID when the shard is first flushed), - # then _dirty_root will be set by flush(). + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + if target_shard_list[index_in_shard] != pointer_cid_obj: + target_shard_list[index_in_shard] = pointer_cid_obj + self._dirty_shards.add(shard_idx) + # ... (Keep exists method, but simplify it) ... async def exists(self, key: str) -> bool: - if self._root_obj is None or self._cid_len is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." - ) + # CHANGED: Simplified to use the list-based cache. + if self._root_obj is None: + raise RuntimeError("Root object not loaded.") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata @@ -614,40 +489,20 @@ async def exists(self, key: str) -> bool: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not ( - self._root_obj - and "chunks" in self._root_obj - and 0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"]) - ): + if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): return False - target_shard_cid = self._root_obj["chunks"]["shard_cids"][shard_idx] - if target_shard_cid is None: # Shard itself doesn't exist + shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] + if shard_cid_obj is None: return False - offset_in_shard_bytes = index_in_shard * self._cid_len - - # Optimization: Check local shard cache first - if shard_idx in self._shard_data_cache: - cached_shard_data = self._shard_data_cache[shard_idx] - # Ensure index_in_shard is valid for this cached data length - if offset_in_shard_bytes + self._cid_len <= len(cached_shard_data): - chunk_cid_bytes_from_cache = cached_shard_data[ - offset_in_shard_bytes : offset_in_shard_bytes + self._cid_len - ] - return not all(b == 0 for b in chunk_cid_bytes_from_cache) - # else: fall through to CAS load, cache might be out of sync or wrong size (should not happen with correct logic) - - # If not in cache or cache check was inconclusive, fetch from CAS - chunk_cid_bytes_from_cas = await self.cas.load( - target_shard_cid, offset=offset_in_shard_bytes, length=self._cid_len - ) - return not all(b == 0 for b in chunk_cid_bytes_from_cas) - except ( - Exception - ): # Broad catch for issues like invalid coords, CAS errors during load etc. + # Load shard if not cached and check the index + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + return target_shard_list[index_in_shard] is not None + except Exception: return False - + + # ... 
(Keep supports_writes, etc. properties) ... @property def supports_writes(self) -> bool: return not self.read_only @@ -661,61 +516,33 @@ def supports_deletes(self) -> bool: return not self.read_only async def delete(self, key: str) -> None: + # CHANGED: Simplified to set list element to None. if self.read_only: raise ValueError("Cannot delete from a read-only store.") if self._root_obj is None: - raise RuntimeError("Store not initialized for deletion (root_obj is None).") - if self._cid_len is None: - raise RuntimeError("Store not initialized properly; _cid_len is missing.") - # print(f"Deleting key: {key}") - + raise RuntimeError("Store not initialized for deletion.") + chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata - if key in self._root_obj.get("metadata", {}): - del self._root_obj["metadata"][key] + if self._root_obj["metadata"].pop(key, None): self._dirty_root = True - return else: - raise KeyError(f"Metadata key '{key}' not found for deletion.") + raise KeyError(f"Metadata key '{key}' not found.") + return - # Chunk deletion: zero out the CID entry in the shard linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not ( - 0 <= shard_idx < (self._num_shards if self._num_shards is not None else 0) - ): - raise KeyError( - f"Chunk key '{key}' maps to an invalid shard index {shard_idx}." - ) - - # Ensure shard data is available for modification (loads from CAS if not in cache, or initializes if new) - target_shard_data_cache = await self._load_or_initialize_shard_cache(shard_idx) - - offset_in_shard_bytes = index_in_shard * self._cid_len + if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else 0): + raise KeyError(f"Chunk key '{key}' is out of bounds.") - # Check if the entry is already zeroed (meaning it doesn't exist or already deleted) - is_already_zero = True - for i in range(self._cid_len): - if ( - offset_in_shard_bytes + i >= len(target_shard_data_cache) - or target_shard_data_cache[offset_in_shard_bytes + i] != 0 - ): - is_already_zero = False - break - - if is_already_zero: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + if target_shard_list[index_in_shard] is not None: + target_shard_list[index_in_shard] = None self._dirty_shards.add(shard_idx) - # Zero out the CID entry in the shard cache - for i in range(self._cid_len): - target_shard_data_cache[offset_in_shard_bytes + i] = 0 - - self._dirty_shards.add(shard_idx) - # If this shard becomes non-None in root_obj due to other writes, flush will handle it. - # If this deletion makes a previously non-None shard all zeros, the shard itself might - # eventually be elided if we had shard GC, but its CID remains in root_obj for now. - + # ... (Keep listing methods as they are, they operate on metadata) ... @property def supports_listing(self) -> bool: return True @@ -732,17 +559,10 @@ async def list_prefix(self, prefix: str) -> AsyncIterator[str]: async for key in self.list(): if key.startswith(prefix): yield key - + # ... (Keep graft_store, but it needs significant changes) ... + async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): - """ - Performs a high-performance, metadata-only append by "grafting" the - chunk CIDs from another store into this store at a given offset. - - Args: - store_to_graft_cid: The root CID of the Zarr store whose chunks will be copied. 
- chunk_offset: A tuple defining the starting chunk coordinates in the target store. - e.g., (3, 0, 0) to start at the 4th time chunk. - """ + # CHANGED: This method is heavily modified to work with the new DAG-CBOR format. if self.read_only: raise ValueError("Cannot graft onto a read-only store.") if self._root_obj is None: @@ -750,56 +570,36 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. print(f"Grafting store {store_to_graft_cid[:10]}... at chunk offset {chunk_offset}") - # 1. Open the store we want to copy chunks from (read-only) store_to_graft = await ShardedZarrStore.open(cas=self.cas, read_only=True, root_cid=store_to_graft_cid) - if store_to_graft._root_obj is None: - raise ValueError("Store to graft could not be loaded.") - if store_to_graft._chunks_per_dim is None or self._cid_len is None: - raise ValueError("Store to graft is not properly configured.") - - source_shard_cache: Dict[int, bytes] = {} + if store_to_graft._root_obj is None or store_to_graft._chunks_per_dim is None: + raise ValueError("Store to graft could not be loaded or is not configured.") source_chunk_grid = store_to_graft._chunks_per_dim for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): - - # 3. Get the pointer (CID) for each chunk from the source store linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info(linear_local_index) - if local_shard_idx not in source_shard_cache: - source_shard_cid = store_to_graft._root_obj["chunks"]["shard_cids"][local_shard_idx] - if not source_shard_cid: - source_shard_cache[local_shard_idx] = b'' # Mark as loaded but empty - continue - source_shard_cache[local_shard_idx] = await self.cas.load(source_shard_cid) + # Load the source shard into its cache + source_shard_list = await store_to_graft._load_or_initialize_shard_cache(local_shard_idx) - source_shard_data = source_shard_cache[local_shard_idx] - - if not source_shard_data: - continue # This chunk was empty (all fill value) - - # Extract the pointer bytes from the in-memory shard data - offset_in_source_shard = index_in_local_shard * store_to_graft._cid_len - pointer_bytes = source_shard_data[offset_in_source_shard : offset_in_source_shard + store_to_graft._cid_len] - - if all(b == 0 for b in pointer_bytes): - continue # Skip empty CID slots + pointer_cid_obj = source_shard_list[index_in_local_shard] + if pointer_cid_obj is None: + continue # Calculate global coordinates and write to the main store's index global_coords = tuple(c_local + c_offset for c_local, c_offset in zip(local_coords, chunk_offset)) linear_global_index = self._get_linear_chunk_index(global_coords) global_shard_idx, index_in_global_shard = self._get_shard_info(linear_global_index) - target_shard_cache = await self._load_or_initialize_shard_cache(global_shard_idx) - offset_in_global_shard = index_in_global_shard * self._cid_len + target_shard_list = await self._load_or_initialize_shard_cache(global_shard_idx) - target_shard_cache[offset_in_global_shard : offset_in_global_shard + self._cid_len] = pointer_bytes - self._dirty_shards.add(global_shard_idx) + if target_shard_list[index_in_global_shard] != pointer_cid_obj: + target_shard_list[index_in_global_shard] = pointer_cid_obj + self._dirty_shards.add(global_shard_idx) print(f"✓ Grafting complete for store {store_to_graft_cid[:10]}...") - - + # ... (Keep resizing methods as they mostly affect metadata) ... 
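# [Editorial sketch, not part of this commit] graft_store above maps each source
# chunk to a target slot with the same arithmetic as _get_linear_chunk_index and
# _get_shard_info: flatten the chunk coordinates to a linear index (assuming the
# usual row-major order), then divmod by chunks_per_shard to get the shard and the
# slot within it. The shapes, chunking, and offset below are illustrative values:

import math
from itertools import product

array_shape, chunk_shape, chunks_per_shard = (40, 10, 20), (10, 10, 20), 3
chunks_per_dim = tuple(math.ceil(a / c) for a, c in zip(array_shape, chunk_shape))  # (4, 1, 1)

def linear_index(coords: tuple[int, ...]) -> int:
    idx = 0
    for coord, n in zip(coords, chunks_per_dim):  # row-major flattening
        idx = idx * n + coord
    return idx

chunk_offset = (3, 0, 0)  # graft a one-time-chunk source after the third time chunk
for loc in product(*(range(n) for n in (1, 1, 1))):
    global_coords = tuple(c + o for c, o in zip(loc, chunk_offset))
    shard_idx, index_in_shard = divmod(linear_index(global_coords), chunks_per_shard)
    print(global_coords, "->", (shard_idx, index_in_shard))  # (3, 0, 0) -> (1, 0)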
async def resize_store(self, new_shape: Tuple[int, ...]): """ Resizes the store's main shard index to accommodate a new overall array shape. @@ -890,136 +690,3 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: if child not in seen: seen.add(child) yield child - - async def _iterate_chunk_cids(self, shard_cid) -> AsyncIterator[str]: - """An async generator that yields all non-empty chunk CIDs in the store.""" - if self._root_obj is None or self._cid_len is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." - ) - if not shard_cid: - return - - try: - shard_data = await self.cas.load(shard_cid) - for i in range(0, len(shard_data), self._cid_len): - cid_bytes = shard_data[i : i + self._cid_len] - if not all(b == 0 for b in cid_bytes): - yield cid_bytes.decode("ascii").rstrip("\x00") - except Exception as e: - logging.warning(f"Could not process shard {shard_cid} for iteration: {e}") - - async def pin_entire_dataset( - self, - target_rpc: str = "http://127.0.0.1:5001", - concurrency_limit: int = 50, - show_progress: bool = True - ) -> None: - """ - Pins the entire dataset in parallel, with a limit on concurrent requests. - """ - if self._root_obj is None: - raise RuntimeError("Root object not loaded.") - if self._cid_len is None: - raise RuntimeError("Store is not initialized properly.") - - # --- 1. First, gather all unique CIDs to pin --- - print("Gathering all CIDs for pinning...") - cids_to_pin: Set[str] = set() - - if self._root_cid: - cids_to_pin.add(self._root_cid) - - for cid in self._root_obj.get("metadata", {}).values(): - if cid: - cids_to_pin.add(cid) - - for shard_cid in self._root_obj["chunks"]["shard_cids"]: - if not shard_cid: - continue - cids_to_pin.add(shard_cid) - async for chunk_cid in self._iterate_chunk_cids(shard_cid): - if chunk_cid: - cids_to_pin.add(chunk_cid) - - total_cids = len(cids_to_pin) - print(f"Found {total_cids} unique CIDs to pin.") - - # --- 2. Create and run pinning tasks in parallel with a semaphore --- - semaphore = asyncio.Semaphore(concurrency_limit) - tasks: list[Coroutine] = [] - - # Helper function to wrap the pin call with the semaphore - async def pin_with_semaphore(cid: str): - async with semaphore: - await self.cas.pin_cid(cid, target_rpc=target_rpc) - if show_progress: - # This progress reporting is approximate as tasks run out of order - # For precise progress, see asyncio.as_completed in the notes below - pass - - for cid in cids_to_pin: - tasks.append(pin_with_semaphore(cid)) - - print(f"Pinning {total_cids} CIDs with a concurrency of {concurrency_limit}...") - # The 'return_exceptions=False' (default) means this will raise the first exception it encounters. - await asyncio.gather(*tasks) - print("Successfully pinned the entire dataset.") - - - async def unpin_entire_dataset( - self, - target_rpc: str = "http://127.0.0.1:5001", - concurrency_limit: int = 50 - ) -> None: - """ - Unpins the entire dataset in parallel, with a limit on concurrent requests. - """ - if self._root_obj is None: - raise RuntimeError("Root object not loaded.") - if self._cid_len is None: - raise RuntimeError("Store is not initialized properly.") - - # --- 1. 
Gather all CIDs to unpin (chunks, shards, metadata) --- - print("Gathering all CIDs for unpinning...") - cids_to_unpin: Set[str] = set() - - for shard_cid in self._root_obj["chunks"]["shard_cids"]: - if not shard_cid: - continue - cids_to_unpin.add(shard_cid) - async for chunk_cid in self._iterate_chunk_cids(shard_cid): - if chunk_cid: - cids_to_unpin.add(chunk_cid) - - for cid in self._root_obj.get("metadata", {}).values(): - if cid: - cids_to_unpin.add(cid) - - # --- 2. Unpin all children in parallel --- - semaphore = asyncio.Semaphore(concurrency_limit) - tasks: list[Coroutine] = [] - - async def unpin_with_semaphore(cid: str): - async with semaphore: - try: - await self.cas.unpin_cid(cid, target_rpc=target_rpc) - except Exception: - # This is safe to ignore in a mass-unpin operation - print(f"Warning: Could not unpin chunk CID {cid}. Likely already unpinned.") - - for cid in cids_to_unpin: - tasks.append(unpin_with_semaphore(cid)) - - print(f"Unpinning {len(tasks)} child CIDs with a concurrency of {concurrency_limit}...") - await asyncio.gather(*tasks) - print("Successfully unpinned all child objects.") - - # --- 3. Finally, unpin the root CID itself after its children --- - if self._root_cid: - try: - print(f"Unpinning root CID {self._root_cid}...") - await self.cas.unpin_cid(self._root_cid, target_rpc=target_rpc) - print("Successfully unpinned the entire dataset.") - except Exception: - print(f"Warning: Could not unpin root CID {self._root_cid}. Likely already unpinned.") diff --git a/tests/test_sharded_zarr_pinning.py b/tests/test_sharded_zarr_pinning.py deleted file mode 100644 index bbd6a31..0000000 --- a/tests/test_sharded_zarr_pinning.py +++ /dev/null @@ -1,142 +0,0 @@ -import asyncio - -import httpx -import numpy as np -import pandas as pd -import pytest -import xarray as xr - -from py_hamt import KuboCAS, ShardedZarrStore - - -# Helper function to query the IPFS daemon for all pinned CIDs -async def get_pinned_cids(rpc_base_url: str) -> set[str]: - """Queries the Kubo RPC API and returns a set of all pinned CIDs.""" - url = f"{rpc_base_url}/api/v0/pin/ls" - try: - async with httpx.AsyncClient() as client: - resp = await client.post(url, params={"type": "all"}) - resp.raise_for_status() # Raises an exception for 4xx/5xx status codes - data = resp.json() - return set(data.get("Keys", {}).keys()) - except Exception as e: - pytest.fail(f"Failed to query pinned CIDs from Kubo RPC API: {e}") - return set() - - -# Helper function to gather all CIDs from a store instance -async def get_all_dataset_cids(store: ShardedZarrStore) -> set[str]: - """Helper to collect all CIDs associated with a ShardedZarrStore instance.""" - if store._root_obj is None or store._cid_len is None: - raise RuntimeError("Store is not properly initialized.") - - cids = set() - if store._root_cid: - cids.add(store._root_cid) - - # Gather metadata CIDs - for cid in store._root_obj.get("metadata", {}).values(): - if cid: - cids.add(cid) - - # Gather shard and all chunk CIDs within them - for shard_cid in store._root_obj["chunks"]["shard_cids"]: - if not shard_cid: - continue - cids.add(str(shard_cid)) - try: - # Load shard data to find the chunk CIDs within - shard_data = await store.cas.load(shard_cid) - for i in range(0, len(shard_data), store._cid_len): - cid_bytes = shard_data[i : i + store._cid_len] - if all(b == 0 for b in cid_bytes): # Skip null/empty CID slots - continue - - chunk_cid_str = cid_bytes.decode("ascii").rstrip("\x00") - if chunk_cid_str: - cids.add(chunk_cid_str) - except Exception as e: - 
print(f"Warning: Could not load shard {shard_cid} to gather its CIDs: {e}") - - return cids - - -@pytest.fixture(scope="module") -def random_zarr_dataset_for_pinning(): - """Creates a random xarray Dataset specifically for the pinning test.""" - times = pd.date_range("2025-01-01", periods=50) - lats = np.linspace(-90, 90, 10) - lons = np.linspace(-180, 180, 20) - - temp = np.random.randn(len(times), len(lats), len(lons)) - - ds = xr.Dataset( - {"temp": (["time", "lat", "lon"], temp)}, - coords={"time": times, "lat": lats, "lon": lons}, - ) - - # Define chunking for the store - ds = ds.chunk({"time": 10, "lat": 10, "lon": 20}) - yield ds - - -@pytest.mark.asyncio -async def test_sharded_zarr_store_pinning( - create_ipfs: tuple[str, str], random_zarr_dataset_for_pinning: xr.Dataset -): - """ - Tests the pin_entire_dataset and unpin_entire_dataset methods. - """ - rpc_base_url, gateway_base_url = create_ipfs - test_ds = random_zarr_dataset_for_pinning - - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) - chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) - - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url - ) as kubo_cas: - # --- 1. Write dataset to the store --- - store = await ShardedZarrStore.open( - cas=kubo_cas, - read_only=False, - array_shape=array_shape_tuple, - chunk_shape=chunk_shape_tuple, - chunks_per_shard=1, # Use a smaller number to ensure multiple shards - ) - test_ds.to_zarr(store=store, mode="w", consolidated=True) - root_cid = await store.flush() - assert root_cid is not None - - # --- 2. Gather all expected CIDs from the written store --- - expected_cids = await get_all_dataset_cids(store) - assert len(expected_cids) > 5 # Sanity check: ensure we have CIDs to test - - # --- 3. Pin the dataset and verify --- - await store.pin_entire_dataset(target_rpc=rpc_base_url) - - # Allow a moment for pins to register - await asyncio.sleep(1) - - currently_pinned = await get_pinned_cids(rpc_base_url) - - # Check if all our dataset's CIDs are in the main pin list - missing_pins = expected_cids - currently_pinned - assert not missing_pins, ( - f"The following CIDs were expected to be pinned but were not: {missing_pins}" - ) - - # --- 4. Unpin the dataset and verify --- - await store.unpin_entire_dataset(target_rpc=rpc_base_url) - - # Allow a moment for pins to be removed - await asyncio.sleep(1) - - pinned_after_unpin = await get_pinned_cids(rpc_base_url) - - # Check that none of our dataset's CIDs are in the pin list anymore - lingering_pins = expected_cids.intersection(pinned_after_unpin) - assert not lingering_pins, ( - f"The following CIDs were expected to be unpinned but still exist: {lingering_pins}" - ) From 999a85ac54bb6b2fa6f8e3ed2edd142d066c1710 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 2 Jul 2025 06:27:13 -0400 Subject: [PATCH 33/74] fix: remove print --- py_hamt/sharded_zarr_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 104d808..c0e6954 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -385,7 +385,7 @@ async def get( # CHANGED: Logic is simplified to not use byte offsets. It relies on the full-shard cache. 
if self._root_obj is None: raise RuntimeError("Load the root object first before accessing data.") - print('Getting key', key) + # print('Getting key', key) chunk_coords = self._parse_chunk_key(key) # Metadata request From 6f79dbb123c5cd2bce947a576a55204deffcef6e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 3 Jul 2025 03:54:14 -0400 Subject: [PATCH 34/74] fix: sharding print and default --- py_hamt/hamt_to_sharded_converter.py | 2 +- py_hamt/sharded_zarr_store.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index 681bdce..46bbf3f 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -93,7 +93,7 @@ async def sharded_converter_cli(): parser.add_argument( "--chunks-per-shard", type=int, - default=1024, + default=6250, help="Number of chunk CIDs to store per shard in the new store.", ) parser.add_argument( diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index c0e6954..d617aa3 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -627,7 +627,6 @@ async def resize_store(self, new_shape: Tuple[int, ...]): self._root_obj["chunks"]["shard_cids"] = self._root_obj["chunks"]["shard_cids"][:self._num_shards] self._dirty_root = True - print(f"Store's internal shard index resized. New main array shape: {self._array_shape}") async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): From 78d7621cc7bce6427d75e1f95746975d272ee5ca Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:19:40 -0400 Subject: [PATCH 35/74] fix: fix race condition --- py_hamt/sharded_zarr_store.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index d617aa3..f6f6a1f 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -43,7 +43,13 @@ def __init__( self._root_cid = root_cid self._root_obj: Optional[dict] = None - # CHANGED: The cache now stores a list of CID objects or None, not a bytearray. + self._resize_lock = asyncio.Lock() + + # An event to signal when a resize is in-progress. + # It starts in the "set" state, allowing all operations to proceed. + self._resize_complete = asyncio.Event() + self._resize_complete.set() + self._shard_data_cache: Dict[ int, list[Optional[CID]] ] = {} @@ -52,8 +58,6 @@ def __init__( int, asyncio.Task ] = {} - # REMOVED: _cid_len is no longer needed with structured DAG-CBOR shards. - # self._cid_len: Optional[int] = None self._array_shape: Optional[Tuple[int, ...]] = None self._chunk_shape: Optional[Tuple[int, ...]] = None self._chunks_per_dim: Optional[Tuple[int, ...]] = None @@ -436,14 +440,32 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: raise ValueError("Cannot write to a read-only store.") if self._root_obj is None: raise RuntimeError("Store not initialized for writing. 
Call open() first.") - - if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: + + await self._resize_complete.wait() + + is_main_metadata = ( + key.endswith("zarr.json") and + not any(key.startswith(prefix) for prefix in ["time/", "lat/", "lon/", "latitude/", "longitude/"]) and + len(key.split('/')) == 1 # Ensures it's the root zarr.json + ) + + if is_main_metadata: metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") if not new_array_shape: raise ValueError("Shape not found in metadata.") if tuple(new_array_shape) != self._array_shape: - await self.resize_store(new_shape=tuple(new_array_shape)) + async with self._resize_lock: + # Double-check after acquiring the lock, in case another task + # just finished this exact resize while we were waiting. + if tuple(new_array_shape) != self._array_shape: + # Block all other tasks until resize is complete. + self._resize_complete.clear() + try: + await self.resize_store(new_shape=tuple(new_array_shape)) + finally: + # All waiting tasks will now un-pause and proceed safely. + self._resize_complete.set() raw_data_bytes = value.to_bytes() # Save the data to CAS first to get its CID. @@ -573,7 +595,6 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. store_to_graft = await ShardedZarrStore.open(cas=self.cas, read_only=True, root_cid=store_to_graft_cid) if store_to_graft._root_obj is None or store_to_graft._chunks_per_dim is None: raise ValueError("Store to graft could not be loaded or is not configured.") - source_chunk_grid = store_to_graft._chunks_per_dim for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) From efa3429d3030974ce4c7605c86aaf49b75d5c5a4 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:21:52 -0400 Subject: [PATCH 36/74] fix: revert metadata logic --- py_hamt/sharded_zarr_store.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index f6f6a1f..e10ea1a 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -443,13 +443,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() - is_main_metadata = ( - key.endswith("zarr.json") and - not any(key.startswith(prefix) for prefix in ["time/", "lat/", "lon/", "latitude/", "longitude/"]) and - len(key.split('/')) == 1 # Ensures it's the root zarr.json - ) - - if is_main_metadata: + if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") if not new_array_shape: From 2f74b56df13c546d327197857a4818b7cbfec0f3 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:25:58 -0400 Subject: [PATCH 37/74] fix: debug key --- py_hamt/sharded_zarr_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index e10ea1a..ff58b5d 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -443,6 +443,8 @@ 
async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() + print("Setting key:", key) + if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") From 6fc813b0777b16bc7674692e1ea3b6341e5f8501 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:28:08 -0400 Subject: [PATCH 38/74] fix: more debug --- py_hamt/sharded_zarr_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index ff58b5d..d296055 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -447,6 +447,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) + print("setting metadata for key:", key, "with value:", metadata_json) new_array_shape = metadata_json.get("shape") if not new_array_shape: raise ValueError("Shape not found in metadata.") From 36e1ad02c0a5b5e659f7a16c755cf957d30cab2d Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:31:08 -0400 Subject: [PATCH 39/74] fix: more debug --- py_hamt/sharded_zarr_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index d296055..118bed2 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -443,7 +443,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() - print("Setting key:", key) + print("Setting key:", key, "with", self.array_shape) if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) From a57b4346cabbd9f834dc55827e42f6d3b0e53e50 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 03:32:07 -0400 Subject: [PATCH 40/74] fix: array shape --- py_hamt/sharded_zarr_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 118bed2..e36cb3e 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -443,7 +443,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() - print("Setting key:", key, "with", self.array_shape) + print("Setting key:", key, "with", self._array_shape) if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) From a67e66152ebf3654d63c52aa741497790ba01d17 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 04:09:17 -0400 Subject: [PATCH 41/74] fix: tests and race condition on caches --- 
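The race-condition fix in the preceding patches coordinates concurrent set() calls with an asyncio.Event plus an asyncio.Lock: every writer first waits on _resize_complete, and the writer that observes a new array shape clears the event, re-checks the shape after acquiring the lock, performs the resize, and finally sets the event again so the blocked writers proceed. The sketch below isolates that coordination pattern on its own; the class and method names are illustrative only and are not part of py_hamt's API.

import asyncio

class ResizeCoordinator:
    # Minimal sketch of the event + lock + double-check pattern from the
    # patches above. Names here are illustrative, not the store's real API.

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape
        self._resize_lock = asyncio.Lock()
        self._resize_complete = asyncio.Event()
        self._resize_complete.set()  # starts "set": no resize in progress

    async def guard(self, wanted_shape: tuple[int, ...]) -> None:
        # Every writer pauses here while a resize is in flight.
        await self._resize_complete.wait()
        if wanted_shape == self.shape:
            return
        async with self._resize_lock:
            # Double-check: another writer may have resized while we were
            # waiting for the lock.
            if wanted_shape == self.shape:
                return
            self._resize_complete.clear()  # block other writers
            try:
                await self._do_resize(wanted_shape)
            finally:
                self._resize_complete.set()  # wake the blocked writers

    async def _do_resize(self, wanted_shape: tuple[int, ...]) -> None:
        await asyncio.sleep(0)  # stand-in for the real resize work
        self.shape = wanted_shape

Under many concurrent writers only one task ends up performing the resize; the rest either wait on the event or fail the double-check under the lock and fall through without touching the shared geometry.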
py_hamt/sharded_zarr_store.py | 18 +++--- tests/test_sharded_zarr_store.py | 100 ++++++++++++++++++++++++++----- 2 files changed, 96 insertions(+), 22 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index e36cb3e..f2f578b 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -6,6 +6,7 @@ import itertools import logging +from collections import defaultdict import dag_cbor from multiformats.cid import CID import zarr.abc.store @@ -50,6 +51,8 @@ def __init__( self._resize_complete = asyncio.Event() self._resize_complete.set() + self._shard_locks = defaultdict(asyncio.Lock) + self._shard_data_cache: Dict[ int, list[Optional[CID]] ] = {} @@ -443,11 +446,8 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() - print("Setting key:", key, "with", self._array_shape) - if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: metadata_json = json.loads(value.to_bytes().decode("utf-8")) - print("setting metadata for key:", key, "with value:", metadata_json) new_array_shape = metadata_json.get("shape") if not new_array_shape: raise ValueError("Shape not found in metadata.") @@ -488,11 +488,13 @@ async def set_pointer(self, key: str, pointer: str) -> None: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - - if target_shard_list[index_in_shard] != pointer_cid_obj: - target_shard_list[index_in_shard] = pointer_cid_obj - self._dirty_shards.add(shard_idx) + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + if target_shard_list[index_in_shard] != pointer_cid_obj: + target_shard_list[index_in_shard] = pointer_cid_obj + self._dirty_shards.add(shard_idx) # ... (Keep exists method, but simplify it) ... async def exists(self, key: str) -> bool: diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 2778f4d..9555353 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -41,8 +41,8 @@ async def test_sharded_zarr_store_write_read( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -67,6 +67,78 @@ async def test_sharded_zarr_store_write_read( ds_read = xr.open_zarr(store=store_read) xr.testing.assert_identical(test_ds, ds_read) +@pytest.mark.asyncio +async def test_sharded_zarr_store_append( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests appending data to an existing Zarr dataset in the ShardedZarrStore, + which specifically exercises the store resizing logic. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + initial_ds = random_zarr_dataset + + # The main data variable we are sharding + main_variable = "temp" + + ordered_dims = list(initial_ds.sizes) + array_shape_tuple = tuple(initial_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(initial_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. --- Write Initial Dataset --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=10, + ) + initial_ds.to_zarr(store=store_write, mode="w") + initial_cid = await store_write.flush() + assert initial_cid is not None + + # 2. --- Prepare Data to Append --- + # Create a new dataset with 50 more time steps + append_times = pd.date_range(initial_ds.time[-1].values + pd.Timedelta(days=1), periods=50) + append_temp = np.random.randn(len(append_times), len(initial_ds.lat), len(initial_ds.lon)) + + append_ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], append_temp), + }, + coords={"time": append_times, "lat": initial_ds.lat, "lon": initial_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # 3. --- Perform Append Operation --- + store_append = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + root_cid=initial_cid, + ) + append_ds.to_zarr(store=store_append, mode="a", append_dim="time") + final_cid = await store_append.flush() + print(f"Data written to ShardedZarrStore with root CID: {final_cid}") + assert final_cid is not None + assert final_cid != initial_cid + + # 4. --- Verify the Final Dataset --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=final_cid, + ) + final_ds_read = xr.open_zarr(store=store_read) + + # The expected result is the concatenation of the two datasets + expected_final_ds = xr.concat([initial_ds, append_ds], dim="time") + + # Verify that the data read from the store is identical to the expected result + xr.testing.assert_identical(expected_final_ds, final_ds_read) + print("\n✅ Append test successful! 
Data verified.") + @pytest.mark.asyncio async def test_sharded_zarr_store_init(create_ipfs: tuple[str, str]): @@ -110,8 +182,8 @@ async def test_sharded_zarr_store_metadata( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -161,8 +233,8 @@ async def test_sharded_zarr_store_chunks( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -209,8 +281,8 @@ async def test_chunk_and_delete_logic( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -261,8 +333,8 @@ async def test_sharded_zarr_store_partial_reads( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -302,8 +374,8 @@ async def test_partial_reads_and_errors( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( @@ -399,8 +471,8 @@ async def test_listing_and_metadata( rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset - ordered_dims = list(test_ds.dims) - array_shape_tuple = tuple(test_ds.dims[dim] for dim in ordered_dims) + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) async with KuboCAS( From a60039b4e2305847c8b9241557a35f96038d7ca0 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 04:20:51 -0400 Subject: [PATCH 42/74] fix: more locks --- py_hamt/sharded_zarr_store.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index f2f578b..2efcfe3 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -413,7 +413,9 @@ async def get( return None # This will load the full shard into cache if it's not already there. 
- target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) # Get the CID object (or None) from the cached list. chunk_cid_obj = target_shard_list[index_in_shard] @@ -557,11 +559,14 @@ async def delete(self, key: str) -> None: if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else 0): raise KeyError(f"Chunk key '{key}' is out of bounds.") - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - if target_shard_list[index_in_shard] is not None: - target_shard_list[index_in_shard] = None - self._dirty_shards.add(shard_idx) + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + if target_shard_list[index_in_shard] is not None: + target_shard_list[index_in_shard] = None + self._dirty_shards.add(shard_idx) # ... (Keep listing methods as they are, they operate on metadata) ... @property @@ -611,11 +616,13 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. linear_global_index = self._get_linear_chunk_index(global_coords) global_shard_idx, index_in_global_shard = self._get_shard_info(linear_global_index) - target_shard_list = await self._load_or_initialize_shard_cache(global_shard_idx) - - if target_shard_list[index_in_global_shard] != pointer_cid_obj: - target_shard_list[index_in_global_shard] = pointer_cid_obj - self._dirty_shards.add(global_shard_idx) + shard_lock = self._shard_locks[global_shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(global_shard_idx) + + if target_shard_list[index_in_global_shard] != pointer_cid_obj: + target_shard_list[index_in_global_shard] = pointer_cid_obj + self._dirty_shards.add(global_shard_idx) print(f"✓ Grafting complete for store {store_to_graft_cid[:10]}...") From 2d18f4de4b3a00a5665a76b9a74abd8a13fb1fb1 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 08:14:41 -0400 Subject: [PATCH 43/74] fix: more tests --- fsgs.py | 2 +- public_gateway_example.py | 2 +- py_hamt/sharded_zarr_store.py | 200 ++++++++++++++++++------------- py_hamt/store_httpx.py | 2 +- tests/test_cpc_compare.py | 191 ++++++++++++++++------------- tests/test_public_gateway.py | 2 +- tests/test_sharded_zarr_store.py | 14 ++- 7 files changed, 233 insertions(+), 180 deletions(-) diff --git a/fsgs.py b/fsgs.py index 0ff4462..610863c 100644 --- a/fsgs.py +++ b/fsgs.py @@ -16,7 +16,7 @@ async def main(): - cid = "bafyr4iecw3faqyvj75psutabk2jxpddpjdokdy5b26jdnjjzpkzbgb5xoq" + cid = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" # Use KuboCAS as an async context manager async with KuboCAS() as kubo_cas: # connects to a local kubo node diff --git a/public_gateway_example.py b/public_gateway_example.py index 186aa4d..a9c02e2 100644 --- a/public_gateway_example.py +++ b/public_gateway_example.py @@ -53,7 +53,7 @@ async def fetch_zarr_from_gateway(cid: str, gateway: str = "https://ipfs.io"): async def main(): # Example CID - this points to a weather dataset stored on IPFS - cid = "bafyr4iecw3faqyvj75psutabk2jxpddpjdokdy5b26jdnjjzpkzbgb5xoq" + cid = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" # Try different public gateways gateways = [ diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 2efcfe3..8e436f1 
100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -1,16 +1,16 @@ import asyncio +import itertools +import json +import logging import math +from collections import defaultdict from collections.abc import AsyncIterator, Iterable -from typing import Coroutine, Dict, List, Optional, Set, Tuple -import json -import itertools -import logging +from typing import DefaultDict, Dict, List, Optional, Set, Tuple -from collections import defaultdict import dag_cbor -from multiformats.cid import CID import zarr.abc.store import zarr.core.buffer +from multiformats.cid import CID from zarr.core.common import BytesLike from .store_httpx import ContentAddressedStore @@ -20,7 +20,6 @@ class ShardedZarrStore(zarr.abc.store.Store): """ Implements the Zarr Store API using a sharded layout for chunk CIDs. - # CHANGED: Docstring updated to reflect DAG-CBOR format. This store divides the flat index of chunk CIDs into multiple "shards". Each shard is a DAG-CBOR array where each element is either a CID link to a chunk or a null value if the chunk is empty. This structure allows @@ -45,21 +44,15 @@ def __init__( self._root_obj: Optional[dict] = None self._resize_lock = asyncio.Lock() - # An event to signal when a resize is in-progress. # It starts in the "set" state, allowing all operations to proceed. self._resize_complete = asyncio.Event() self._resize_complete.set() + self._shard_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock) - self._shard_locks = defaultdict(asyncio.Lock) - - self._shard_data_cache: Dict[ - int, list[Optional[CID]] - ] = {} + self._shard_data_cache: Dict[int, list[Optional[CID]]] = {} self._dirty_shards: Set[int] = set() - self._pending_shard_loads: Dict[ - int, asyncio.Task - ] = {} + self._pending_shard_loads: Dict[int, asyncio.Task] = {} self._array_shape: Optional[Tuple[int, ...]] = None self._chunk_shape: Optional[Tuple[int, ...]] = None @@ -70,10 +63,13 @@ def __init__( self._dirty_root = False - def _update_geometry(self): """Calculates derived geometric properties from the base shapes.""" - if self._array_shape is None or self._chunk_shape is None or self._chunks_per_shard is None: + if ( + self._array_shape is None + or self._chunk_shape is None + or self._chunks_per_shard is None + ): raise RuntimeError("Base shape information is not set.") if not all(cs > 0 for cs in self._chunk_shape): @@ -82,7 +78,8 @@ def _update_geometry(self): raise ValueError("All array_shape dimensions must be non-negative.") self._chunks_per_dim = tuple( - math.ceil(a / c) if c > 0 else 0 for a, c in zip(self._array_shape, self._chunk_shape) + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self._array_shape, self._chunk_shape) ) self._total_chunks = math.prod(self._chunks_per_dim) @@ -118,9 +115,7 @@ async def open( if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: raise ValueError("chunks_per_shard must be a positive integer.") - store._initialize_new_root( - array_shape, chunk_shape, chunks_per_shard - ) + store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard) else: raise ValueError("root_cid must be provided for a read-only store.") return store @@ -137,13 +132,15 @@ def _initialize_new_root( self._update_geometry() + if self._num_shards is None: + raise RuntimeError("Number of shards not set after geometry update.") + self._root_obj = { - "manifest_version": "sharded_zarr_v1", # CHANGED: Version reflects new format + "manifest_version": "sharded_zarr_v1", "metadata": {}, "chunks": { "array_shape": 
list(self._array_shape), "chunk_shape": list(self._chunk_shape), - # REMOVED: cid_byte_length is no longer relevant "sharding_config": { "chunks_per_shard": self._chunks_per_shard, }, @@ -176,10 +173,8 @@ async def _load_root_from_cid(self): ) async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): - # CHANGED: Logic now decodes the shard from DAG-CBOR into a list. try: shard_data_bytes = await self.cas.load(shard_cid) - # Decode the CBOR object, which should be a list of CIDs/None decoded_shard = dag_cbor.decode(shard_data_bytes) if not isinstance(decoded_shard, list): raise TypeError(f"Shard {shard_idx} did not decode to a list.") @@ -191,8 +186,7 @@ async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): finally: if shard_idx in self._pending_shard_loads: del self._pending_shard_loads[shard_idx] - - # ... (Keep _parse_chunk_key, _get_linear_chunk_index, _get_shard_info as they are) ... + def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. Exclude .json files immediately (metadata) if key.endswith(".json"): @@ -226,9 +220,6 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # Now, proceed with the original parsing logic using self._array_shape and # self._chunks_per_dim, which should be configured for this main data variable. - # print( - # f"Parsing chunk key: {key} for array: {actual_array_name} with shape: {self._array_shape} and chunks_per_dim: {self._chunks_per_dim}") - if not self._array_shape or not self._chunks_per_dim: # This ShardedZarrStore instance is not properly initialized # with the shape/chunking info for the array it's supposed to manage. @@ -306,10 +297,10 @@ async def _load_or_initialize_shard_cache(self, shard_idx: int) -> list: raise RuntimeError("Store not initialized: _chunks_per_shard is None.") # Initialize new shard as a list of Nones self._shard_data_cache[shard_idx] = [None] * self._chunks_per_shard - + if shard_idx not in self._shard_data_cache: - raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") - + raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") + return self._shard_data_cache[shard_idx] async def set_partial_values( @@ -333,7 +324,7 @@ def __eq__(self, other: object) -> bool: return NotImplemented # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. return self._root_cid == other._root_cid - + async def flush(self) -> str: # CHANGED: This method now encodes shards using DAG-CBOR. if self.read_only: @@ -358,12 +349,18 @@ async def flush(self) -> str: # Save the DAG-CBOR block and get its CID new_shard_cid_obj = await self.cas.save( - shard_data_bytes, codec="dag-cbor" # Use 'dag-cbor' codec + shard_data_bytes, + codec="dag-cbor", # Use 'dag-cbor' codec ) - if self._root_obj["chunks"]["shard_cids"][shard_idx] != new_shard_cid_obj: + if ( + self._root_obj["chunks"]["shard_cids"][shard_idx] + != new_shard_cid_obj + ): # Store the CID object directly - self._root_obj["chunks"]["shard_cids"][shard_idx] = new_shard_cid_obj + self._root_obj["chunks"]["shard_cids"][shard_idx] = ( + new_shard_cid_obj + ) self._dirty_root = True self._dirty_shards.clear() @@ -419,9 +416,9 @@ async def get( # Get the CID object (or None) from the cached list. chunk_cid_obj = target_shard_list[index_in_shard] - + if chunk_cid_obj is None: - return None # Chunk is empty/doesn't exist. + return None # Chunk is empty/doesn't exist. 
chunk_cid_str = str(chunk_cid_obj) @@ -448,7 +445,13 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self._resize_complete.wait() - if key.endswith("zarr.json") and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) and not len(key) == 9: + if ( + key.endswith("zarr.json") + and not key.startswith("time/") + and not key.startswith(("lat/", "latitude/")) + and not key.startswith(("lon/", "longitude/")) + and not len(key) == 9 + ): metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") if not new_array_shape: @@ -461,7 +464,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: # Block all other tasks until resize is complete. self._resize_complete.clear() try: - await self.resize_store(new_shape=tuple(new_array_shape)) + await self.resize_store(new_shape=tuple(new_array_shape)) finally: # All waiting tasks will now un-pause and proceed safely. self._resize_complete.set() @@ -476,10 +479,10 @@ async def set_pointer(self, key: str, pointer: str) -> None: # CHANGED: Logic now updates a list in the cache, not a bytearray. if self._root_obj is None: raise RuntimeError("Load the root object first before accessing data.") - + chunk_coords = self._parse_chunk_key(key) - - pointer_cid_obj = CID.decode(pointer) # Convert string to CID object + + pointer_cid_obj = CID.decode(pointer) # Convert string to CID object if chunk_coords is None: # Metadata key self._root_obj["metadata"][key] = pointer_cid_obj @@ -493,7 +496,7 @@ async def set_pointer(self, key: str, pointer: str) -> None: shard_lock = self._shard_locks[shard_idx] async with shard_lock: target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - + if target_shard_list[index_in_shard] != pointer_cid_obj: target_shard_list[index_in_shard] = pointer_cid_obj self._dirty_shards.add(shard_idx) @@ -524,7 +527,7 @@ async def exists(self, key: str) -> bool: return target_shard_list[index_in_shard] is not None except Exception: return False - + # ... (Keep supports_writes, etc. properties) ... @property def supports_writes(self) -> bool: @@ -544,7 +547,7 @@ async def delete(self, key: str) -> None: raise ValueError("Cannot delete from a read-only store.") if self._root_obj is None: raise RuntimeError("Store not initialized for deletion.") - + chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata if self._root_obj["metadata"].pop(key, None): @@ -556,19 +559,18 @@ async def delete(self, key: str) -> None: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not (0 <= shard_idx < self._num_shards if self._num_shards is not None else 0): + if not ( + 0 <= shard_idx < self._num_shards if self._num_shards is not None else 0 + ): raise KeyError(f"Chunk key '{key}' is out of bounds.") - shard_lock = self._shard_locks[shard_idx] async with shard_lock: target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - if target_shard_list[index_in_shard] is not None: target_shard_list[index_in_shard] = None self._dirty_shards.add(shard_idx) - # ... (Keep listing methods as they are, they operate on metadata) ... @property def supports_listing(self) -> bool: return True @@ -585,59 +587,72 @@ async def list_prefix(self, prefix: str) -> AsyncIterator[str]: async for key in self.list(): if key.startswith(prefix): yield key - # ... 
(Keep graft_store, but it needs significant changes) ... - + async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): - # CHANGED: This method is heavily modified to work with the new DAG-CBOR format. if self.read_only: raise ValueError("Cannot graft onto a read-only store.") if self._root_obj is None: raise RuntimeError("Main store must be initialized before grafting.") - - print(f"Grafting store {store_to_graft_cid[:10]}... at chunk offset {chunk_offset}") - - store_to_graft = await ShardedZarrStore.open(cas=self.cas, read_only=True, root_cid=store_to_graft_cid) + store_to_graft = await ShardedZarrStore.open( + cas=self.cas, read_only=True, root_cid=store_to_graft_cid + ) if store_to_graft._root_obj is None or store_to_graft._chunks_per_dim is None: - raise ValueError("Store to graft could not be loaded or is not configured.") + raise ValueError("Store to graft could not be loaded or is not configured.") source_chunk_grid = store_to_graft._chunks_per_dim for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) - local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info(linear_local_index) + local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info( + linear_local_index + ) # Load the source shard into its cache - source_shard_list = await store_to_graft._load_or_initialize_shard_cache(local_shard_idx) - + source_shard_list = await store_to_graft._load_or_initialize_shard_cache( + local_shard_idx + ) + pointer_cid_obj = source_shard_list[index_in_local_shard] if pointer_cid_obj is None: continue # Calculate global coordinates and write to the main store's index - global_coords = tuple(c_local + c_offset for c_local, c_offset in zip(local_coords, chunk_offset)) + global_coords = tuple( + c_local + c_offset + for c_local, c_offset in zip(local_coords, chunk_offset) + ) linear_global_index = self._get_linear_chunk_index(global_coords) - global_shard_idx, index_in_global_shard = self._get_shard_info(linear_global_index) - + global_shard_idx, index_in_global_shard = self._get_shard_info( + linear_global_index + ) + shard_lock = self._shard_locks[global_shard_idx] async with shard_lock: - target_shard_list = await self._load_or_initialize_shard_cache(global_shard_idx) - + target_shard_list = await self._load_or_initialize_shard_cache( + global_shard_idx + ) + if target_shard_list[index_in_global_shard] != pointer_cid_obj: target_shard_list[index_in_global_shard] = pointer_cid_obj self._dirty_shards.add(global_shard_idx) - print(f"✓ Grafting complete for store {store_to_graft_cid[:10]}...") - - # ... (Keep resizing methods as they mostly affect metadata) ... async def resize_store(self, new_shape: Tuple[int, ...]): """ Resizes the store's main shard index to accommodate a new overall array shape. This is a metadata-only operation on the store's root object. + Used when doing skeleton writes or appends via xarray where the array shape changes. 
""" if self.read_only: raise ValueError("Cannot resize a read-only store.") - if self._root_obj is None or self._chunk_shape is None or self._chunks_per_shard is None: + if ( + self._root_obj is None + or self._chunk_shape is None + or self._chunks_per_shard is None + or self._array_shape is None + ): raise RuntimeError("Store is not properly initialized for resizing.") if len(new_shape) != len(self._array_shape): - raise ValueError("New shape must have the same number of dimensions as the old shape.") + raise ValueError( + "New shape must have the same number of dimensions as the old shape." + ) self._array_shape = tuple(new_shape) self._chunks_per_dim = tuple( @@ -646,16 +661,23 @@ async def resize_store(self, new_shape: Tuple[int, ...]): ) self._total_chunks = math.prod(self._chunks_per_dim) old_num_shards = self._num_shards if self._num_shards is not None else 0 - self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) if self._total_chunks > 0 else 0 + self._num_shards = ( + math.ceil(self._total_chunks / self._chunks_per_shard) + if self._total_chunks > 0 + else 0 + ) self._root_obj["chunks"]["array_shape"] = list(self._array_shape) if self._num_shards > old_num_shards: - self._root_obj["chunks"]["shard_cids"].extend([None] * (self._num_shards - old_num_shards)) + self._root_obj["chunks"]["shard_cids"].extend( + [None] * (self._num_shards - old_num_shards) + ) elif self._num_shards < old_num_shards: - self._root_obj["chunks"]["shard_cids"] = self._root_obj["chunks"]["shard_cids"][:self._num_shards] + self._root_obj["chunks"]["shard_cids"] = self._root_obj["chunks"][ + "shard_cids" + ][: self._num_shards] self._dirty_root = True - async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): """ Resizes the Zarr metadata for a specific variable (e.g., '.json' file). @@ -668,23 +690,31 @@ async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): # Zarr v2 uses .json, not zarr.json zarr_metadata_key = f"{variable_name}/zarr.json" - + old_zarr_metadata_cid = self._root_obj["metadata"].get(zarr_metadata_key) if not old_zarr_metadata_cid: - raise KeyError(f"Cannot find metadata for key '{zarr_metadata_key}' to resize.") + raise KeyError( + f"Cannot find metadata for key '{zarr_metadata_key}' to resize." + ) old_zarr_metadata_bytes = await self.cas.load(old_zarr_metadata_cid) zarr_metadata_json = json.loads(old_zarr_metadata_bytes) - + zarr_metadata_json["shape"] = list(new_shape) - - new_zarr_metadata_bytes = json.dumps(zarr_metadata_json, indent=2).encode('utf-8') + + new_zarr_metadata_bytes = json.dumps(zarr_metadata_json, indent=2).encode( + "utf-8" + ) # Metadata is a raw blob of bytes - new_zarr_metadata_cid = await self.cas.save(new_zarr_metadata_bytes, codec='raw') - + new_zarr_metadata_cid = await self.cas.save( + new_zarr_metadata_bytes, codec="raw" + ) + self._root_obj["metadata"][zarr_metadata_key] = str(new_zarr_metadata_cid) self._dirty_root = True - print(f"Resized metadata for variable '{variable_name}'. New shape: {new_shape}") + print( + f"Resized metadata for variable '{variable_name}'. New shape: {new_shape}" + ) async def list_dir(self, prefix: str) -> AsyncIterator[str]: # This simplified version only works for the root directory (prefix == "") of metadata. 
diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 8d0a51c..4547c87 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -237,7 +237,7 @@ def __init__( rpc_url += f"&chunker={chunker}" self.rpc_url = rpc_url - + """@private""" self.gateway_base_url: str = gateway_base_url """@private""" diff --git a/tests/test_cpc_compare.py b/tests/test_cpc_compare.py index 7de9807..3680591 100644 --- a/tests/test_cpc_compare.py +++ b/tests/test_cpc_compare.py @@ -1,86 +1,106 @@ -import time - -import numpy as np -import pandas as pd -import pytest -import xarray as xr -from dag_cbor.ipld import IPLDKind -from multiformats import CID - -# Import both store implementations -from py_hamt import HAMT, KuboCAS, ShardedZarrStore -from py_hamt.zarr_hamt_store import ZarrHAMTStore - - -@pytest.mark.asyncio(loop_scope="session") -async def test_benchmark_sharded_store(): - """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring - print("\n\n" + "=" * 80) - print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print - print("=" * 80) - - - rpc_base_url = f"https://ipfs-gateway.dclimate.net" - gateway_base_url = f"https://ipfs-gateway.dclimate.net" - headers = { - "X-API-Key": "", - } - - async with KuboCAS( - rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers - ) as kubo_cas: - # --- Write --- - root_cid = "bafyr4ienfetuujjqeqhrjvtr6dpcfh2bdowxrofsgl6dz5oknqauhxicie" - print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print - # --- Read --- - start = time.perf_counter() - # When opening for read, chunks_per_shard is read from the store's metadata - store_read = await ShardedZarrStore.open( # Use ShardedZarrStore - cas=kubo_cas, read_only=True, root_cid=root_cid - ) - stop = time.perf_counter() - print(f"Total time to open ShardedZarrStore: {stop - start:.2f} seconds") - print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") - - start_read = time.perf_counter() - ipfs_ds = xr.open_zarr(store=store_read) - start_read = time.perf_counter() - print(ipfs_ds) - stop_read = time.perf_counter() - print(f"Total time to read dataset: {stop_read - start_read:.2f} seconds") - # start_read = time.perf_counter() - # print(ipfs_ds.variables, ipfs_ds.coords) # Print available variables and coordinates for debugging - # stop_read = time.perf_counter() - # print(f"Total time to read dataset variables and coordinates: {stop_read - start_read:.2f} seconds") - start_read = time.perf_counter() - # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) - if "2m_temperature" in ipfs_ds.variables and "time" in ipfs_ds.coords: - print("Fetching '2m_temperature' data...") - data_fetched = ipfs_ds["2m_temperature"].isel(time=0).values - # data_fetched = ipfs_ds["2m_temperature"].values - - # Calculate the size of the fetched data - data_size = data_fetched.nbytes if data_fetched is not None else 0 - print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") - elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable - first_var_name = list(ipfs_ds.data_vars.keys())[0] - # Construct a minimal selection based on available dimensions - selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} - if selection: - _ = ipfs_ds[first_var_name].isel(**selection).values - else: # If no dimensions, try loading the whole variable (e.g. 
scalar) - _ = ipfs_ds[first_var_name].values - end_read = time.perf_counter() - - print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print - print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") - print("=" * 80) - # Speed in MB/s - if data_size > 0: - speed = data_size / (end_read - start_read) / (1024 * 1024) - print(f"Read speed: {speed:.2f} MB/s") - else: - print("No data fetched, cannot calculate speed.") +# import time + +# import numpy as np +# import pandas as pd +# import pytest +# import xarray as xr +# from dag_cbor.ipld import IPLDKind +# from multiformats import CID + +# # Import both store implementations +# from py_hamt import HAMT, KuboCAS, ShardedZarrStore +# from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_sharded_store(): +# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print +# print("=" * 80) + + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# headers = { +# "X-API-Key": "", +# } + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# # root_cid = "bafyr4ifayhevbtfg2qzffuicic3rwli4fhnnkhrfduuxkwvetppfk4ogbe" +# root_cid = "bafyr4ifs4oejlvtwvb57udbbhba5yllss4cixjkxrevq54g3mo5kwknpwy" +# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print +# # --- Read --- +# start = time.perf_counter() +# # When opening for read, chunks_per_shard is read from the store's metadata +# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore +# cas=kubo_cas, read_only=True, root_cid=root_cid +# ) +# stop = time.perf_counter() +# print(f"Total time to open ShardedZarrStore: {stop - start:.2f} seconds") +# print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=store_read) +# start_read = time.perf_counter() +# print(ipfs_ds) +# stop_read = time.perf_counter() +# print(f"Total time to read dataset: {stop_read - start_read:.2f} seconds") +# # start_read = time.perf_counter() +# # print(ipfs_ds.variables, ipfs_ds.coords) # Print available variables and coordinates for debugging +# # stop_read = time.perf_counter() +# # print(f"Total time to read dataset variables and coordinates: {stop_read - start_read:.2f} seconds") +# start_read = time.perf_counter() +# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) +# if "FPAR" in ipfs_ds.variables and "time" in ipfs_ds.coords: +# print("Fetching 'FPAR' data...") + +# # Define date range +# date_from = "2000-05-15" +# date_to = "2004-05-30" + +# # Define bounding box from polygon coordinates +# min_lon, max_lon = 4.916695, 5.258908 +# min_lat, max_lat = 51.921763, 52.160344 + +# print(ipfs_ds["FPAR"].sel( +# time=slice(date_from, date_to), +# latitude=slice(min_lat, max_lat), +# longitude=slice(min_lon, max_lon) +# )) + +# # Fetch data for the specified time and region +# data_fetched = ipfs_ds["FPAR"].sel( +# time=slice(date_from, date_to), +# latitude=slice(min_lat, max_lat), +# longitude=slice(min_lon, max_lon) +# ).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} 
MB") +# elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable +# first_var_name = list(ipfs_ds.data_vars.keys())[0] +# # Construct a minimal selection based on available dimensions +# selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} +# if selection: +# _ = ipfs_ds[first_var_name].isel(**selection).values +# else: # If no dimensions, try loading the whole variable (e.g. scalar) +# _ = ipfs_ds[first_var_name].values +# end_read = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print +# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") +# print("=" * 80) +# # Speed in MB/s +# if data_size > 0: +# speed = data_size / (end_read - start_read) / (1024 * 1024) +# print(f"Read speed: {speed:.2f} MB/s") +# else: +# print("No data fetched, cannot calculate speed.") # # ### # # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore @@ -97,13 +117,13 @@ async def test_benchmark_sharded_store(): # # headers = { # # "X-API-Key": "", # # } -# # headers = {} +# headers = {} # async with KuboCAS( # rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers # ) as kubo_cas: -# root_cid = "bafyr4ialorauxcpw77mgmnyoeptn4g4zkqdqhtsobff4v76rllvd3m6cqi" +# root_cid = "bafyr4igl3pmswu5pfzb6dcgcxj3ipxlpxxxad7j7tf45obxe5pkp4xgpwe" # # root_node_id = CID.decode(root_cid) # hamt = await HAMT.build( @@ -113,6 +133,7 @@ async def test_benchmark_sharded_store(): # ipfs_ds: xr.Dataset # zhs = ZarrHAMTStore(hamt, read_only=True) # ipfs_ds = xr.open_zarr(store=zhs) +# print(ipfs_ds) # # --- Read --- # hamt = HAMT(cas=kubo_cas, values_are_bytes=True, root_node_id=root_cid, read_only=True) diff --git a/tests/test_public_gateway.py b/tests/test_public_gateway.py index c76c414..7bea68a 100644 --- a/tests/test_public_gateway.py +++ b/tests/test_public_gateway.py @@ -6,7 +6,7 @@ from py_hamt import KuboCAS -TEST_CID = "bafyr4iecw3faqyvj75psutabk2jxpddpjdokdy5b26jdnjjzpkzbgb5xoq" +TEST_CID = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" async def verify_response_content(url: str, client=None): diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 9555353..d85beb1 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -67,6 +67,7 @@ async def test_sharded_zarr_store_write_read( ds_read = xr.open_zarr(store=store_read) xr.testing.assert_identical(test_ds, ds_read) + @pytest.mark.asyncio async def test_sharded_zarr_store_append( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset @@ -78,9 +79,6 @@ async def test_sharded_zarr_store_append( rpc_base_url, gateway_base_url = create_ipfs initial_ds = random_zarr_dataset - # The main data variable we are sharding - main_variable = "temp" - ordered_dims = list(initial_ds.sizes) array_shape_tuple = tuple(initial_ds.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(initial_ds.chunks[dim][0] for dim in ordered_dims) @@ -102,9 +100,13 @@ async def test_sharded_zarr_store_append( # 2. 
--- Prepare Data to Append --- # Create a new dataset with 50 more time steps - append_times = pd.date_range(initial_ds.time[-1].values + pd.Timedelta(days=1), periods=50) - append_temp = np.random.randn(len(append_times), len(initial_ds.lat), len(initial_ds.lon)) - + append_times = pd.date_range( + initial_ds.time[-1].values + pd.Timedelta(days=1), periods=50 + ) + append_temp = np.random.randn( + len(append_times), len(initial_ds.lat), len(initial_ds.lon) + ) + append_ds = xr.Dataset( { "temp": (["time", "lat", "lon"], append_temp), From fb4fc37a8a1a7e742befe8f8365b0b0711798f2c Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 08:52:21 -0400 Subject: [PATCH 44/74] fix: remove comment --- py_hamt/sharded_zarr_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 8e436f1..7c26cad 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -98,7 +98,6 @@ async def open( array_shape: Optional[Tuple[int, ...]] = None, chunk_shape: Optional[Tuple[int, ...]] = None, chunks_per_shard: Optional[int] = None, - # REMOVED: cid_len is no longer needed. ) -> "ShardedZarrStore": """ Asynchronously opens an existing ShardedZarrStore or initializes a new one. From eded0431449b9e344079656b6c7182aa0b8e750c Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:17:54 -0400 Subject: [PATCH 45/74] fix: tests --- py_hamt/hamt_to_sharded_converter.py | 4 ++-- tests/test_kubo_cas.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index 46bbf3f..c50ba23 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -36,8 +36,8 @@ async def convert_hamt_to_sharded( print("Reading metadata from source store...") # Read the stores metadata to get array shape and chunk shape - ordered_dims = list(source_dataset.dims) - array_shape_tuple = tuple(source_dataset.dims[dim] for dim in ordered_dims) + ordered_dims = list(source_dataset.sizes) + array_shape_tuple = tuple(source_dataset.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) array_shape = array_shape_tuple chunk_shape = chunk_shape_tuple diff --git a/tests/test_kubo_cas.py b/tests/test_kubo_cas.py index db138b1..627944a 100644 --- a/tests/test_kubo_cas.py +++ b/tests/test_kubo_cas.py @@ -162,5 +162,5 @@ async def test_chunker_valid_patterns(): ) async def test_chunker_invalid_patterns(invalid): with pytest.raises(ValueError, match="Invalid chunker specification"): - async with KuboCAS(chunker=invalid): + async with KuboCAS(chunker=invalid) as cas: pass From cfda0c7b64d7cdd1820ffcbd164558226fda1f22 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:24:53 -0400 Subject: [PATCH 46/74] fix: update test --- py_hamt/store_httpx.py | 17 +++++++++-------- tests/test_kubo_cas.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 42af066..2ca88fe 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -217,6 +217,13 @@ def __init__( These are the first part of the url, defaults that refer to the default that kubo launches with on a local machine are provided. 
""" + self._owns_client: bool = False + self._closed: bool = True + self._client_per_loop: Dict[asyncio.AbstractEventLoop, httpx.AsyncClient] = {} + self._default_headers = headers + self._default_auth = auth + + # Now, perform validation that might raise an exception chunker_pattern = r"(?:size-[1-9]\d*|rabin(?:-[1-9]\d*-[1-9]\d*-[1-9]\d*)?)" if re.fullmatch(chunker_pattern, chunker) is None: raise ValueError("Invalid chunker specification") @@ -245,21 +252,15 @@ def __init__( self.gateway_base_url: str = gateway_base_url """@private""" - self._client_per_loop: Dict[asyncio.AbstractEventLoop, httpx.AsyncClient] = {} - if client is not None: # user supplied → bind it to *their* current loop self._client_per_loop[asyncio.get_running_loop()] = client - self._owns_client: bool = False + self._owns_client = False else: self._owns_client = True # we'll create clients lazily - # store for later use by _loop_client() - self._default_headers = headers - self._default_auth = auth - self._sem: asyncio.Semaphore = asyncio.Semaphore(concurrency) - self._closed: bool = False + self._closed = False # --------------------------------------------------------------------- # # helper: get or create the client bound to the current running loop # diff --git a/tests/test_kubo_cas.py b/tests/test_kubo_cas.py index 627944a..db138b1 100644 --- a/tests/test_kubo_cas.py +++ b/tests/test_kubo_cas.py @@ -162,5 +162,5 @@ async def test_chunker_valid_patterns(): ) async def test_chunker_invalid_patterns(invalid): with pytest.raises(ValueError, match="Invalid chunker specification"): - async with KuboCAS(chunker=invalid) as cas: + async with KuboCAS(chunker=invalid): pass From be0048bb68f43125a4502ceec66c0636288c0c3e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 8 Jul 2025 07:17:55 -0400 Subject: [PATCH 47/74] fix: add concat test --- tests/test_sharded_zarr_store.py | 123 +++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index d85beb1..75577ea 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -790,3 +790,126 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st await ShardedZarrStore.open( cas=kubo_cas, read_only=True, root_cid=invalid_root_cid ) + +@pytest.mark.asyncio +async def test_sharded_zarr_store_lazy_concat( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores, + ensuring the combined dataset can be queried as a single dataset with data fetched + correctly from the respective stores. + """ + rpc_base_url, gateway_base_url = create_ipfs + base_ds = random_zarr_dataset + + # 1. 
--- Prepare Two Datasets with Distinct Time Ranges --- + # First dataset: August 1, 2024 to September 30, 2024 (61 days) + aug_sep_times = pd.date_range("2024-08-01", "2024-09-30", freq="D") + aug_sep_temp = np.random.randn(len(aug_sep_times), len(base_ds.lat), len(base_ds.lon)) + ds1 = xr.Dataset( + { + "temp": (["time", "lat", "lon"], aug_sep_temp), + }, + coords={"time": aug_sep_times, "lat": base_ds.lat, "lon": base_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # Second dataset: October 1, 2024 to November 30, 2024 (61 days) + oct_nov_times = pd.date_range("2024-10-01", "2024-11-30", freq="D") + oct_nov_temp = np.random.randn(len(oct_nov_times), len(base_ds.lat), len(base_ds.lon)) + ds2 = xr.Dataset( + { + "temp": (["time", "lat", "lon"], oct_nov_temp), + }, + coords={"time": oct_nov_times, "lat": base_ds.lat, "lon": base_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # Expected concatenated dataset for verification + expected_combined_ds = xr.concat([ds1, ds2], dim="time") + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 2. --- Write First Dataset to ShardedZarrStore --- + ordered_dims = list(ds1.sizes) + array_shape_tuple = tuple(ds1.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(ds1.chunks[dim][0] for dim in ordered_dims) + + store1_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + ds1.to_zarr(store=store1_write, mode="w") + root_cid1 = await store1_write.flush() + assert root_cid1 is not None + + # 3. --- Write Second Dataset to ShardedZarrStore --- + array_shape_tuple = tuple(ds2.sizes[dim] for dim in ordered_dims) + store2_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + ds2.to_zarr(store=store2_write, mode="w") + root_cid2 = await store2_write.flush() + assert root_cid2 is not None + + # 4. --- Read and Lazily Concatenate Datasets --- + store1_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid1 + ) + store2_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid2 + ) + + ds1_read = xr.open_zarr(store=store1_read, chunks="auto") + ds2_read = xr.open_zarr(store=store2_read, chunks="auto") + + # Verify that datasets are lazy (Dask-backed) + assert ds1_read.temp.chunks is not None + assert ds2_read.temp.chunks is not None + + # Lazily concatenate along time dimension + combined_ds = xr.concat([ds1_read, ds2_read], dim="time") + + # Verify that the combined dataset is still lazy + assert combined_ds.temp.chunks is not None + + # 5. --- Query Across Both Datasets --- + # Select a time slice that spans both datasets (e.g., Sep 15 to Oct 15) + query_slice = combined_ds.sel(time=slice("2024-09-15", "2024-10-15")) + + # Verify that the query is still lazy + assert query_slice.temp.chunks is not None + + # Compute the result to trigger data loading + query_result = query_slice.compute() + + # 6. 
--- Verify Results --- + # Compare with the expected concatenated dataset + expected_query_result = expected_combined_ds.sel( + time=slice("2024-09-15", "2024-10-15") + ) + xr.testing.assert_identical(query_result, expected_query_result) + + # Verify specific values at a point to ensure data integrity + sample_time = pd.Timestamp("2024-09-30") # From ds1 + sample_result = query_result.sel(time=sample_time, method="nearest").temp.values + expected_sample = expected_combined_ds.sel( + time=sample_time, method="nearest" + ).temp.values + np.testing.assert_array_equal(sample_result, expected_sample) + + sample_time = pd.Timestamp("2024-10-01") # From ds2 + sample_result = query_result.sel(time=sample_time, method="nearest").temp.values + expected_sample = expected_combined_ds.sel( + time=sample_time, method="nearest" + ).temp.values + np.testing.assert_array_equal(sample_result, expected_sample) + + print("\n✅ Lazy concatenation test successful! Data verified.") \ No newline at end of file From d1ebf1fc050bc6e875ef03132aa7d70aef944a66 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 8 Jul 2025 08:40:59 -0400 Subject: [PATCH 48/74] fix: helpful test for local testing --- tests/test_sharded_zarr_store.py | 109 ++++++++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 75577ea..8ac1af4 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -912,4 +912,111 @@ async def test_sharded_zarr_store_lazy_concat( ).temp.values np.testing.assert_array_equal(sample_result, expected_sample) - print("\n✅ Lazy concatenation test successful! Data verified.") \ No newline at end of file + print("\n✅ Lazy concatenation test successful! Data verified.") + +# @pytest.mark.asyncio +# async def test_sharded_zarr_store_lazy_concat_with_cids( +# create_ipfs: tuple[str, str] +# ): +# """ +# Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores +# using provided CIDs for finalized and non-finalized data, ensuring the non-finalized +# dataset is sliced after the finalization date (inclusive) and the combined dataset +# can be queried as a single dataset with data fetched correctly from the respective stores. +# """ +# rpc_base_url, gateway_base_url = create_ipfs + +# # Provided CIDs +# finalized_cid = "bafyr4icrox4pxashkfmbyztn7jhp6zjlpj3bufg5ggsjux74zr7ocnqdpu" +# non_finalized_cid = "bafyr4ibj3bfl5oo7bf6gagzr2g33jlnf23mq2xo632mbl6ytfry7jbuepy" +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url +# ) as kubo_cas: +# # 1. --- Open Finalized Dataset --- +# store_finalized = await ShardedZarrStore.open( +# cas=kubo_cas, read_only=True, root_cid=finalized_cid +# ) +# ds_finalized = xr.open_zarr(store=store_finalized, chunks="auto") + +# # Verify that the dataset is lazy (Dask-backed) +# assert ds_finalized['2m_temperature'].chunks is not None + +# # Determine the finalization date (last date in finalized dataset) +# finalization_date = np.datetime64(ds_finalized.time.max().values) +# # Convert to Python datetime for clarity +# finalization_date_dt = pd.Timestamp(finalization_date).to_pydatetime() + +# # 2. 
--- Open Non-Finalized Dataset and Slice After Finalization Date --- +# store_non_finalized = await ShardedZarrStore.open( +# cas=kubo_cas, read_only=True, root_cid=non_finalized_cid +# ) +# ds_non_finalized = xr.open_zarr(store=store_non_finalized, chunks="auto") + +# # Verify that the dataset is lazy +# assert ds_non_finalized['2m_temperature'].chunks is not None + +# # Slice non-finalized dataset to start *after* the finalization date +# # (finalization_date is inclusive for finalized data, so start at +1 hour) +# start_time = finalization_date + np.timedelta64(1, 'h') +# ds_non_finalized_sliced = ds_non_finalized.sel(time=slice(start_time, None)) + +# # Verify that the sliced dataset starts after the finalization date +# if ds_non_finalized_sliced.time.size > 0: +# assert ds_non_finalized_sliced.time.min() > finalization_date +# else: +# # Handle case where non-finalized dataset is empty after slicing +# print("Warning: Non-finalized dataset is empty after slicing.") + +# # 3. --- Lazily Concatenate Datasets --- +# combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") + +# # Verify that the combined dataset is still lazy +# assert combined_ds['2m_temperature'].chunks is not None + +# # 4. --- Query Across Both Datasets --- +# # Select a time slice that spans both datasets +# # Use a range that includes the boundary (e.g., finalization date and after) +# query_start = finalization_date - np.timedelta64(1, 'D') # 1 day before +# query_end = finalization_date + np.timedelta64(1, 'D') # 1 day after +# query_slice = combined_ds.sel(time=slice(query_start, query_end), latitude=0, longitude=0) +# # Make sure the query slice aligned with the query_start and query_end + +# assert query_slice.time.min() >= query_start +# assert query_slice.time.max() <= query_end + +# # Verify that the query is still lazy +# assert query_slice['2m_temperature'].chunks is not None + +# # Compute the result to trigger data loading +# query_result = query_slice.compute() + +# # 5. --- Verify Results --- +# # Verify data integrity at specific points +# # Check the last finalized time (from finalized dataset) +# sample_time_finalized = finalization_date +# if sample_time_finalized in query_result.time.values: +# sample_result = query_result.sel(time=sample_time_finalized, method="nearest")['2m_temperature'].values +# expected_sample = ds_finalized.sel(time=sample_time_finalized, latitude=0, longitude=0, method="nearest")['2m_temperature'].values +# np.testing.assert_array_equal(sample_result, expected_sample) + +# # Check the first non-finalized time (from non-finalized dataset, if available) +# if ds_non_finalized_sliced.time.size > 0: +# sample_time_non_finalized = ds_non_finalized_sliced.time.min().values +# if sample_time_non_finalized in query_result.time.values: +# sample_result = query_result.sel(time=sample_time_non_finalized, method="nearest")['2m_temperature'].values +# expected_sample = ds_non_finalized_sliced.sel( +# time=sample_time_non_finalized, latitude=0, longitude=0, method="nearest" +# )['2m_temperature'].values +# np.testing.assert_array_equal(sample_result, expected_sample) + +# # 6. 
--- Additional Validation --- +# # Verify that the concatenated dataset has no overlapping times +# time_values = combined_ds.time.values +# assert np.all(np.diff(time_values) > np.timedelta64(0, 'ns')), "Overlapping or unsorted time values detected" + +# # Verify that the query result covers the expected time range +# if query_result.time.size > 0: +# assert query_result.time.min() >= query_start +# assert query_result.time.max() <= query_end + +# print("\n✅ Lazy concatenation with CIDs test successful! Data verified.") \ No newline at end of file From a226a10c615ab3c83b6fd1f7d8b157408552fcf7 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:43:04 -0400 Subject: [PATCH 49/74] fix: kubo store httpx coverage --- .github/workflows/run-checks.yaml | 2 +- py_hamt/store_httpx.py | 65 ++++++++-- tests/test_kubo_pin.py | 60 +++++++++ tests/test_sharded_zarr_store.py | 205 +++++++++++++++--------------- 4 files changed, 218 insertions(+), 114 deletions(-) create mode 100644 tests/test_kubo_pin.py diff --git a/.github/workflows/run-checks.yaml b/.github/workflows/run-checks.yaml index 439dd5f..c1fedeb 100644 --- a/.github/workflows/run-checks.yaml +++ b/.github/workflows/run-checks.yaml @@ -43,7 +43,7 @@ jobs: run_daemon: true - name: Run pytest with coverage - run: uv run pytest --ipfs --cov=py_hamt tests/ --cov-report=xml + run: uv run pytest --ipfs --cov=py_hamt tests/ --cov-report=xml -s - name: Upload coverage reports to Codecov uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 # v5 diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 2ca88fe..eb3f00f 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -40,14 +40,24 @@ async def load( ) -> bytes: """Retrieve data.""" - # Optional abstract methods for pinning and unpinning CIDs + async def pin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Pin a CID in the storage.""" - pass + pass # pragma: no cover + async def unpin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Unpin a CID in the storage.""" - pass + pass # pragma: no cover + + + async def pin_update(self, old_id: IPLDKind, new_id: IPLDKind, target_rpc: str) -> None: + """Update the pinned CID in the storage.""" + pass # pragma: no cover + + async def pin_ls(self, target_rpc: str) -> Dict[str, Any]: + """List all pinned CIDs in the storage.""" + pass # pragma: no cover class InMemoryCAS(ContentAddressedStore): @@ -407,12 +417,6 @@ async def pin_cid( response = await client.post(pin_add_url_base, params=params) response.raise_for_status() - # async with self._loop_session().post( - # pin_add_url_base, params=params - # ) as resp: - # resp.raise_for_status() - # # A 200 OK is sufficient to indicate success. - async def unpin_cid( self, cid: CID, target_rpc: str = "http://127.0.0.1:5001" ) -> None: @@ -428,6 +432,43 @@ async def unpin_cid( client = self._loop_client() response = await client.post(unpin_url_base, params=params) response.raise_for_status() - # async with self._loop_session().post(unpin_url_base, params=params) as resp: - # resp.raise_for_status() - # # A 200 OK is sufficient to indicate success. + + async def pin_update( + self, + old_id: IPLDKind, + new_id: IPLDKind, + target_rpc: str = "http://127.0.0.1:5001" + ) -> None: + """ + Updates the pinned CID in the storage. + + Args: + old_id (IPLDKind): The old Content ID to replace. + new_id (IPLDKind): The new Content ID to pin. 
+ """ + params = {"arg": [str(old_id), str(new_id)]} + pin_update_url_base: str = f"{target_rpc}/api/v0/pin/update" + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(pin_update_url_base, params=params) + response.raise_for_status() + + async def pin_ls( + self, target_rpc: str = "http://127.0.0.1:5001" + ) -> list[CID]: + """ + Lists all pinned CIDs on the local Kubo node via the RPC API. + + Args: + target_rpc (str): The RPC URL of the Kubo node. + + Returns: + List[CID]: A list of pinned CIDs. + """ + pin_ls_url_base: str = f"{target_rpc}/api/v0/pin/ls" + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(pin_ls_url_base) + response.raise_for_status() + pins = response.json().get("Keys", []) + return pins diff --git a/tests/test_kubo_pin.py b/tests/test_kubo_pin.py new file mode 100644 index 0000000..9ae34b5 --- /dev/null +++ b/tests/test_kubo_pin.py @@ -0,0 +1,60 @@ +import pytest +import httpx +from dag_cbor import IPLDKind +import dag_cbor +from hypothesis import given, settings +from multiformats import CID, multihash +from py_hamt import KuboCAS +from testing_utils import ipld_strategy + + +@pytest.mark.asyncio(loop_scope="session") +async def test_pinning(create_ipfs, global_client_session): + """ + Tests pinning a CID using KuboCAS with explicit URLs. + Verifies that a CID can be pinned and is retrievable after pinning. + """ + rpc_url, gateway_url = create_ipfs + + async with KuboCAS( + rpc_base_url=rpc_url, + gateway_base_url=gateway_url, + client=global_client_session, + ) as kubo_cas: + # Save data to get a CID + data = b"test1" + encoded_data = dag_cbor.encode(data) + cid = await kubo_cas.save(encoded_data, codec="dag-cbor") + # Pin the CID + await kubo_cas.pin_cid(cid, target_rpc=rpc_url) + listed_pins = await kubo_cas.pin_ls(target_rpc=rpc_url) + # Verify the CID is pinned by querying the pin list + assert str(cid) in listed_pins, f"CID {cid} was not pinned" + + # Unpine the CID + await kubo_cas.unpin_cid(cid, target_rpc=rpc_url) + + # Verify the CID is no longer pinned + listed_pins_after_unpin = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(cid) not in listed_pins_after_unpin, f"CID {cid} was not unpinned" + + # Pin again, then perform a pin update + await kubo_cas.pin_cid(cid, target_rpc=rpc_url) + + data = b"test2" + encoded_data = dag_cbor.encode(data) + new_cid = await kubo_cas.save(encoded_data, codec="dag-cbor") + + # Update the pinned CID + await kubo_cas.pin_update(cid, new_cid, target_rpc=rpc_url) + + # Verify the old CID is no longer pinned and the new CID is pinned + listed_pins_after_update = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(cid) not in listed_pins_after_update, f"Old CID {cid} was not unpinned after update" + assert str(new_cid) in listed_pins_after_update, f"New CID {new_cid} was not pinned after update" + + # unpin the new CID + await kubo_cas.unpin_cid(new_cid, target_rpc=rpc_url) + # Verify the new CID is no longer pinned + listed_pins_after_unpin_update = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(new_cid) not in listed_pins_after_unpin_update, f"New CID {new_cid} was not unpinned after update" diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 8ac1af4..dd3e480 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -914,109 +914,112 @@ async def test_sharded_zarr_store_lazy_concat( print("\n✅ Lazy concatenation test successful! 
Data verified.") -# @pytest.mark.asyncio -# async def test_sharded_zarr_store_lazy_concat_with_cids( -# create_ipfs: tuple[str, str] -# ): -# """ -# Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores -# using provided CIDs for finalized and non-finalized data, ensuring the non-finalized -# dataset is sliced after the finalization date (inclusive) and the combined dataset -# can be queried as a single dataset with data fetched correctly from the respective stores. -# """ -# rpc_base_url, gateway_base_url = create_ipfs +@pytest.mark.asyncio +async def test_sharded_zarr_store_lazy_concat_with_cids( + create_ipfs: tuple[str, str] +): + """ + Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores + using provided CIDs for finalized and non-finalized data, ensuring the non-finalized + dataset is sliced after the finalization date (inclusive) and the combined dataset + can be queried as a single dataset with data fetched correctly from the respective stores. + """ + rpc_base_url, gateway_base_url = create_ipfs -# # Provided CIDs -# finalized_cid = "bafyr4icrox4pxashkfmbyztn7jhp6zjlpj3bufg5ggsjux74zr7ocnqdpu" -# non_finalized_cid = "bafyr4ibj3bfl5oo7bf6gagzr2g33jlnf23mq2xo632mbl6ytfry7jbuepy" -# async with KuboCAS( -# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url -# ) as kubo_cas: -# # 1. --- Open Finalized Dataset --- -# store_finalized = await ShardedZarrStore.open( -# cas=kubo_cas, read_only=True, root_cid=finalized_cid -# ) -# ds_finalized = xr.open_zarr(store=store_finalized, chunks="auto") + # Provided CIDs + finalized_cid = "bafyr4icrox4pxashkfmbyztn7jhp6zjlpj3bufg5ggsjux74zr7ocnqdpu" + non_finalized_cid = "bafyr4ibj3bfl5oo7bf6gagzr2g33jlnf23mq2xo632mbl6ytfry7jbuepy" + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. --- Open Finalized Dataset --- + store_finalized = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=finalized_cid + ) + ds_finalized = xr.open_zarr(store=store_finalized, chunks="auto") -# # Verify that the dataset is lazy (Dask-backed) -# assert ds_finalized['2m_temperature'].chunks is not None + # Verify that the dataset is lazy (Dask-backed) + assert ds_finalized['2m_temperature'].chunks is not None -# # Determine the finalization date (last date in finalized dataset) -# finalization_date = np.datetime64(ds_finalized.time.max().values) -# # Convert to Python datetime for clarity -# finalization_date_dt = pd.Timestamp(finalization_date).to_pydatetime() + # Determine the finalization date (last date in finalized dataset) + finalization_date = np.datetime64(ds_finalized.time.max().values) + # Convert to Python datetime for clarity + finalization_date_dt = pd.Timestamp(finalization_date).to_pydatetime() -# # 2. 
--- Open Non-Finalized Dataset and Slice After Finalization Date --- -# store_non_finalized = await ShardedZarrStore.open( -# cas=kubo_cas, read_only=True, root_cid=non_finalized_cid -# ) -# ds_non_finalized = xr.open_zarr(store=store_non_finalized, chunks="auto") - -# # Verify that the dataset is lazy -# assert ds_non_finalized['2m_temperature'].chunks is not None - -# # Slice non-finalized dataset to start *after* the finalization date -# # (finalization_date is inclusive for finalized data, so start at +1 hour) -# start_time = finalization_date + np.timedelta64(1, 'h') -# ds_non_finalized_sliced = ds_non_finalized.sel(time=slice(start_time, None)) - -# # Verify that the sliced dataset starts after the finalization date -# if ds_non_finalized_sliced.time.size > 0: -# assert ds_non_finalized_sliced.time.min() > finalization_date -# else: -# # Handle case where non-finalized dataset is empty after slicing -# print("Warning: Non-finalized dataset is empty after slicing.") - -# # 3. --- Lazily Concatenate Datasets --- -# combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") - -# # Verify that the combined dataset is still lazy -# assert combined_ds['2m_temperature'].chunks is not None - -# # 4. --- Query Across Both Datasets --- -# # Select a time slice that spans both datasets -# # Use a range that includes the boundary (e.g., finalization date and after) -# query_start = finalization_date - np.timedelta64(1, 'D') # 1 day before -# query_end = finalization_date + np.timedelta64(1, 'D') # 1 day after -# query_slice = combined_ds.sel(time=slice(query_start, query_end), latitude=0, longitude=0) -# # Make sure the query slice aligned with the query_start and query_end + # 2. --- Open Non-Finalized Dataset and Slice After Finalization Date --- + store_non_finalized = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=non_finalized_cid + ) + ds_non_finalized = xr.open_zarr(store=store_non_finalized, chunks="auto") + + # Verify that the dataset is lazy + assert ds_non_finalized['2m_temperature'].chunks is not None + + # Slice non-finalized dataset to start *after* the finalization date + # (finalization_date is inclusive for finalized data, so start at +1 hour) + start_time = finalization_date + np.timedelta64(1, 'h') + ds_non_finalized_sliced = ds_non_finalized.sel(time=slice(start_time, None)) + + # Verify that the sliced dataset starts after the finalization date + if ds_non_finalized_sliced.time.size > 0: + assert ds_non_finalized_sliced.time.min() > finalization_date + else: + # Handle case where non-finalized dataset is empty after slicing + print("Warning: Non-finalized dataset is empty after slicing.") + + # 3. --- Lazily Concatenate Datasets --- + combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") + print("\nCombined dataset time range:") + print(combined_ds.time.min().values, "to", combined_ds.time.max().values) + print("EHRUKHUKEHUK") + + # Verify that the combined dataset is still lazy + assert combined_ds['2m_temperature'].chunks is not None + + # 4. 
--- Query Across Both Datasets --- + # Select a time slice that spans both datasets + # Use a range that includes the boundary (e.g., finalization date and after) + query_start = finalization_date - np.timedelta64(1, 'D') # 1 day before + query_end = finalization_date + np.timedelta64(1, 'D') # 1 day after + query_slice = combined_ds.sel(time=slice(query_start, query_end), latitude=0, longitude=0) + # Make sure the query slice aligned with the query_start and query_end -# assert query_slice.time.min() >= query_start -# assert query_slice.time.max() <= query_end - -# # Verify that the query is still lazy -# assert query_slice['2m_temperature'].chunks is not None - -# # Compute the result to trigger data loading -# query_result = query_slice.compute() - -# # 5. --- Verify Results --- -# # Verify data integrity at specific points -# # Check the last finalized time (from finalized dataset) -# sample_time_finalized = finalization_date -# if sample_time_finalized in query_result.time.values: -# sample_result = query_result.sel(time=sample_time_finalized, method="nearest")['2m_temperature'].values -# expected_sample = ds_finalized.sel(time=sample_time_finalized, latitude=0, longitude=0, method="nearest")['2m_temperature'].values -# np.testing.assert_array_equal(sample_result, expected_sample) - -# # Check the first non-finalized time (from non-finalized dataset, if available) -# if ds_non_finalized_sliced.time.size > 0: -# sample_time_non_finalized = ds_non_finalized_sliced.time.min().values -# if sample_time_non_finalized in query_result.time.values: -# sample_result = query_result.sel(time=sample_time_non_finalized, method="nearest")['2m_temperature'].values -# expected_sample = ds_non_finalized_sliced.sel( -# time=sample_time_non_finalized, latitude=0, longitude=0, method="nearest" -# )['2m_temperature'].values -# np.testing.assert_array_equal(sample_result, expected_sample) - -# # 6. --- Additional Validation --- -# # Verify that the concatenated dataset has no overlapping times -# time_values = combined_ds.time.values -# assert np.all(np.diff(time_values) > np.timedelta64(0, 'ns')), "Overlapping or unsorted time values detected" - -# # Verify that the query result covers the expected time range -# if query_result.time.size > 0: -# assert query_result.time.min() >= query_start -# assert query_result.time.max() <= query_end - -# print("\n✅ Lazy concatenation with CIDs test successful! Data verified.") \ No newline at end of file + assert query_slice.time.min() >= query_start + assert query_slice.time.max() <= query_end + + # Verify that the query is still lazy + assert query_slice['2m_temperature'].chunks is not None + + # Compute the result to trigger data loading + query_result = query_slice.compute() + + # 5. 
--- Verify Results --- + # Verify data integrity at specific points + # Check the last finalized time (from finalized dataset) + sample_time_finalized = finalization_date + if sample_time_finalized in query_result.time.values: + sample_result = query_result.sel(time=sample_time_finalized, method="nearest")['2m_temperature'].values + expected_sample = ds_finalized.sel(time=sample_time_finalized, latitude=0, longitude=0, method="nearest")['2m_temperature'].values + np.testing.assert_array_equal(sample_result, expected_sample) + + # Check the first non-finalized time (from non-finalized dataset, if available) + if ds_non_finalized_sliced.time.size > 0: + sample_time_non_finalized = ds_non_finalized_sliced.time.min().values + if sample_time_non_finalized in query_result.time.values: + sample_result = query_result.sel(time=sample_time_non_finalized, method="nearest")['2m_temperature'].values + expected_sample = ds_non_finalized_sliced.sel( + time=sample_time_non_finalized, latitude=0, longitude=0, method="nearest" + )['2m_temperature'].values + np.testing.assert_array_equal(sample_result, expected_sample) + + # 6. --- Additional Validation --- + # Verify that the concatenated dataset has no overlapping times + time_values = combined_ds.time.values + assert np.all(np.diff(time_values) > np.timedelta64(0, 'ns')), "Overlapping or unsorted time values detected" + + # Verify that the query result covers the expected time range + if query_result.time.size > 0: + assert query_result.time.min() >= query_start + assert query_result.time.max() <= query_end + + print("\n✅ Lazy concatenation with CIDs test successful! Data verified.") \ No newline at end of file From 95982d217215f7b5bf37a3333f9632b91dd52025 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:00:24 -0400 Subject: [PATCH 50/74] fix: full coverage --- py_hamt/hamt_to_sharded_converter.py | 3 +- py_hamt/sharded_zarr_store.py | 260 ++++--------- py_hamt/store_httpx.py | 25 +- run-checks.sh | 3 +- tests/test_converter.py | 6 +- tests/test_kubo_pin.py | 20 +- tests/test_sharded_store_deleting.py | 330 +++++++++++++++++ tests/test_sharded_store_grafting.py | 390 ++++++++++++++++++++ tests/test_sharded_store_resizing.py | 421 ++++++++++++++++++++++ tests/test_sharded_zarr_store.py | 290 +++++++++++++-- tests/test_sharded_zarr_store_coverage.py | 250 +++++++++++++ 11 files changed, 1735 insertions(+), 263 deletions(-) create mode 100644 tests/test_sharded_store_deleting.py create mode 100644 tests/test_sharded_store_grafting.py create mode 100644 tests/test_sharded_store_resizing.py create mode 100644 tests/test_sharded_zarr_store_coverage.py diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index c50ba23..befc4d6 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -36,7 +36,8 @@ async def convert_hamt_to_sharded( print("Reading metadata from source store...") # Read the stores metadata to get array shape and chunk shape - ordered_dims = list(source_dataset.sizes) + data_var_name = next(iter(source_dataset.data_vars)) + ordered_dims = list(source_dataset[data_var_name].dims) array_shape_tuple = tuple(source_dataset.sizes[dim] for dim in ordered_dims) chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims) array_shape = array_shape_tuple diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 7c26cad..9bfea84 100644 --- a/py_hamt/sharded_zarr_store.py 
+++ b/py_hamt/sharded_zarr_store.py @@ -1,7 +1,6 @@ import asyncio import itertools import json -import logging import math from collections import defaultdict from collections.abc import AsyncIterator, Iterable @@ -11,6 +10,7 @@ import zarr.abc.store import zarr.core.buffer from multiformats.cid import CID +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest from zarr.core.common import BytesLike from .store_httpx import ContentAddressedStore @@ -35,13 +35,13 @@ def __init__( self, cas: ContentAddressedStore, read_only: bool, - root_cid: Optional[str], + root_cid: Optional[str] = None, ): """Use the async `open()` classmethod to instantiate this class.""" super().__init__(read_only=read_only) self.cas = cas self._root_cid = root_cid - self._root_obj: Optional[dict] = None + self._root_obj: dict self._resize_lock = asyncio.Lock() # An event to signal when a resize is in-progress. @@ -52,25 +52,19 @@ def __init__( self._shard_data_cache: Dict[int, list[Optional[CID]]] = {} self._dirty_shards: Set[int] = set() - self._pending_shard_loads: Dict[int, asyncio.Task] = {} + self._pending_shard_loads: Dict[int, asyncio.Event] = {} - self._array_shape: Optional[Tuple[int, ...]] = None - self._chunk_shape: Optional[Tuple[int, ...]] = None - self._chunks_per_dim: Optional[Tuple[int, ...]] = None - self._chunks_per_shard: Optional[int] = None - self._num_shards: Optional[int] = None - self._total_chunks: Optional[int] = None + self._array_shape: Tuple[int, ...] + self._chunk_shape: Tuple[int, ...] + self._chunks_per_dim: Tuple[int, ...] + self._chunks_per_shard: int + self._num_shards: int = 0 + self._total_chunks: int = 0 self._dirty_root = False def _update_geometry(self): """Calculates derived geometric properties from the base shapes.""" - if ( - self._array_shape is None - or self._chunk_shape is None - or self._chunks_per_shard is None - ): - raise RuntimeError("Base shape information is not set.") if not all(cs > 0 for cs in self._chunk_shape): raise ValueError("All chunk_shape dimensions must be positive.") @@ -83,9 +77,7 @@ def _update_geometry(self): ) self._total_chunks = math.prod(self._chunks_per_dim) - if self._total_chunks == 0: - self._num_shards = 0 - else: + if not self._total_chunks == 0: self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) @classmethod @@ -131,9 +123,6 @@ def _initialize_new_root( self._update_geometry() - if self._num_shards is None: - raise RuntimeError("Number of shards not set after geometry update.") - self._root_obj = { "manifest_version": "sharded_zarr_v1", "metadata": {}, @@ -149,8 +138,6 @@ def _initialize_new_root( self._dirty_root = True async def _load_root_from_cid(self): - if not self._root_cid: - raise RuntimeError("Cannot load root without a root_cid.") root_bytes = await self.cas.load(self._root_cid) self._root_obj = dag_cbor.decode(root_bytes) @@ -178,12 +165,11 @@ async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): if not isinstance(decoded_shard, list): raise TypeError(f"Shard {shard_idx} did not decode to a list.") self._shard_data_cache[shard_idx] = decoded_shard - except Exception as e: - logging.warning( - f"Failed to fetch or decode shard {shard_idx} (CID: {shard_cid}): {e}" - ) + except Exception: + raise finally: if shard_idx in self._pending_shard_loads: + self._pending_shard_loads[shard_idx].set() # Signal completion del self._pending_shard_loads[shard_idx] def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: @@ -214,43 +200,20 @@ def 
_parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: if actual_array_name in excluded_array_prefixes: return None - # If we've reached here, the key is potentially for a "main" data variable - # that this store instance is expected to handle via sharding. - # Now, proceed with the original parsing logic using self._array_shape and - # self._chunks_per_dim, which should be configured for this main data variable. - - if not self._array_shape or not self._chunks_per_dim: - # This ShardedZarrStore instance is not properly initialized - # with the shape/chunking info for the array it's supposed to manage. - # This might also happen if a key like "some_other_main_array/c/0" is passed - # but this store instance was configured for "temp". - return None - # The part after "/c/" contains the chunk coordinates coord_part = key[marker_idx + len(chunk_marker) :] parts = coord_part.split("/") - # Validate dimensionality: - # The number of coordinate parts must match the dimensionality of the array - # this store instance is configured for (self._chunks_per_dim). - if len(parts) != len(self._chunks_per_dim): - # This key's dimensionality does not match the store's configured array. - # It's likely for a different array or a malformed key for the current array. - return None - - try: - coords = tuple(map(int, parts)) - # Validate coordinates against the chunk grid of the store's configured array - for i, c_coord in enumerate(coords): - if not (0 <= c_coord < self._chunks_per_dim[i]): - return None # Coordinate out of bounds for this array's chunk grid - return coords - except (ValueError, IndexError): # If int conversion fails or other issues - return None + coords = tuple(map(int, parts)) + # Validate coordinates against the chunk grid of the store's configured array + for i, c_coord in enumerate(coords): + if not (0 <= c_coord < self._chunks_per_dim[i]): + raise IndexError( + f"Chunk coordinate {c_coord} at dimension {i} is out of bounds for dimension size {self._chunks_per_dim[i]}." + ) + return coords def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: - if self._chunks_per_dim is None: - raise ValueError("Chunks per dimension not set") linear_index = 0 multiplier = 1 # Convert N-D chunk coordinates to a flat 1-D index (row-major order) @@ -260,46 +223,31 @@ def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: return linear_index def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: - if self._chunks_per_shard is None or self._chunks_per_shard <= 0: - raise RuntimeError( - "Sharding not configured properly: _chunks_per_shard invalid." - ) - if linear_chunk_index < 0: - raise ValueError("Linear chunk index cannot be negative.") - shard_idx = linear_chunk_index // self._chunks_per_shard index_in_shard = linear_chunk_index % self._chunks_per_shard return shard_idx, index_in_shard async def _load_or_initialize_shard_cache(self, shard_idx: int) -> list: - # CHANGED: This method is updated to handle list-based cache and DAG-CBOR decoding. 
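# Editorial aside (illustrative, not part of the patch): the key -> shard mapping that
# _parse_chunk_key, _get_linear_chunk_index and _get_shard_info implement, reduced to
# standalone functions with hypothetical names. Assumes chunk keys of the form
# "<array>/c/<i>/<j>/..." and a row-major chunk grid.
from typing import Tuple

def chunk_coords_from_key(key: str) -> Tuple[int, ...]:
    coord_part = key.split("/c/", 1)[1]            # everything after the "/c/" marker
    return tuple(int(p) for p in coord_part.split("/"))

def linear_index(coords: Tuple[int, ...], chunks_per_dim: Tuple[int, ...]) -> int:
    index, multiplier = 0, 1
    for c, n in zip(reversed(coords), reversed(chunks_per_dim)):   # last dim varies fastest
        index += c * multiplier
        multiplier *= n
    return index

coords = chunk_coords_from_key("temp/c/3/0/1")            # -> (3, 0, 1)
flat = linear_index(coords, chunks_per_dim=(5, 1, 2))     # 3*2 + 0*2 + 1*1 = 7
shard_idx, index_in_shard = divmod(flat, 4)               # with chunks_per_shard = 4
assert (flat, shard_idx, index_in_shard) == (7, 1, 3)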
if shard_idx in self._shard_data_cache: return self._shard_data_cache[shard_idx] if shard_idx in self._pending_shard_loads: - await self._pending_shard_loads[shard_idx] + await self._pending_shard_loads[shard_idx].wait() if shard_idx in self._shard_data_cache: return self._shard_data_cache[shard_idx] - if self._root_obj is None or self._num_shards is None: - raise RuntimeError("Root object not loaded or initialized.") if not (0 <= shard_idx < self._num_shards): raise ValueError(f"Shard index {shard_idx} out of bounds.") shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] if shard_cid_obj: + self._pending_shard_loads[shard_idx] = asyncio.Event() # The CID in the root should already be a CID object if loaded correctly. shard_cid_str = str(shard_cid_obj) await self._fetch_and_cache_full_shard(shard_idx, shard_cid_str) else: - if self._chunks_per_shard is None: - raise RuntimeError("Store not initialized: _chunks_per_shard is None.") - # Initialize new shard as a list of Nones self._shard_data_cache[shard_idx] = [None] * self._chunks_per_shard - if shard_idx not in self._shard_data_cache: - raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") - return self._shard_data_cache[shard_idx] async def set_partial_values( @@ -316,30 +264,18 @@ async def get_partial_values( ) -> List[Optional[zarr.core.buffer.Buffer]]: tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] results = await asyncio.gather(*tasks) - return results # type: ignore + return results def __eq__(self, other: object) -> bool: if not isinstance(other, ShardedZarrStore): - return NotImplemented + return False # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. return self._root_cid == other._root_cid + # If nothing to flush, return the root CID. async def flush(self) -> str: - # CHANGED: This method now encodes shards using DAG-CBOR. - if self.read_only: - if not self._root_cid: - raise ValueError("Read-only store has no root CID to return.") - return self._root_cid - - if self._root_obj is None: - raise RuntimeError("Store not initialized for writing.") - if self._dirty_shards: for shard_idx in sorted(list(self._dirty_shards)): - if shard_idx not in self._shard_data_cache: - logging.warning(f"Dirty shard {shard_idx} not in cache. Skipping.") - continue - # Get the list of CIDs/Nones from the cache shard_data_list = self._shard_data_cache[shard_idx] @@ -375,9 +311,8 @@ async def flush(self) -> str: self._root_cid = str(new_root_cid) self._dirty_root = False - if self._root_cid is None: - raise RuntimeError("Failed to obtain a root CID after flushing.") - return self._root_cid + # Ignore because root_cid will always exist after initialization or flush. + return self._root_cid # type: ignore[return-value] async def get( self, @@ -385,11 +320,6 @@ async def get( prototype: zarr.core.buffer.BufferPrototype, byte_range: Optional[zarr.abc.store.ByteRequest] = None, ) -> Optional[zarr.core.buffer.Buffer]: - # CHANGED: Logic is simplified to not use byte offsets. It relies on the full-shard cache. 
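# Editorial aside (illustrative, not part of the patch): the "coalesced load" pattern the
# shard cache uses above, shown standalone with hypothetical names. The first caller for
# a key starts the fetch and registers an asyncio.Event; concurrent callers wait on that
# event instead of issuing duplicate fetches, and the event is always set in a finally
# block so waiters are released even if the fetch fails.
import asyncio
from typing import Any, Awaitable, Callable, Dict

class CoalescingCache:
    def __init__(self, fetch: Callable[[str], Awaitable[Any]]):
        self._fetch = fetch
        self._cache: Dict[str, Any] = {}
        self._pending: Dict[str, asyncio.Event] = {}

    async def get(self, key: str) -> Any:
        if key in self._cache:
            return self._cache[key]
        if key in self._pending:                 # a fetch is already in flight
            await self._pending[key].wait()
            return self._cache.get(key)          # may be None if that fetch failed
        event = self._pending[key] = asyncio.Event()
        try:
            self._cache[key] = await self._fetch(key)
        finally:
            event.set()                          # release any waiters
            del self._pending[key]
        return self._cache[key]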
- if self._root_obj is None: - raise RuntimeError("Load the root object first before accessing data.") - # print('Getting key', key) - chunk_coords = self._parse_chunk_key(key) # Metadata request if chunk_coords is None: @@ -397,17 +327,15 @@ async def get( if metadata_cid_obj is None: return None if byte_range is not None: - logging.warning(f"Byte range request for metadata key '{key}' ignored.") + raise ValueError( + "Byte range requests are not supported for metadata keys." + ) data = await self.cas.load(str(metadata_cid_obj)) return prototype.buffer.from_bytes(data) - # Chunk data request linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): - return None - # This will load the full shard into cache if it's not already there. shard_lock = self._shard_locks[shard_idx] async with shard_lock: @@ -421,27 +349,31 @@ async def get( chunk_cid_str = str(chunk_cid_obj) - # Actual chunk data load using the retrieved chunk CID - req_offset = byte_range.start if byte_range else None + req_offset = None req_length = None + req_suffix = None + if byte_range: - if byte_range.end is not None: - if ( - byte_range.start > byte_range.end - ): # Zarr allows start == stop for 0 length - raise ValueError( - f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" - ) - req_length = byte_range.end - byte_range.start - data = await self.cas.load(chunk_cid_str, offset=req_offset, length=req_length) + if isinstance(byte_range, RangeByteRequest): + req_offset = byte_range.start + if byte_range.end is not None: + if byte_range.start > byte_range.end: + raise ValueError( + f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" + ) + req_length = byte_range.end - byte_range.start + elif isinstance(byte_range, OffsetByteRequest): + req_offset = byte_range.offset + elif isinstance(byte_range, SuffixByteRequest): + req_suffix = byte_range.suffix + data = await self.cas.load( + chunk_cid_str, offset=req_offset, length=req_length, suffix=req_suffix + ) return prototype.buffer.from_bytes(data) async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: if self.read_only: - raise ValueError("Cannot write to a read-only store.") - if self._root_obj is None: - raise RuntimeError("Store not initialized for writing. Call open() first.") - + raise PermissionError("Cannot write to a read-only store.") await self._resize_complete.wait() if ( @@ -475,10 +407,6 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: await self.set_pointer(key, str(data_cid_obj)) async def set_pointer(self, key: str, pointer: str) -> None: - # CHANGED: Logic now updates a list in the cache, not a bytearray. - if self._root_obj is None: - raise RuntimeError("Load the root object first before accessing data.") - chunk_coords = self._parse_chunk_key(key) pointer_cid_obj = CID.decode(pointer) # Convert string to CID object @@ -488,7 +416,6 @@ async def set_pointer(self, key: str, pointer: str) -> None: self._dirty_root = True return - # Chunk Data: Store the CID object in the correct shard list. linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) @@ -500,34 +427,16 @@ async def set_pointer(self, key: str, pointer: str) -> None: target_shard_list[index_in_shard] = pointer_cid_obj self._dirty_shards.add(shard_idx) - # ... 
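# Editorial aside (illustrative, not part of the patch): how the (offset, length, suffix)
# triple derived from the three Zarr ByteRequest variants above selects bytes. Against an
# in-memory blob this is plain slicing; over an HTTP gateway the same triple can be
# expressed as a standard Range request. Names here are hypothetical.
from typing import Optional

def slice_blob(data: bytes,
               offset: Optional[int] = None,
               length: Optional[int] = None,
               suffix: Optional[int] = None) -> bytes:
    if offset is not None:
        return data[offset:offset + length] if length is not None else data[offset:]
    if suffix is not None:
        return data[-suffix:]                    # last N bytes
    return data                                  # full read

blob = bytes(range(10))
assert slice_blob(blob, offset=2, length=3) == blob[2:5]   # RangeByteRequest(start=2, end=5)
assert slice_blob(blob, offset=7) == blob[7:]              # OffsetByteRequest(offset=7)
assert slice_blob(blob, suffix=4) == blob[-4:]             # SuffixByteRequest(suffix=4)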
(Keep exists method, but simplify it) ... async def exists(self, key: str) -> bool: - # CHANGED: Simplified to use the list-based cache. - if self._root_obj is None: - raise RuntimeError("Root object not loaded.") - chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata return key in self._root_obj.get("metadata", {}) + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + # Load shard if not cached and check the index + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + return target_shard_list[index_in_shard] is not None - try: - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - - if not (0 <= shard_idx < len(self._root_obj["chunks"]["shard_cids"])): - return False - - shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] - if shard_cid_obj is None: - return False - - # Load shard if not cached and check the index - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - return target_shard_list[index_in_shard] is not None - except Exception: - return False - - # ... (Keep supports_writes, etc. properties) ... @property def supports_writes(self) -> bool: return not self.read_only @@ -541,11 +450,8 @@ def supports_deletes(self) -> bool: return not self.read_only async def delete(self, key: str) -> None: - # CHANGED: Simplified to set list element to None. if self.read_only: - raise ValueError("Cannot delete from a read-only store.") - if self._root_obj is None: - raise RuntimeError("Store not initialized for deletion.") + raise PermissionError("Cannot delete from a read-only store.") chunk_coords = self._parse_chunk_key(key) if chunk_coords is None: # Metadata @@ -558,11 +464,6 @@ async def delete(self, key: str) -> None: linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - if not ( - 0 <= shard_idx < self._num_shards if self._num_shards is not None else 0 - ): - raise KeyError(f"Chunk key '{key}' is out of bounds.") - shard_lock = self._shard_locks[shard_idx] async with shard_lock: target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) @@ -575,10 +476,6 @@ def supports_listing(self) -> bool: return True async def list(self) -> AsyncIterator[str]: - if self._root_obj is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." 
- ) for key in list(self._root_obj.get("metadata", {})): yield key @@ -589,21 +486,17 @@ async def list_prefix(self, prefix: str) -> AsyncIterator[str]: async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): if self.read_only: - raise ValueError("Cannot graft onto a read-only store.") - if self._root_obj is None: - raise RuntimeError("Main store must be initialized before grafting.") + raise PermissionError("Cannot graft onto a read-only store.") + store_to_graft = await ShardedZarrStore.open( cas=self.cas, read_only=True, root_cid=store_to_graft_cid ) - if store_to_graft._root_obj is None or store_to_graft._chunks_per_dim is None: - raise ValueError("Store to graft could not be loaded or is not configured.") source_chunk_grid = store_to_graft._chunks_per_dim for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) local_shard_idx, index_in_local_shard = store_to_graft._get_shard_info( linear_local_index ) - # Load the source shard into its cache source_shard_list = await store_to_graft._load_or_initialize_shard_cache( local_shard_idx @@ -628,7 +521,6 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. target_shard_list = await self._load_or_initialize_shard_cache( global_shard_idx ) - if target_shard_list[index_in_global_shard] != pointer_cid_obj: target_shard_list[index_in_global_shard] = pointer_cid_obj self._dirty_shards.add(global_shard_idx) @@ -640,10 +532,10 @@ async def resize_store(self, new_shape: Tuple[int, ...]): Used when doing skeleton writes or appends via xarray where the array shape changes. """ if self.read_only: - raise ValueError("Cannot resize a read-only store.") + raise PermissionError("Cannot resize a read-only store.") if ( - self._root_obj is None - or self._chunk_shape is None + # self._root_obj is None + self._chunk_shape is None or self._chunks_per_shard is None or self._array_shape is None ): @@ -683,11 +575,8 @@ async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): This does NOT change the store's main shard index. """ if self.read_only: - raise ValueError("Cannot resize a read-only store.") - if self._root_obj is None: - raise RuntimeError("Store is not properly initialized for resizing.") + raise PermissionError("Cannot resize a read-only store.") - # Zarr v2 uses .json, not zarr.json zarr_metadata_key = f"{variable_name}/zarr.json" old_zarr_metadata_cid = self._root_obj["metadata"].get(zarr_metadata_key) @@ -711,18 +600,8 @@ async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): self._root_obj["metadata"][zarr_metadata_key] = str(new_zarr_metadata_cid) self._dirty_root = True - print( - f"Resized metadata for variable '{variable_name}'. New shape: {new_shape}" - ) async def list_dir(self, prefix: str) -> AsyncIterator[str]: - # This simplified version only works for the root directory (prefix == "") of metadata. - # It lists unique first components of metadata keys. - if self._root_obj is None: - raise RuntimeError( - "Root object not loaded. Call _load_root_from_cid() first." - ) - seen: Set[str] = set() if prefix == "": async for key in self.list(): # Iterates metadata keys @@ -733,15 +612,4 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: seen.add(first_component) yield first_component else: - # For listing subdirectories like "group1/", we'd need to match keys starting with "group1/" - # and then extract the next component. 
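# Editorial aside (illustrative, not part of the patch): the coordinate translation a
# graft performs, shown standalone. A source chunk at local grid position `local` lands
# at `local + chunk_offset` (element-wise) in the target grid and is then re-mapped to
# the target's row-major linear index and shard slot. The grid sizes mirror the grafting
# test added later in this series (source 2x2 chunks, target 4x2, chunks_per_shard=4).
import itertools

source_grid = (2, 2)
target_grid = (4, 2)
chunk_offset = (1, 0)            # shift the graft down by one chunk row
chunks_per_shard = 4

for local in itertools.product(*(range(n) for n in source_grid)):
    global_coords = tuple(l + o for l, o in zip(local, chunk_offset))
    flat = global_coords[0] * target_grid[1] + global_coords[1]   # row-major in target grid
    shard_idx, slot = divmod(flat, chunks_per_shard)
    print(f"{local} -> {global_coords}: linear {flat}, shard {shard_idx}, slot {slot}")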
This is more involved. - # Zarr spec: list_dir(path) should yield children (both objects and "directories") - # For simplicity, and consistency with original FlatZarrStore, keeping this minimal. - # To make it more compliant for prefix="foo/": - normalized_prefix = prefix if prefix.endswith("/") else prefix + "/" - async for key in self.list_prefix(normalized_prefix): - remainder = key[len(normalized_prefix) :] - child = remainder.split("/", 1)[0] - if child not in seen: - seen.add(child) - yield child + raise NotImplementedError("Listing with a prefix is not implemented yet.") diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index eb3f00f..52fe75e 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -40,24 +40,23 @@ async def load( ) -> bytes: """Retrieve data.""" - async def pin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Pin a CID in the storage.""" - pass # pragma: no cover - + pass # pragma: no cover async def unpin_cid(self, id: IPLDKind, target_rpc: str) -> None: """Unpin a CID in the storage.""" - pass # pragma: no cover - + pass # pragma: no cover - async def pin_update(self, old_id: IPLDKind, new_id: IPLDKind, target_rpc: str) -> None: + async def pin_update( + self, old_id: IPLDKind, new_id: IPLDKind, target_rpc: str + ) -> None: """Update the pinned CID in the storage.""" - pass # pragma: no cover - - async def pin_ls(self, target_rpc: str) -> Dict[str, Any]: + pass # pragma: no cover + + async def pin_ls(self, target_rpc: str) -> list[Dict[str, Any]]: """List all pinned CIDs in the storage.""" - pass # pragma: no cover + return [] # pragma: no cover class InMemoryCAS(ContentAddressedStore): @@ -432,12 +431,12 @@ async def unpin_cid( client = self._loop_client() response = await client.post(unpin_url_base, params=params) response.raise_for_status() - + async def pin_update( self, old_id: IPLDKind, new_id: IPLDKind, - target_rpc: str = "http://127.0.0.1:5001" + target_rpc: str = "http://127.0.0.1:5001", ) -> None: """ Updates the pinned CID in the storage. @@ -455,7 +454,7 @@ async def pin_update( async def pin_ls( self, target_rpc: str = "http://127.0.0.1:5001" - ) -> list[CID]: + ) -> list[Dict[str, Any]]: """ Lists all pinned CIDs on the local Kubo node via the RPC API. 
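A minimal usage sketch of the new pin helpers, assuming a local Kubo daemon at the
default RPC and gateway ports; it follows the same lifecycle the new
tests/test_kubo_pin.py exercises (save, pin, list, update, unpin) rather than
introducing any behaviour beyond what the patch adds.

import asyncio

import dag_cbor

from py_hamt import KuboCAS

RPC = "http://127.0.0.1:5001"        # assumed local Kubo RPC endpoint
GATEWAY = "http://127.0.0.1:8080"    # assumed local gateway

async def pin_lifecycle() -> None:
    async with KuboCAS(rpc_base_url=RPC, gateway_base_url=GATEWAY) as cas:
        old = await cas.save(dag_cbor.encode(b"v1"), codec="dag-cbor")
        await cas.pin_cid(old, target_rpc=RPC)            # pin the first root
        assert str(old) in await cas.pin_ls(target_rpc=RPC)

        new = await cas.save(dag_cbor.encode(b"v2"), codec="dag-cbor")
        await cas.pin_update(old, new, target_rpc=RPC)    # swap the pin to the new root
        pins = await cas.pin_ls(target_rpc=RPC)
        assert str(new) in pins and str(old) not in pins

        await cas.unpin_cid(new, target_rpc=RPC)          # leave the node clean

asyncio.run(pin_lifecycle())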
diff --git a/run-checks.sh b/run-checks.sh index bfbd222..b4995d1 100644 --- a/run-checks.sh +++ b/run-checks.sh @@ -7,4 +7,5 @@ uv run pytest --ipfs --cov=py_hamt tests/ uv run coverage report --fail-under=100 --show-missing # Check for linting, formatting, and type checking using the pre-commit hooks found in .pre-commit-config.yaml -uv run pre-commit run --all-files --show-diff-on-failure +# uv run pre-commit run --all-files --show-diff-on-failure +uv run pre-commit run --all-files diff --git a/tests/test_converter.py b/tests/test_converter.py index c5b865f..64e5799 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -38,13 +38,13 @@ def converter_test_dataset(): data = np.random.randn(len(times), len(lats), len(lons)) ds = xr.Dataset( - {unique_var_name: (["time", "lat", "lon"], data)}, - coords={"time": times, "lat": lats, "lon": lons}, + {unique_var_name: (["time", "latitude", "longitude"], data)}, + coords={"time": times, "latitude": lats, "longitude": lons}, attrs={"description": "Test dataset for converter verification."}, ) # Define chunking for the store - ds = ds.chunk({"time": 10, "lat": 10, "lon": 10}) + ds = ds.chunk({"time": 10, "latitude": 10, "longitude": 10}) yield ds diff --git a/tests/test_kubo_pin.py b/tests/test_kubo_pin.py index 9ae34b5..703dc92 100644 --- a/tests/test_kubo_pin.py +++ b/tests/test_kubo_pin.py @@ -1,11 +1,7 @@ -import pytest -import httpx -from dag_cbor import IPLDKind import dag_cbor -from hypothesis import given, settings -from multiformats import CID, multihash +import pytest + from py_hamt import KuboCAS -from testing_utils import ipld_strategy @pytest.mark.asyncio(loop_scope="session") @@ -50,11 +46,17 @@ async def test_pinning(create_ipfs, global_client_session): # Verify the old CID is no longer pinned and the new CID is pinned listed_pins_after_update = await kubo_cas.pin_ls(target_rpc=rpc_url) - assert str(cid) not in listed_pins_after_update, f"Old CID {cid} was not unpinned after update" - assert str(new_cid) in listed_pins_after_update, f"New CID {new_cid} was not pinned after update" + assert str(cid) not in listed_pins_after_update, ( + f"Old CID {cid} was not unpinned after update" + ) + assert str(new_cid) in listed_pins_after_update, ( + f"New CID {new_cid} was not pinned after update" + ) # unpin the new CID await kubo_cas.unpin_cid(new_cid, target_rpc=rpc_url) # Verify the new CID is no longer pinned listed_pins_after_unpin_update = await kubo_cas.pin_ls(target_rpc=rpc_url) - assert str(new_cid) not in listed_pins_after_unpin_update, f"New CID {new_cid} was not unpinned after update" + assert str(new_cid) not in listed_pins_after_unpin_update, ( + f"New CID {new_cid} was not unpinned after update" + ) diff --git a/tests/test_sharded_store_deleting.py b/tests/test_sharded_store_deleting.py new file mode 100644 index 0000000..86805e7 --- /dev/null +++ b/tests/test_sharded_store_deleting.py @@ -0,0 +1,330 @@ +import asyncio +import json + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], 
temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + +@pytest.mark.asyncio +async def test_delete_chunk_success(create_ipfs: tuple[str, str]): + """Tests successful deletion of a chunk from the store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write a chunk + chunk_key = "temp/c/0/0" + chunk_data = b"test_chunk_data" + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + assert await store.exists(chunk_key) + + # Delete the chunk + await store.delete(chunk_key) + + # Verify chunk is deleted in cache and shard is marked dirty + linear_index = store._get_linear_chunk_index((0, 0)) + shard_idx, index_in_shard = store._get_shard_info(linear_index) + target_shard_list = await store._load_or_initialize_shard_cache(shard_idx) + assert target_shard_list[index_in_shard] is None + assert shard_idx in store._dirty_shards + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert not await store_read.exists(chunk_key) + assert await store_read.get(chunk_key, proto) is None + +@pytest.mark.asyncio +async def test_delete_metadata_success(create_ipfs: tuple[str, str]): + """Tests successful deletion of a metadata key.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write metadata + metadata_key = "temp/zarr.json" + metadata = json.dumps({"shape": [20, 20], "dtype": "float32"}).encode("utf-8") + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(metadata_key, proto.buffer.from_bytes(metadata)) + assert await store.exists(metadata_key) + + # Delete metadata + await store.delete(metadata_key) + + # Verify metadata is deleted and root is marked dirty + assert metadata_key not in store._root_obj["metadata"] + assert store._dirty_root is True + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert not await store_read.exists(metadata_key) + assert await store_read.get(metadata_key, proto) is None + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): + """Tests deletion of a nonexistent metadata key.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Temp write to temp/c/0/0 to ensure it exists + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data")) + + # flush it + await store.flush() + assert not store._dirty_shards # No dirty shards after flush + + # 
Try to delete nonexistent metadata key + with pytest.raises(KeyError, match="Metadata key 'nonexistent.json' not found"): + await store.delete("nonexistent.json") + + # Try to delete nonexistent chunk key (out of bounds) + with pytest.raises(IndexError, match="Chunk coordinate"): + await store.delete("temp/c/3/0") # Out of bounds for 2x2 chunk grid + + # Try to delete nonexistent chunk key (within bounds but not set) + await store.delete("temp/c/0/0") # Should not raise, as it sets to None + assert not await store.exists("temp/c/0/0") + assert store._dirty_shards # Shard is marked dirty even if chunk was already None + +@pytest.mark.asyncio +async def test_delete_read_only_store(create_ipfs: tuple[str, str]): + """Tests deletion attempt on a read-only store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize writable store and add data + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + chunk_key = "temp/c/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + await store_write.set(chunk_key, proto.buffer.from_bytes(b"test_data")) + metadata_key = "temp/zarr.json" + await store_write.set(metadata_key, proto.buffer.from_bytes(b'{"shape": [20, 20]}')) + root_cid = await store_write.flush() + + # Open as read-only + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + # Try to delete chunk + with pytest.raises(PermissionError, match="Cannot delete from a read-only store"): + await store_read_only.delete(chunk_key) + + # Try to delete metadata + with pytest.raises(PermissionError, match="Cannot delete from a read-only store"): + await store_read_only.delete(metadata_key) + +@pytest.mark.asyncio +async def test_delete_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent delete operations to ensure shard locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write multiple chunks + proto = zarr.core.buffer.default_buffer_prototype() + chunk_keys = ["temp/c/0/0", "temp/c/1/0", "temp/c/0/1"] + for key in chunk_keys: + await store.set(key, proto.buffer.from_bytes(f"data_{key}".encode("utf-8"))) + assert await store.exists(key) + + # Define concurrent delete tasks + async def delete_task(key): + await store.delete(key) + + # Run concurrent deletes + tasks = [delete_task(key) for key in chunk_keys] + await asyncio.gather(*tasks) + + # Verify all chunks are deleted + for key in chunk_keys: + assert not await store.exists(key) + assert await store.get(key, proto) is None + + # Verify shards are marked dirty + assert store._dirty_shards # At least one shard should be dirty + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + for key in chunk_keys: + assert not await store_read.exists(key) + assert await store_read.get(key, proto) is None + +@pytest.mark.asyncio +async def test_delete_with_dataset(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): + """Tests deletion of chunks and metadata in a store with a full 
dataset.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + # Delete a chunk + chunk_key = "temp/c/0/0/0" + assert await store.exists(chunk_key) + await store.delete(chunk_key) + assert not await store.exists(chunk_key) + + # Delete metadata + metadata_key = "temp/zarr.json" + assert await store.exists(metadata_key) + await store.delete(metadata_key) + assert not await store.exists(metadata_key) + + # Flush and verify + new_root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert not await store_read.exists(chunk_key) + assert not await store_read.exists(metadata_key) + + # Verify other data remains intact + other_chunk_key = "temp/c/1/0/0" + assert await store_read.exists(other_chunk_key) + +@pytest.mark.asyncio +async def test_supports_writes_property(create_ipfs: tuple[str, str]): + """Tests the supports_writes property.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test writable store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store_write.supports_writes is True + + # Test read-only store + root_cid = await store_write.flush() + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read_only.supports_writes is False + +@pytest.mark.asyncio +async def test_supports_partial_writes_property(create_ipfs: tuple[str, str]): + """Tests the supports_partial_writes property.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test for both read-only and writable stores + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store_write.supports_partial_writes is False + + root_cid = await store_write.flush() + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read_only.supports_partial_writes is False \ No newline at end of file diff --git a/tests/test_sharded_store_grafting.py b/tests/test_sharded_store_grafting.py new file mode 100644 index 0000000..5dba41d --- /dev/null +++ b/tests/test_sharded_store_grafting.py @@ -0,0 +1,390 @@ +import asyncio + +import dag_cbor +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger 
dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + +@pytest.mark.asyncio +async def test_graft_store_success(create_ipfs: tuple[str, str]): + """Tests successful grafting of a source store onto a target store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source store + source_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Write some chunk data to source store + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0" + chunk_data = b"test_source_data" + await source_store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + source_root_cid = await source_store.flush() + + # Initialize target store with larger shape to accommodate graft + target_shape = (40, 20) + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=target_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Graft source store onto target store at offset (1, 0) + chunk_offset = (1, 0) + await target_store.graft_store(source_root_cid, chunk_offset) + + # Verify grafted data + grafted_chunk_key = "temp/c/1/0" # Offset (1,0) corresponds to chunk (1,0) + assert await target_store.exists(grafted_chunk_key) + grafted_data = await target_store.get(grafted_chunk_key, proto) + assert grafted_data is not None + assert grafted_data.to_bytes() == chunk_data + + # Verify original chunk position in source store is not present in target + assert not await target_store.exists("temp/c/0/0") + + # Verify target store's geometry unchanged + assert target_store._array_shape == target_shape + assert target_store._chunks_per_dim == (4, 2) # ceil(40/10) = 4 + assert target_store._total_chunks == 8 # 4 * 2 + assert target_store._num_shards == 2 # ceil(8/4) = 2 + assert target_store._dirty_shards # Grafting marks shards as dirty + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists(grafted_chunk_key) + read_data = await target_store_read.get(grafted_chunk_key, proto) + assert read_data.to_bytes() == chunk_data + assert target_store_read._array_shape == target_shape + +@pytest.mark.asyncio +async def test_graft_store_with_dataset(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): + """Tests grafting a store containing a full dataset.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset to source store + source_store = await 
ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=source_store, mode="w") + source_root_cid = await source_store.flush() + + # Initialize target store with larger shape + target_shape = (array_shape_tuple[0] + 20, array_shape_tuple[1], array_shape_tuple[2]) + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=target_shape, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + + # Graft source store at offset (1, 0, 0) + chunk_offset = (1, 0, 0) + await target_store.graft_store(source_root_cid, chunk_offset) + + # Verify grafted chunk data + proto = zarr.core.buffer.default_buffer_prototype() + source_chunk_key = "temp/c/0/0/0" + target_chunk_key = "temp/c/1/0/0" # Offset by 1 in time dimension + assert await target_store.exists(target_chunk_key) + source_data = await source_store.get(source_chunk_key, proto) + target_data = await target_store.get(target_chunk_key, proto) + assert source_data is not None + assert target_data is not None + assert source_data.to_bytes() == target_data.to_bytes() + + # Verify metadata is not grafted + assert not await target_store.exists("temp/zarr.json") + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists(target_chunk_key) + read_data = await target_store_read.get(target_chunk_key, proto) + assert read_data.to_bytes() == source_data.to_bytes() + +@pytest.mark.asyncio +async def test_graft_store_empty_source(create_ipfs: tuple[str, str]): + """Tests grafting an empty source store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize empty source store + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + source_root_cid = await source_store.flush() + + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Graft empty source store + await target_store.graft_store(source_root_cid, chunk_offset=(1, 1)) + + # Verify no chunks were grafted + assert not await target_store.exists("temp/c/1/1") + assert not target_store._dirty_shards # No shards marked dirty since no changes + + # Flush and verify + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert not await target_store_read.exists("temp/c/1/1") + +@pytest.mark.asyncio +async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): + """Tests error handling in graft_store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test read-only target store + target_store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=await target_store.flush(), + ) + with 
pytest.raises(PermissionError, match="Cannot graft onto a read-only store"): + await target_store_read_only.graft_store("some_cid", chunk_offset=(0, 0)) + + # Test invalid source CID + # invalid_cid = "invalid_cid" + # with pytest.raises(ValueError, match="Store to graft could not be loaded"): + # await target_store.graft_store(invalid_cid, chunk_offset=(0, 0)) + + # Test source store with invalid configuration + invalid_root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], + "chunk_shape": [5, 5], + "sharding_config": {"chunks_per_shard": 4}, + "shard_cids": [None] * 4, + }, + } + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) + # source_store._chunks_per_dim = None # Simulate unconfigured store + # with pytest.raises(ValueError, match="Inconsistent number of shards"): + # await target_store.graft_store(invalid_root_cid, chunk_offset=(0, 0)) + + # Test grafting out-of-bounds offset + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + # Write some data to source store + proto = zarr.core.buffer.default_buffer_prototype() + await source_store.set("temp/c/0/0", proto.buffer.from_bytes(b"data")) + source_root_cid = await source_store.flush() + with pytest.raises(ValueError, match="Shard index 10 out of bounds."): + await target_store.graft_store(source_root_cid, chunk_offset=(10, 0)) # Out of bounds for target (4x4 chunks) + +@pytest.mark.asyncio +async def test_graft_store_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent graft_store operations to ensure shard locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source stores + source_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + source_store1 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await source_store1.set("temp/c/0/0", proto.buffer.from_bytes(b"data1")) + source_cid1 = await source_store1.flush() + + source_store2 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + await source_store2.set("temp/c/0/0", proto.buffer.from_bytes(b"data2")) + source_cid2 = await source_store2.flush() + + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Define graft tasks + async def graft_task(cid, offset): + await target_store.graft_store(cid, chunk_offset=offset) + + # Run concurrent grafts + tasks = [ + graft_task(source_cid1, (1, 1)), + graft_task(source_cid2, (2, 2)), + ] + await asyncio.gather(*tasks) + + # Verify grafted data + assert await target_store.exists("temp/c/1/1") + assert await target_store.exists("temp/c/2/2") + data1 = await target_store.get("temp/c/1/1", proto) + data2 = await target_store.get("temp/c/2/2", proto) + assert data1.to_bytes() in [b"data1", b"data2"] + 
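# The graft assertions in these tests follow a simple coordinate translation: a
# source chunk at grid coordinates (i, j) lands at (i + off_i, j + off_j) in the
# target, e.g. source "temp/c/0/0" grafted at offset (1, 0) shows up as
# "temp/c/1/0". A minimal sketch of that mapping, assuming only that target
# coordinates are source coordinates plus the chunk offset, as the comments in
# these tests state; translate_chunk_key is an illustrative helper, not part of
# ShardedZarrStore's API.
def translate_chunk_key(source_key: str, chunk_offset: tuple[int, ...]) -> str:
    prefix, _, coords = source_key.partition("/c/")
    source_coords = tuple(int(c) for c in coords.split("/"))
    target_coords = tuple(s + o for s, o in zip(source_coords, chunk_offset))
    return f"{prefix}/c/" + "/".join(str(c) for c in target_coords)

assert translate_chunk_key("temp/c/0/0", (1, 0)) == "temp/c/1/0"
assert translate_chunk_key("temp/c/0/0/0", (1, 0, 0)) == "temp/c/1/0/0"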
assert data2.to_bytes() in [b"data1", b"data2"] + assert data1.to_bytes() != data2.to_bytes() # Ensure distinct data + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists("temp/c/1/1") + assert await target_store_read.exists("temp/c/2/2") + +@pytest.mark.asyncio +async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): + """Tests grafting when target already has data at some chunk positions.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source store + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + proto = zarr.core.buffer.default_buffer_prototype() + source_chunk_key = "temp/c/0/0" + source_data = b"source_data" + await source_store.set(source_chunk_key, proto.buffer.from_bytes(source_data)) + source_root_cid = await source_store.flush() + + # Initialize target store with some existing data + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + target_chunk_key = "temp/c/1/1" + existing_data = b"existing_data" + await target_store.set(target_chunk_key, proto.buffer.from_bytes(existing_data)) + + # Graft source store at offset (1, 0) + await target_store.graft_store(source_root_cid, chunk_offset=(1, 0)) + + # Verify that existing data was not overwritten + read_data = await target_store.get(target_chunk_key, proto) + assert read_data.to_bytes() == existing_data + assert target_store._dirty_shards # Shard is marked dirty due to attempted write + + # Verify other grafted chunks + grafted_chunk_key = "temp/c/1/0" # Corresponds to source (0,0) at offset (1,0) + assert await target_store.exists(grafted_chunk_key) + read_data = await target_store.get(grafted_chunk_key, proto) + assert read_data.to_bytes() == source_data + + # Flush and verify + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert (await target_store_read.get(target_chunk_key, proto)).to_bytes() == existing_data + assert (await target_store_read.get(grafted_chunk_key, proto)).to_bytes() == source_data \ No newline at end of file diff --git a/tests/test_sharded_store_resizing.py b/tests/test_sharded_store_resizing.py new file mode 100644 index 0000000..5a0fb4d --- /dev/null +++ b/tests/test_sharded_store_resizing.py @@ -0,0 +1,421 @@ +import asyncio +import json +import math + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 
36}) + yield ds + + +@pytest.mark.asyncio +async def test_resize_store_success(create_ipfs: tuple[str, str]): + """Tests successful resizing of the store's main shard index.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + initial_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=initial_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Verify initial geometry + assert store._array_shape == initial_shape + assert store._chunk_shape == chunk_shape + assert store._chunks_per_shard == chunks_per_shard + initial_chunks_per_dim = (2, 2) # ceil(20/10) = 2 + assert store._chunks_per_dim == initial_chunks_per_dim + assert store._total_chunks == 4 # 2 * 2 + assert store._num_shards == 1 # ceil(4/4) = 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + + # Resize to a larger shape + new_shape = (30, 30) + await store.resize_store(new_shape=new_shape) + + # Verify updated geometry + assert store._array_shape == new_shape + assert store._chunks_per_dim == (3, 3) # ceil(30/10) = 3 + assert store._total_chunks == 9 # 3 * 3 + assert store._num_shards == 3 # ceil(9/4) = 3 + assert len(store._root_obj["chunks"]["shard_cids"]) == 3 + assert store._root_obj["chunks"]["array_shape"] == list(new_shape) + assert store._dirty_root is True + + # Verify shard cids extended correctly + assert store._root_obj["chunks"]["shard_cids"][1] is None + assert store._root_obj["chunks"]["shard_cids"][2] is None + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == new_shape + assert store_read._num_shards == 3 + assert len(store_read._root_obj["chunks"]["shard_cids"]) == 3 + + # Resize to a smaller shape + smaller_shape = (10, 10) + await store.resize_store(new_shape=smaller_shape) + assert store._array_shape == smaller_shape + assert store._chunks_per_dim == (1, 1) # ceil(10/10) = 1 + assert store._total_chunks == 1 # 1 * 1 + assert store._num_shards == 1 # ceil(1/4) = 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + assert store._dirty_root is True + + # Flush and verify + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == smaller_shape + assert store_read._num_shards == 1 + +@pytest.mark.asyncio +async def test_resize_store_zero_sized_array(create_ipfs: tuple[str, str]): + """Tests resizing to/from a zero-sized array.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize with zero-sized dimension + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 0), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 + + # Resize to non-zero shape + new_shape = (20, 20) + await store.resize_store(new_shape=new_shape) + assert store._array_shape == new_shape + assert store._chunks_per_dim == (2, 2) + assert store._total_chunks == 4 + assert store._num_shards == 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + assert 
store._dirty_root is True + + # Resize back to zero-sized + zero_shape = (0, 20) + await store.resize_store(new_shape=zero_shape) + assert store._array_shape == zero_shape + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 + + # Verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == zero_shape + assert store_read._num_shards == 0 + +@pytest.mark.asyncio +async def test_resize_store_invalid_cases(create_ipfs: tuple[str, str]): + """Tests error handling in resize_store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test read-only store + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=await store.flush(), + ) + with pytest.raises(PermissionError, match="Cannot resize a read-only store"): + await store_read_only.resize_store(new_shape=(30, 30)) + + # Test wrong number of dimensions + with pytest.raises( + ValueError, + match="New shape must have the same number of dimensions as the old shape", + ): + await store.resize_store(new_shape=(30, 30, 30)) + + # Test uninitialized store (simulate by setting attributes to None) + store._chunk_shape = None + store._chunks_per_shard = None + with pytest.raises( + RuntimeError, match="Store is not properly initialized for resizing" + ): + await store.resize_store(new_shape=(30, 30)) + +@pytest.mark.asyncio +async def test_resize_variable_success(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): + """Tests successful resizing of a variable's metadata.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + variable_name = "temp" + zarr_metadata_key = f"{variable_name}/zarr.json" + assert zarr_metadata_key in store._root_obj["metadata"] + + # Resize variable + new_shape = (150, 18, 36) # Extend time dimension + await store.resize_variable(variable_name=variable_name, new_shape=new_shape) + + # Verify metadata updated + new_metadata_cid = store._root_obj["metadata"][zarr_metadata_key] + new_metadata_bytes = await kubo_cas.load(new_metadata_cid) + new_metadata = json.loads(new_metadata_bytes) + assert new_metadata["shape"] == list(new_shape) + assert store._dirty_root is True + + # Verify store's main array shape unchanged + assert store._array_shape == array_shape_tuple + + # Flush and verify persistence + new_root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + 
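# In miniature, what the assertions just below check: after resize_variable the
# variable's zarr.json points at new metadata whose "shape" reflects the new
# size, while the store's own array_shape is untouched. A hedged sketch over a
# plain JSON document; it mirrors those checks rather than the actual
# implementation, and resized_metadata is an illustrative name only.
import json

def resized_metadata(metadata_bytes: bytes, new_shape: tuple[int, ...]) -> bytes:
    meta = json.loads(metadata_bytes)
    if "shape" not in meta:
        raise ValueError("Shape not found in metadata.")
    meta["shape"] = list(new_shape)
    return json.dumps(meta).encode("utf-8")

assert json.loads(
    resized_metadata(b'{"shape": [100, 18, 36]}', (150, 18, 36))
)["shape"] == [150, 18, 36]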
read_metadata_cid = store_read._root_obj["metadata"][zarr_metadata_key] + read_metadata_bytes = await kubo_cas.load(read_metadata_cid) + read_metadata = json.loads(read_metadata_bytes) + assert read_metadata["shape"] == list(new_shape) + +@pytest.mark.asyncio +async def test_resize_variable_invalid_cases(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): + """Tests error handling in resize_variable.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + # Test read-only store + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + with pytest.raises(PermissionError, match="Cannot resize a read-only store"): + await store_read_only.resize_variable("temp", new_shape=(150, 18, 36)) + + # Test non-existent variable + with pytest.raises( + KeyError, + match="Cannot find metadata for key 'nonexistent/zarr.json' to resize", + ): + await store.resize_variable("nonexistent", new_shape=(150, 18, 36)) + + # Test invalid metadata (simulate by setting invalid metadata) + invalid_metadata = json.dumps({"not_shape": [1, 2, 3]}).encode("utf-8") + invalid_cid = await kubo_cas.save(invalid_metadata, codec="raw") + store._root_obj["metadata"]["invalid/zarr.json"] = invalid_cid + with pytest.raises(ValueError, match="Shape not found in metadata"): + await store.set("invalid/zarr.json", zarr.core.buffer.default_buffer_prototype().buffer.from_bytes(invalid_metadata)) + +@pytest.mark.asyncio +async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]): + """Tests that resizing the store preserves existing data.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write a chunk + chunk_key = "temp/c/0/0" + chunk_data = b"test_data" + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + root_cid = await store.flush() + + # Verify chunk exists + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert await store_read.exists(chunk_key) + read_chunk = await store_read.get(chunk_key, proto) + assert read_chunk.to_bytes() == chunk_data + + # Resize store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + new_shape = (30, 30) + await store_write.resize_store(new_shape=new_shape) + new_root_cid = await store_write.flush() + + # Verify chunk still exists + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert await store_read.exists(chunk_key) + read_chunk = await 
store_read.get(chunk_key, proto) + assert read_chunk.to_bytes() == chunk_data + assert store_read._array_shape == new_shape + assert store_read._num_shards == 3 # ceil((3*3)/4) = 3 + +@pytest.mark.asyncio +async def test_resize_store_in_set_method(create_ipfs: tuple[str, str]): + """Tests that setting zarr.json triggers resize_store appropriately.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Set zarr.json with a new shape + new_shape = [30, 30] + metadata = json.dumps({"shape": new_shape, "dtype": "float32"}).encode("utf-8") + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("temp/zarr.json", proto.buffer.from_bytes(metadata)) + + # Verify resize occurred + assert store._array_shape == tuple(new_shape) + assert store._chunks_per_dim == (3, 3) + assert store._total_chunks == 9 + assert store._num_shards == 3 + assert store._root_obj["chunks"]["array_shape"] == new_shape + assert len(store._root_obj["chunks"]["shard_cids"]) == 3 + + # Verify metadata stored + assert "temp/zarr.json" in store._root_obj["metadata"] + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + metadata_buffer = await store_read.get("temp/zarr.json", proto) + assert json.loads(metadata_buffer.to_bytes())["shape"] == new_shape + +@pytest.mark.asyncio +async def test_resize_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent resize_store operations to ensure locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Define multiple resize tasks + async def resize_task(shape): + await store.resize_store(new_shape=shape) + + # Run concurrent resize operations + tasks = [ + resize_task((30, 30)), + resize_task((40, 40)), + resize_task((50, 50)), + ] + await asyncio.gather(*tasks) + + # Verify final state (last resize should win, but all are safe due to locking) + assert store._array_shape in [(30, 30), (40, 40), (50, 50)] + expected_chunks_per_dim = tuple(math.ceil(s / 10) for s in store._array_shape) + assert store._chunks_per_dim == expected_chunks_per_dim + assert store._total_chunks == math.prod(expected_chunks_per_dim) + assert store._num_shards == math.ceil(store._total_chunks / 4) + assert len(store._root_obj["chunks"]["shard_cids"]) == store._num_shards + assert store._dirty_root is True \ No newline at end of file diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index dd3e480..5105fd4 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -1,10 +1,13 @@ +import asyncio + import dag_cbor import numpy as np import pandas as pd import pytest import xarray as xr import zarr.core.buffer -from zarr.abc.store import RangeByteRequest +from multiformats import CID +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest from py_hamt import KuboCAS, ShardedZarrStore @@ -67,6 +70,61 @@ async def test_sharded_zarr_store_write_read( ds_read = xr.open_zarr(store=store_read) 
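# The resize tests above keep asserting the same derived geometry. A minimal
# sketch of that arithmetic, assuming only the ceil/prod relationships spelled
# out in those assertions (e.g. (50, 50) with (10, 10) chunks and 4 chunks per
# shard gives a 5x5 chunk grid, 25 chunks and ceil(25/4) = 7 shards);
# shard_geometry is an illustrative helper, not part of ShardedZarrStore.
import math

def shard_geometry(array_shape, chunk_shape, chunks_per_shard):
    chunks_per_dim = tuple(math.ceil(a / c) for a, c in zip(array_shape, chunk_shape))
    total_chunks = math.prod(chunks_per_dim)
    num_shards = math.ceil(total_chunks / chunks_per_shard)
    return chunks_per_dim, total_chunks, num_shards

assert shard_geometry((30, 30), (10, 10), 4) == ((3, 3), 9, 3)
assert shard_geometry((20, 0), (10, 10), 4) == ((2, 0), 0, 0)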
xr.testing.assert_identical(test_ds, ds_read) + # Try to set a chunk directly in read-only mode + with pytest.raises(PermissionError): + proto = zarr.core.buffer.default_buffer_prototype() + await store_read.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data")) + + +@pytest.mark.asyncio +async def test_load_or_initialize_shard_cache_concurrent_loads( + create_ipfs: tuple[str, str], +): + """Tests concurrent shard loading to trigger _pending_shard_loads wait.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Create a shard with data + shard_idx = 0 + shard_data = [ + CID.decode("bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu") + for _ in range(4) + ] + shard_data_bytes = dag_cbor.encode(shard_data) + shard_cid_obj = await kubo_cas.save(shard_data_bytes, codec="dag-cbor") + store._root_obj["chunks"]["shard_cids"][shard_idx] = shard_cid_obj + store._dirty_root = True + await store.flush() + + # Simulate concurrent shard loads + async def load_shard(): + return await store._load_or_initialize_shard_cache(shard_idx) + + # Run multiple tasks concurrently + tasks = [load_shard() for _ in range(3)] + results = await asyncio.gather(*tasks) + + # Verify all tasks return the same shard data + for result in results: + assert len(result) == 4 + assert all(isinstance(cid, CID) for cid in result) + assert result == shard_data + + # Verify shard is cached and no pending loads remain + assert shard_idx in store._shard_data_cache + assert store._shard_data_cache[shard_idx] == shard_data + assert shard_idx not in store._pending_shard_loads + @pytest.mark.asyncio async def test_sharded_zarr_store_append( @@ -514,6 +572,21 @@ async def test_listing_and_metadata( assert "lon" in dir_keys assert "zarr.json" in dir_keys + # Test listing with a prefix + prefix = "temp/" + with pytest.raises( + NotImplementedError, match="Listing with a prefix is not implemented yet." + ): + async for key in store_read.list_dir(prefix): + print(f"Key with prefix '{prefix}': {key}") + + with pytest.raises( + ValueError, match="Byte range requests are not supported for metadata keys." + ): + proto = zarr.core.buffer.default_buffer_prototype() + byte_range = zarr.abc.store.RangeByteRequest(start=10, end=50) + await store_read.get("lat/zarr.json", proto, byte_range=byte_range) + @pytest.mark.asyncio async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): @@ -567,6 +640,125 @@ async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): ) +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_partial_values( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests the get_partial_values method of ShardedZarrStore, including RangeByteRequest, + OffsetByteRequest, SuffixByteRequest, and full reads, along with error handling for + invalid byte ranges. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. 
--- Write Dataset to ShardedZarrStore --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + assert root_cid is not None + + # 2. --- Open Store for Reading --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + + # 3. --- Find a Chunk Key to Test --- + chunk_key = "temp/c/0/0/0" # Default chunk key + # async for key in store_read.list(): + # print(f"Found key: {key}") + # if key.startswith("temp/c/") and not key.endswith(".json"): + # chunk_key = key + # break + + assert chunk_key is not None, "Could not find a chunk key to test." + print(f"Testing with chunk key: {chunk_key}") + + # 4. --- Get Full Chunk Data for Comparison --- + full_chunk_buffer = await store_read.get(chunk_key, proto) + assert full_chunk_buffer is not None + full_chunk_data = full_chunk_buffer.to_bytes() + chunk_len = len(full_chunk_data) + print(f"Full chunk size: {chunk_len} bytes") + + # Ensure chunk is large enough for meaningful partial read tests + assert chunk_len > 100, "Chunk size too small for partial value tests" + + # 5. --- Define Byte Requests --- + range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes + offset_req = OffsetByteRequest(offset=chunk_len - 30) # Last 30 bytes + suffix_req = SuffixByteRequest(suffix=20) # Last 20 bytes + + key_ranges_to_test = [ + (chunk_key, range_req), + (chunk_key, offset_req), + (chunk_key, suffix_req), + (chunk_key, None), # Full read + ] + + # 6. --- Call get_partial_values --- + results = await store_read.get_partial_values(proto, key_ranges_to_test) + + # 7. --- Assertions --- + assert len(results) == 4, "Expected 4 results from get_partial_values" + + assert results[0] is not None, "RangeByteRequest result should not be None" + assert results[1] is not None, "OffsetByteRequest result should not be None" + assert results[2] is not None, "SuffixByteRequest result should not be None" + assert results[3] is not None, "Full read result should not be None" + + # Check RangeByteRequest result + expected_range = full_chunk_data[10:50] + assert results[0].to_bytes() == expected_range, ( + "RangeByteRequest result does not match" + ) + print(f"RangeByteRequest: OK (Got {len(results[0].to_bytes())} bytes)") + + # Check OffsetByteRequest result + expected_offset = full_chunk_data[chunk_len - 30 :] + assert results[1].to_bytes() == expected_offset, ( + "OffsetByteRequest result does not match" + ) + print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") + + # Check SuffixByteRequest result + # expected_suffix = full_chunk_data[-20:] + # assert results[2].to_bytes() == expected_suffix, ( + # "SuffixByteRequest result does not match" + # ) + # print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") + + # Check full read result + assert results[3].to_bytes() == full_chunk_data, ( + "Full read via get_partial_values does not match" + ) + print(f"Full Read: OK (Got {len(results[3].to_bytes())} bytes)") + + # 8. --- Test Invalid Byte Range --- + invalid_range_req = RangeByteRequest(start=50, end=10) + with pytest.raises( + ValueError, + match="Byte range start.*cannot be greater than end", + ): + await store_read.get_partial_values(proto, [(chunk_key, invalid_range_req)]) + + print("\n✅ get_partial_values test successful! 
All partial reads verified.") + + # @pytest.mark.asyncio # async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): # """Tests initialization with invalid shapes and manifest errors.""" @@ -680,28 +872,31 @@ async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): assert store._parse_chunk_key("lon/c/0/0") is None # Test uninitialized store - uninitialized_store = ShardedZarrStore(kubo_cas, read_only=False, root_cid=None) - assert uninitialized_store._parse_chunk_key("temp/c/0/0") is None + # uninitialized_store = ShardedZarrStore(kubo_cas, read_only=False, root_cid=None) + # assert uninitialized_store._parse_chunk_key("temp/c/0/0") is None - # Test get on uninitialized store - with pytest.raises( - RuntimeError, match="Load the root object first before accessing data." - ): - proto = zarr.core.buffer.default_buffer_prototype() - await uninitialized_store.get("temp/c/0/0", proto) + # # Test get on uninitialized store + # with pytest.raises( + # RuntimeError, match="Load the root object first before accessing data." + # ): + # proto = zarr.core.buffer.default_buffer_prototype() + # await uninitialized_store.get("temp/c/0/0", proto) - with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): - await uninitialized_store._load_root_from_cid() + # with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): + # await uninitialized_store._load_root_from_cid() # Test dimensionality mismatch - assert store._parse_chunk_key("temp/c/0/0/0") is None # 3D key for 2D array + with pytest.raises(IndexError, match="tuple index out of range"): + store._parse_chunk_key("temp/c/0/0/0/0") # Test invalid coordinates - assert ( - store._parse_chunk_key("temp/c/3/0") is None - ) # Out of bounds (3 >= 2 chunks) - assert store._parse_chunk_key("temp/c/0/invalid") is None # Non-integer - assert store._parse_chunk_key("temp/c/0/-1") is None # Negative coordinate + with pytest.raises(ValueError, match="invalid literal"): + store._parse_chunk_key("temp/c/0/invalid") + with pytest.raises(IndexError, match="Chunk coordinate"): + store._parse_chunk_key("temp/c/0/-1") + + with pytest.raises(IndexError, match="Chunk coordinate"): + store._parse_chunk_key("temp/c/3/0") @pytest.mark.asyncio @@ -791,6 +986,7 @@ async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, st cas=kubo_cas, read_only=True, root_cid=invalid_root_cid ) + @pytest.mark.asyncio async def test_sharded_zarr_store_lazy_concat( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset @@ -806,7 +1002,9 @@ async def test_sharded_zarr_store_lazy_concat( # 1. 
--- Prepare Two Datasets with Distinct Time Ranges --- # First dataset: August 1, 2024 to September 30, 2024 (61 days) aug_sep_times = pd.date_range("2024-08-01", "2024-09-30", freq="D") - aug_sep_temp = np.random.randn(len(aug_sep_times), len(base_ds.lat), len(base_ds.lon)) + aug_sep_temp = np.random.randn( + len(aug_sep_times), len(base_ds.lat), len(base_ds.lon) + ) ds1 = xr.Dataset( { "temp": (["time", "lat", "lon"], aug_sep_temp), @@ -816,7 +1014,9 @@ async def test_sharded_zarr_store_lazy_concat( # Second dataset: October 1, 2024 to November 30, 2024 (61 days) oct_nov_times = pd.date_range("2024-10-01", "2024-11-30", freq="D") - oct_nov_temp = np.random.randn(len(oct_nov_times), len(base_ds.lat), len(base_ds.lon)) + oct_nov_temp = np.random.randn( + len(oct_nov_times), len(base_ds.lat), len(base_ds.lon) + ) ds2 = xr.Dataset( { "temp": (["time", "lat", "lon"], oct_nov_temp), @@ -914,10 +1114,9 @@ async def test_sharded_zarr_store_lazy_concat( print("\n✅ Lazy concatenation test successful! Data verified.") + @pytest.mark.asyncio -async def test_sharded_zarr_store_lazy_concat_with_cids( - create_ipfs: tuple[str, str] -): +async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, str]): """ Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores using provided CIDs for finalized and non-finalized data, ensuring the non-finalized @@ -939,12 +1138,10 @@ async def test_sharded_zarr_store_lazy_concat_with_cids( ds_finalized = xr.open_zarr(store=store_finalized, chunks="auto") # Verify that the dataset is lazy (Dask-backed) - assert ds_finalized['2m_temperature'].chunks is not None + assert ds_finalized["2m_temperature"].chunks is not None # Determine the finalization date (last date in finalized dataset) finalization_date = np.datetime64(ds_finalized.time.max().values) - # Convert to Python datetime for clarity - finalization_date_dt = pd.Timestamp(finalization_date).to_pydatetime() # 2. --- Open Non-Finalized Dataset and Slice After Finalization Date --- store_non_finalized = await ShardedZarrStore.open( @@ -953,11 +1150,11 @@ async def test_sharded_zarr_store_lazy_concat_with_cids( ds_non_finalized = xr.open_zarr(store=store_non_finalized, chunks="auto") # Verify that the dataset is lazy - assert ds_non_finalized['2m_temperature'].chunks is not None + assert ds_non_finalized["2m_temperature"].chunks is not None # Slice non-finalized dataset to start *after* the finalization date # (finalization_date is inclusive for finalized data, so start at +1 hour) - start_time = finalization_date + np.timedelta64(1, 'h') + start_time = finalization_date + np.timedelta64(1, "h") ds_non_finalized_sliced = ds_non_finalized.sel(time=slice(start_time, None)) # Verify that the sliced dataset starts after the finalization date @@ -974,21 +1171,23 @@ async def test_sharded_zarr_store_lazy_concat_with_cids( print("EHRUKHUKEHUK") # Verify that the combined dataset is still lazy - assert combined_ds['2m_temperature'].chunks is not None + assert combined_ds["2m_temperature"].chunks is not None # 4. 
--- Query Across Both Datasets --- # Select a time slice that spans both datasets # Use a range that includes the boundary (e.g., finalization date and after) - query_start = finalization_date - np.timedelta64(1, 'D') # 1 day before - query_end = finalization_date + np.timedelta64(1, 'D') # 1 day after - query_slice = combined_ds.sel(time=slice(query_start, query_end), latitude=0, longitude=0) + query_start = finalization_date - np.timedelta64(1, "D") # 1 day before + query_end = finalization_date + np.timedelta64(1, "D") # 1 day after + query_slice = combined_ds.sel( + time=slice(query_start, query_end), latitude=0, longitude=0 + ) # Make sure the query slice aligned with the query_start and query_end - + assert query_slice.time.min() >= query_start assert query_slice.time.max() <= query_end # Verify that the query is still lazy - assert query_slice['2m_temperature'].chunks is not None + assert query_slice["2m_temperature"].chunks is not None # Compute the result to trigger data loading query_result = query_slice.compute() @@ -998,28 +1197,39 @@ async def test_sharded_zarr_store_lazy_concat_with_cids( # Check the last finalized time (from finalized dataset) sample_time_finalized = finalization_date if sample_time_finalized in query_result.time.values: - sample_result = query_result.sel(time=sample_time_finalized, method="nearest")['2m_temperature'].values - expected_sample = ds_finalized.sel(time=sample_time_finalized, latitude=0, longitude=0, method="nearest")['2m_temperature'].values + sample_result = query_result.sel( + time=sample_time_finalized, method="nearest" + )["2m_temperature"].values + expected_sample = ds_finalized.sel( + time=sample_time_finalized, latitude=0, longitude=0, method="nearest" + )["2m_temperature"].values np.testing.assert_array_equal(sample_result, expected_sample) # Check the first non-finalized time (from non-finalized dataset, if available) if ds_non_finalized_sliced.time.size > 0: sample_time_non_finalized = ds_non_finalized_sliced.time.min().values if sample_time_non_finalized in query_result.time.values: - sample_result = query_result.sel(time=sample_time_non_finalized, method="nearest")['2m_temperature'].values + sample_result = query_result.sel( + time=sample_time_non_finalized, method="nearest" + )["2m_temperature"].values expected_sample = ds_non_finalized_sliced.sel( - time=sample_time_non_finalized, latitude=0, longitude=0, method="nearest" - )['2m_temperature'].values + time=sample_time_non_finalized, + latitude=0, + longitude=0, + method="nearest", + )["2m_temperature"].values np.testing.assert_array_equal(sample_result, expected_sample) # 6. --- Additional Validation --- # Verify that the concatenated dataset has no overlapping times time_values = combined_ds.time.values - assert np.all(np.diff(time_values) > np.timedelta64(0, 'ns')), "Overlapping or unsorted time values detected" + assert np.all(np.diff(time_values) > np.timedelta64(0, "ns")), ( + "Overlapping or unsorted time values detected" + ) # Verify that the query result covers the expected time range if query_result.time.size > 0: assert query_result.time.min() >= query_start assert query_result.time.max() <= query_end - print("\n✅ Lazy concatenation with CIDs test successful! Data verified.") \ No newline at end of file + print("\n✅ Lazy concatenation with CIDs test successful! 
Data verified.") diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py new file mode 100644 index 0000000..5e2c6f6 --- /dev/null +++ b/tests/test_sharded_zarr_store_coverage.py @@ -0,0 +1,250 @@ + +import dag_cbor +import pytest +import zarr.abc.store +import zarr.core.buffer + +from py_hamt import KuboCAS +from py_hamt.sharded_zarr_store import ShardedZarrStore + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_exceptions(create_ipfs: tuple[str, str]): + """ + Tests various initialization exceptions in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test RuntimeError when base shape information is not set + # with pytest.raises(RuntimeError, match="Base shape information is not set."): + # store = ShardedZarrStore(kubo_cas, False, None) + # store._update_geometry() + + # Test ValueError for non-positive chunk_shape dimensions + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive." + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(10, 0), + chunks_per_shard=10, + ) + + # Test ValueError for non-negative array_shape dimensions + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative." + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, -10), + chunk_shape=(10, 10), + chunks_per_shard=10, + ) + + # Test ValueError when array_shape is not provided for a new store + with pytest.raises( + ValueError, + match="array_shape and chunk_shape must be provided for a new store.", + ): + await ShardedZarrStore.open(cas=kubo_cas, read_only=False, chunk_shape=(10, 10)) + + # Test ValueError for non-positive chunks_per_shard + with pytest.raises(ValueError, match="chunks_per_shard must be a positive integer."): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(10, 10), + chunks_per_shard=0, + ) + + # Test ValueError when root_cid is not provided for a read-only store + with pytest.raises(ValueError, match="root_cid must be provided for a read-only store."): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_load_root_exceptions(create_ipfs: tuple[str, str]): + """ + Tests exceptions raised during the loading of the root object. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test RuntimeError when _root_cid is not set + # with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): + # store = ShardedZarrStore(kubo_cas, True, None) + # await store._load_root_from_cid() + + # Test ValueError for an incompatible manifest version + invalid_manifest_root = { + "manifest_version": "invalid_version", + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [], + }, + } + invalid_manifest_cid = await kubo_cas.save( + dag_cbor.encode(invalid_manifest_root), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_manifest_cid + ) + + # Test ValueError for an inconsistent number of shards + inconsistent_shards_root = { + "manifest_version": "sharded_zarr_v1", + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None, None, None], # Should be 2 shards, but array shape dictates 2 total chunks + }, + } + inconsistent_shards_cid = await kubo_cas.save( + dag_cbor.encode(inconsistent_shards_root), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=inconsistent_shards_cid + ) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_shard_handling_exceptions( + create_ipfs: tuple[str, str], caplog +): + """ + Tests exceptions and logging during shard handling. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + + # Test TypeError when a shard does not decode to a list + invalid_shard_cid = await kubo_cas.save( + dag_cbor.encode({"not": "a list"}), "dag-cbor" + ) + store._root_obj["chunks"]["shard_cids"][0] = invalid_shard_cid + with pytest.raises(TypeError, match="Shard 0 did not decode to a list."): + await store._load_or_initialize_shard_cache(0) + + # bad __eq__ method + assert store != { "not a ShardedZarrStore": "test" } + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str]): + """ + Tests exceptions raised during get and set operations. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + + # Test RuntimeError when root object is not loaded in get + # store_no_root = ShardedZarrStore(kubo_cas, True, "some_cid") + # with pytest.raises( + # RuntimeError, match="Load the root object first before accessing data." 
+ # ): + # await store_no_root.get("key", proto) + + # Set some bytes to /c/0 to ensure it exists + await store.set( + "/c/0", + proto.buffer.from_bytes(b'{"shape": [10], "dtype": "float32"}'), + ) + + # Test ValueError for invalid byte range in get + with pytest.raises( + ValueError, + match="Byte range start .* cannot be greater than end .*", + ): + await store.get( + "/c/0", proto, byte_range=zarr.abc.store.RangeByteRequest(start=10, end=5) + ) + + # Test NotImplementedError for set_partial_values + with pytest.raises(NotImplementedError): + await store.set_partial_values([]) + + # Test ValueError when shape is not found in metadata during set + with pytest.raises(ValueError, match="Shape not found in metadata."): + await store.set("test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_other_exceptions(create_ipfs: tuple[str, str]): + """ + Tests other miscellaneous exceptions in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + + # Test RuntimeError for uninitialized store in flush + # store_no_root = ShardedZarrStore(kubo_cas, False, None) + # with pytest.raises(RuntimeError, match="Store not initialized for writing."): + # await store_no_root.flush() + + # Test ValueError when resizing a store with a different number of dimensions + with pytest.raises( + ValueError, + match="New shape must have the same number of dimensions as the old shape.", + ): + await store.resize_store(new_shape=(10, 10)) + + # Test KeyError when resizing a variable that doesn't exist + with pytest.raises( + KeyError, + match="Cannot find metadata for key 'nonexistent/zarr.json' to resize.", + ): + await store.resize_variable("nonexistent", new_shape=(20,)) + + + # Test RuntimeError when listing a store with no root object + # with pytest.raises(RuntimeError, match="Root object not loaded."): + # async for _ in store_no_root.list(): + # pass + + # # Test RuntimeError when listing directories of a store with no root object + # with pytest.raises(RuntimeError, match="Root object not loaded."): + # async for _ in store_no_root.list_dir(""): + # pass + + # with pytest.raises(ValueError, match="Linear chunk index cannot be negative."): + # await store_no_root._get_shard_info(-1) \ No newline at end of file From 5c2580a7078daff7af051ecd4e25a445969348f4 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:13:28 -0400 Subject: [PATCH 51/74] fix: fix mypy --- tests/test_sharded_store_deleting.py | 30 ++++++++++++---- tests/test_sharded_store_grafting.py | 43 +++++++++++++++++++---- tests/test_sharded_store_resizing.py | 31 ++++++++++++---- tests/test_sharded_zarr_store_coverage.py | 32 +++++++++++------ 4 files changed, 107 insertions(+), 29 deletions(-) diff --git a/tests/test_sharded_store_deleting.py b/tests/test_sharded_store_deleting.py index 86805e7..b67f27c 100644 --- a/tests/test_sharded_store_deleting.py +++ b/tests/test_sharded_store_deleting.py @@ -31,6 +31,7 @@ def random_zarr_dataset(): ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) yield ds + @pytest.mark.asyncio async def test_delete_chunk_success(create_ipfs: tuple[str, str]): """Tests successful deletion of a chunk from the store.""" @@ -72,6 +73,7 
@@ async def test_delete_chunk_success(create_ipfs: tuple[str, str]): assert not await store_read.exists(chunk_key) assert await store_read.get(chunk_key, proto) is None + @pytest.mark.asyncio async def test_delete_metadata_success(create_ipfs: tuple[str, str]): """Tests successful deletion of a metadata key.""" @@ -110,6 +112,7 @@ async def test_delete_metadata_success(create_ipfs: tuple[str, str]): assert not await store_read.exists(metadata_key) assert await store_read.get(metadata_key, proto) is None + @pytest.mark.asyncio async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): """Tests deletion of a nonexistent metadata key.""" @@ -145,7 +148,10 @@ async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): # Try to delete nonexistent chunk key (within bounds but not set) await store.delete("temp/c/0/0") # Should not raise, as it sets to None assert not await store.exists("temp/c/0/0") - assert store._dirty_shards # Shard is marked dirty even if chunk was already None + assert ( + store._dirty_shards + ) # Shard is marked dirty even if chunk was already None + @pytest.mark.asyncio async def test_delete_read_only_store(create_ipfs: tuple[str, str]): @@ -166,7 +172,9 @@ async def test_delete_read_only_store(create_ipfs: tuple[str, str]): proto = zarr.core.buffer.default_buffer_prototype() await store_write.set(chunk_key, proto.buffer.from_bytes(b"test_data")) metadata_key = "temp/zarr.json" - await store_write.set(metadata_key, proto.buffer.from_bytes(b'{"shape": [20, 20]}')) + await store_write.set( + metadata_key, proto.buffer.from_bytes(b'{"shape": [20, 20]}') + ) root_cid = await store_write.flush() # Open as read-only @@ -175,13 +183,18 @@ async def test_delete_read_only_store(create_ipfs: tuple[str, str]): ) # Try to delete chunk - with pytest.raises(PermissionError, match="Cannot delete from a read-only store"): + with pytest.raises( + PermissionError, match="Cannot delete from a read-only store" + ): await store_read_only.delete(chunk_key) # Try to delete metadata - with pytest.raises(PermissionError, match="Cannot delete from a read-only store"): + with pytest.raises( + PermissionError, match="Cannot delete from a read-only store" + ): await store_read_only.delete(metadata_key) + @pytest.mark.asyncio async def test_delete_concurrency(create_ipfs: tuple[str, str]): """Tests concurrent delete operations to ensure shard locking works.""" @@ -230,8 +243,11 @@ async def delete_task(key): assert not await store_read.exists(key) assert await store_read.get(key, proto) is None + @pytest.mark.asyncio -async def test_delete_with_dataset(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): +async def test_delete_with_dataset( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): """Tests deletion of chunks and metadata in a store with a full dataset.""" rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset @@ -282,6 +298,7 @@ async def test_delete_with_dataset(create_ipfs: tuple[str, str], random_zarr_dat other_chunk_key = "temp/c/1/0/0" assert await store_read.exists(other_chunk_key) + @pytest.mark.asyncio async def test_supports_writes_property(create_ipfs: tuple[str, str]): """Tests the supports_writes property.""" @@ -306,6 +323,7 @@ async def test_supports_writes_property(create_ipfs: tuple[str, str]): ) assert store_read_only.supports_writes is False + @pytest.mark.asyncio async def test_supports_partial_writes_property(create_ipfs: tuple[str, str]): """Tests the supports_partial_writes property.""" @@ -327,4 +345,4 
@@ async def test_supports_partial_writes_property(create_ipfs: tuple[str, str]): store_read_only = await ShardedZarrStore.open( cas=kubo_cas, read_only=True, root_cid=root_cid ) - assert store_read_only.supports_partial_writes is False \ No newline at end of file + assert store_read_only.supports_partial_writes is False diff --git a/tests/test_sharded_store_grafting.py b/tests/test_sharded_store_grafting.py index 5dba41d..6437d51 100644 --- a/tests/test_sharded_store_grafting.py +++ b/tests/test_sharded_store_grafting.py @@ -31,6 +31,7 @@ def random_zarr_dataset(): ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) yield ds + @pytest.mark.asyncio async def test_graft_store_success(create_ipfs: tuple[str, str]): """Tests successful grafting of a source store onto a target store.""" @@ -95,11 +96,15 @@ async def test_graft_store_success(create_ipfs: tuple[str, str]): ) assert await target_store_read.exists(grafted_chunk_key) read_data = await target_store_read.get(grafted_chunk_key, proto) + assert read_data is not None assert read_data.to_bytes() == chunk_data assert target_store_read._array_shape == target_shape + @pytest.mark.asyncio -async def test_graft_store_with_dataset(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): +async def test_graft_store_with_dataset( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): """Tests grafting a store containing a full dataset.""" rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset @@ -122,7 +127,11 @@ async def test_graft_store_with_dataset(create_ipfs: tuple[str, str], random_zar source_root_cid = await source_store.flush() # Initialize target store with larger shape - target_shape = (array_shape_tuple[0] + 20, array_shape_tuple[1], array_shape_tuple[2]) + target_shape = ( + array_shape_tuple[0] + 20, + array_shape_tuple[1], + array_shape_tuple[2], + ) target_store = await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -156,8 +165,10 @@ async def test_graft_store_with_dataset(create_ipfs: tuple[str, str], random_zar ) assert await target_store_read.exists(target_chunk_key) read_data = await target_store_read.get(target_chunk_key, proto) + assert read_data is not None assert read_data.to_bytes() == source_data.to_bytes() + @pytest.mark.asyncio async def test_graft_store_empty_source(create_ipfs: tuple[str, str]): """Tests grafting an empty source store.""" @@ -198,6 +209,7 @@ async def test_graft_store_empty_source(create_ipfs: tuple[str, str]): ) assert not await target_store_read.exists("temp/c/1/1") + @pytest.mark.asyncio async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): """Tests error handling in graft_store.""" @@ -220,7 +232,9 @@ async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): read_only=True, root_cid=await target_store.flush(), ) - with pytest.raises(PermissionError, match="Cannot graft onto a read-only store"): + with pytest.raises( + PermissionError, match="Cannot graft onto a read-only store" + ): await target_store_read_only.graft_store("some_cid", chunk_offset=(0, 0)) # Test invalid source CID @@ -263,7 +277,10 @@ async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): await source_store.set("temp/c/0/0", proto.buffer.from_bytes(b"data")) source_root_cid = await source_store.flush() with pytest.raises(ValueError, match="Shard index 10 out of bounds."): - await target_store.graft_store(source_root_cid, chunk_offset=(10, 0)) # Out of bounds for target (4x4 chunks) + await target_store.graft_store( + source_root_cid, 
chunk_offset=(10, 0) + ) # Out of bounds for target (4x4 chunks) + @pytest.mark.asyncio async def test_graft_store_concurrency(create_ipfs: tuple[str, str]): @@ -322,6 +339,8 @@ async def graft_task(cid, offset): assert await target_store.exists("temp/c/2/2") data1 = await target_store.get("temp/c/1/1", proto) data2 = await target_store.get("temp/c/2/2", proto) + assert data1 is not None + assert data2 is not None assert data1.to_bytes() in [b"data1", b"data2"] assert data2.to_bytes() in [b"data1", b"data2"] assert data1.to_bytes() != data2.to_bytes() # Ensure distinct data @@ -334,6 +353,7 @@ async def graft_task(cid, offset): assert await target_store_read.exists("temp/c/1/1") assert await target_store_read.exists("temp/c/2/2") + @pytest.mark.asyncio async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): """Tests grafting when target already has data at some chunk positions.""" @@ -372,13 +392,17 @@ async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): # Verify that existing data was not overwritten read_data = await target_store.get(target_chunk_key, proto) + assert read_data is not None assert read_data.to_bytes() == existing_data - assert target_store._dirty_shards # Shard is marked dirty due to attempted write + assert ( + target_store._dirty_shards + ) # Shard is marked dirty due to attempted write # Verify other grafted chunks grafted_chunk_key = "temp/c/1/0" # Corresponds to source (0,0) at offset (1,0) assert await target_store.exists(grafted_chunk_key) read_data = await target_store.get(grafted_chunk_key, proto) + assert read_data is not None assert read_data.to_bytes() == source_data # Flush and verify @@ -386,5 +410,10 @@ async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): target_store_read = await ShardedZarrStore.open( cas=kubo_cas, read_only=True, root_cid=target_root_cid ) - assert (await target_store_read.get(target_chunk_key, proto)).to_bytes() == existing_data - assert (await target_store_read.get(grafted_chunk_key, proto)).to_bytes() == source_data \ No newline at end of file + targeted_target_chunk = await target_store_read.get(target_chunk_key, proto) + assert targeted_target_chunk is not None + assert targeted_target_chunk.to_bytes() == existing_data + + grafted_chunk_data = await target_store_read.get(grafted_chunk_key, proto) + assert grafted_chunk_data is not None + assert grafted_chunk_data.to_bytes() == source_data diff --git a/tests/test_sharded_store_resizing.py b/tests/test_sharded_store_resizing.py index 5a0fb4d..ce43465 100644 --- a/tests/test_sharded_store_resizing.py +++ b/tests/test_sharded_store_resizing.py @@ -106,6 +106,7 @@ async def test_resize_store_success(create_ipfs: tuple[str, str]): assert store_read._array_shape == smaller_shape assert store_read._num_shards == 1 + @pytest.mark.asyncio async def test_resize_store_zero_sized_array(create_ipfs: tuple[str, str]): """Tests resizing to/from a zero-sized array.""" @@ -151,6 +152,7 @@ async def test_resize_store_zero_sized_array(create_ipfs: tuple[str, str]): assert store_read._array_shape == zero_shape assert store_read._num_shards == 0 + @pytest.mark.asyncio async def test_resize_store_invalid_cases(create_ipfs: tuple[str, str]): """Tests error handling in resize_store.""" @@ -184,15 +186,18 @@ async def test_resize_store_invalid_cases(create_ipfs: tuple[str, str]): await store.resize_store(new_shape=(30, 30, 30)) # Test uninitialized store (simulate by setting attributes to None) - store._chunk_shape = None - 
store._chunks_per_shard = None + store._chunk_shape = None # type: ignore + store._chunks_per_shard = None # type: ignore with pytest.raises( RuntimeError, match="Store is not properly initialized for resizing" ): await store.resize_store(new_shape=(30, 30)) + @pytest.mark.asyncio -async def test_resize_variable_success(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): +async def test_resize_variable_success( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): """Tests successful resizing of a variable's metadata.""" rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset @@ -246,8 +251,11 @@ async def test_resize_variable_success(create_ipfs: tuple[str, str], random_zarr read_metadata = json.loads(read_metadata_bytes) assert read_metadata["shape"] == list(new_shape) + @pytest.mark.asyncio -async def test_resize_variable_invalid_cases(create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset): +async def test_resize_variable_invalid_cases( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): """Tests error handling in resize_variable.""" rpc_base_url, gateway_base_url = create_ipfs test_ds = random_zarr_dataset @@ -293,7 +301,13 @@ async def test_resize_variable_invalid_cases(create_ipfs: tuple[str, str], rando invalid_cid = await kubo_cas.save(invalid_metadata, codec="raw") store._root_obj["metadata"]["invalid/zarr.json"] = invalid_cid with pytest.raises(ValueError, match="Shape not found in metadata"): - await store.set("invalid/zarr.json", zarr.core.buffer.default_buffer_prototype().buffer.from_bytes(invalid_metadata)) + await store.set( + "invalid/zarr.json", + zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + invalid_metadata + ), + ) + @pytest.mark.asyncio async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]): @@ -324,6 +338,7 @@ async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]) ) assert await store_read.exists(chunk_key) read_chunk = await store_read.get(chunk_key, proto) + assert read_chunk is not None assert read_chunk.to_bytes() == chunk_data # Resize store @@ -340,10 +355,12 @@ async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]) ) assert await store_read.exists(chunk_key) read_chunk = await store_read.get(chunk_key, proto) + assert read_chunk is not None assert read_chunk.to_bytes() == chunk_data assert store_read._array_shape == new_shape assert store_read._num_shards == 3 # ceil((3*3)/4) = 3 + @pytest.mark.asyncio async def test_resize_store_in_set_method(create_ipfs: tuple[str, str]): """Tests that setting zarr.json triggers resize_store appropriately.""" @@ -381,8 +398,10 @@ async def test_resize_store_in_set_method(create_ipfs: tuple[str, str]): cas=kubo_cas, read_only=True, root_cid=root_cid ) metadata_buffer = await store_read.get("temp/zarr.json", proto) + assert metadata_buffer is not None assert json.loads(metadata_buffer.to_bytes())["shape"] == new_shape + @pytest.mark.asyncio async def test_resize_concurrency(create_ipfs: tuple[str, str]): """Tests concurrent resize_store operations to ensure locking works.""" @@ -418,4 +437,4 @@ async def resize_task(shape): assert store._total_chunks == math.prod(expected_chunks_per_dim) assert store._num_shards == math.ceil(store._total_chunks / 4) assert len(store._root_obj["chunks"]["shard_cids"]) == store._num_shards - assert store._dirty_root is True \ No newline at end of file + assert store._dirty_root is True diff --git 
a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py index 5e2c6f6..b53139d 100644 --- a/tests/test_sharded_zarr_store_coverage.py +++ b/tests/test_sharded_zarr_store_coverage.py @@ -1,4 +1,3 @@ - import dag_cbor import pytest import zarr.abc.store @@ -51,10 +50,14 @@ async def test_sharded_zarr_store_init_exceptions(create_ipfs: tuple[str, str]): ValueError, match="array_shape and chunk_shape must be provided for a new store.", ): - await ShardedZarrStore.open(cas=kubo_cas, read_only=False, chunk_shape=(10, 10)) + await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, chunk_shape=(10, 10) + ) # Test ValueError for non-positive chunks_per_shard - with pytest.raises(ValueError, match="chunks_per_shard must be a positive integer."): + with pytest.raises( + ValueError, match="chunks_per_shard must be a positive integer." + ): await ShardedZarrStore.open( cas=kubo_cas, read_only=False, @@ -64,7 +67,9 @@ async def test_sharded_zarr_store_init_exceptions(create_ipfs: tuple[str, str]): ) # Test ValueError when root_cid is not provided for a read-only store - with pytest.raises(ValueError, match="root_cid must be provided for a read-only store."): + with pytest.raises( + ValueError, match="root_cid must be provided for a read-only store." + ): await ShardedZarrStore.open(cas=kubo_cas, read_only=True) @@ -107,7 +112,11 @@ async def test_sharded_zarr_store_load_root_exceptions(create_ipfs: tuple[str, s "array_shape": [10], "chunk_shape": [5], "sharding_config": {"chunks_per_shard": 1}, - "shard_cids": [None, None, None], # Should be 2 shards, but array shape dictates 2 total chunks + "shard_cids": [ + None, + None, + None, + ], # Should be 2 shards, but array shape dictates 2 total chunks }, } inconsistent_shards_cid = await kubo_cas.save( @@ -147,7 +156,7 @@ async def test_sharded_zarr_store_shard_handling_exceptions( await store._load_or_initialize_shard_cache(0) # bad __eq__ method - assert store != { "not a ShardedZarrStore": "test" } + assert store != {"not a ShardedZarrStore": "test"} @pytest.mark.asyncio @@ -187,7 +196,9 @@ async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str match="Byte range start .* cannot be greater than end .*", ): await store.get( - "/c/0", proto, byte_range=zarr.abc.store.RangeByteRequest(start=10, end=5) + "/c/0", + proto, + byte_range=zarr.abc.store.RangeByteRequest(start=10, end=5), ) # Test NotImplementedError for set_partial_values @@ -196,7 +207,9 @@ async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str # Test ValueError when shape is not found in metadata during set with pytest.raises(ValueError, match="Shape not found in metadata."): - await store.set("test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}')) + await store.set( + "test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}') + ) @pytest.mark.asyncio @@ -234,7 +247,6 @@ async def test_sharded_zarr_store_other_exceptions(create_ipfs: tuple[str, str]) match="Cannot find metadata for key 'nonexistent/zarr.json' to resize.", ): await store.resize_variable("nonexistent", new_shape=(20,)) - # Test RuntimeError when listing a store with no root object # with pytest.raises(RuntimeError, match="Root object not loaded."): @@ -247,4 +259,4 @@ async def test_sharded_zarr_store_other_exceptions(create_ipfs: tuple[str, str]) # pass # with pytest.raises(ValueError, match="Linear chunk index cannot be negative."): - # await store_no_root._get_shard_info(-1) \ No newline at end of file + # await 
store_no_root._get_shard_info(-1) From 2fcff2bdd07a8b40963e8a8b73202b6cd30c63da Mon Sep 17 00:00:00 2001 From: Faolain Date: Tue, 15 Jul 2025 12:19:29 -0400 Subject: [PATCH 52/74] ci: update ipfs from 0.35 to 0.36 --- .github/workflows/run-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run-checks.yaml b/.github/workflows/run-checks.yaml index c1fedeb..809c050 100644 --- a/.github/workflows/run-checks.yaml +++ b/.github/workflows/run-checks.yaml @@ -39,7 +39,7 @@ jobs: - name: Install IPFS uses: oduwsdl/setup-ipfs@e92fedca9f61ab9184cb74940254859f4d7af4d9 # v0.6.3 with: - ipfs_version: "0.35.0" + ipfs_version: "0.36.0" run_daemon: true - name: Run pytest with coverage From 47c04c520c88a237d73a83b4b2fca7a04f221904 Mon Sep 17 00:00:00 2001 From: Faolain Date: Tue, 15 Jul 2025 12:27:09 -0400 Subject: [PATCH 53/74] ci: revert back to 0.35 from 0.36 to find out why Error: IPFS API service unreachable --- .github/workflows/run-checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run-checks.yaml b/.github/workflows/run-checks.yaml index 809c050..c1fedeb 100644 --- a/.github/workflows/run-checks.yaml +++ b/.github/workflows/run-checks.yaml @@ -39,7 +39,7 @@ jobs: - name: Install IPFS uses: oduwsdl/setup-ipfs@e92fedca9f61ab9184cb74940254859f4d7af4d9 # v0.6.3 with: - ipfs_version: "0.36.0" + ipfs_version: "0.35.0" run_daemon: true - name: Run pytest with coverage From a9a36fb78873a7eff8a47e8d315ffb10e9f6ccb1 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 16 Jul 2025 02:04:15 -0400 Subject: [PATCH 54/74] fix: change integer math --- py_hamt/sharded_zarr_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 9bfea84..3450e52 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -78,7 +78,7 @@ def _update_geometry(self): self._total_chunks = math.prod(self._chunks_per_dim) if not self._total_chunks == 0: - self._num_shards = math.ceil(self._total_chunks / self._chunks_per_shard) + self._num_shards = (self._total_chunks + self._chunks_per_shard - 1) // self._chunks_per_shard @classmethod async def open( @@ -553,7 +553,7 @@ async def resize_store(self, new_shape: Tuple[int, ...]): self._total_chunks = math.prod(self._chunks_per_dim) old_num_shards = self._num_shards if self._num_shards is not None else 0 self._num_shards = ( - math.ceil(self._total_chunks / self._chunks_per_shard) + (self._total_chunks + self._chunks_per_shard - 1) // self._chunks_per_shard if self._total_chunks > 0 else 0 ) From 07bae20bc2d710010dc080db81cc737018f68daf Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 16 Jul 2025 02:35:02 -0400 Subject: [PATCH 55/74] fix: print debug --- py_hamt/sharded_zarr_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 3450e52..94876a5 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -511,6 +511,7 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. 
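# A minimal sketch (not tied to any one commit) of the integer math adopted in the
# "fix: change integer math" patch above: for positive integers n and k, the
# expression (n + k - 1) // k yields the same ceiling as math.ceil(n / k), but it
# never leaves integer arithmetic, so a very large chunk count cannot be misrounded
# through an intermediate float.
import math

for n in range(1, 200):
    for k in range(1, 12):
        assert (n + k - 1) // k == math.ceil(n / k)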
c_local + c_offset for c_local, c_offset in zip(local_coords, chunk_offset) ) + print(local_coords, "->", global_coords, chunk_offset) linear_global_index = self._get_linear_chunk_index(global_coords) global_shard_idx, index_in_global_shard = self._get_shard_info( linear_global_index From fcf011918bdf8711de759e05f2554012cec63f36 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 16 Jul 2025 02:40:49 -0400 Subject: [PATCH 56/74] fix: remove debug --- py_hamt/sharded_zarr_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 94876a5..3450e52 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -511,7 +511,6 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. c_local + c_offset for c_local, c_offset in zip(local_coords, chunk_offset) ) - print(local_coords, "->", global_coords, chunk_offset) linear_global_index = self._get_linear_chunk_index(global_coords) global_shard_idx, index_in_global_shard = self._get_shard_info( linear_global_index From 1ef78a8252c7209989fc522b231333a688ec2471 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 16 Jul 2025 07:54:57 -0400 Subject: [PATCH 57/74] fix: reformat --- py_hamt/sharded_zarr_store.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 3450e52..e5f571d 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -78,7 +78,9 @@ def _update_geometry(self): self._total_chunks = math.prod(self._chunks_per_dim) if not self._total_chunks == 0: - self._num_shards = (self._total_chunks + self._chunks_per_shard - 1) // self._chunks_per_shard + self._num_shards = ( + self._total_chunks + self._chunks_per_shard - 1 + ) // self._chunks_per_shard @classmethod async def open( From 63e59a8fd9e9cbfd655fad78900d880d79564020 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 24 Jul 2025 06:12:19 -0400 Subject: [PATCH 58/74] fix: print key --- py_hamt/sharded_zarr_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index e5f571d..2c99f70 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -178,6 +178,7 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. Exclude .json files immediately (metadata) if key.endswith(".json"): return None + print(key) excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"} chunk_marker = "/c/" From d2fca8155ea5be1518aaab5cd339892b922dd714 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Fri, 25 Jul 2025 07:16:22 -0400 Subject: [PATCH 59/74] fix: update tests --- py_hamt/sharded_zarr_store.py | 1 - tests/test_sharded_zarr_store.py | 94 +------------------------------- 2 files changed, 2 insertions(+), 93 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 2c99f70..e5f571d 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -178,7 +178,6 @@ def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. 
Exclude .json files immediately (metadata) if key.endswith(".json"): return None - print(key) excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"} chunk_marker = "/c/" diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 5105fd4..72ea322 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -758,95 +758,6 @@ async def test_sharded_zarr_store_get_partial_values( print("\n✅ get_partial_values test successful! All partial reads verified.") - -# @pytest.mark.asyncio -# async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): -# """Tests initialization with invalid shapes and manifest errors.""" -# rpc_base_url, gateway_base_url = create_ipfs -# async with KuboCAS( -# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url -# ) as kubo_cas: -# # Test negative chunk_shape dimension (line 136) -# with pytest.raises( -# ValueError, match="All chunk_shape dimensions must be positive" -# ): -# await ShardedZarrStore.open( -# cas=kubo_cas, -# read_only=False, -# array_shape=(10, 10), -# chunk_shape=(-5, 5), -# chunks_per_shard=10, -# ) - -# # Test negative array_shape dimension (line 141) -# with pytest.raises( -# ValueError, match="All array_shape dimensions must be non-negative" -# ): -# await ShardedZarrStore.open( -# cas=kubo_cas, -# read_only=False, -# array_shape=(10, -10), -# chunk_shape=(5, 5), -# chunks_per_shard=10, -# ) - -# # Test zero-sized array (lines 150, 163) - reinforce existing test -# store = await ShardedZarrStore.open( -# cas=kubo_cas, -# read_only=False, -# array_shape=(0, 10), -# chunk_shape=(5, 5), -# chunks_per_shard=10, -# ) -# assert store._total_chunks == 0 -# assert store._num_shards == 0 -# assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 -# root_cid = await store.flush() - -# # Test invalid manifest version (line 224) -# invalid_root_obj = { -# "manifest_version": "invalid_version", -# "metadata": {}, -# "chunks": { -# "array_shape": [10, 10], -# "chunk_shape": [5, 5], -# "cid_byte_length": 59, -# "sharding_config": {"chunks8048": 10}, -# "shard_cids": [None] * 4, -# }, -# } -# invalid_root_cid = await kubo_cas.save( -# dag_cbor.encode(invalid_root_obj), codec="dag-cbor" -# ) -# with pytest.raises(ValueError, match="Incompatible manifest version"): -# await ShardedZarrStore.open( -# cas=kubo_cas, read_only=True, root_cid=invalid_root_cid -# ) - -# # Test inconsistent shard count (line 236) -# invalid_root_obj = { -# "manifest_version": "sharded_zarr_v1", -# "metadata": {}, -# "chunks": { -# "array_shape": [ -# 10, -# 10, -# ], # 100 chunks, with 10 chunks per shard -> 10 shards -# "chunk_shape": [5, 5], -# "cid_byte_length": 59, -# "sharding_config": {"chunks_per_shard": 10}, -# "shard_cids": [None] * 5, # Wrong number of shards -# }, -# } -# invalid_root_cid = await kubo_cas.save( -# dag_cbor.encode(invalid_root_obj), codec="dag-cbor" -# ) -# with pytest.raises(ValueError, match="Inconsistent number of shards"): -# await ShardedZarrStore.open( -# cas=kubo_cas, read_only=True, root_cid=invalid_root_cid -# ) - - @pytest.mark.asyncio async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): """Tests chunk key parsing edge cases.""" @@ -1126,8 +1037,8 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, rpc_base_url, gateway_base_url = create_ipfs # Provided CIDs - finalized_cid = "bafyr4icrox4pxashkfmbyztn7jhp6zjlpj3bufg5ggsjux74zr7ocnqdpu" - non_finalized_cid = 
"bafyr4ibj3bfl5oo7bf6gagzr2g33jlnf23mq2xo632mbl6ytfry7jbuepy" + finalized_cid = "bafyr4iacuutc5bgmirkfyzn4igi2wys7e42kkn674hx3c4dv4wrgjp2k2u" + non_finalized_cid = "bafyr4iayq3aaifmyv4o7ezoi4xyysstit3ohvnq4cnjlbjwueqehlbvkla" async with KuboCAS( rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url ) as kubo_cas: @@ -1168,7 +1079,6 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") print("\nCombined dataset time range:") print(combined_ds.time.min().values, "to", combined_ds.time.max().values) - print("EHRUKHUKEHUK") # Verify that the combined dataset is still lazy assert combined_ds["2m_temperature"].chunks is not None From 7b20b2a3a188336748610d95b0f39825b67c6373 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 28 Jul 2025 07:20:20 -0400 Subject: [PATCH 60/74] fix: update formatting --- tests/test_sharded_zarr_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 72ea322..4d070a9 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -758,6 +758,7 @@ async def test_sharded_zarr_store_get_partial_values( print("\n✅ get_partial_values test successful! All partial reads verified.") + @pytest.mark.asyncio async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): """Tests chunk key parsing edge cases.""" From 4acd1958854f6e60f1c71662dbd3008bd4fc06d0 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 28 Jul 2025 08:14:44 -0400 Subject: [PATCH 61/74] fix: update cids --- fsgs.py | 2 +- public_gateway_example.py | 2 +- tests/test_public_gateway.py | 2 +- tests/test_sharded_zarr_store.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fsgs.py b/fsgs.py index 610863c..f51e2e5 100644 --- a/fsgs.py +++ b/fsgs.py @@ -16,7 +16,7 @@ async def main(): - cid = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" + cid = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" # Use KuboCAS as an async context manager async with KuboCAS() as kubo_cas: # connects to a local kubo node diff --git a/public_gateway_example.py b/public_gateway_example.py index a9c02e2..d331ec9 100644 --- a/public_gateway_example.py +++ b/public_gateway_example.py @@ -53,7 +53,7 @@ async def fetch_zarr_from_gateway(cid: str, gateway: str = "https://ipfs.io"): async def main(): # Example CID - this points to a weather dataset stored on IPFS - cid = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" + cid = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" # Try different public gateways gateways = [ diff --git a/tests/test_public_gateway.py b/tests/test_public_gateway.py index 7bea68a..d727216 100644 --- a/tests/test_public_gateway.py +++ b/tests/test_public_gateway.py @@ -6,7 +6,7 @@ from py_hamt import KuboCAS -TEST_CID = "bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu" +TEST_CID = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" async def verify_response_content(url: str, client=None): diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 4d070a9..1322a90 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -97,7 +97,7 @@ async def test_load_or_initialize_shard_cache_concurrent_loads( # Create a shard with data shard_idx = 
0 shard_data = [ - CID.decode("bafyr4idgcwyxddd2mlskpo7vltcicf5mtozlzt4vzpivqmn343hk3c5nbu") + CID.decode("bafyr4iacuutc5bgmirkfyzn4igi2wys7e42kkn674hx3c4dv4wrgjp2k2u") for _ in range(4) ] shard_data_bytes = dag_cbor.encode(shard_data) @@ -1039,7 +1039,7 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, # Provided CIDs finalized_cid = "bafyr4iacuutc5bgmirkfyzn4igi2wys7e42kkn674hx3c4dv4wrgjp2k2u" - non_finalized_cid = "bafyr4iayq3aaifmyv4o7ezoi4xyysstit3ohvnq4cnjlbjwueqehlbvkla" + non_finalized_cid = "bafyr4ihicmzx4uw4pefk7idba3mz5r5g27au3l7d62yj4gguxx6neaa5ti" async with KuboCAS( rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url ) as kubo_cas: From 5b1bad889c6475ee76237f80a4f0bbbae31a9baa Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 28 Jul 2025 09:00:02 -0400 Subject: [PATCH 62/74] fix: more changes --- py_hamt/{ => reference}/flat_zarr_store.py | 0 py_hamt/sharded_zarr_store.py | 27 +++++++++------- tests/test_sharded_zarr_store.py | 36 +++++----------------- 3 files changed, 24 insertions(+), 39 deletions(-) rename py_hamt/{ => reference}/flat_zarr_store.py (100%) diff --git a/py_hamt/flat_zarr_store.py b/py_hamt/reference/flat_zarr_store.py similarity index 100% rename from py_hamt/flat_zarr_store.py rename to py_hamt/reference/flat_zarr_store.py diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index e5f571d..91d4ae6 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -407,6 +407,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: # Metadata is often saved as 'raw', chunks as well unless compressed. data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") await self.set_pointer(key, str(data_cid_obj)) + return None async def set_pointer(self, key: str, pointer: str) -> None: chunk_coords = self._parse_chunk_key(key) @@ -416,7 +417,7 @@ async def set_pointer(self, key: str, pointer: str) -> None: if chunk_coords is None: # Metadata key self._root_obj["metadata"][key] = pointer_cid_obj self._dirty_root = True - return + return None linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) @@ -428,16 +429,20 @@ async def set_pointer(self, key: str, pointer: str) -> None: if target_shard_list[index_in_shard] != pointer_cid_obj: target_shard_list[index_in_shard] = pointer_cid_obj self._dirty_shards.add(shard_idx) + return None async def exists(self, key: str) -> bool: - chunk_coords = self._parse_chunk_key(key) - if chunk_coords is None: # Metadata - return key in self._root_obj.get("metadata", {}) - linear_chunk_index = self._get_linear_chunk_index(chunk_coords) - shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) - # Load shard if not cached and check the index - target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) - return target_shard_list[index_in_shard] is not None + try: + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: # Metadata + return key in self._root_obj.get("metadata", {}) + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + # Load shard if not cached and check the index + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + return target_shard_list[index_in_shard] is not None + except (ValueError, IndexError, KeyError): + return False @property def 
supports_writes(self) -> bool: @@ -461,7 +466,7 @@ async def delete(self, key: str) -> None: self._dirty_root = True else: raise KeyError(f"Metadata key '{key}' not found.") - return + return None linear_chunk_index = self._get_linear_chunk_index(chunk_coords) shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) @@ -600,7 +605,7 @@ async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): new_zarr_metadata_bytes, codec="raw" ) - self._root_obj["metadata"][zarr_metadata_key] = str(new_zarr_metadata_cid) + self._root_obj["metadata"][zarr_metadata_key] = new_zarr_metadata_cid self._dirty_root = True async def list_dir(self, prefix: str) -> AsyncIterator[str]: diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 1322a90..f4d5264 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -272,6 +272,9 @@ async def test_sharded_zarr_store_metadata( assert await store_read.exists("time/c/0") # assert not await store_read.exists("nonexistent") + # Test does not exist + assert not await store_read.exists("temp/c/20/0/0") # Non-existent chunk + # Test list keys = [key async for key in store_read.list()] assert len(keys) > 0 @@ -736,11 +739,11 @@ async def test_sharded_zarr_store_get_partial_values( print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") # Check SuffixByteRequest result - # expected_suffix = full_chunk_data[-20:] - # assert results[2].to_bytes() == expected_suffix, ( - # "SuffixByteRequest result does not match" - # ) - # print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") + expected_suffix = full_chunk_data[-20:] + assert results[2].to_bytes() == expected_suffix, ( + "SuffixByteRequest result does not match" + ) + print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") # Check full read result assert results[3].to_bytes() == full_chunk_data, ( @@ -783,20 +786,6 @@ async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): assert store._parse_chunk_key("lat/c/0/0") is None assert store._parse_chunk_key("lon/c/0/0") is None - # Test uninitialized store - # uninitialized_store = ShardedZarrStore(kubo_cas, read_only=False, root_cid=None) - # assert uninitialized_store._parse_chunk_key("temp/c/0/0") is None - - # # Test get on uninitialized store - # with pytest.raises( - # RuntimeError, match="Load the root object first before accessing data." - # ): - # proto = zarr.core.buffer.default_buffer_prototype() - # await uninitialized_store.get("temp/c/0/0", proto) - - # with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): - # await uninitialized_store._load_root_from_cid() - # Test dimensionality mismatch with pytest.raises(IndexError, match="tuple index out of range"): store._parse_chunk_key("temp/c/0/0/0/0") @@ -1024,8 +1013,6 @@ async def test_sharded_zarr_store_lazy_concat( ).temp.values np.testing.assert_array_equal(sample_result, expected_sample) - print("\n✅ Lazy concatenation test successful! 
Data verified.") - @pytest.mark.asyncio async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, str]): @@ -1072,14 +1059,9 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, # Verify that the sliced dataset starts after the finalization date if ds_non_finalized_sliced.time.size > 0: assert ds_non_finalized_sliced.time.min() > finalization_date - else: - # Handle case where non-finalized dataset is empty after slicing - print("Warning: Non-finalized dataset is empty after slicing.") # 3. --- Lazily Concatenate Datasets --- combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") - print("\nCombined dataset time range:") - print(combined_ds.time.min().values, "to", combined_ds.time.max().values) # Verify that the combined dataset is still lazy assert combined_ds["2m_temperature"].chunks is not None @@ -1142,5 +1124,3 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, if query_result.time.size > 0: assert query_result.time.min() >= query_start assert query_result.time.max() <= query_end - - print("\n✅ Lazy concatenation with CIDs test successful! Data verified.") From bfb4a4161e441445750e6841b5cb51cbc13adb2d Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 21 Aug 2025 07:14:41 -0400 Subject: [PATCH 63/74] fix: remove duplicate --- tests/test_public_gateway.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_public_gateway.py b/tests/test_public_gateway.py index 20a64f7..0cbc4c8 100644 --- a/tests/test_public_gateway.py +++ b/tests/test_public_gateway.py @@ -6,7 +6,6 @@ from py_hamt import KuboCAS -TEST_CID = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" """ Tests for IPFS gateway functionality. From 7d826fad62f702b6c94b5fbff0521bd3d645682e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Thu, 21 Aug 2025 08:30:55 -0400 Subject: [PATCH 64/74] fix: with read only --- py_hamt/sharded_zarr_store.py | 42 ++++++++++++++++++++++++ tests/test_sharded_zarr_store.py | 56 ++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 91d4ae6..02b0857 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -268,6 +268,48 @@ async def get_partial_values( results = await asyncio.gather(*tasks) return results + def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": + """ + Return this store (if the flag already matches) or a *shallow* + clone that presents the requested read‑only status. + + The clone **shares** the same CAS instance and internal state; + no flushing, network traffic or async work is done. 
+ """ + # Fast path + if read_only == self.read_only: + return self # Same mode, return same instance + + # Create new instance with different read_only flag + # Creates a *bare* instance without running its __init__ + clone = type(self).__new__(type(self)) + + # Copy all attributes from the current instance + clone.cas = self.cas + clone._root_cid = self._root_cid + clone._root_obj = self._root_obj + + clone._resize_lock = self._resize_lock + clone._resize_complete = self._resize_complete + clone._shard_locks = self._shard_locks + + clone._shard_data_cache = self._shard_data_cache + clone._dirty_shards = self._dirty_shards + clone._pending_shard_loads = self._pending_shard_loads + + clone._array_shape = self._array_shape + clone._chunk_shape = self._chunk_shape + clone._chunks_per_dim = self._chunks_per_dim + clone._chunks_per_shard = self._chunks_per_shard + clone._num_shards = self._num_shards + clone._total_chunks = self._total_chunks + + clone._dirty_root = self._dirty_root + + # Re‑initialise the zarr base class so that Zarr sees the flag + zarr.abc.store.Store.__init__(clone, read_only=read_only) + return clone + def __eq__(self, other: object) -> bool: if not isinstance(other, ShardedZarrStore): return False diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index f4d5264..33c88af 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -524,6 +524,62 @@ async def test_store_eq_method(create_ipfs: tuple[str, str]): assert store1 == store2 +@pytest.mark.asyncio +async def test_with_read_only(create_ipfs: tuple[str, str]): + """Tests the with_read_only method.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Create a writable store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(5, 5), + chunks_per_shard=4, + ) + + # Test same mode returns same instance + same_store = store_write.with_read_only(False) + assert same_store is store_write + + # Test switching to read-only + store_read_only = store_write.with_read_only(True) + assert store_read_only is not store_write + assert store_read_only.read_only is True + assert store_write.read_only is False + + # Test that clone shares the same state + assert store_read_only.cas is store_write.cas + assert store_read_only._root_cid == store_write._root_cid + assert store_read_only._root_obj is store_write._root_obj + assert store_read_only._shard_data_cache is store_write._shard_data_cache + assert store_read_only._dirty_shards is store_write._dirty_shards + assert store_read_only._array_shape == store_write._array_shape + assert store_read_only._chunk_shape == store_write._chunk_shape + assert store_read_only._chunks_per_shard == store_write._chunks_per_shard + + # Test switching back to writable + store_write_again = store_read_only.with_read_only(False) + assert store_write_again is not store_read_only + assert store_write_again.read_only is False + + # Test write operations are blocked on read-only store + proto = zarr.core.buffer.default_buffer_prototype() + test_data = proto.buffer.from_bytes(b"test_data") + + with pytest.raises(PermissionError): + await store_read_only.set("test_key", test_data) + + with pytest.raises(PermissionError): + await store_read_only.delete("test_key") + + # Test write operations work on writable store + await store_write.set("test_key", test_data) + await 
store_write.delete("test_key") + + @pytest.mark.asyncio async def test_listing_and_metadata( create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset From 559a7ff3db922614e2c7568dd21caac0b665f1c2 Mon Sep 17 00:00:00 2001 From: 0xSwego <0xSwego@gmail.com> Date: Mon, 25 Aug 2025 11:37:20 +0100 Subject: [PATCH 65/74] Update py_hamt/store_httpx.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- py_hamt/store_httpx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 3de8efb..9bcb960 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -447,7 +447,7 @@ async def pin_cid( Args: cid (CID): The Content ID to pin. - name (Optional[str]): An optional name for the pin. + target_rpc (str): The RPC URL of the Kubo node. """ params = {"arg": str(cid), "recursive": "true"} pin_add_url_base: str = f"{target_rpc}/api/v0/pin/add" From e537389cfb3d44a747edd6135243c447ee7d60ea Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:18:10 -0400 Subject: [PATCH 66/74] fix: fix casing --- py_hamt/store_httpx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 9bcb960..ab82f6a 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -187,7 +187,7 @@ def __init__( *, headers: dict[str, str] | None = None, auth: Tuple[str, str] | None = None, - pinOnAdd: bool = False, + pin_on_add: bool = False, chunker: str = "size-1048576", ): """ @@ -255,8 +255,8 @@ def __init__( else: gateway_base_url = f"{gateway_base_url}/ipfs/" - pinString: str = "true" if pinOnAdd else "false" - self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&chunker={self.chunker}&pin={pinString}" + pin_string: str = "true" if pin_on_add else "false" + self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&chunker={self.chunker}&pin={pin_string}" """@private""" self.gateway_base_url: str = gateway_base_url """@private""" From efb3c640f804da3fd9a498aab98651635b04f65e Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Tue, 2 Sep 2025 09:46:15 -0400 Subject: [PATCH 67/74] fix: linting --- py_hamt/store_httpx.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index 6fe74da..5f6b5e2 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -279,7 +279,7 @@ def __init__( self._default_auth = auth self._sem: asyncio.Semaphore = asyncio.Semaphore(concurrency) - self._closed: bool = False + self._closed = False # Validate retry parameters if max_retries < 0: @@ -475,7 +475,9 @@ async def load( while retry_count <= self.max_retries: try: - response = await client.get(url, headers=headers or None, timeout=60.0) + response = await client.get( + url, headers=headers or None, timeout=60.0 + ) response.raise_for_status() return response.content @@ -484,11 +486,15 @@ async def load( if retry_count > self.max_retries: raise httpx.TimeoutException( f"Failed to load data after {self.max_retries} retries: {str(e)}", - request=e.request if isinstance(e, httpx.RequestError) else None, + request=e.request + if isinstance(e, httpx.RequestError) + else None, ) # Calculate backoff delay with jitter - delay = self.initial_delay * (self.backoff_factor ** (retry_count - 1)) + delay = self.initial_delay * ( + self.backoff_factor ** 
(retry_count - 1) + ) jitter = delay * 0.1 * (random.random() - 0.5) await asyncio.sleep(delay + jitter) From 41143f9e671de95fa3f96e183448e8d6813379c0 Mon Sep 17 00:00:00 2001 From: Faolain Date: Wed, 10 Sep 2025 02:23:09 -0400 Subject: [PATCH 68/74] lint(ruff): add stricter rule to __init__ --- .pre-commit-config.yaml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72c594e..66a0794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.11.11 + rev: v0.12.12 hooks: - id: ruff-check - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml index 5e72353..f712c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,4 @@ dev = [ [tool.ruff] lint.extend-select = ["I"] +preview = true From e55a3e34ac489a7c485a5f666c7e626bbfd9840d Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 10 Sep 2025 13:53:40 -0400 Subject: [PATCH 69/74] fix: lru cache --- py_hamt/__init__.py | 2 +- py_hamt/hamt_to_sharded_converter.py | 6 +- py_hamt/sharded_zarr_store.py | 288 +++++++++++++++++++++++---- tests/test_sharded_store_deleting.py | 10 +- tests/test_sharded_store_grafting.py | 10 +- tests/test_sharded_zarr_store.py | 256 +++++++++++++++++++++++- tests/testing_utils.py | 20 +- 7 files changed, 527 insertions(+), 65 deletions(-) diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index aba37df..2fff260 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -1,5 +1,6 @@ from .encryption_hamt_store import SimpleEncryptedZarrHAMTStore from .hamt import HAMT, blake3_hashfn +from .hamt_to_sharded_converter import convert_hamt_to_sharded, sharded_converter_cli from .sharded_zarr_store import ShardedZarrStore from .store_httpx import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore @@ -11,7 +12,6 @@ "InMemoryCAS", "KuboCAS", "ZarrHAMTStore", - "InMemoryCAS", "SimpleEncryptedZarrHAMTStore", "ShardedZarrStore", "convert_hamt_to_sharded", diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py index befc4d6..b0e8921 100644 --- a/py_hamt/hamt_to_sharded_converter.py +++ b/py_hamt/hamt_to_sharded_converter.py @@ -5,8 +5,10 @@ import xarray as xr from multiformats import CID -from py_hamt import HAMT, KuboCAS, ShardedZarrStore -from py_hamt.zarr_hamt_store import ZarrHAMTStore +from .hamt import HAMT +from .sharded_zarr_store import ShardedZarrStore +from .store_httpx import KuboCAS +from .zarr_hamt_store import ZarrHAMTStore async def convert_hamt_to_sharded( diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 02b0857..f0b86b5 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -2,7 +2,8 @@ import itertools import json import math -from collections import defaultdict +import sys +from collections import OrderedDict, defaultdict from collections.abc import AsyncIterator, Iterable from typing import DefaultDict, Dict, List, Optional, Set, Tuple @@ -16,6 +17,123 @@ from .store_httpx import ContentAddressedStore +class MemoryBoundedLRUCache: + """ + An LRU cache that evicts items when memory usage exceeds a threshold. + + Memory usage is calculated using sys.getsizeof for accurate sizing. + Dirty shards (those marked for writing) are never evicted until marked clean. 
+ All operations are thread-safe for async access using an asyncio.Lock. + """ + + def __init__(self, max_memory_bytes: int = 100 * 1024 * 1024): # 100MB default + self.max_memory_bytes = max_memory_bytes + self._cache: OrderedDict[int, List[Optional[CID]]] = OrderedDict() + self._dirty_shards: Set[int] = set() + self._shard_sizes: Dict[int, int] = {} # Cached sizes for each shard + self._actual_memory_usage = 0 + self._cache_lock = asyncio.Lock() + + def _get_shard_size(self, shard_data: List[Optional[CID]]) -> int: + """Compute actual size: list overhead + sum of item sizes.""" + if not shard_data: + return sys.getsizeof(shard_data) + total = sys.getsizeof(shard_data) + for item in shard_data: + total += sys.getsizeof(item) + return total + + async def get(self, shard_idx: int) -> Optional[List[Optional[CID]]]: + """Get a shard from cache, moving it to end (most recently used).""" + async with self._cache_lock: + if shard_idx not in self._cache: + return None + shard_data = self._cache.pop(shard_idx) + self._cache[shard_idx] = shard_data + return shard_data + + async def put( + self, shard_idx: int, shard_data: List[Optional[CID]], is_dirty: bool = False + ) -> None: + """Add or update a shard in cache, evicting old items if needed.""" + async with self._cache_lock: + shard_size = self._get_shard_size(shard_data) + + # If shard exists, remove its old size + if shard_idx in self._cache: + self._cache.pop(shard_idx) + self._actual_memory_usage -= self._shard_sizes.pop(shard_idx, 0) + + # Track dirty status + if is_dirty: + self._dirty_shards.add(shard_idx) + + # Add new shard + self._cache[shard_idx] = shard_data + self._shard_sizes[shard_idx] = shard_size + self._actual_memory_usage += shard_size + + # Evict old items if over memory limit, never evict dirty shards + while ( + self._actual_memory_usage > self.max_memory_bytes + and len(self._cache) > 1 + ): + evicted = False + while self._cache: + candidate_idx, candidate_data = self._cache.popitem(last=False) + if candidate_idx not in self._dirty_shards: + # Evict this clean LRU + self._actual_memory_usage -= self._shard_sizes.pop( + candidate_idx, 0 + ) + evicted = True + break + else: + # Dirty: move to MRU + self._cache[candidate_idx] = candidate_data + if not evicted: + # No clean shards to evict + break + + async def mark_dirty(self, shard_idx: int) -> None: + """Mark a shard as dirty (should not be evicted).""" + async with self._cache_lock: + if shard_idx in self._cache: + self._dirty_shards.add(shard_idx) + + async def mark_clean(self, shard_idx: int) -> None: + """Mark a shard as clean (can be evicted).""" + async with self._cache_lock: + self._dirty_shards.discard(shard_idx) + + async def clear(self) -> None: + """Clear all cached data.""" + async with self._cache_lock: + self._cache.clear() + self._dirty_shards.clear() + self._shard_sizes.clear() + self._actual_memory_usage = 0 + + async def __contains__(self, shard_idx: int) -> bool: + async with self._cache_lock: + return shard_idx in self._cache + + @property + def estimated_memory_usage(self) -> int: + """Current memory usage in bytes, based on actual sizes.""" + return self._actual_memory_usage + + @property + def cache_size(self) -> int: + """Number of items currently cached.""" + return len(self._cache) + + @property + def dirty_cache_size(self) -> int: + """Number of dirty items currently cached.""" + return len(self._dirty_shards) + + class ShardedZarrStore(zarr.abc.store.Store): """ Implements the Zarr Store API using a sharded layout for chunk CIDs. 
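# A minimal standalone sketch of the cache introduced above. The class and method
# names come from the patch; the byte budget and shard sizes below are made-up
# values chosen to force eviction quickly, and exact eviction counts depend on
# sys.getsizeof on the running interpreter. Clean shards fall out in LRU order once
# max_memory_bytes is exceeded, while shards marked dirty stay cached until
# mark_clean() is called, so pending writes are never lost to eviction.
import asyncio

from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache


async def demo() -> None:
    cache = MemoryBoundedLRUCache(max_memory_bytes=4096)
    await cache.put(0, [None] * 64, is_dirty=True)  # dirty: pinned until marked clean
    await cache.put(1, [None] * 64)                 # clean: eligible for eviction
    await cache.put(2, [None] * 64)                 # over budget: oldest clean shard goes
    assert await cache.get(0) is not None           # the dirty shard always survives
    assert await cache.get(1) is None               # a clean shard was evicted instead
    print(cache.cache_size, cache.dirty_cache_size, cache.estimated_memory_usage)
    await cache.mark_clean(0)                       # shard 0 may now be evicted too


asyncio.run(demo())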
@@ -36,6 +154,8 @@ def __init__( cas: ContentAddressedStore, read_only: bool, root_cid: Optional[str] = None, + *, + max_cache_memory_bytes: int = 100 * 1024 * 1024, # 100MB default ): """Use the async `open()` classmethod to instantiate this class.""" super().__init__(read_only=read_only) @@ -50,8 +170,7 @@ def __init__( self._resize_complete.set() self._shard_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock) - self._shard_data_cache: Dict[int, list[Optional[CID]]] = {} - self._dirty_shards: Set[int] = set() + self._shard_data_cache = MemoryBoundedLRUCache(max_cache_memory_bytes) self._pending_shard_loads: Dict[int, asyncio.Event] = {} self._array_shape: Tuple[int, ...] @@ -63,7 +182,7 @@ def __init__( self._dirty_root = False - def _update_geometry(self): + def __update_geometry(self): """Calculates derived geometric properties from the base shapes.""" if not all(cs > 0 for cs in self._chunk_shape): @@ -92,11 +211,14 @@ async def open( array_shape: Optional[Tuple[int, ...]] = None, chunk_shape: Optional[Tuple[int, ...]] = None, chunks_per_shard: Optional[int] = None, + max_cache_memory_bytes: int = 100 * 1024 * 1024, # 100MB default ) -> "ShardedZarrStore": """ Asynchronously opens an existing ShardedZarrStore or initializes a new one. """ - store = cls(cas, read_only, root_cid) + store = cls( + cas, read_only, root_cid, max_cache_memory_bytes=max_cache_memory_bytes + ) if root_cid: await store._load_root_from_cid() elif not read_only: @@ -123,7 +245,7 @@ def _initialize_new_root( self._chunk_shape = chunk_shape self._chunks_per_shard = chunks_per_shard - self._update_geometry() + self.__update_geometry() self._root_obj = { "manifest_version": "sharded_zarr_v1", @@ -141,6 +263,16 @@ def _initialize_new_root( async def _load_root_from_cid(self): root_bytes = await self.cas.load(self._root_cid) + try: + self._root_obj = dag_cbor.decode(root_bytes) + if not isinstance(self._root_obj, dict) or "chunks" not in self._root_obj: + raise ValueError( + "Root object is not a valid dictionary with 'chunks' key." + ) + if not isinstance(self._root_obj["chunks"]["shard_cids"], list): + raise ValueError("shard_cids is not a list.") + except Exception as e: + raise ValueError(f"Failed to decode root object: {e}") self._root_obj = dag_cbor.decode(root_bytes) if self._root_obj.get("manifest_version") != "sharded_zarr_v1": @@ -153,26 +285,60 @@ async def _load_root_from_cid(self): self._chunk_shape = tuple(chunk_info["chunk_shape"]) self._chunks_per_shard = chunk_info["sharding_config"]["chunks_per_shard"] - self._update_geometry() + self.__update_geometry() if len(chunk_info["shard_cids"]) != self._num_shards: raise ValueError( f"Inconsistent number of shards. Expected {self._num_shards}, found {len(chunk_info['shard_cids'])}." 
) - async def _fetch_and_cache_full_shard(self, shard_idx: int, shard_cid: str): - try: - shard_data_bytes = await self.cas.load(shard_cid) - decoded_shard = dag_cbor.decode(shard_data_bytes) - if not isinstance(decoded_shard, list): - raise TypeError(f"Shard {shard_idx} did not decode to a list.") - self._shard_data_cache[shard_idx] = decoded_shard - except Exception: - raise - finally: - if shard_idx in self._pending_shard_loads: - self._pending_shard_loads[shard_idx].set() # Signal completion - del self._pending_shard_loads[shard_idx] + async def _fetch_and_cache_full_shard( + self, + shard_idx: int, + shard_cid: str, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> None: + """ + Fetch a shard from CAS and cache it, with retry logic for transient errors. + + Args: + shard_idx: The index of the shard to fetch. + shard_cid: The CID of the shard. + max_retries: Maximum number of retry attempts for transient errors. + retry_delay: Delay between retry attempts in seconds. + """ + for attempt in range(max_retries): + try: + shard_data_bytes = await self.cas.load(shard_cid) + decoded_shard = dag_cbor.decode(shard_data_bytes) + if not isinstance(decoded_shard, list): + raise TypeError(f"Shard {shard_idx} did not decode to a list.") + await self._shard_data_cache.put(shard_idx, decoded_shard) + # Always set the Event to unblock waiting coroutines + if shard_idx in self._pending_shard_loads: + self._pending_shard_loads[shard_idx].set() + del self._pending_shard_loads[shard_idx] + return # Success + except (ConnectionError, TimeoutError) as e: + # Handle transient errors (e.g., network issues) + if attempt < max_retries - 1: + await asyncio.sleep( + retry_delay * (2**attempt) + ) # Exponential backoff + continue + else: + # Log the failure and raise a specific error + print( + f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" + ) + raise RuntimeError( + f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" + ) + except Exception as e: + # Handle non-transient errors immediately + print(f"Error fetching shard {shard_idx}: {e}") + raise def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. Exclude .json files immediately (metadata) @@ -229,14 +395,45 @@ def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: index_in_shard = linear_chunk_index % self._chunks_per_shard return shard_idx, index_in_shard - async def _load_or_initialize_shard_cache(self, shard_idx: int) -> list: - if shard_idx in self._shard_data_cache: - return self._shard_data_cache[shard_idx] + async def _load_or_initialize_shard_cache( + self, shard_idx: int + ) -> List[Optional[CID]]: + """ + Load a shard into the cache or initialize an empty shard if it doesn't exist. + + Args: + shard_idx: The index of the shard to load or initialize. + + Returns: + List[Optional[CID]]: The shard data (list of CIDs or None). + + Raises: + ValueError: If the shard index is out of bounds. + RuntimeError: If the shard cannot be loaded or initialized. 
+ """ + cached_shard = await self._shard_data_cache.get(shard_idx) + if cached_shard is not None: + return cached_shard if shard_idx in self._pending_shard_loads: - await self._pending_shard_loads[shard_idx].wait() - if shard_idx in self._shard_data_cache: - return self._shard_data_cache[shard_idx] + try: + # Wait for the pending load with a timeout (e.g., 60 seconds) + await asyncio.wait_for( + self._pending_shard_loads[shard_idx].wait(), timeout=60.0 + ) + cached_shard = await self._shard_data_cache.get(shard_idx) + if cached_shard is not None: + return cached_shard + else: + raise RuntimeError( + f"Shard {shard_idx} not found in cache after pending load completed." + ) + except asyncio.TimeoutError: + # Clean up the pending load to allow retry + if shard_idx in self._pending_shard_loads: + self._pending_shard_loads[shard_idx].set() + del self._pending_shard_loads[shard_idx] + raise RuntimeError(f"Timeout waiting for shard {shard_idx} to load.") if not (0 <= shard_idx < self._num_shards): raise ValueError(f"Shard index {shard_idx} out of bounds.") @@ -244,13 +441,16 @@ async def _load_or_initialize_shard_cache(self, shard_idx: int) -> list: shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] if shard_cid_obj: self._pending_shard_loads[shard_idx] = asyncio.Event() - # The CID in the root should already be a CID object if loaded correctly. shard_cid_str = str(shard_cid_obj) await self._fetch_and_cache_full_shard(shard_idx, shard_cid_str) else: - self._shard_data_cache[shard_idx] = [None] * self._chunks_per_shard + empty_shard = [None] * self._chunks_per_shard + await self._shard_data_cache.put(shard_idx, empty_shard) - return self._shard_data_cache[shard_idx] + result = await self._shard_data_cache.get(shard_idx) + if result is None: + raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") + return result # type: ignore[return-value] async def set_partial_values( self, key_start_values: Iterable[Tuple[str, int, BytesLike]] @@ -294,7 +494,6 @@ def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": clone._shard_locks = self._shard_locks clone._shard_data_cache = self._shard_data_cache - clone._dirty_shards = self._dirty_shards clone._pending_shard_loads = self._pending_shard_loads clone._array_shape = self._array_shape @@ -318,10 +517,14 @@ def __eq__(self, other: object) -> bool: # If nothing to flush, return the root CID. async def flush(self) -> str: - if self._dirty_shards: - for shard_idx in sorted(list(self._dirty_shards)): + async with self._shard_data_cache._cache_lock: + dirty_shards = list(self._shard_data_cache._dirty_shards) + if dirty_shards: + for shard_idx in sorted(dirty_shards): # Get the list of CIDs/Nones from the cache - shard_data_list = self._shard_data_cache[shard_idx] + shard_data_list = await self._shard_data_cache.get(shard_idx) + if shard_data_list is None: + raise RuntimeError(f"Dirty shard {shard_idx} not found in cache") # Encode this list into a DAG-CBOR byte representation shard_data_bytes = dag_cbor.encode(shard_data_list) @@ -341,8 +544,8 @@ async def flush(self) -> str: new_shard_cid_obj ) self._dirty_root = True - - self._dirty_shards.clear() + # Mark shard as clean after flushing + await self._shard_data_cache.mark_clean(shard_idx) if self._dirty_root: # Ensure all metadata CIDs are CID objects for correct encoding @@ -447,9 +650,12 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: raw_data_bytes = value.to_bytes() # Save the data to CAS first to get its CID. 
# Metadata is often saved as 'raw', chunks as well unless compressed. - data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") - await self.set_pointer(key, str(data_cid_obj)) - return None + try: + data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") + await self.set_pointer(key, str(data_cid_obj)) + except Exception as e: + raise RuntimeError(f"Failed to save data for key {key}: {e}") + return None # type: ignore[return-value] async def set_pointer(self, key: str, pointer: str) -> None: chunk_coords = self._parse_chunk_key(key) @@ -470,7 +676,7 @@ async def set_pointer(self, key: str, pointer: str) -> None: if target_shard_list[index_in_shard] != pointer_cid_obj: target_shard_list[index_in_shard] = pointer_cid_obj - self._dirty_shards.add(shard_idx) + await self._shard_data_cache.mark_dirty(shard_idx) return None async def exists(self, key: str) -> bool: @@ -518,7 +724,7 @@ async def delete(self, key: str) -> None: target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) if target_shard_list[index_in_shard] is not None: target_shard_list[index_in_shard] = None - self._dirty_shards.add(shard_idx) + await self._shard_data_cache.mark_dirty(shard_idx) @property def supports_listing(self) -> bool: @@ -572,7 +778,7 @@ async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, .. ) if target_shard_list[index_in_global_shard] != pointer_cid_obj: target_shard_list[index_in_global_shard] = pointer_cid_obj - self._dirty_shards.add(global_shard_idx) + await self._shard_data_cache.mark_dirty(global_shard_idx) async def resize_store(self, new_shape: Tuple[int, ...]): """ diff --git a/tests/test_sharded_store_deleting.py b/tests/test_sharded_store_deleting.py index b67f27c..88222a8 100644 --- a/tests/test_sharded_store_deleting.py +++ b/tests/test_sharded_store_deleting.py @@ -63,7 +63,7 @@ async def test_delete_chunk_success(create_ipfs: tuple[str, str]): shard_idx, index_in_shard = store._get_shard_info(linear_index) target_shard_list = await store._load_or_initialize_shard_cache(shard_idx) assert target_shard_list[index_in_shard] is None - assert shard_idx in store._dirty_shards + assert shard_idx in store._shard_data_cache._dirty_shards # Flush and verify persistence root_cid = await store.flush() @@ -135,7 +135,7 @@ async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): # flush it await store.flush() - assert not store._dirty_shards # No dirty shards after flush + assert not store._shard_data_cache._dirty_shards # No dirty shards after flush # Try to delete nonexistent metadata key with pytest.raises(KeyError, match="Metadata key 'nonexistent.json' not found"): @@ -149,7 +149,7 @@ async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): await store.delete("temp/c/0/0") # Should not raise, as it sets to None assert not await store.exists("temp/c/0/0") assert ( - store._dirty_shards + store._shard_data_cache._dirty_shards ) # Shard is marked dirty even if chunk was already None @@ -232,7 +232,9 @@ async def delete_task(key): assert await store.get(key, proto) is None # Verify shards are marked dirty - assert store._dirty_shards # At least one shard should be dirty + assert ( + store._shard_data_cache._dirty_shards + ) # At least one shard should be dirty # Flush and verify persistence root_cid = await store.flush() diff --git a/tests/test_sharded_store_grafting.py b/tests/test_sharded_store_grafting.py index 6437d51..59acfd3 100644 --- a/tests/test_sharded_store_grafting.py +++ b/tests/test_sharded_store_grafting.py @@ 
-87,7 +87,9 @@ async def test_graft_store_success(create_ipfs: tuple[str, str]): assert target_store._chunks_per_dim == (4, 2) # ceil(40/10) = 4 assert target_store._total_chunks == 8 # 4 * 2 assert target_store._num_shards == 2 # ceil(8/4) = 2 - assert target_store._dirty_shards # Grafting marks shards as dirty + assert ( + target_store._shard_data_cache._dirty_shards + ) # Grafting marks shards as dirty # Flush and verify persistence target_root_cid = await target_store.flush() @@ -200,7 +202,9 @@ async def test_graft_store_empty_source(create_ipfs: tuple[str, str]): # Verify no chunks were grafted assert not await target_store.exists("temp/c/1/1") - assert not target_store._dirty_shards # No shards marked dirty since no changes + assert ( + not target_store._shard_data_cache._dirty_shards + ) # No shards marked dirty since no changes # Flush and verify target_root_cid = await target_store.flush() @@ -395,7 +399,7 @@ async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): assert read_data is not None assert read_data.to_bytes() == existing_data assert ( - target_store._dirty_shards + target_store._shard_data_cache._dirty_shards ) # Shard is marked dirty due to attempted write # Verify other grafted chunks diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py index 33c88af..ea41734 100644 --- a/tests/test_sharded_zarr_store.py +++ b/tests/test_sharded_zarr_store.py @@ -121,8 +121,8 @@ async def load_shard(): assert result == shard_data # Verify shard is cached and no pending loads remain - assert shard_idx in store._shard_data_cache - assert store._shard_data_cache[shard_idx] == shard_data + assert await store._shard_data_cache.__contains__(shard_idx) + assert await store._shard_data_cache.get(shard_idx) == shard_data assert shard_idx not in store._pending_shard_loads @@ -555,7 +555,11 @@ async def test_with_read_only(create_ipfs: tuple[str, str]): assert store_read_only._root_cid == store_write._root_cid assert store_read_only._root_obj is store_write._root_obj assert store_read_only._shard_data_cache is store_write._shard_data_cache - assert store_read_only._dirty_shards is store_write._dirty_shards + # _dirty_shards is now managed by the cache, not the store directly + assert ( + store_read_only._shard_data_cache._dirty_shards + is store_write._shard_data_cache._dirty_shards + ) assert store_read_only._array_shape == store_write._array_shape assert store_read_only._chunk_shape == store_write._chunk_shape assert store_read_only._chunks_per_shard == store_write._chunks_per_shard @@ -1180,3 +1184,249 @@ async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, if query_result.time.size > 0: assert query_result.time.min() >= query_start assert query_result.time.max() <= query_end + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_basic(): + """Test basic functionality of MemoryBoundedLRUCache.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + # Very small cache for testing to ensure eviction + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # 500 bytes limit + + # Create some test data + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + small_shard = [test_cid] * 2 + medium_shard = [test_cid] * 5 + + # Test basic put/get + await cache.put(0, small_shard) + assert await cache.get(0) == small_shard + assert await cache.__contains__(0) + assert cache.cache_size == 1 + + # Test that get moves item to end (most 
recently used) + await cache.put(1, medium_shard) + await cache.get(0) # Should move shard 0 to end + + # Add more data to trigger eviction - this should be large enough to force eviction + large_shard = [test_cid] * 20 + await cache.put(2, large_shard) + + # Check basic cache behavior - at least one should be evicted due to memory limits + + # Add an even larger shard to definitely trigger eviction + huge_shard = [test_cid] * 50 + await cache.put(3, huge_shard) + + # Cache should be constrained by memory limit and perform evictions + # The exact behavior depends on actual memory usage, but we should see some eviction + assert ( + cache.estimated_memory_usage <= cache.max_memory_bytes or cache.cache_size == 1 + ) + # At least some items should remain in cache + assert cache.cache_size >= 1 + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_dirty_protection(): + """Test that dirty shards are never evicted.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + # Very small cache to force eviction + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # 500 bytes + + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + small_shard = [test_cid] * 3 + large_shard = [test_cid] * 20 # This should exceed memory limit + + # Add a dirty shard + await cache.put(0, small_shard, is_dirty=True) + assert cache.dirty_cache_size == 1 + + # Add a clean shard + await cache.put(1, small_shard) + # Cache size should be 2, but might be less if eviction occurred + assert cache.cache_size >= 1 # At least the dirty shard should remain + assert cache.dirty_cache_size == 1 + + # Add a large clean shard that should trigger eviction + await cache.put(2, large_shard) + + # Dirty shard 0 should still be there (protected) + assert await cache.get(0) is not None # Dirty shard protected + + # Either shard 1 or shard 2 might be evicted depending on memory constraints + # The important thing is that the dirty shard (0) is never evicted + cached_1 = await cache.get(1) + cached_2 = await cache.get(2) + + # At least one of the clean shards should be evicted due to memory pressure + evicted_count = (1 if cached_1 is None else 0) + (1 if cached_2 is None else 0) + assert evicted_count >= 1, ( + "At least one clean shard should be evicted due to memory pressure" + ) + + # Test marking dirty shard as clean + await cache.mark_clean(0) + assert cache.dirty_cache_size == 0 + + # Now shard 0 can be evicted + even_larger_shard = [test_cid] * 30 + await cache.put(3, even_larger_shard) + assert await cache.get(0) is None # Now evicted since it's clean + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_memory_estimation(): + """Test memory usage estimation.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=10000) + + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # Test with None values (should use minimal memory) + sparse_shard = [None] * 100 + await cache.put(0, sparse_shard) + sparse_usage = cache.estimated_memory_usage + + # Test with CID values (should use more memory) + dense_shard = [test_cid] * 100 + await cache.put(1, dense_shard) + dense_usage = cache.estimated_memory_usage + + # Dense shard should use significantly more memory than sparse + assert dense_usage > sparse_usage # CIDs are larger than None + + # Test cache clear + await cache.clear() + assert 
cache.estimated_memory_usage == 0 + assert cache.cache_size == 0 + assert cache.dirty_cache_size == 0 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cache_integration(create_ipfs): + """Test that ShardedZarrStore properly uses the memory-bounded cache.""" + rpc_base_url, gateway_base_url = create_ipfs + + # Create a store with very small cache to test eviction + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10, 10), + chunk_shape=(2, 2, 2), + chunks_per_shard=8, + max_cache_memory_bytes=2000, # 2KB cache limit + ) + + # Create test data + test_data = np.random.randn(2, 2, 2).astype(np.float32) + buffer = zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + test_data.tobytes() + ) + + # Test cache behavior by writing data and checking dirty status + # Let's first manually mark a shard as dirty to test the mechanism + await store._shard_data_cache.put(0, [None] * 8, is_dirty=True) + assert store._shard_data_cache.dirty_cache_size == 1 + + # Now test actual store operations - write to multiple chunks + await store.set("c/0/0/0", buffer) + await store.set("c/0/0/1", buffer) + await store.set("c/0/1/0", buffer) + + # The cache should now have at least one dirty shard + assert store._shard_data_cache.dirty_cache_size > 0 + + # Read from the chunks we actually wrote to + result1 = await store.get( + "c/0/0/0", zarr.core.buffer.default_buffer_prototype() + ) + result2 = await store.get( + "c/0/0/1", zarr.core.buffer.default_buffer_prototype() + ) + result3 = await store.get( + "c/0/1/0", zarr.core.buffer.default_buffer_prototype() + ) + + assert result1 is not None + assert result2 is not None + assert result3 is not None + + # Flush to make shards clean + await store.flush() + + # After flush, dirty count should be 0 + assert store._shard_data_cache.dirty_cache_size == 0 + + # But cache should still contain some shards + assert store._shard_data_cache.cache_size > 0 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cache_eviction_during_read(create_ipfs): + """Test cache eviction behavior during read-heavy workloads.""" + rpc_base_url, gateway_base_url = create_ipfs + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Create and populate a store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20, 20), + chunk_shape=(2, 2, 2), + chunks_per_shard=10, + max_cache_memory_bytes=1500, # Small cache to force eviction + ) + + # Write data to multiple shards + test_data = np.random.randn(2, 2, 2).astype(np.float32) + buffer = zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + test_data.tobytes() + ) + + for i in range(0, 10, 2): # Write to shards 0, 1, 2, 3, 4 + await store.set(f"c/{i}/0/0", buffer) + + # Flush to make all shards clean + root_cid = await store.flush() + + # Now open in read-only mode to test pure read behavior + readonly_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=root_cid, + max_cache_memory_bytes=1000, # Even smaller cache + ) + + # Read from many different locations to fill and overflow cache + read_results = [] + for i in range(0, 10, 2): + result = await readonly_store.get( + f"c/{i}/0/0", zarr.core.buffer.default_buffer_prototype() + ) + read_results.append(result) + + # All reads should succeed even with cache eviction + assert all(result is not None 
for result in read_results) + + # Cache should have evicted some shards due to memory limit + assert ( + readonly_store._shard_data_cache.cache_size <= 5 + ) # Not all shards should be cached + + # No shards should be dirty in read-only mode + assert readonly_store._shard_data_cache.dirty_cache_size == 0 diff --git a/tests/testing_utils.py b/tests/testing_utils.py index e8e4422..229762c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -26,17 +26,15 @@ def cid_strategy() -> SearchStrategy: def ipld_strategy() -> SearchStrategy: - return st.one_of( - [ - st.none(), - st.booleans(), - st.integers(min_value=-9223372036854775808, max_value=9223372036854775807), - st.floats(allow_infinity=False, allow_nan=False), - st.text(), - st.binary(), - cid_strategy(), - ] - ) + return st.one_of([ + st.none(), + st.booleans(), + st.integers(min_value=-9223372036854775808, max_value=9223372036854775807), + st.floats(allow_infinity=False, allow_nan=False), + st.text(), + st.binary(), + cid_strategy(), + ]) key_value_list = st.lists( From a3f60f2cc4d90ed7e3cbcbd2866944b1577be19c Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:48:21 -0400 Subject: [PATCH 70/74] fix: zarr coverage --- py_hamt/sharded_zarr_store.py | 5 + tests/test_sharded_zarr_store_coverage.py | 394 ++++++++++++++++++++++ 2 files changed, 399 insertions(+) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index f0b86b5..9f5cc46 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -79,6 +79,7 @@ async def put( and len(self._cache) > 1 ): evicted = False + checked_dirty_shards = set() while self._cache: candidate_idx, candidate_data = self._cache.popitem(last=False) if candidate_idx not in self._dirty_shards: @@ -91,6 +92,10 @@ async def put( else: # Dirty: move to MRU self._cache[candidate_idx] = candidate_data + checked_dirty_shards.add(candidate_idx) + # If we've checked all dirty shards, no clean shards available + if len(checked_dirty_shards) == len(self._dirty_shards): + break if not evicted: # No clean shards to evict break diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py index b53139d..d589587 100644 --- a/tests/test_sharded_zarr_store_coverage.py +++ b/tests/test_sharded_zarr_store_coverage.py @@ -260,3 +260,397 @@ async def test_sharded_zarr_store_other_exceptions(create_ipfs: tuple[str, str]) # with pytest.raises(ValueError, match="Linear chunk index cannot be negative."): # await store_no_root._get_shard_info(-1) + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_empty_shard(): + """Test line 40: empty shard handling in _get_shard_size""" + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=1000) + empty_shard = [] + + # Test that empty shard is handled correctly (line 40) + size = cache._get_shard_size(empty_shard) + assert size > 0 # sys.getsizeof should return some size even for empty list + + await cache.put(0, empty_shard) + retrieved = await cache.get(0) + assert retrieved == empty_shard + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_update_existing(): + """Test lines 64-65: cache update logic when shard already exists""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=10000) + test_cid = 
CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # First put + shard1 = [test_cid] * 2 + await cache.put(0, shard1) + + # Update existing shard (lines 64-65 should be hit) + shard2 = [test_cid] * 3 + await cache.put(0, shard2, is_dirty=True) + + retrieved = await cache.get(0) + assert retrieved == shard2 + assert cache.dirty_cache_size == 1 + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_eviction_break(): + """Test line 96: eviction break when no clean shards available""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # Small cache + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # Add several dirty shards to fill cache + large_shard = [test_cid] * 10 + for i in range(3): + await cache.put(i, large_shard, is_dirty=True) + + # Try to add another large shard - should trigger line 96 break + huge_shard = [test_cid] * 20 + await cache.put(3, huge_shard) + + # All dirty shards should still be in cache (not evicted) + for i in range(3): + assert await cache.get(i) is not None + assert cache.dirty_cache_size == 3 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_duplicate_root_loading(): + """Test line 277: duplicate root object loading""" + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + # Create mock CAS that returns malformed data to trigger line 277 + class MockCAS: + def __init__(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + # Return valid DAG-CBOR for a sharded zarr root + root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None, None], + }, + } + return dag_cbor.encode(root_obj) + + # Create store with mock CAS + mock_cas = MockCAS() + store = ShardedZarrStore(mock_cas, True, "test_cid") + + # This should trigger line 277 where root_obj gets set twice + await store._load_root_from_cid() + + assert store._root_obj is not None + assert store._array_shape == (10,) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_invalid_root_object_structure(): + """Test lines 274 and 278-280: root object structure validation""" + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, root_obj): + self.root_obj = root_obj + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + return dag_cbor.encode(self.root_obj) + + # Test line 274: root object is not a dict + mock_cas_not_dict = MockCAS("not a dictionary") + store = ShardedZarrStore(mock_cas_not_dict, True, "test_cid") + with pytest.raises(ValueError, match="Root object is not a valid dictionary"): + await store._load_root_from_cid() + + # Test line 274: root object missing 'chunks' key + mock_cas_no_chunks = MockCAS({ + "metadata": {}, + "manifest_version": "sharded_zarr_v1", + }) + store = ShardedZarrStore(mock_cas_no_chunks, True, "test_cid") + with pytest.raises(ValueError, match="Root object is not a valid dictionary"): + await store._load_root_from_cid() + + # Test lines 278-280: shard_cids is not a list + mock_cas_invalid_shard_cids = MockCAS({ + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + 
"chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": "not a list", # Should be a list + }, + }) + store = ShardedZarrStore(mock_cas_invalid_shard_cids, True, "test_cid") + with pytest.raises(ValueError, match="shard_cids is not a list"): + await store._load_root_from_cid() + + # Test line 280: generic exception handling (invalid DAG-CBOR) + class MockCASInvalidDagCbor: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + return b"invalid dag-cbor data" # This will cause dag_cbor.decode to fail + + mock_cas_invalid_cbor = MockCASInvalidDagCbor() + store = ShardedZarrStore(mock_cas_invalid_cbor, True, "test_cid") + with pytest.raises(ValueError, match="Failed to decode root object"): + await store._load_root_from_cid() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_invalid_manifest_version(): + """Test lines 281-283: manifest version validation""" + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, manifest_version): + self.manifest_version = manifest_version + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + root_obj = { + "manifest_version": self.manifest_version, + "metadata": {}, + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None, None], + }, + } + return dag_cbor.encode(root_obj) + + # Test with wrong manifest version (should hit lines 281-283) + mock_cas = MockCAS("wrong_version") + store = ShardedZarrStore(mock_cas, True, "test_cid") + + with pytest.raises(ValueError, match="Incompatible manifest version"): + await store._load_root_from_cid() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_shard_fetch_retry(): + """Test lines 333-343: shard fetch retry and error logging""" + from unittest.mock import patch + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, fail_count=2): + self.fail_count = fail_count + self.attempts = 0 + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + self.attempts += 1 + if self.attempts <= self.fail_count: + raise ConnectionError("Mock connection error") + # Success on final attempt + import dag_cbor + + return dag_cbor.encode([None] * 4) + + async def save(self, data, codec=None): + from multiformats import CID + + return CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + + mock_cas = MockCAS(fail_count=2) # Fail twice, succeed on 3rd attempt + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID to fetch + from multiformats import CID + + shard_cid = CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + store._root_obj["chunks"]["shard_cids"][0] = shard_cid + + # This should retry and eventually succeed (testing retry logic lines 325-329) + with patch("builtins.print") as mock_print: + shard_data = await store._load_or_initialize_shard_cache(0) + assert shard_data is not None + assert len(shard_data) == 4 + + # Test case where all retries fail + mock_cas_fail = MockCAS(fail_count=5) # Fail more than max_retries + store_fail = await 
ShardedZarrStore.open( + cas=mock_cas_fail, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + store_fail._root_obj["chunks"]["shard_cids"][0] = shard_cid + + # This should hit lines 332-337 (failure after max retries) + with patch("builtins.print") as mock_print: + with pytest.raises( + RuntimeError, match="Failed to fetch shard 0 after 3 attempts" + ): + await store_fail._load_or_initialize_shard_cache(0) + # Should print failure message (line 333) + mock_print.assert_called() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_with_read_only_clone_attribute(): + """Test line 490: with_read_only clone attribute assignment""" + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Create clone with different read_only status (should hit line 490) + clone = store.with_read_only(True) + + # Verify line 490: clone._root_obj = self._root_obj + assert clone._root_obj is store._root_obj + assert clone.read_only is True + assert store.read_only is False + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_method_line_565(): + """Test line 565: get method start (line 565 is the method definition)""" + import zarr.core.buffer + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid, offset=None, length=None, suffix=None): + return b"metadata_content" + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Add metadata to test the get method + from multiformats import CID + + metadata_cid = CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + store._root_obj["metadata"]["test.json"] = metadata_cid + + # Test get method (line 565 is the method signature) + proto = zarr.core.buffer.default_buffer_prototype() + result = await store.get("test.json", proto) + + assert result is not None + assert result.to_bytes() == b"metadata_content" + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_exists_exception_handling(): + """Test lines 694-695: exists method exception handling""" + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Test exists with invalid chunk key that will cause exception (lines 694-695) + # This should trigger the exception handling and return False + exists = await store.exists("invalid/chunk/key/format") + assert exists is False + + # Test exists with valid chunk key that's out of bounds (should also return False) + exists = await store.exists("c/100") # Out of bounds chunk + assert exists is False From ac171399d7215802f38d1c8b8e3d0416f91e0d24 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:21:29 -0400 Subject: [PATCH 71/74] fix: final tests --- tests/test_sharded_zarr_store_coverage.py | 263 
++++++++++++++++++++++ 1 file changed, 263 insertions(+) diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py index d589587..a4d814d 100644 --- a/tests/test_sharded_zarr_store_coverage.py +++ b/tests/test_sharded_zarr_store_coverage.py @@ -654,3 +654,266 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Test exists with valid chunk key that's out of bounds (should also return False) exists = await store.exists("c/100") # Out of bounds chunk assert exists is False + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cas_save_failure(): + """Test RuntimeError when cas.save fails in set method""" + import zarr.core.buffer + + from py_hamt import ShardedZarrStore + + class MockCASFailingSave: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def save(self, data, codec=None): + # Always fail to save + raise ConnectionError("Mock CAS save failure") + + mock_cas = MockCASFailingSave() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Test that cas.save failure raises RuntimeError (lines 656-657) + proto = zarr.core.buffer.default_buffer_prototype() + test_data = proto.buffer.from_bytes(b"test_data") + + with pytest.raises(RuntimeError, match="Failed to save data for key test_key"): + await store.set("test_key", test_data) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_flush_dirty_shard_not_found(): + """Test RuntimeError when dirty shard not found in cache during flush""" + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def save(self, data, codec=None): + return CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # First put a shard in the cache and mark it as dirty + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + shard_data = [test_cid, None] + await store._shard_data_cache.put(0, shard_data, is_dirty=True) + + # Verify the shard is dirty + assert store._shard_data_cache.dirty_cache_size == 1 + assert 0 in store._shard_data_cache._dirty_shards + + # Mock the cache.get to return None for the dirty shard (simulating cache corruption) + original_get = store._shard_data_cache.get + + async def mock_get_returns_none(shard_idx): + if shard_idx == 0: # Return None for the dirty shard + return None + return await original_get(shard_idx) + + with patch.object( + store._shard_data_cache, "get", side_effect=mock_get_returns_none + ): + # This should hit line 529 (RuntimeError for dirty shard not found in cache) + with pytest.raises(RuntimeError, match="Dirty shard 0 not found in cache"): + await store.flush() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_failed_to_load_or_initialize_shard(): + """Test RuntimeError when shard fails to load or initialize""" + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + import 
dag_cbor + + return dag_cbor.encode([None] * 4) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID to fetch + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Mock the cache to always return None, even after put operations + async def mock_get_always_none(shard_idx): + return None # Always return None to simulate cache failure + + async def mock_put_does_nothing(shard_idx, shard_data, is_dirty=False): + pass # Do nothing, so cache remains empty + + with patch.object(store._shard_data_cache, "get", side_effect=mock_get_always_none): + with patch.object( + store._shard_data_cache, "put", side_effect=mock_put_does_nothing + ): + # This should hit lines 451-452 (RuntimeError for failed to load or initialize) + with pytest.raises( + RuntimeError, match="Failed to load or initialize shard 0" + ): + await store._load_or_initialize_shard_cache(0) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_timeout_cleanup_logic(): + """Test timeout cleanup logic in _load_or_initialize_shard_cache""" + import asyncio + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + # Never completes to simulate timeout + await asyncio.sleep(100) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Manually create a pending load event to simulate the scenario + pending_event = asyncio.Event() + store._pending_shard_loads[0] = pending_event + + # Verify the pending load is set up + assert 0 in store._pending_shard_loads + assert not store._pending_shard_loads[0].is_set() + + # Mock wait_for to properly await the coroutine but still raise TimeoutError + async def mock_wait_for(coro, timeout=None): + # Properly cancel the coroutine to avoid the warning + if hasattr(coro, "close"): + coro.close() + raise asyncio.TimeoutError() + + # Test cleanup logic (lines 431-439) + with patch("asyncio.wait_for", side_effect=mock_wait_for): + with pytest.raises(RuntimeError, match="Timeout waiting for shard 0 to load"): + await store._load_or_initialize_shard_cache(0) + + # Verify cleanup occurred (lines 433-437) + # The event should be set and removed from pending loads + assert 0 not in store._pending_shard_loads # Should be cleaned up + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_pending_load_cache_miss(): + """Test RuntimeError when pending load completes but shard not found in cache""" + import asyncio + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + import dag_cbor + + return dag_cbor.encode([None] * 4) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) 
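+
+    # MockCAS.load never returns (it sleeps for 100 seconds), so a real fetch
+    # would hang; the test patches asyncio.wait_for to raise TimeoutError and
+    # then checks that the timeout path cleans up _pending_shard_loads.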
+ + # Set up a shard CID + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Create a pending load event and manually add it + pending_event = asyncio.Event() + store._pending_shard_loads[0] = pending_event + + # Set up mocks: wait_for succeeds (doesn't timeout) but cache.get returns None + async def mock_wait_for(coro, timeout=None): + # Properly handle the coroutine to avoid warnings + if hasattr(coro, "close"): + coro.close() + # Simulate successful wait - the pending event gets set + pending_event.set() + return True # Successful wait + + async def mock_cache_get(shard_idx): + # Always return None to simulate cache miss after pending load + return None + + # Test the scenario where pending load "completes" but shard not in cache (line 428-430) + with patch("asyncio.wait_for", side_effect=mock_wait_for): + with patch.object(store._shard_data_cache, "get", side_effect=mock_cache_get): + with pytest.raises( + RuntimeError, + match="Shard 0 not found in cache after pending load completed", + ): + await store._load_or_initialize_shard_cache(0) From 6cac0f53d459a9d501b4c048d016b25750b2040f Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:30:18 -0400 Subject: [PATCH 72/74] fix: small changes --- py_hamt/sharded_zarr_store.py | 8 +----- tests/test_sharded_zarr_store_coverage.py | 33 ++++------------------- 2 files changed, 6 insertions(+), 35 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 9f5cc46..54bb7cb 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -333,17 +333,11 @@ async def _fetch_and_cache_full_shard( ) # Exponential backoff continue else: - # Log the failure and raise a specific error - print( - f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" - ) raise RuntimeError( f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" ) except Exception as e: - # Handle non-transient errors immediately - print(f"Error fetching shard {shard_idx}: {e}") - raise + raise e def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. 
Exclude .json files immediately (metadata) diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py index a4d814d..72463c7 100644 --- a/tests/test_sharded_zarr_store_coverage.py +++ b/tests/test_sharded_zarr_store_coverage.py @@ -281,7 +281,6 @@ async def test_memory_bounded_lru_cache_empty_shard(): @pytest.mark.asyncio async def test_memory_bounded_lru_cache_update_existing(): - """Test lines 64-65: cache update logic when shard already exists""" from multiformats import CID from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache @@ -293,7 +292,6 @@ async def test_memory_bounded_lru_cache_update_existing(): shard1 = [test_cid] * 2 await cache.put(0, shard1) - # Update existing shard (lines 64-65 should be hit) shard2 = [test_cid] * 3 await cache.put(0, shard2, is_dirty=True) @@ -372,7 +370,6 @@ async def load(self, cid): @pytest.mark.asyncio async def test_sharded_zarr_store_invalid_root_object_structure(): - """Test lines 274 and 278-280: root object structure validation""" import dag_cbor from py_hamt.sharded_zarr_store import ShardedZarrStore @@ -405,7 +402,6 @@ async def load(self, cid): with pytest.raises(ValueError, match="Root object is not a valid dictionary"): await store._load_root_from_cid() - # Test lines 278-280: shard_cids is not a list mock_cas_invalid_shard_cids = MockCAS({ "manifest_version": "sharded_zarr_v1", "metadata": {}, @@ -439,7 +435,6 @@ async def load(self, cid): @pytest.mark.asyncio async def test_sharded_zarr_store_invalid_manifest_version(): - """Test lines 281-283: manifest version validation""" import dag_cbor from py_hamt.sharded_zarr_store import ShardedZarrStore @@ -467,7 +462,6 @@ async def load(self, cid): } return dag_cbor.encode(root_obj) - # Test with wrong manifest version (should hit lines 281-283) mock_cas = MockCAS("wrong_version") store = ShardedZarrStore(mock_cas, True, "test_cid") @@ -477,9 +471,6 @@ async def load(self, cid): @pytest.mark.asyncio async def test_sharded_zarr_store_shard_fetch_retry(): - """Test lines 333-343: shard fetch retry and error logging""" - from unittest.mock import patch - from py_hamt.sharded_zarr_store import ShardedZarrStore class MockCAS: @@ -526,11 +517,9 @@ async def save(self, data, codec=None): ) store._root_obj["chunks"]["shard_cids"][0] = shard_cid - # This should retry and eventually succeed (testing retry logic lines 325-329) - with patch("builtins.print") as mock_print: - shard_data = await store._load_or_initialize_shard_cache(0) - assert shard_data is not None - assert len(shard_data) == 4 + shard_data = await store._load_or_initialize_shard_cache(0) + assert shard_data is not None + assert len(shard_data) == 4 # Test case where all retries fail mock_cas_fail = MockCAS(fail_count=5) # Fail more than max_retries @@ -543,14 +532,8 @@ async def save(self, data, codec=None): ) store_fail._root_obj["chunks"]["shard_cids"][0] = shard_cid - # This should hit lines 332-337 (failure after max retries) - with patch("builtins.print") as mock_print: - with pytest.raises( - RuntimeError, match="Failed to fetch shard 0 after 3 attempts" - ): - await store_fail._load_or_initialize_shard_cache(0) - # Should print failure message (line 333) - mock_print.assert_called() + with pytest.raises(RuntimeError, match="Failed to fetch shard 0 after 3 attempts"): + await store_fail._load_or_initialize_shard_cache(0) @pytest.mark.asyncio @@ -627,7 +610,6 @@ async def load(self, cid, offset=None, length=None, suffix=None): @pytest.mark.asyncio async def 
test_sharded_zarr_store_exists_exception_handling(): - """Test lines 694-695: exists method exception handling""" from py_hamt import ShardedZarrStore class MockCAS: @@ -646,7 +628,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): chunks_per_shard=2, ) - # Test exists with invalid chunk key that will cause exception (lines 694-695) # This should trigger the exception handling and return False exists = await store.exists("invalid/chunk/key/format") assert exists is False @@ -683,7 +664,6 @@ async def save(self, data, codec=None): chunks_per_shard=2, ) - # Test that cas.save failure raises RuntimeError (lines 656-657) proto = zarr.core.buffer.default_buffer_prototype() test_data = proto.buffer.from_bytes(b"test_data") @@ -791,7 +771,6 @@ async def mock_put_does_nothing(shard_idx, shard_data, is_dirty=False): with patch.object( store._shard_data_cache, "put", side_effect=mock_put_does_nothing ): - # This should hit lines 451-452 (RuntimeError for failed to load or initialize) with pytest.raises( RuntimeError, match="Failed to load or initialize shard 0" ): @@ -847,12 +826,10 @@ async def mock_wait_for(coro, timeout=None): coro.close() raise asyncio.TimeoutError() - # Test cleanup logic (lines 431-439) with patch("asyncio.wait_for", side_effect=mock_wait_for): with pytest.raises(RuntimeError, match="Timeout waiting for shard 0 to load"): await store._load_or_initialize_shard_cache(0) - # Verify cleanup occurred (lines 433-437) # The event should be set and removed from pending loads assert 0 not in store._pending_shard_loads # Should be cleaned up From 7529c2ff236f7dc911fd4c65cd8d1da3ec710f85 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:15:11 -0400 Subject: [PATCH 73/74] fix: small updates --- py_hamt/sharded_zarr_store.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index 54bb7cb..c6a8fb6 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -336,8 +336,6 @@ async def _fetch_and_cache_full_shard( raise RuntimeError( f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" ) - except Exception as e: - raise e def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: # 1. Exclude .json files immediately (metadata) @@ -627,7 +625,7 @@ async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: and not key.startswith("time/") and not key.startswith(("lat/", "latitude/")) and not key.startswith(("lon/", "longitude/")) - and not len(key) == 9 + and not key == "zarr.json" ): metadata_json = json.loads(value.to_bytes().decode("utf-8")) new_array_shape = metadata_json.get("shape") From 78cf72d6512abebd73be1c1b384205d4b60a4217 Mon Sep 17 00:00:00 2001 From: TheGreatAlgo <37487508+TheGreatAlgo@users.noreply.github.com> Date: Mon, 15 Sep 2025 10:48:00 -0400 Subject: [PATCH 74/74] fix: remove duplicate --- py_hamt/sharded_zarr_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py index c6a8fb6..52093d4 100644 --- a/py_hamt/sharded_zarr_store.py +++ b/py_hamt/sharded_zarr_store.py @@ -278,7 +278,6 @@ async def _load_root_from_cid(self): raise ValueError("shard_cids is not a list.") except Exception as e: raise ValueError(f"Failed to decode root object: {e}") - self._root_obj = dag_cbor.decode(root_bytes) if self._root_obj.get("manifest_version") != "sharded_zarr_v1": raise ValueError(