diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 72c594e..66a0794 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
       - id: mixed-line-ending
       - id: trailing-whitespace
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.11.11
+    rev: v0.12.12
     hooks:
       - id: ruff-check
       - id: ruff-format
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..49c6cf4
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,87 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Common Development Commands
+
+Setup environment:
+```bash
+uv sync
+source .venv/bin/activate
+pre-commit install
+```
+
+Run all checks (tests, linting, formatting, type checking):
+```bash
+bash run-checks.sh
+```
+
+Run tests:
+```bash
+# All tests (requires IPFS daemon or Docker)
+pytest --ipfs --cov=py_hamt tests/
+
+# Quick tests without IPFS integration
+pytest --cov=py_hamt tests/
+
+# Single test file
+pytest tests/test_hamt.py
+
+# Coverage report
+uv run coverage report --fail-under=100 --show-missing
+```
+
+Linting and formatting:
+```bash
+# Run all pre-commit hooks
+uv run pre-commit run --all-files --show-diff-on-failure
+
+# Fix auto-fixable ruff issues
+uv run ruff check --fix
+```
+
+Type checking and other tools:
+```bash
+# Type checking is handled by pre-commit hooks (mypy)
+# Documentation preview
+uv run pdoc py_hamt
+```
+
+## Architecture Overview
+
+py-hamt implements a Hash Array Mapped Trie (HAMT) for IPFS/IPLD content-addressed storage. The core architecture follows this pattern:
+
+1. **ContentAddressedStore (CAS)** - Abstract storage layer (store_httpx.py)
+   - `KuboCAS` - IPFS/Kubo implementation for production
+   - `InMemoryCAS` - In-memory implementation for testing
+
+2. **HAMT** - Core data structure (hamt.py)
+   - Uses blake3 hashing by default
+   - Implements a content-addressed trie for efficient key-value storage
+   - Supports async operations for large datasets
+
+3. **ZarrHAMTStore** - Zarr integration (zarr_hamt_store.py)
+   - Implements the zarr.abc.store.Store interface
+   - Enables storing large Zarr arrays on IPFS via HAMT
+   - Keys stored verbatim, values as raw bytes
+
+4. **Encryption Layer** - Optional encryption (encryption_hamt_store.py)
+   - `SimpleEncryptedZarrHAMTStore` for fully encrypted storage
+
+## Key Design Patterns
+
+- All storage operations are async to handle IPFS network calls
+- Content addressing means identical data gets the same hash/CID
+- HAMT provides O(log n) access time for large key sets
+- Store abstractions allow swapping storage backends
+- Type hints required throughout (mypy enforced)
+- 100% test coverage required, with hypothesis property-based testing
+
+## IPFS Integration Requirements
+
+Tests run against one of the following setups:
+- A local IPFS daemon (`ipfs daemon`)
+- Docker available for containerized IPFS
+- Neither (unit tests only; integration tests are skipped)
+
+The `--ipfs` pytest flag controls IPFS test execution.
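For orientation, a minimal read-path sketch that wires together the three layers described above (KuboCAS, HAMT, ZarrHAMTStore). It uses only calls that appear elsewhere in this diff; the root CID is the example dataset CID from fsgs.py, and a running local Kubo node is assumed.

```python
import asyncio

import xarray as xr

from py_hamt import HAMT, KuboCAS, ZarrHAMTStore


async def main() -> None:
    # KuboCAS connects to a local Kubo node (RPC on :5001, gateway on :8080 by default).
    async with KuboCAS() as cas:
        # Build a read-only HAMT rooted at a known CID; values are raw bytes.
        hamt = await HAMT.build(
            cas=cas,
            root_node_id="bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y",
            values_are_bytes=True,
            read_only=True,
        )
        # Expose the HAMT through the Zarr store interface and open it with xarray.
        store = ZarrHAMTStore(hamt, read_only=True)
        ds = xr.open_zarr(store=store, consolidated=True)
        print(ds)


asyncio.run(main())
```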
diff --git a/fsgs.py b/fsgs.py index 0ff4462..f51e2e5 100644 --- a/fsgs.py +++ b/fsgs.py @@ -16,7 +16,7 @@ async def main(): - cid = "bafyr4iecw3faqyvj75psutabk2jxpddpjdokdy5b26jdnjjzpkzbgb5xoq" + cid = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" # Use KuboCAS as an async context manager async with KuboCAS() as kubo_cas: # connects to a local kubo node diff --git a/public_gateway_example.py b/public_gateway_example.py index 186aa4d..d331ec9 100644 --- a/public_gateway_example.py +++ b/public_gateway_example.py @@ -53,7 +53,7 @@ async def fetch_zarr_from_gateway(cid: str, gateway: str = "https://ipfs.io"): async def main(): # Example CID - this points to a weather dataset stored on IPFS - cid = "bafyr4iecw3faqyvj75psutabk2jxpddpjdokdy5b26jdnjjzpkzbgb5xoq" + cid = "bafyr4ibiduv7ml3jeyl3gn6cjcrcizfqss7j64rywpbj3whr7tc6xipt3y" # Try different public gateways gateways = [ diff --git a/py_hamt/__init__.py b/py_hamt/__init__.py index d322d1b..2fff260 100644 --- a/py_hamt/__init__.py +++ b/py_hamt/__init__.py @@ -1,5 +1,7 @@ from .encryption_hamt_store import SimpleEncryptedZarrHAMTStore from .hamt import HAMT, blake3_hashfn +from .hamt_to_sharded_converter import convert_hamt_to_sharded, sharded_converter_cli +from .sharded_zarr_store import ShardedZarrStore from .store_httpx import ContentAddressedStore, InMemoryCAS, KuboCAS from .zarr_hamt_store import ZarrHAMTStore @@ -11,6 +13,7 @@ "KuboCAS", "ZarrHAMTStore", "SimpleEncryptedZarrHAMTStore", + "ShardedZarrStore", + "convert_hamt_to_sharded", + "sharded_converter_cli", ] - -print("Running py-hamt from source!") diff --git a/py_hamt/hamt.py b/py_hamt/hamt.py index 693505c..d22ddcf 100644 --- a/py_hamt/hamt.py +++ b/py_hamt/hamt.py @@ -8,6 +8,7 @@ Callable, Dict, Iterator, + Optional, cast, ) @@ -589,10 +590,18 @@ async def delete(self, key: str) -> None: # If we didn't make a change, then this key must not exist within the HAMT raise KeyError - async def get(self, key: str) -> IPLDKind: + async def get( + self, + key: str, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> IPLDKind: """Get a value.""" pointer: IPLDKind = await self.get_pointer(key) - data: bytes = await self.cas.load(pointer) + data: bytes = await self.cas.load( + pointer, offset=offset, length=length, suffix=suffix + ) if self.values_are_bytes: return data else: diff --git a/py_hamt/hamt_to_sharded_converter.py b/py_hamt/hamt_to_sharded_converter.py new file mode 100644 index 0000000..b0e8921 --- /dev/null +++ b/py_hamt/hamt_to_sharded_converter.py @@ -0,0 +1,130 @@ +import argparse +import asyncio +import time + +import xarray as xr +from multiformats import CID + +from .hamt import HAMT +from .sharded_zarr_store import ShardedZarrStore +from .store_httpx import KuboCAS +from .zarr_hamt_store import ZarrHAMTStore + + +async def convert_hamt_to_sharded( + cas: KuboCAS, hamt_root_cid: str, chunks_per_shard: int +) -> str: + """ + Converts a Zarr dataset from a HAMT-based store to a ShardedZarrStore. + + Args: + cas: An initialized ContentAddressedStore instance (KuboCAS). + hamt_root_cid: The root CID of the source ZarrHAMTStore. + chunks_per_shard: The number of chunks to group into a single shard in the new store. + + Returns: + The root CID of the newly created ShardedZarrStore. + """ + print(f"--- Starting Conversion from HAMT Root {hamt_root_cid} ---") + start_time = time.perf_counter() + # 1. 
Open the source HAMT store for reading
+    print("Opening source HAMT store...")
+    hamt_ro = await HAMT.build(
+        cas=cas, root_node_id=hamt_root_cid, values_are_bytes=True, read_only=True
+    )
+    source_store = ZarrHAMTStore(hamt_ro, read_only=True)
+    source_dataset = xr.open_zarr(store=source_store, consolidated=True)
+    # 2. Introspect the source array to get its configuration
+    print("Reading metadata from source store...")
+
+    # Read the store's metadata to get array shape and chunk shape
+    data_var_name = next(iter(source_dataset.data_vars))
+    ordered_dims = list(source_dataset[data_var_name].dims)
+    array_shape_tuple = tuple(source_dataset.sizes[dim] for dim in ordered_dims)
+    chunk_shape_tuple = tuple(source_dataset.chunks[dim][0] for dim in ordered_dims)
+    array_shape = array_shape_tuple
+    chunk_shape = chunk_shape_tuple
+
+    # 3. Create the destination ShardedZarrStore for writing
+    print(
+        f"Initializing new ShardedZarrStore with {chunks_per_shard} chunks per shard..."
+    )
+    dest_store = await ShardedZarrStore.open(
+        cas=cas,
+        read_only=False,
+        array_shape=array_shape,
+        chunk_shape=chunk_shape,
+        chunks_per_shard=chunks_per_shard,
+    )
+
+    print("Destination store initialized.")
+
+    # 4. Iterate and copy all data from source to destination
+    print("Starting data migration...")
+    count = 0
+    async for key in hamt_ro.keys():
+        count += 1
+        # Read the CID pointer (metadata or chunk) from the source
+        cid: CID = await hamt_ro.get_pointer(key)
+        cid_base32_str = str(cid.encode("base32"))
+
+        # Write the exact same key-value pair to the destination.
+        await dest_store.set_pointer(key, cid_base32_str)
+        if count % 200 == 0:  # pragma: no cover
+            print(f"Migrated {count} keys...")  # pragma: no cover
+
+    print(f"Migration of {count} total keys complete.")
+
+    # 5. Finalize the new store by flushing it to the CAS
+    print("Flushing new store to get final root CID...")
+    new_root_cid = await dest_store.flush()
+    end_time = time.perf_counter()
+
+    print("\n--- Conversion Complete! ---")
+    print(f"Total time: {end_time - start_time:.2f} seconds")
+    print(f"New ShardedZarrStore Root CID: {new_root_cid}")
+    return new_root_cid
+
+
+async def sharded_converter_cli():
+    parser = argparse.ArgumentParser(
+        description="Convert a Zarr HAMT store to a Sharded Zarr store."
+    )
+    parser.add_argument(
+        "hamt_cid", type=str, help="The root CID of the source Zarr HAMT store."
+ ) + parser.add_argument( + "--chunks-per-shard", + type=int, + default=6250, + help="Number of chunk CIDs to store per shard in the new store.", + ) + parser.add_argument( + "--rpc-url", + type=str, + default="http://127.0.0.1:5001", + help="The URL of the IPFS Kubo RPC API.", + ) + parser.add_argument( + "--gateway-url", + type=str, + default="http://127.0.0.1:8080", + help="The URL of the IPFS Gateway.", + ) + args = parser.parse_args() + # Initialize the KuboCAS client with the provided RPC and Gateway URLs + async with KuboCAS( + rpc_base_url=args.rpc_url, gateway_base_url=args.gateway_url + ) as cas_client: + try: + await convert_hamt_to_sharded( + cas=cas_client, + hamt_root_cid=args.hamt_cid, + chunks_per_shard=args.chunks_per_shard, + ) + except Exception as e: + print(f"\nAn error occurred: {e}") + + +if __name__ == "__main__": + asyncio.run(sharded_converter_cli()) # pragma: no cover diff --git a/py_hamt/reference/flat_zarr_store.py b/py_hamt/reference/flat_zarr_store.py new file mode 100644 index 0000000..850d7ae --- /dev/null +++ b/py_hamt/reference/flat_zarr_store.py @@ -0,0 +1,375 @@ +# Functional Flat Store for Zarr using IPFS. This is ultimately the "Best" way to do it +# but only if the metadata is small enough to fit in a single CBOR object. +# This is commented out because we are not using it. However I did not want to delete it +# because it is a good example of how to implement a Zarr store using a flat index + +# import asyncio +# import math +# from collections.abc import AsyncIterator, Iterable +# from typing import Optional, cast + +# import dag_cbor +# import zarr.abc.store +# import zarr.core.buffer +# from zarr.core.common import BytesLike + +# from .store import ContentAddressedStore + + +# class FlatZarrStore(zarr.abc.store.Store): +# """ +# Implements the Zarr Store API using a flat, predictable layout for chunk CIDs. + +# This store bypasses the need for a HAMT, offering direct, calculated +# access to chunk data based on a mathematical formula. It is designed for +# dense, multi-dimensional arrays where chunk locations are predictable. + +# The store is structured around a single root CBOR object. This root object contains: +# 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. +# 2. A single CID pointing to a large, contiguous block of bytes (the "flat index"). +# This flat index is a concatenation of the CIDs of all data chunks. + +# Accessing a chunk involves: +# 1. Loading the root object (if not cached). +# 2. Calculating the byte offset of the chunk's CID within the flat index. +# 3. Fetching that specific CID using a byte-range request on the flat index. +# 4. Fetching the actual chunk data using the retrieved CID. + +# ### Sample Code +# ```python +# import xarray as xr +# import numpy as np +# from py_hamt import KuboCAS, FlatZarrStore + +# # --- Write --- +# ds = xr.Dataset( +# {"data": (("t", "y", "x"), np.arange(24).reshape(2, 3, 4))}, +# ) +# cas = KuboCAS() +# # When creating, must provide array shape and chunk shape +# store_write = await FlatZarrStore.open( +# cas, +# read_only=False, +# array_shape=ds.data.shape, +# chunk_shape=ds.data.encoding['chunks'] +# ) +# ds.to_zarr(store=store_write, mode="w") +# root_cid = await store_write.flush() # IMPORTANT: flush to get final root CID +# print(f"Finished writing. 
Root CID: {root_cid}") + + +# # --- Read --- +# store_read = await FlatZarrStore.open(cas, read_only=True, root_cid=root_cid) +# ds_read = xr.open_zarr(store=store_read) +# print("Read back dataset:") +# print(ds_read) +# xr.testing.assert_identical(ds, ds_read) +# ``` +# """ + +# def __init__( +# self, cas: ContentAddressedStore, read_only: bool, root_cid: Optional[str] +# ): +# """Use the async `open()` classmethod to instantiate this class.""" +# super().__init__(read_only=read_only) +# self.cas = cas +# self._root_cid = root_cid +# self._root_obj: Optional[dict] = None +# self._flat_index_cache: Optional[bytearray] = None +# self._cid_len: Optional[int] = None +# self._array_shape: Optional[tuple[int, ...]] = None +# self._chunk_shape: Optional[tuple[int, ...]] = None +# self._chunks_per_dim: Optional[tuple[int, ...]] = None +# self._dirty = False + +# @classmethod +# async def open( +# cls, +# cas: ContentAddressedStore, +# read_only: bool, +# root_cid: Optional[str] = None, +# *, +# array_shape: Optional[tuple[int, ...]] = None, +# chunk_shape: Optional[tuple[int, ...]] = None, +# cid_len: int = 59, # Default for base32 v1 CIDs like bafy... +# ) -> "FlatZarrStore": +# """ +# Asynchronously opens an existing FlatZarrStore or initializes a new one. + +# Args: +# cas: The Content Addressed Store (e.g., KuboCAS). +# read_only: If True, the store is in read-only mode. +# root_cid: The root CID of an existing store to open. Required for read_only. +# array_shape: The full shape of the Zarr array. Required for a new writeable store. +# chunk_shape: The shape of a single chunk. Required for a new writeable store. +# cid_len: The expected byte length of a CID string. +# """ +# store = cls(cas, read_only, root_cid) +# if root_cid: +# await store._load_root_from_cid() +# elif not read_only: +# if not all([array_shape, chunk_shape]): +# raise ValueError( +# "array_shape and chunk_shape must be provided for a new store." 
+# ) +# store._initialize_new_root(array_shape, chunk_shape, cid_len) +# else: +# raise ValueError("root_cid must be provided for a read-only store.") +# return store + +# def _initialize_new_root( +# self, +# array_shape: tuple[int, ...], +# chunk_shape: tuple[int, ...], +# cid_len: int, +# ): +# self._array_shape = array_shape +# self._chunk_shape = chunk_shape +# self._cid_len = cid_len +# self._chunks_per_dim = tuple( +# math.ceil(a / c) for a, c in zip(array_shape, chunk_shape) +# ) +# self._root_obj = { +# "manifest_version": "flat_zarr_v1", +# "metadata": {}, +# "chunks": { +# "cid": None, # Will be filled on first flush +# "array_shape": list(self._array_shape), +# "chunk_shape": list(self._chunk_shape), +# "cid_byte_length": self._cid_len, +# }, +# } +# self._dirty = True + +# async def _load_root_from_cid(self): +# if not self._root_cid: +# raise ValueError("Cannot load root without a root_cid.") +# root_bytes = await self.cas.load(self._root_cid) +# self._root_obj = dag_cbor.decode(root_bytes) +# chunk_info = self._root_obj.get("chunks", {}) +# self._array_shape = tuple(chunk_info["array_shape"]) +# self._chunk_shape = tuple(chunk_info["chunk_shape"]) +# self._cid_len = chunk_info["cid_byte_length"] +# self._chunks_per_dim = tuple( +# math.ceil(a / c) for a, c in zip(self._array_shape, self._chunk_shape) +# ) + +# def _parse_chunk_key(self, key: str) -> Optional[tuple[int, ...]]: +# if not self._array_shape or not key.startswith("c/"): +# return None +# parts = key.split("/") +# if len(parts) != len(self._array_shape) + 1: +# return None +# try: +# return tuple(map(int, parts[1:])) +# except (ValueError, IndexError): +# return None + +# async def set_partial_values( +# self, key_start_values: Iterable[tuple[str, int, BytesLike]] +# ) -> None: +# """@private""" +# raise NotImplementedError("Partial writes are not supported by this store.") + +# async def get_partial_values( +# self, +# prototype: zarr.core.buffer.BufferPrototype, +# key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], +# ) -> list[zarr.core.buffer.Buffer | None]: +# """ +# Retrieves multiple keys or byte ranges concurrently. +# """ +# tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] +# results = await asyncio.gather(*tasks) +# return results + +# def __eq__(self, other: object) -> bool: +# """@private""" +# if not isinstance(other, FlatZarrStore): +# return NotImplemented +# return self._root_cid == other._root_cid + +# def _get_chunk_offset(self, chunk_coords: tuple[int, ...]) -> int: +# linear_index = 0 +# multiplier = 1 +# for i in reversed(range(len(self._chunks_per_dim))): +# linear_index += chunk_coords[i] * multiplier +# multiplier *= self._chunks_per_dim[i] +# return linear_index * self._cid_len + +# async def flush(self) -> str: +# """ +# Writes all pending changes (metadata and chunk index) to the CAS +# and returns the new root CID. This MUST be called after all writes are complete. 
+# """ +# if self.read_only or not self._dirty: +# return self._root_cid + +# if self._flat_index_cache is not None: +# flat_index_cid_obj = await self.cas.save( +# bytes(self._flat_index_cache), codec="raw" +# ) +# self._root_obj["chunks"]["cid"] = str(flat_index_cid_obj) + +# root_obj_bytes = dag_cbor.encode(self._root_obj) +# new_root_cid_obj = await self.cas.save(root_obj_bytes, codec="dag-cbor") +# self._root_cid = str(new_root_cid_obj) +# self._dirty = False +# return self._root_cid + +# async def get( +# self, +# key: str, +# prototype: zarr.core.buffer.BufferPrototype, +# byte_range: zarr.abc.store.ByteRequest | None = None, +# ) -> zarr.core.buffer.Buffer | None: +# """@private""" +# if self._root_obj is None: +# await self._load_root_from_cid() + +# chunk_coords = self._parse_chunk_key(key) +# try: +# # Metadata request +# if chunk_coords is None: +# metadata_cid = self._root_obj["metadata"].get(key) +# if metadata_cid is None: +# return None +# data = await self.cas.load(metadata_cid) +# return prototype.buffer.from_bytes(data) + +# # Chunk data request +# flat_index_cid = self._root_obj["chunks"]["cid"] +# if flat_index_cid is None: +# return None + +# offset = self._get_chunk_offset(chunk_coords) +# chunk_cid_bytes = await self.cas.load( +# flat_index_cid, offset=offset, length=self._cid_len +# ) + +# if all(b == 0 for b in chunk_cid_bytes): +# return None # Chunk doesn't exist + +# chunk_cid = chunk_cid_bytes.decode("ascii") +# data = await self.cas.load(chunk_cid) +# return prototype.buffer.from_bytes(data) + +# except (KeyError, IndexError): +# return None + +# async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: +# """@private""" +# if self.read_only: +# raise ValueError("Cannot write to a read-only store.") +# if self._root_obj is None: +# raise RuntimeError("Store not initialized for writing.") + +# self._dirty = True +# raw_bytes = value.to_bytes() +# value_cid_obj = await self.cas.save(raw_bytes, codec="raw") +# value_cid = str(value_cid_obj) + +# if len(value_cid) != self._cid_len: +# raise ValueError( +# f"Inconsistent CID length. Expected {self._cid_len}, got {len(value_cid)}" +# ) + +# chunk_coords = self._parse_chunk_key(key) + +# if chunk_coords is None: # Metadata +# self._root_obj["metadata"][key] = value_cid +# return + +# # Chunk Data +# if self._flat_index_cache is None: +# num_chunks = math.prod(self._chunks_per_dim) +# self._flat_index_cache = bytearray(num_chunks * self._cid_len) + +# offset = self._get_chunk_offset(chunk_coords) +# self._flat_index_cache[offset : offset + self._cid_len] = value_cid.encode( +# "ascii" +# ) + +# # --- Other required zarr.abc.store methods --- + +# async def exists(self, key: str) -> bool: +# """@private""" +# # A more efficient version might check for null bytes in the flat index +# # but this is functionally correct. 
+ +# # TODO: Optimize this check +# return True + + +# # return (await self.get(key, zarr.core.buffer.Buffer.prototype, None)) is not None + +# @property +# def supports_writes(self) -> bool: +# """@private""" +# return not self.read_only + +# @property +# def supports_partial_writes(self) -> bool: +# """@private""" +# return False # Each chunk is an immutable object + +# @property +# def supports_deletes(self) -> bool: +# """@private""" +# return not self.read_only + +# async def delete(self, key: str) -> None: +# if self.read_only: +# raise ValueError("Cannot delete from a read-only store.") +# if self._root_obj is None: +# await self._load_root_from_cid() +# chunk_coords = self._parse_chunk_key(key) +# if chunk_coords is None: +# if key in self._root_obj["metadata"]: +# del self._root_obj["metadata"][key] +# self._dirty = True +# return +# else: +# raise KeyError(f"Metadata key '{key}' not found.") +# flat_index_cid = self._root_obj["chunks"]["cid"] +# if self._flat_index_cache is None: +# if not flat_index_cid: +# raise KeyError(f"Chunk key '{key}' not found in non-existent index.") +# self._flat_index_cache = bytearray(await self.cas.load(flat_index_cid)) +# offset = self._get_chunk_offset(chunk_coords) +# if all(b == 0 for b in self._flat_index_cache[offset : offset + self._cid_len]): +# raise KeyError(f"Chunk key '{key}' not found.") +# self._flat_index_cache[offset : offset + self._cid_len] = bytearray(self._cid_len) +# self._dirty = True + +# @property +# def supports_listing(self) -> bool: +# """@private""" +# return True + +# async def list(self) -> AsyncIterator[str]: +# """@private""" +# if self._root_obj is None: +# await self._load_root_from_cid() +# for key in self._root_obj["metadata"]: +# yield key +# # Note: Listing all chunk keys without reading the index is non-trivial. +# # A full implementation might need an efficient way to iterate non-null chunks. +# # This basic version only lists metadata. + +# async def list_prefix(self, prefix: str) -> AsyncIterator[str]: +# """@private""" +# async for key in self.list(): +# if key.startswith(prefix): +# yield key + +# async def list_dir(self, prefix: str) -> AsyncIterator[str]: +# """@private""" +# # This simplified version only works for the root. +# if prefix == "": +# seen = set() +# async for key in self.list(): +# name = key.split('/')[0] +# if name not in seen: +# seen.add(name) +# yield name diff --git a/py_hamt/sharded_zarr_store.py b/py_hamt/sharded_zarr_store.py new file mode 100644 index 0000000..52093d4 --- /dev/null +++ b/py_hamt/sharded_zarr_store.py @@ -0,0 +1,866 @@ +import asyncio +import itertools +import json +import math +import sys +from collections import OrderedDict, defaultdict +from collections.abc import AsyncIterator, Iterable +from typing import DefaultDict, Dict, List, Optional, Set, Tuple + +import dag_cbor +import zarr.abc.store +import zarr.core.buffer +from multiformats.cid import CID +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest +from zarr.core.common import BytesLike + +from .store_httpx import ContentAddressedStore + + +class MemoryBoundedLRUCache: + """ + An LRU cache that evicts items when memory usage exceeds a threshold. + + Memory usage is calculated using sys.getsizeof for accurate sizing. + Dirty shards (those marked for writing) are never evicted until marked clean. + All operations are thread-safe for async access using an asyncio.Lock. 
+ """ + + def __init__(self, max_memory_bytes: int = 100 * 1024 * 1024): # 100MB default + self.max_memory_bytes = max_memory_bytes + self._cache: OrderedDict[int, List[Optional[CID]]] = OrderedDict() + self._dirty_shards: Set[int] = set() + self._shard_sizes: Dict[int, int] = {} # Cached sizes for each shard + self._actual_memory_usage = 0 + self._cache_lock = asyncio.Lock() + + def _get_shard_size(self, shard_data: List[Optional[CID]]) -> int: + """Compute actual size: list overhead + sum of item sizes.""" + if not shard_data: + return sys.getsizeof(shard_data) + total = sys.getsizeof(shard_data) + for item in shard_data: + total += sys.getsizeof(item) + return total + + async def get(self, shard_idx: int) -> Optional[List[Optional[CID]]]: + """Get a shard from cache, moving it to end (most recently used).""" + async with self._cache_lock: + if shard_idx not in self._cache: + return None + shard_data = self._cache.pop(shard_idx) + self._cache[shard_idx] = shard_data + return shard_data + + async def put( + self, shard_idx: int, shard_data: List[Optional[CID]], is_dirty: bool = False + ) -> None: + """Add or update a shard in cache, evicting old items if needed.""" + async with self._cache_lock: + shard_size = self._get_shard_size(shard_data) + + # If shard exists, remove its old size + if shard_idx in self._cache: + self._cache.pop(shard_idx) + self._actual_memory_usage -= self._shard_sizes.pop(shard_idx, 0) + + # Track dirty status + if is_dirty: + self._dirty_shards.add(shard_idx) + + # Add new shard + self._cache[shard_idx] = shard_data + self._shard_sizes[shard_idx] = shard_size + self._actual_memory_usage += shard_size + + # Evict old items if over memory limit, never evict dirty shards + while ( + self._actual_memory_usage > self.max_memory_bytes + and len(self._cache) > 1 + ): + evicted = False + checked_dirty_shards = set() + while self._cache: + candidate_idx, candidate_data = self._cache.popitem(last=False) + if candidate_idx not in self._dirty_shards: + # Evict this clean LRU + self._actual_memory_usage -= self._shard_sizes.pop( + candidate_idx, 0 + ) + evicted = True + break + else: + # Dirty: move to MRU + self._cache[candidate_idx] = candidate_data + checked_dirty_shards.add(candidate_idx) + # If we've checked all dirty shards, no clean shards available + if len(checked_dirty_shards) == len(self._dirty_shards): + break + if not evicted: + # No clean shards to evict + break + + async def mark_dirty(self, shard_idx: int) -> None: + """Mark a shard as dirty (should not be evicted).""" + async with self._cache_lock: + if shard_idx in self._cache: + self._dirty_shards.add(shard_idx) + + async def mark_clean(self, shard_idx: int) -> None: + """Mark a shard as clean (can be evicted).""" + async with self._cache_lock: + self._dirty_shards.discard(shard_idx) + + async def clear(self) -> None: + """Clear all cached data.""" + async with self._cache_lock: + self._cache.clear() + self._dirty_shards.clear() + self._shard_sizes.clear() + self._actual_memory_usage = 0 + + async def __contains__(self, shard_idx: int) -> bool: + async with self._cache_lock: + return shard_idx in self._cache + + @property + def estimated_memory_usage(self) -> int: + """Current memory usage in bytes, based on actual sizes.""" + return self._actual_memory_usage + + @property + def cache_size(self) -> int: + """Number of items currently cached.""" + return len(self._cache) + + @property + def dirty_cache_size(self) -> int: + """Number of dirty items currently cached.""" + return len(self._dirty_shards) + + 
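The docstring above states the eviction contract; here is a small standalone sketch of how `MemoryBoundedLRUCache` behaves. It is an internal helper defined in `py_hamt/sharded_zarr_store.py` (not re-exported from the package), and the tiny memory budget is an assumption chosen so eviction triggers after only a few shards; the store itself defaults to 100 MB.

```python
import asyncio

from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache


async def demo() -> None:
    # Small budget (for the demo only) so eviction kicks in quickly.
    cache = MemoryBoundedLRUCache(max_memory_bytes=4096)

    # A dirty shard holds not-yet-flushed chunk CIDs; it must never be evicted.
    await cache.put(0, [None] * 64, is_dirty=True)

    # Clean shards are ordinary LRU entries and may be evicted under memory pressure.
    for idx in range(1, 8):
        await cache.put(idx, [None] * 64)

    print(cache.cache_size, cache.dirty_cache_size, cache.estimated_memory_usage)
    assert await cache.get(0) is not None  # the dirty shard survived eviction

    # After a flush, the store marks the shard clean so it becomes evictable again.
    await cache.mark_clean(0)


asyncio.run(demo())
```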
+class ShardedZarrStore(zarr.abc.store.Store): + """ + Implements the Zarr Store API using a sharded layout for chunk CIDs. + + This store divides the flat index of chunk CIDs into multiple "shards". + Each shard is a DAG-CBOR array where each element is either a CID link + to a chunk or a null value if the chunk is empty. This structure allows + for efficient traversal by IPLD-aware systems. + + The store's root object contains: + 1. A dictionary mapping metadata keys (like 'zarr.json') to their CIDs. + 2. A list of CIDs, where each CID points to a shard object. + 3. Sharding configuration details (e.g., chunks_per_shard). + """ + + def __init__( + self, + cas: ContentAddressedStore, + read_only: bool, + root_cid: Optional[str] = None, + *, + max_cache_memory_bytes: int = 100 * 1024 * 1024, # 100MB default + ): + """Use the async `open()` classmethod to instantiate this class.""" + super().__init__(read_only=read_only) + self.cas = cas + self._root_cid = root_cid + self._root_obj: dict + + self._resize_lock = asyncio.Lock() + # An event to signal when a resize is in-progress. + # It starts in the "set" state, allowing all operations to proceed. + self._resize_complete = asyncio.Event() + self._resize_complete.set() + self._shard_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock) + + self._shard_data_cache = MemoryBoundedLRUCache(max_cache_memory_bytes) + self._pending_shard_loads: Dict[int, asyncio.Event] = {} + + self._array_shape: Tuple[int, ...] + self._chunk_shape: Tuple[int, ...] + self._chunks_per_dim: Tuple[int, ...] + self._chunks_per_shard: int + self._num_shards: int = 0 + self._total_chunks: int = 0 + + self._dirty_root = False + + def __update_geometry(self): + """Calculates derived geometric properties from the base shapes.""" + + if not all(cs > 0 for cs in self._chunk_shape): + raise ValueError("All chunk_shape dimensions must be positive.") + if not all(s >= 0 for s in self._array_shape): + raise ValueError("All array_shape dimensions must be non-negative.") + + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self._array_shape, self._chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + + if not self._total_chunks == 0: + self._num_shards = ( + self._total_chunks + self._chunks_per_shard - 1 + ) // self._chunks_per_shard + + @classmethod + async def open( + cls, + cas: ContentAddressedStore, + read_only: bool, + root_cid: Optional[str] = None, + *, + array_shape: Optional[Tuple[int, ...]] = None, + chunk_shape: Optional[Tuple[int, ...]] = None, + chunks_per_shard: Optional[int] = None, + max_cache_memory_bytes: int = 100 * 1024 * 1024, # 100MB default + ) -> "ShardedZarrStore": + """ + Asynchronously opens an existing ShardedZarrStore or initializes a new one. + """ + store = cls( + cas, read_only, root_cid, max_cache_memory_bytes=max_cache_memory_bytes + ) + if root_cid: + await store._load_root_from_cid() + elif not read_only: + if array_shape is None or chunk_shape is None: + raise ValueError( + "array_shape and chunk_shape must be provided for a new store." 
+ ) + + if not isinstance(chunks_per_shard, int) or chunks_per_shard <= 0: + raise ValueError("chunks_per_shard must be a positive integer.") + + store._initialize_new_root(array_shape, chunk_shape, chunks_per_shard) + else: + raise ValueError("root_cid must be provided for a read-only store.") + return store + + def _initialize_new_root( + self, + array_shape: Tuple[int, ...], + chunk_shape: Tuple[int, ...], + chunks_per_shard: int, + ): + self._array_shape = array_shape + self._chunk_shape = chunk_shape + self._chunks_per_shard = chunks_per_shard + + self.__update_geometry() + + self._root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": list(self._array_shape), + "chunk_shape": list(self._chunk_shape), + "sharding_config": { + "chunks_per_shard": self._chunks_per_shard, + }, + "shard_cids": [None] * self._num_shards, + }, + } + self._dirty_root = True + + async def _load_root_from_cid(self): + root_bytes = await self.cas.load(self._root_cid) + try: + self._root_obj = dag_cbor.decode(root_bytes) + if not isinstance(self._root_obj, dict) or "chunks" not in self._root_obj: + raise ValueError( + "Root object is not a valid dictionary with 'chunks' key." + ) + if not isinstance(self._root_obj["chunks"]["shard_cids"], list): + raise ValueError("shard_cids is not a list.") + except Exception as e: + raise ValueError(f"Failed to decode root object: {e}") + + if self._root_obj.get("manifest_version") != "sharded_zarr_v1": + raise ValueError( + f"Incompatible manifest version: {self._root_obj.get('manifest_version')}. Expected 'sharded_zarr_v1'." + ) + + chunk_info = self._root_obj["chunks"] + self._array_shape = tuple(chunk_info["array_shape"]) + self._chunk_shape = tuple(chunk_info["chunk_shape"]) + self._chunks_per_shard = chunk_info["sharding_config"]["chunks_per_shard"] + + self.__update_geometry() + + if len(chunk_info["shard_cids"]) != self._num_shards: + raise ValueError( + f"Inconsistent number of shards. Expected {self._num_shards}, found {len(chunk_info['shard_cids'])}." + ) + + async def _fetch_and_cache_full_shard( + self, + shard_idx: int, + shard_cid: str, + max_retries: int = 3, + retry_delay: float = 1.0, + ) -> None: + """ + Fetch a shard from CAS and cache it, with retry logic for transient errors. + + Args: + shard_idx: The index of the shard to fetch. + shard_cid: The CID of the shard. + max_retries: Maximum number of retry attempts for transient errors. + retry_delay: Delay between retry attempts in seconds. + """ + for attempt in range(max_retries): + try: + shard_data_bytes = await self.cas.load(shard_cid) + decoded_shard = dag_cbor.decode(shard_data_bytes) + if not isinstance(decoded_shard, list): + raise TypeError(f"Shard {shard_idx} did not decode to a list.") + await self._shard_data_cache.put(shard_idx, decoded_shard) + # Always set the Event to unblock waiting coroutines + if shard_idx in self._pending_shard_loads: + self._pending_shard_loads[shard_idx].set() + del self._pending_shard_loads[shard_idx] + return # Success + except (ConnectionError, TimeoutError) as e: + # Handle transient errors (e.g., network issues) + if attempt < max_retries - 1: + await asyncio.sleep( + retry_delay * (2**attempt) + ) # Exponential backoff + continue + else: + raise RuntimeError( + f"Failed to fetch shard {shard_idx} after {max_retries} attempts: {e}" + ) + + def _parse_chunk_key(self, key: str) -> Optional[Tuple[int, ...]]: + # 1. 
Exclude .json files immediately (metadata) + if key.endswith(".json"): + return None + excluded_array_prefixes = {"time", "lat", "lon", "latitude", "longitude"} + + chunk_marker = "/c/" + marker_idx = key.rfind(chunk_marker) # Use rfind for robustness + if marker_idx == -1: + # Key does not contain "/c/", so it's not a chunk data key + # in the expected format (e.g., could be .zattrs, .zgroup at various levels). + return None + + # Extract the part of the key before "/c/", which might represent the array/group path + # e.g., "temp" from "temp/c/0/0/0" + # e.g., "group1/lat" from "group1/lat/c/0" + # e.g., "" if key is "c/0/0/0" (root array) + path_before_c = key[:marker_idx] + + # Determine the actual array name (the last component of the path before "/c/") + actual_array_name = "" + if path_before_c: + actual_array_name = path_before_c.split("/")[-1] + + # 2. If the determined array name is in our exclusion list, return None. + if actual_array_name in excluded_array_prefixes: + return None + + # The part after "/c/" contains the chunk coordinates + coord_part = key[marker_idx + len(chunk_marker) :] + parts = coord_part.split("/") + + coords = tuple(map(int, parts)) + # Validate coordinates against the chunk grid of the store's configured array + for i, c_coord in enumerate(coords): + if not (0 <= c_coord < self._chunks_per_dim[i]): + raise IndexError( + f"Chunk coordinate {c_coord} at dimension {i} is out of bounds for dimension size {self._chunks_per_dim[i]}." + ) + return coords + + def _get_linear_chunk_index(self, chunk_coords: Tuple[int, ...]) -> int: + linear_index = 0 + multiplier = 1 + # Convert N-D chunk coordinates to a flat 1-D index (row-major order) + for i in reversed(range(len(self._chunks_per_dim))): + linear_index += chunk_coords[i] * multiplier + multiplier *= self._chunks_per_dim[i] + return linear_index + + def _get_shard_info(self, linear_chunk_index: int) -> Tuple[int, int]: + shard_idx = linear_chunk_index // self._chunks_per_shard + index_in_shard = linear_chunk_index % self._chunks_per_shard + return shard_idx, index_in_shard + + async def _load_or_initialize_shard_cache( + self, shard_idx: int + ) -> List[Optional[CID]]: + """ + Load a shard into the cache or initialize an empty shard if it doesn't exist. + + Args: + shard_idx: The index of the shard to load or initialize. + + Returns: + List[Optional[CID]]: The shard data (list of CIDs or None). + + Raises: + ValueError: If the shard index is out of bounds. + RuntimeError: If the shard cannot be loaded or initialized. + """ + cached_shard = await self._shard_data_cache.get(shard_idx) + if cached_shard is not None: + return cached_shard + + if shard_idx in self._pending_shard_loads: + try: + # Wait for the pending load with a timeout (e.g., 60 seconds) + await asyncio.wait_for( + self._pending_shard_loads[shard_idx].wait(), timeout=60.0 + ) + cached_shard = await self._shard_data_cache.get(shard_idx) + if cached_shard is not None: + return cached_shard + else: + raise RuntimeError( + f"Shard {shard_idx} not found in cache after pending load completed." 
+ ) + except asyncio.TimeoutError: + # Clean up the pending load to allow retry + if shard_idx in self._pending_shard_loads: + self._pending_shard_loads[shard_idx].set() + del self._pending_shard_loads[shard_idx] + raise RuntimeError(f"Timeout waiting for shard {shard_idx} to load.") + + if not (0 <= shard_idx < self._num_shards): + raise ValueError(f"Shard index {shard_idx} out of bounds.") + + shard_cid_obj = self._root_obj["chunks"]["shard_cids"][shard_idx] + if shard_cid_obj: + self._pending_shard_loads[shard_idx] = asyncio.Event() + shard_cid_str = str(shard_cid_obj) + await self._fetch_and_cache_full_shard(shard_idx, shard_cid_str) + else: + empty_shard = [None] * self._chunks_per_shard + await self._shard_data_cache.put(shard_idx, empty_shard) + + result = await self._shard_data_cache.get(shard_idx) + if result is None: + raise RuntimeError(f"Failed to load or initialize shard {shard_idx}") + return result # type: ignore[return-value] + + async def set_partial_values( + self, key_start_values: Iterable[Tuple[str, int, BytesLike]] + ) -> None: + raise NotImplementedError( + "Partial writes are not supported by ShardedZarrStore." + ) + + async def get_partial_values( + self, + prototype: zarr.core.buffer.BufferPrototype, + key_ranges: Iterable[Tuple[str, zarr.abc.store.ByteRequest | None]], + ) -> List[Optional[zarr.core.buffer.Buffer]]: + tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + results = await asyncio.gather(*tasks) + return results + + def with_read_only(self, read_only: bool = False) -> "ShardedZarrStore": + """ + Return this store (if the flag already matches) or a *shallow* + clone that presents the requested read‑only status. + + The clone **shares** the same CAS instance and internal state; + no flushing, network traffic or async work is done. + """ + # Fast path + if read_only == self.read_only: + return self # Same mode, return same instance + + # Create new instance with different read_only flag + # Creates a *bare* instance without running its __init__ + clone = type(self).__new__(type(self)) + + # Copy all attributes from the current instance + clone.cas = self.cas + clone._root_cid = self._root_cid + clone._root_obj = self._root_obj + + clone._resize_lock = self._resize_lock + clone._resize_complete = self._resize_complete + clone._shard_locks = self._shard_locks + + clone._shard_data_cache = self._shard_data_cache + clone._pending_shard_loads = self._pending_shard_loads + + clone._array_shape = self._array_shape + clone._chunk_shape = self._chunk_shape + clone._chunks_per_dim = self._chunks_per_dim + clone._chunks_per_shard = self._chunks_per_shard + clone._num_shards = self._num_shards + clone._total_chunks = self._total_chunks + + clone._dirty_root = self._dirty_root + + # Re‑initialise the zarr base class so that Zarr sees the flag + zarr.abc.store.Store.__init__(clone, read_only=read_only) + return clone + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ShardedZarrStore): + return False + # For equality, root CID is primary. Config like chunks_per_shard is part of that root's identity. + return self._root_cid == other._root_cid + + # If nothing to flush, return the root CID. 
+ async def flush(self) -> str: + async with self._shard_data_cache._cache_lock: + dirty_shards = list(self._shard_data_cache._dirty_shards) + if dirty_shards: + for shard_idx in sorted(dirty_shards): + # Get the list of CIDs/Nones from the cache + shard_data_list = await self._shard_data_cache.get(shard_idx) + if shard_data_list is None: + raise RuntimeError(f"Dirty shard {shard_idx} not found in cache") + + # Encode this list into a DAG-CBOR byte representation + shard_data_bytes = dag_cbor.encode(shard_data_list) + + # Save the DAG-CBOR block and get its CID + new_shard_cid_obj = await self.cas.save( + shard_data_bytes, + codec="dag-cbor", # Use 'dag-cbor' codec + ) + + if ( + self._root_obj["chunks"]["shard_cids"][shard_idx] + != new_shard_cid_obj + ): + # Store the CID object directly + self._root_obj["chunks"]["shard_cids"][shard_idx] = ( + new_shard_cid_obj + ) + self._dirty_root = True + # Mark shard as clean after flushing + await self._shard_data_cache.mark_clean(shard_idx) + + if self._dirty_root: + # Ensure all metadata CIDs are CID objects for correct encoding + self._root_obj["metadata"] = { + k: (CID.decode(v) if isinstance(v, str) else v) + for k, v in self._root_obj["metadata"].items() + } + root_obj_bytes = dag_cbor.encode(self._root_obj) + new_root_cid = await self.cas.save(root_obj_bytes, codec="dag-cbor") + self._root_cid = str(new_root_cid) + self._dirty_root = False + + # Ignore because root_cid will always exist after initialization or flush. + return self._root_cid # type: ignore[return-value] + + async def get( + self, + key: str, + prototype: zarr.core.buffer.BufferPrototype, + byte_range: Optional[zarr.abc.store.ByteRequest] = None, + ) -> Optional[zarr.core.buffer.Buffer]: + chunk_coords = self._parse_chunk_key(key) + # Metadata request + if chunk_coords is None: + metadata_cid_obj = self._root_obj["metadata"].get(key) + if metadata_cid_obj is None: + return None + if byte_range is not None: + raise ValueError( + "Byte range requests are not supported for metadata keys." + ) + data = await self.cas.load(str(metadata_cid_obj)) + return prototype.buffer.from_bytes(data) + # Chunk data request + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + # This will load the full shard into cache if it's not already there. + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + # Get the CID object (or None) from the cached list. + chunk_cid_obj = target_shard_list[index_in_shard] + + if chunk_cid_obj is None: + return None # Chunk is empty/doesn't exist. 
+ + chunk_cid_str = str(chunk_cid_obj) + + req_offset = None + req_length = None + req_suffix = None + + if byte_range: + if isinstance(byte_range, RangeByteRequest): + req_offset = byte_range.start + if byte_range.end is not None: + if byte_range.start > byte_range.end: + raise ValueError( + f"Byte range start ({byte_range.start}) cannot be greater than end ({byte_range.end})" + ) + req_length = byte_range.end - byte_range.start + elif isinstance(byte_range, OffsetByteRequest): + req_offset = byte_range.offset + elif isinstance(byte_range, SuffixByteRequest): + req_suffix = byte_range.suffix + data = await self.cas.load( + chunk_cid_str, offset=req_offset, length=req_length, suffix=req_suffix + ) + return prototype.buffer.from_bytes(data) + + async def set(self, key: str, value: zarr.core.buffer.Buffer) -> None: + if self.read_only: + raise PermissionError("Cannot write to a read-only store.") + await self._resize_complete.wait() + + if ( + key.endswith("zarr.json") + and not key.startswith("time/") + and not key.startswith(("lat/", "latitude/")) + and not key.startswith(("lon/", "longitude/")) + and not key == "zarr.json" + ): + metadata_json = json.loads(value.to_bytes().decode("utf-8")) + new_array_shape = metadata_json.get("shape") + if not new_array_shape: + raise ValueError("Shape not found in metadata.") + if tuple(new_array_shape) != self._array_shape: + async with self._resize_lock: + # Double-check after acquiring the lock, in case another task + # just finished this exact resize while we were waiting. + if tuple(new_array_shape) != self._array_shape: + # Block all other tasks until resize is complete. + self._resize_complete.clear() + try: + await self.resize_store(new_shape=tuple(new_array_shape)) + finally: + # All waiting tasks will now un-pause and proceed safely. + self._resize_complete.set() + + raw_data_bytes = value.to_bytes() + # Save the data to CAS first to get its CID. + # Metadata is often saved as 'raw', chunks as well unless compressed. 
+ try: + data_cid_obj = await self.cas.save(raw_data_bytes, codec="raw") + await self.set_pointer(key, str(data_cid_obj)) + except Exception as e: + raise RuntimeError(f"Failed to save data for key {key}: {e}") + return None # type: ignore[return-value] + + async def set_pointer(self, key: str, pointer: str) -> None: + chunk_coords = self._parse_chunk_key(key) + + pointer_cid_obj = CID.decode(pointer) # Convert string to CID object + + if chunk_coords is None: # Metadata key + self._root_obj["metadata"][key] = pointer_cid_obj + self._dirty_root = True + return None + + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + + if target_shard_list[index_in_shard] != pointer_cid_obj: + target_shard_list[index_in_shard] = pointer_cid_obj + await self._shard_data_cache.mark_dirty(shard_idx) + return None + + async def exists(self, key: str) -> bool: + try: + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: # Metadata + return key in self._root_obj.get("metadata", {}) + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + # Load shard if not cached and check the index + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + return target_shard_list[index_in_shard] is not None + except (ValueError, IndexError, KeyError): + return False + + @property + def supports_writes(self) -> bool: + return not self.read_only + + @property + def supports_partial_writes(self) -> bool: + return False # Each chunk CID is written atomically into a shard slot + + @property + def supports_deletes(self) -> bool: + return not self.read_only + + async def delete(self, key: str) -> None: + if self.read_only: + raise PermissionError("Cannot delete from a read-only store.") + + chunk_coords = self._parse_chunk_key(key) + if chunk_coords is None: # Metadata + if self._root_obj["metadata"].pop(key, None): + self._dirty_root = True + else: + raise KeyError(f"Metadata key '{key}' not found.") + return None + + linear_chunk_index = self._get_linear_chunk_index(chunk_coords) + shard_idx, index_in_shard = self._get_shard_info(linear_chunk_index) + + shard_lock = self._shard_locks[shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache(shard_idx) + if target_shard_list[index_in_shard] is not None: + target_shard_list[index_in_shard] = None + await self._shard_data_cache.mark_dirty(shard_idx) + + @property + def supports_listing(self) -> bool: + return True + + async def list(self) -> AsyncIterator[str]: + for key in list(self._root_obj.get("metadata", {})): + yield key + + async def list_prefix(self, prefix: str) -> AsyncIterator[str]: + async for key in self.list(): + if key.startswith(prefix): + yield key + + async def graft_store(self, store_to_graft_cid: str, chunk_offset: Tuple[int, ...]): + if self.read_only: + raise PermissionError("Cannot graft onto a read-only store.") + + store_to_graft = await ShardedZarrStore.open( + cas=self.cas, read_only=True, root_cid=store_to_graft_cid + ) + source_chunk_grid = store_to_graft._chunks_per_dim + for local_coords in itertools.product(*[range(s) for s in source_chunk_grid]): + linear_local_index = store_to_graft._get_linear_chunk_index(local_coords) + local_shard_idx, index_in_local_shard = 
store_to_graft._get_shard_info( + linear_local_index + ) + # Load the source shard into its cache + source_shard_list = await store_to_graft._load_or_initialize_shard_cache( + local_shard_idx + ) + + pointer_cid_obj = source_shard_list[index_in_local_shard] + if pointer_cid_obj is None: + continue + + # Calculate global coordinates and write to the main store's index + global_coords = tuple( + c_local + c_offset + for c_local, c_offset in zip(local_coords, chunk_offset) + ) + linear_global_index = self._get_linear_chunk_index(global_coords) + global_shard_idx, index_in_global_shard = self._get_shard_info( + linear_global_index + ) + + shard_lock = self._shard_locks[global_shard_idx] + async with shard_lock: + target_shard_list = await self._load_or_initialize_shard_cache( + global_shard_idx + ) + if target_shard_list[index_in_global_shard] != pointer_cid_obj: + target_shard_list[index_in_global_shard] = pointer_cid_obj + await self._shard_data_cache.mark_dirty(global_shard_idx) + + async def resize_store(self, new_shape: Tuple[int, ...]): + """ + Resizes the store's main shard index to accommodate a new overall array shape. + This is a metadata-only operation on the store's root object. + Used when doing skeleton writes or appends via xarray where the array shape changes. + """ + if self.read_only: + raise PermissionError("Cannot resize a read-only store.") + if ( + # self._root_obj is None + self._chunk_shape is None + or self._chunks_per_shard is None + or self._array_shape is None + ): + raise RuntimeError("Store is not properly initialized for resizing.") + if len(new_shape) != len(self._array_shape): + raise ValueError( + "New shape must have the same number of dimensions as the old shape." + ) + + self._array_shape = tuple(new_shape) + self._chunks_per_dim = tuple( + math.ceil(a / c) if c > 0 else 0 + for a, c in zip(self._array_shape, self._chunk_shape) + ) + self._total_chunks = math.prod(self._chunks_per_dim) + old_num_shards = self._num_shards if self._num_shards is not None else 0 + self._num_shards = ( + (self._total_chunks + self._chunks_per_shard - 1) // self._chunks_per_shard + if self._total_chunks > 0 + else 0 + ) + self._root_obj["chunks"]["array_shape"] = list(self._array_shape) + if self._num_shards > old_num_shards: + self._root_obj["chunks"]["shard_cids"].extend( + [None] * (self._num_shards - old_num_shards) + ) + elif self._num_shards < old_num_shards: + self._root_obj["chunks"]["shard_cids"] = self._root_obj["chunks"][ + "shard_cids" + ][: self._num_shards] + + self._dirty_root = True + + async def resize_variable(self, variable_name: str, new_shape: Tuple[int, ...]): + """ + Resizes the Zarr metadata for a specific variable (e.g., '.json' file). + This does NOT change the store's main shard index. + """ + if self.read_only: + raise PermissionError("Cannot resize a read-only store.") + + zarr_metadata_key = f"{variable_name}/zarr.json" + + old_zarr_metadata_cid = self._root_obj["metadata"].get(zarr_metadata_key) + if not old_zarr_metadata_cid: + raise KeyError( + f"Cannot find metadata for key '{zarr_metadata_key}' to resize." 
+ ) + + old_zarr_metadata_bytes = await self.cas.load(old_zarr_metadata_cid) + zarr_metadata_json = json.loads(old_zarr_metadata_bytes) + + zarr_metadata_json["shape"] = list(new_shape) + + new_zarr_metadata_bytes = json.dumps(zarr_metadata_json, indent=2).encode( + "utf-8" + ) + # Metadata is a raw blob of bytes + new_zarr_metadata_cid = await self.cas.save( + new_zarr_metadata_bytes, codec="raw" + ) + + self._root_obj["metadata"][zarr_metadata_key] = new_zarr_metadata_cid + self._dirty_root = True + + async def list_dir(self, prefix: str) -> AsyncIterator[str]: + seen: Set[str] = set() + if prefix == "": + async for key in self.list(): # Iterates metadata keys + # e.g., if key is "group1/.zgroup" or "array1/.json", first_component is "group1" or "array1" + # if key is ".zgroup", first_component is ".zgroup" + first_component = key.split("/", 1)[0] + if first_component not in seen: + seen.add(first_component) + yield first_component + else: + raise NotImplementedError("Listing with a prefix is not implemented yet.") diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index ecbe4d0..5f6b5e2 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -2,7 +2,7 @@ import random import re from abc import ABC, abstractmethod -from typing import Any, Literal, Tuple, cast +from typing import Any, Dict, Literal, Optional, Tuple, cast import httpx from dag_cbor.ipld import IPLDKind @@ -32,9 +32,33 @@ async def save(self, data: bytes, codec: CodecInput) -> IPLDKind: """ @abstractmethod - async def load(self, id: IPLDKind) -> bytes: + async def load( + self, + id: IPLDKind, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> bytes: """Retrieve data.""" + async def pin_cid(self, id: IPLDKind, target_rpc: str) -> None: + """Pin a CID in the storage.""" + pass # pragma: no cover + + async def unpin_cid(self, id: IPLDKind, target_rpc: str) -> None: + """Unpin a CID in the storage.""" + pass # pragma: no cover + + async def pin_update( + self, old_id: IPLDKind, new_id: IPLDKind, target_rpc: str + ) -> None: + """Update the pinned CID in the storage.""" + pass # pragma: no cover + + async def pin_ls(self, target_rpc: str) -> list[Dict[str, Any]]: + """List all pinned CIDs in the storage.""" + return [] # pragma: no cover + class InMemoryCAS(ContentAddressedStore): """Used mostly for faster testing, this is why this is not exported. It hashes all inputs and uses that as a key to an in-memory python dict, mimicking a content addressed storage system. The hash bytes are the ID that `save` returns and `load` takes in.""" @@ -51,7 +75,13 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> by self.store[hash] = data return hash - async def load(self, id: IPLDKind) -> bytes: + async def load( + self, + id: IPLDKind, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> bytes: """ `ContentAddressedStore` allows any IPLD scalar key. 
For the in-memory backend we *require* a `bytes` hash; anything else is rejected at run @@ -66,12 +96,24 @@ async def load(self, id: IPLDKind) -> bytes: raise TypeError( f"InMemoryCAS only supports byte‐hash keys; got {type(id).__name__}" ) - + data: bytes try: - return self.store[key] + data = self.store[key] except KeyError as exc: raise KeyError("Object not found in in-memory store") from exc + if offset is not None: + start = offset + if length is not None: + end = start + length + return data[start:end] + else: + return data[start:] + elif suffix is not None: # If only length is given, assume start from 0 + return data[-suffix:] + else: # Full load + return data + class KuboCAS(ContentAddressedStore): """ @@ -146,6 +188,7 @@ def __init__( *, headers: dict[str, str] | None = None, auth: Tuple[str, str] | None = None, + pin_on_add: bool = False, chunker: str = "size-1048576", max_retries: int = 3, initial_delay: float = 1.0, @@ -187,6 +230,13 @@ def __init__( These are the first part of the url, defaults that refer to the default that kubo launches with on a local machine are provided. """ + self._owns_client: bool = False + self._closed: bool = True + self._client_per_loop: Dict[asyncio.AbstractEventLoop, httpx.AsyncClient] = {} + self._default_headers = headers + self._default_auth = auth + + # Now, perform validation that might raise an exception chunker_pattern = r"(?:size-[1-9]\d*|rabin(?:-[1-9]\d*-[1-9]\d*-[1-9]\d*)?)" if re.fullmatch(chunker_pattern, chunker) is None: raise ValueError("Invalid chunker specification") @@ -209,7 +259,8 @@ def __init__( else: gateway_base_url = f"{gateway_base_url}/ipfs/" - self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&chunker={self.chunker}&pin=false" + pin_string: str = "true" if pin_on_add else "false" + self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&chunker={self.chunker}&pin={pin_string}" """@private""" self.gateway_base_url: str = gateway_base_url """@private""" @@ -228,7 +279,7 @@ def __init__( self._default_auth = auth self._sem: asyncio.Semaphore = asyncio.Semaphore(concurrency) - self._closed: bool = False + self._closed = False # Validate retry parameters if max_retries < 0: @@ -392,16 +443,41 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI raise raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover - async def load(self, id: IPLDKind) -> bytes: + async def load( + self, + id: IPLDKind, + offset: Optional[int] = None, + length: Optional[int] = None, + suffix: Optional[int] = None, + ) -> bytes: + """Load data from a CID using the IPFS gateway with optional Range requests.""" cid = cast(CID, id) url: str = f"{self.gateway_base_url + str(cid)}" - async with self._sem: + headers: Dict[str, str] = {} + + # Construct the Range header if required + if offset is not None: + start = offset + if length is not None: + # Standard HTTP Range: bytes=start-end (inclusive) + end = start + length - 1 + headers["Range"] = f"bytes={start}-{end}" + else: + # Standard HTTP Range: bytes=start- (from start to end) + headers["Range"] = f"bytes={start}-" + elif suffix is not None: + # Standard HTTP Range: bytes=-N (last N bytes) + headers["Range"] = f"bytes=-{suffix}" + + async with self._sem: # Throttle gateway client = self._loop_client() retry_count = 0 while retry_count <= self.max_retries: try: - response = await client.get(url, timeout=60.0) + response = await client.get( + url, headers=headers or None, timeout=60.0 + ) response.raise_for_status() return 
response.content @@ -415,15 +491,96 @@ async def load(self, id: IPLDKind) -> bytes: else None, ) - # Calculate backoff delay + # Calculate backoff delay with jitter delay = self.initial_delay * ( self.backoff_factor ** (retry_count - 1) ) - # Add some jitter to prevent thundering herd jitter = delay * 0.1 * (random.random() - 0.5) await asyncio.sleep(delay + jitter) except httpx.HTTPStatusError: # Re-raise non-timeout HTTP errors immediately raise + raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover + + # --------------------------------------------------------------------- # + # pin_cid() – method to pin a CID # + # --------------------------------------------------------------------- # + async def pin_cid( + self, + cid: CID, + target_rpc: str = "http://127.0.0.1:5001", + ) -> None: + """ + Pins a CID to the local Kubo node via the RPC API. + + This call is recursive by default, pinning all linked objects. + + Args: + cid (CID): The Content ID to pin. + target_rpc (str): The RPC URL of the Kubo node. + """ + params = {"arg": str(cid), "recursive": "true"} + pin_add_url_base: str = f"{target_rpc}/api/v0/pin/add" + + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(pin_add_url_base, params=params) + response.raise_for_status() + + async def unpin_cid( + self, cid: CID, target_rpc: str = "http://127.0.0.1:5001" + ) -> None: + """ + Unpins a CID from the local Kubo node via the RPC API. + + Args: + cid (CID): The Content ID to unpin. + """ + params = {"arg": str(cid), "recursive": "true"} + unpin_url_base: str = f"{target_rpc}/api/v0/pin/rm" + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(unpin_url_base, params=params) + response.raise_for_status() + + async def pin_update( + self, + old_id: IPLDKind, + new_id: IPLDKind, + target_rpc: str = "http://127.0.0.1:5001", + ) -> None: + """ + Updates the pinned CID in the storage. + + Args: + old_id (IPLDKind): The old Content ID to replace. + new_id (IPLDKind): The new Content ID to pin. + """ + params = {"arg": [str(old_id), str(new_id)]} + pin_update_url_base: str = f"{target_rpc}/api/v0/pin/update" + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(pin_update_url_base, params=params) + response.raise_for_status() + + async def pin_ls( + self, target_rpc: str = "http://127.0.0.1:5001" + ) -> list[Dict[str, Any]]: + """ + Lists all pinned CIDs on the local Kubo node via the RPC API. + + Args: + target_rpc (str): The RPC URL of the Kubo node. + + Returns: + List[CID]: A list of pinned CIDs. 
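        The returned value is taken directly from the `Keys` field of Kubo's
        `/api/v0/pin/ls` response, so the usual pattern is a membership check
        against `str(cid)`. Illustrative sketch only, assuming a local Kubo
        daemon reachable at the default RPC address:

            cid = await kubo_cas.save(b"hello", codec="raw")
            await kubo_cas.pin_cid(cid, target_rpc="http://127.0.0.1:5001")
            pins = await kubo_cas.pin_ls(target_rpc="http://127.0.0.1:5001")
            assert str(cid) in pins
            await kubo_cas.unpin_cid(cid, target_rpc="http://127.0.0.1:5001")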
+ """ + pin_ls_url_base: str = f"{target_rpc}/api/v0/pin/ls" + async with self._sem: # throttle RPC + client = self._loop_client() + response = await client.post(pin_ls_url_base) + response.raise_for_status() + pins = response.json().get("Keys", []) + return pins diff --git a/py_hamt/zarr_hamt_store.py b/py_hamt/zarr_hamt_store.py index 2892f4e..ba752e6 100644 --- a/py_hamt/zarr_hamt_store.py +++ b/py_hamt/zarr_hamt_store.py @@ -1,5 +1,6 @@ +import asyncio from collections.abc import AsyncIterator, Iterable -from typing import cast +from typing import Optional, cast import zarr.abc.store import zarr.core.buffer @@ -80,6 +81,29 @@ def __init__(self, hamt: HAMT, read_only: bool = False) -> None: self.metadata_read_cache: dict[str, bytes] = {} """@private""" + def _map_byte_request( + self, byte_range: Optional[zarr.abc.store.ByteRequest] + ) -> tuple[Optional[int], Optional[int], Optional[int]]: + """Helper to map Zarr ByteRequest to offset, length, suffix.""" + offset: Optional[int] = None + length: Optional[int] = None + suffix: Optional[int] = None + + if byte_range: + if isinstance(byte_range, zarr.abc.store.RangeByteRequest): + offset = byte_range.start + length = byte_range.end - byte_range.start + if length is not None and length < 0: + raise ValueError("End must be >= start for RangeByteRequest") + elif isinstance(byte_range, zarr.abc.store.OffsetByteRequest): + offset = byte_range.offset + elif isinstance(byte_range, zarr.abc.store.SuffixByteRequest): + suffix = byte_range.suffix + else: + raise TypeError(f"Unsupported ByteRequest type: {type(byte_range)}") + + return offset, length, suffix + @property def read_only(self) -> bool: # type: ignore[override] if self._forced_read_only is not None: # instance attr overrides @@ -131,27 +155,40 @@ async def get( len(key) >= 9 and key[-9:] == "zarr.json" ) # if path ends with zarr.json - if is_metadata and key in self.metadata_read_cache: + if is_metadata and byte_range is None and key in self.metadata_read_cache: val = self.metadata_read_cache[key] else: + offset, length, suffix = self._map_byte_request(byte_range) val = cast( - bytes, await self.hamt.get(key) + bytes, + await self.hamt.get( + key, offset=offset, length=length, suffix=suffix + ), ) # We know values received will always be bytes since we only store bytes in the HAMT - if is_metadata: + if is_metadata and byte_range is None: self.metadata_read_cache[key] = val return prototype.buffer.from_bytes(val) except KeyError: # Sometimes zarr queries keys that don't exist anymore, just return nothing on those cases return None + except Exception as e: + print(f"Error getting key '{key}' with range {byte_range}: {e}") + raise async def get_partial_values( self, prototype: zarr.core.buffer.BufferPrototype, key_ranges: Iterable[tuple[str, zarr.abc.store.ByteRequest | None]], ) -> list[zarr.core.buffer.Buffer | None]: - """@private""" - raise NotImplementedError + """ + Retrieves multiple keys or byte ranges concurrently using asyncio.gather. 
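        Illustrative sketch (the chunk keys are hypothetical and assumed to
        already exist; `zhs` is an open ZarrHAMTStore and `prototype` the usual
        zarr BufferPrototype):

            from zarr.abc.store import RangeByteRequest, SuffixByteRequest

            bufs = await zhs.get_partial_values(
                prototype,
                [
                    ("temp/c/0/0/0", RangeByteRequest(start=0, end=128)),  # first 128 bytes
                    ("temp/c/0/0/1", SuffixByteRequest(suffix=16)),  # last 16 bytes
                    ("zarr.json", None),  # full value
                ],
            )

        Each request is mapped by `_map_byte_request` onto the offset/length/suffix
        arguments of `HAMT.get`, which `KuboCAS.load` in turn translates into an
        HTTP `Range` header on the gateway request (e.g. `bytes=0-127` or `bytes=-16`).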
+ """ + tasks = [self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + results = await asyncio.gather( + *tasks, return_exceptions=False + ) # Set return_exceptions=True for debugging + return results async def exists(self, key: str) -> bool: """@private""" diff --git a/pyproject.toml b/pyproject.toml index 5e72353..f712c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,4 @@ dev = [ [tool.ruff] lint.extend-select = ["I"] +preview = true diff --git a/run-checks.sh b/run-checks.sh index bfbd222..b4995d1 100644 --- a/run-checks.sh +++ b/run-checks.sh @@ -7,4 +7,5 @@ uv run pytest --ipfs --cov=py_hamt tests/ uv run coverage report --fail-under=100 --show-missing # Check for linting, formatting, and type checking using the pre-commit hooks found in .pre-commit-config.yaml -uv run pre-commit run --all-files --show-diff-on-failure +# uv run pre-commit run --all-files --show-diff-on-failure +uv run pre-commit run --all-files diff --git a/tests/test_benchmark_stores.py b/tests/test_benchmark_stores.py new file mode 100644 index 0000000..b2b9534 --- /dev/null +++ b/tests/test_benchmark_stores.py @@ -0,0 +1,310 @@ +# import time + +# import numpy as np +# import pandas as pd +# import pytest +# import xarray as xr +# from dag_cbor.ipld import IPLDKind + +# # Import both store implementations +# from py_hamt import HAMT, KuboCAS, FlatZarrStore, ShardedZarrStore +# from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +# @pytest.fixture(scope="module") +# def random_zarr_dataset(): +# """Creates a random xarray Dataset for benchmarking.""" +# # Using a slightly larger dataset for a more meaningful benchmark +# times = pd.date_range("2024-01-01", periods=100) +# lats = np.linspace(-90, 90, 18) +# lons = np.linspace(-180, 180, 36) + +# temp = np.random.randn(len(times), len(lats), len(lons)) +# precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + +# ds = xr.Dataset( +# { +# "temp": (["time", "lat", "lon"], temp), +# }, +# coords={"time": times, "lat": lats, "lon": lons}, +# ) + +# # Define chunking for the store +# ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) +# yield ds + +# @pytest.fixture(scope="module") +# def random_shard_dataset(): +# """Creates a random xarray Dataset for benchmarking.""" +# # Using a slightly larger dataset for a more meaningful benchmark +# times = pd.date_range("2024-01-01", periods=100) +# lats = np.linspace(-90, 90, 18) +# lons = np.linspace(-180, 180, 36) + +# temp = np.random.randn(len(times), len(lats), len(lons)) +# precip = np.random.gamma(4, 2.5, size=(len(times), len(lats), len(lons))) + +# ds = xr.Dataset( +# { +# "precip": (["time", "lat", "lon"], precip), +# }, +# coords={"time": times, "lat": lats, "lon": lons}, +# ) + +# # Define chunking for the store +# ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) +# yield ds + + +# # # ### +# # # BENCHMARK FOR THE NEW FlatZarrStore +# # # ### +# # @pytest.mark.asyncio(loop_scope="session") +# # async def test_benchmark_flat_store( +# # create_ipfs: tuple[str, str], +# # random_zarr_dataset: xr.Dataset, +# # ): +# # """Benchmarks write and read performance for the new FlatZarrStore.""" +# # print("\n\n" + "=" * 80) +# # print("🚀 STARTING BENCHMARK for FlatZarrStore") +# # print("=" * 80) + +# # rpc_base_url, gateway_base_url = create_ipfs +# # # rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# # # gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# # # headers = { +# # # "X-API-Key": "", +# # # } +# # headers = {} +# # test_ds = random_zarr_dataset 
+ +# # async with KuboCAS( +# # rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# # ) as kubo_cas: +# # # --- Write --- +# # # The full shape after appending +# # appended_shape = list(test_ds.dims.values()) +# # time_axis_index = list(test_ds.dims).index("time") +# # appended_shape[time_axis_index] *= 2 +# # final_array_shape = tuple(appended_shape) + +# # final_chunk_shape = [] +# # for dim_name in test_ds.dims: # Preserves dimension order +# # if dim_name in test_ds.chunks: +# # # test_ds.chunks[dim_name] is a tuple e.g. (20,) +# # final_chunk_shape.append(test_ds.chunks[dim_name][0]) +# # else: +# # # Fallback if a dimension isn't explicitly chunked (should use its full size) +# # final_chunk_shape.append(test_ds.dims[dim_name]) +# # final_chunk_shape = tuple(final_chunk_shape) + +# # store_write = await FlatZarrStore.open( +# # cas=kubo_cas, +# # read_only=False, +# # array_shape=final_array_shape, +# # chunk_shape=final_chunk_shape, +# # ) + +# # start_write = time.perf_counter() +# # # Perform an initial write and an append +# # test_ds.to_zarr(store=store_write, mode="w") +# # test_ds.to_zarr(store=store_write, mode="a", append_dim="time") +# # root_cid = await store_write.flush() # Flush to get the final CID +# # end_write = time.perf_counter() + +# # print(f"\n--- [FlatZarr] Write Stats ---") +# # print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# # print(f"Final Root CID: {root_cid}") + +# # # --- Read --- +# # store_read = await FlatZarrStore.open( +# # cas=kubo_cas, read_only=True, root_cid=root_cid +# # ) + +# # start_read = time.perf_counter() +# # ipfs_ds = xr.open_zarr(store=store_read) +# # # Force a read of some data to ensure it's loaded +# # _ = ipfs_ds.temp.isel(time=0).values +# # end_read = time.perf_counter() + +# # print(f"\n--- [FlatZarr] Read Stats ---") +# # print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + +# # # --- Verification --- +# # full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# # xr.testing.assert_identical(full_test_ds, ipfs_ds) +# # print("\n✅ [FlatZarr] Data verification successful.") +# # print("=" * 80) + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_sharded_store( # Renamed function +# create_ipfs: tuple[str, str], +# random_shard_dataset: xr.Dataset, +# ): +# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print +# print("=" * 80) + +# rpc_base_url, gateway_base_url = create_ipfs + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# headers = { +# "X-API-Key": "", +# } +# # headers = {} +# test_ds = random_shard_dataset + +# # Define chunks_per_shard for the ShardedZarrStore +# chunks_per_shard_config = 50 # Configuration for sharding + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# # The full shape after appending +# appended_shape = list(test_ds.dims.values()) +# time_axis_index = list(test_ds.dims).index("time") +# appended_shape[time_axis_index] *= 2 # Simulating appending along time dimension +# final_array_shape = tuple(appended_shape) + +# # Determine chunk shape from the dataset's encoding or dimensions +# final_chunk_shape_list = [] +# for dim_name in test_ds.dims: # Preserves dimension order from the dataset +# if 
dim_name in test_ds.chunks: +# # test_ds.chunks is a dict like {'time': (20,), 'y': (20,), 'x': (20,)} +# final_chunk_shape_list.append(test_ds.chunks[dim_name][0]) +# else: +# # Fallback if a dimension isn't explicitly chunked (should use its full size) +# final_chunk_shape_list.append(test_ds.dims[dim_name]) +# final_chunk_shape = tuple(final_chunk_shape_list) + +# # Use ShardedZarrStore and provide chunks_per_shard +# store_write = await ShardedZarrStore.open( +# cas=kubo_cas, +# read_only=False, +# array_shape=final_array_shape, +# chunk_shape=final_chunk_shape, +# chunks_per_shard=chunks_per_shard_config # Added new parameter +# ) + +# start_write = time.perf_counter() +# # Perform an initial write and an append +# test_ds.to_zarr(store=store_write, mode="w") +# test_ds.to_zarr(store=store_write, mode="a", append_dim="time") +# root_cid = await store_write.flush() # Flush to get the final CID +# end_write = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Write Stats (chunks_per_shard={chunks_per_shard_config}) ---") # Updated print +# print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# print(f"Final Root CID: {root_cid}") + +# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print +# # --- Read --- +# # When opening for read, chunks_per_shard is read from the store's metadata +# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore +# cas=kubo_cas, read_only=True, root_cid=root_cid +# ) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=store_read) +# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) +# if "precip" in ipfs_ds.variables and "time" in ipfs_ds.coords: +# # _ = ipfs_ds.temp.isel(time=0).values +# data_fetched = ipfs_ds.precip.isel(time=slice(0, 1)).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable +# first_var_name = list(ipfs_ds.data_vars.keys())[0] +# # Construct a minimal selection based on available dimensions +# selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} +# if selection: +# _ = ipfs_ds[first_var_name].isel(**selection).values +# else: # If no dimensions, try loading the whole variable (e.g. 
scalar) +# _ = ipfs_ds[first_var_name].values +# end_read = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print +# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") + +# # --- Verification --- +# # Create the expected full dataset after append operation +# full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# xr.testing.assert_identical(full_test_ds, ipfs_ds) +# print("\n✅ [ShardedZarr] Data verification successful.") # Updated print +# print("=" * 80) + +# # ### +# # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore +# # ### +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_hamt_store( +# create_ipfs: tuple[str, str], +# random_zarr_dataset: xr.Dataset, +# ): +# """Benchmarks write and read performance for the ZarrHAMTStore.""" +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ZarrHAMTStore") +# print("=" * 80) + +# rpc_base_url, gateway_base_url = create_ipfs + +# # rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# # gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# # headers = { +# # "X-API-Key": "", +# # } +# headers = {} +# test_ds = random_zarr_dataset + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# print("Building HAMT store...") +# hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) +# print("HAMT store built successfully.") +# zhs = ZarrHAMTStore(hamt) +# print("ZarrHAMTStore created successfully.") + +# start_write = time.perf_counter() +# # Perform an initial write and an append to simulate a common workflow +# test_ds.to_zarr(store=zhs, mode="w") +# print("Initial write completed, now appending...") +# test_ds.to_zarr(store=zhs, mode="a", append_dim="time") +# await hamt.make_read_only() # Flush and freeze to get the final CID +# end_write = time.perf_counter() + +# cid: IPLDKind = hamt.root_node_id +# print(f"\n--- [HAMT] Write Stats ---") +# print(f"Total time to write and append: {end_write - start_write:.2f} seconds") +# print(f"Final Root CID: {cid}") + +# # --- Read --- +# hamt_ro = await HAMT.build( +# cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True +# ) +# zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=zhs_ro) +# # Force a read of some data to ensure it's loaded +# data_fetched = ipfs_ds.temp.isel(time=slice(0, 1)).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# end_read = time.perf_counter() + +# print(f"\n--- [HAMT] Read Stats ---") +# print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + + +# # --- Verification --- +# full_test_ds = xr.concat([test_ds, test_ds], dim="time") +# xr.testing.assert_identical(full_test_ds, ipfs_ds) +# print("\n✅ [HAMT] Data verification successful.") +# print("=" * 80) diff --git a/tests/test_converter.py b/tests/test_converter.py new file mode 100644 index 0000000..64e5799 --- /dev/null +++ b/tests/test_converter.py @@ -0,0 +1,265 @@ +import sys +import time +import uuid +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +# Import store implementations +from py_hamt import ( + HAMT, + KuboCAS, + ShardedZarrStore, +) +from py_hamt.hamt_to_sharded_converter import ( + convert_hamt_to_sharded, + 
sharded_converter_cli, +) +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def converter_test_dataset(): + """ + Creates a random, uniquely-named xarray Dataset specifically for the converter test. + Using a unique variable name helps avoid potential caching issues between test runs. + """ + # A smaller dataset is fine for a verification test + times = pd.date_range("2025-01-01", periods=20) + lats = np.linspace(40, 50, 10) + lons = np.linspace(-85, -75, 20) + + # Generate a unique variable name for this test run + unique_var_name = f"data_{str(uuid.uuid4())[:8]}" + + data = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + {unique_var_name: (["time", "latitude", "longitude"], data)}, + coords={"time": times, "latitude": lats, "longitude": lons}, + attrs={"description": "Test dataset for converter verification."}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 10, "latitude": 10, "longitude": 10}) + yield ds + + +@pytest.mark.asyncio(loop_scope="session") +async def test_converter_produces_identical_dataset( + create_ipfs: tuple[str, str], + converter_test_dataset: xr.Dataset, +): + """ + Tests the hamt_to_sharded_converter by performing a full conversion + and verifying that the resulting dataset is identical to the source. + """ + print("\n\n" + "=" * 80) + print("🚀 STARTING TEST for HAMT to Sharded Converter") + print("=" * 80) + + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + chunks_per_shard_config = 64 # A reasonable value for this test size + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # -------------------------------------------------------------------- + # STEP 1: Create the source HAMT store from our test dataset + # -------------------------------------------------------------------- + print("\n--- STEP 1: Creating source HAMT store ---") + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + + start_write = time.perf_counter() + test_ds.to_zarr(store=source_hamt_store, mode="w") + await hamt_write.make_read_only() # Flush to get the final CID + end_write = time.perf_counter() + + hamt_root_cid = str(hamt_write.root_node_id) + print(f"Source HAMT store created in {end_write - start_write:.2f}s") + print(f"Source HAMT Root CID: {hamt_root_cid}") + + # -------------------------------------------------------------------- + # STEP 2: Run the conversion script + # -------------------------------------------------------------------- + print("\n--- STEP 2: Running conversion script ---") + sharded_root_cid = await convert_hamt_to_sharded( + cas=kubo_cas, + hamt_root_cid=hamt_root_cid, + chunks_per_shard=chunks_per_shard_config, + ) + print("Conversion script finished.") + print(f"New Sharded Store Root CID: {sharded_root_cid}") + assert sharded_root_cid is not None + + # -------------------------------------------------------------------- + # STEP 3: Verification + # -------------------------------------------------------------------- + print("\n--- STEP 3: Verifying data integrity ---") + + # Open the original dataset from the HAMT store + print("Reading data back from original HAMT store...") + + hamt_ro = await HAMT.build( + cas=kubo_cas, + root_node_id=hamt_root_cid, + values_are_bytes=True, + read_only=True, + ) + zhs_ro = ZarrHAMTStore(hamt_ro, read_only=True) + + start_read = time.perf_counter() + ds_from_hamt = xr.open_zarr(store=zhs_ro) + 
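        # Note: xr.open_zarr is lazy, so this timing mostly reflects metadata and
        # coordinate reads; the chunk values themselves are only pulled further
        # below by assert_identical and the per-variable comparisons.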
+ end_read = time.perf_counter() + print(f"Original HAMT store read in {end_read - start_read:.2f}s") + + # Open the converted dataset from the new Sharded store + print("Reading data back from new Sharded store...") + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=sharded_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + + # The ultimate test: are the two xarray.Dataset objects identical? + # This checks coordinates, variables, data values, and attributes. + print("Comparing the two datasets...") + xr.testing.assert_identical(ds_from_hamt, ds_from_sharded) + # Ask for random samples from both datasets to ensure they match + for var in ds_from_hamt.data_vars: + # Assert all identical + np.testing.assert_array_equal( + ds_from_hamt[var].values, ds_from_sharded[var].values + ) + + print("\n✅ Verification successful! The datasets are identical.") + print("=" * 80) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_success( + create_ipfs: tuple[str, str], converter_test_dataset: xr.Dataset, capsys +): + """ + Tests the CLI for successful conversion of a HAMT store to a ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Step 1: Create a HAMT store with the test dataset + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + test_ds.to_zarr(store=source_hamt_store, mode="w", consolidated=True) + await hamt_write.make_read_only() + hamt_root_cid = str(hamt_write.root_node_id) + + # Step 2: Simulate CLI execution with valid arguments + test_args = [ + "script.py", # Dummy script name + hamt_root_cid, + "--chunks-per-shard", + "64", + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Step 3: Capture and verify CLI output + captured = capsys.readouterr() + assert "Starting Conversion from HAMT Root" in captured.out + assert "Conversion Complete!" in captured.out + assert "New ShardedZarrStore Root CID" in captured.out + + # Step 4: Verify the converted dataset + # Extract the new root CID from output (assuming it's the last line) + output_lines = captured.out.strip().split("\n") + new_root_cid = output_lines[-1].split(": ")[-1] + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + xr.testing.assert_identical(test_ds, ds_from_sharded) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_default_args( + create_ipfs: tuple[str, str], converter_test_dataset: xr.Dataset, capsys +): + """ + Tests the CLI with default argument values. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = converter_test_dataset + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Create a HAMT store + hamt_write = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + source_hamt_store = ZarrHAMTStore(hamt_write) + test_ds.to_zarr(store=source_hamt_store, mode="w", consolidated=True) + await hamt_write.make_read_only() + hamt_root_cid = str(hamt_write.root_node_id) + + # Simulate CLI with only hamt_cid and gateway URLs. 
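        # --chunks-per-shard is deliberately omitted so the CLI falls back to
        # whatever default value it defines; only the positional CID plus the
        # RPC/gateway URLs are passed.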
+ test_args = [ + "script.py", # Dummy script name + hamt_root_cid, + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Verify output and conversion + captured = capsys.readouterr() + output_lines = captured.out.strip().split("\n") + print("Captured CLI Output:") + for line in output_lines: + print(line) + new_root_cid = output_lines[-1].split(": ")[-1] + dest_store_ro = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + ds_from_sharded = xr.open_zarr(dest_store_ro) + xr.testing.assert_identical(test_ds, ds_from_sharded) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_hamt_to_sharded_cli_invalid_cid(create_ipfs: tuple[str, str], capsys): + """ + Tests the CLI with an invalid hamt_cid. + """ + rpc_base_url, gateway_base_url = create_ipfs + invalid_cid = "invalid_cid" + + test_args = [ + "script.py", + invalid_cid, + "--chunks-per-shard", + "64", + "--rpc-url", + rpc_base_url, + "--gateway-url", + gateway_base_url, + ] + with patch.object(sys, "argv", test_args): + await sharded_converter_cli() + + # Verify error handling + captured = capsys.readouterr() + assert "An error occurred" in captured.out + assert f"{invalid_cid}" in captured.out diff --git a/tests/test_cpc_compare.py b/tests/test_cpc_compare.py new file mode 100644 index 0000000..3680591 --- /dev/null +++ b/tests/test_cpc_compare.py @@ -0,0 +1,161 @@ +# import time + +# import numpy as np +# import pandas as pd +# import pytest +# import xarray as xr +# from dag_cbor.ipld import IPLDKind +# from multiformats import CID + +# # Import both store implementations +# from py_hamt import HAMT, KuboCAS, ShardedZarrStore +# from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_sharded_store(): +# """Benchmarks write and read performance for the new ShardedZarrStore.""" # Updated docstring +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ShardedZarrStore") # Updated print +# print("=" * 80) + + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net" +# headers = { +# "X-API-Key": "", +# } + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: +# # --- Write --- +# # root_cid = "bafyr4ifayhevbtfg2qzffuicic3rwli4fhnnkhrfduuxkwvetppfk4ogbe" +# root_cid = "bafyr4ifs4oejlvtwvb57udbbhba5yllss4cixjkxrevq54g3mo5kwknpwy" +# print(f"\n--- [ShardedZarr] STARTING READ ---") # Updated print +# # --- Read --- +# start = time.perf_counter() +# # When opening for read, chunks_per_shard is read from the store's metadata +# store_read = await ShardedZarrStore.open( # Use ShardedZarrStore +# cas=kubo_cas, read_only=True, root_cid=root_cid +# ) +# stop = time.perf_counter() +# print(f"Total time to open ShardedZarrStore: {stop - start:.2f} seconds") +# print(f"Opened ShardedZarrStore for reading with root CID: {root_cid}") + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=store_read) +# start_read = time.perf_counter() +# print(ipfs_ds) +# stop_read = time.perf_counter() +# print(f"Total time to read dataset: {stop_read - start_read:.2f} seconds") +# # start_read = time.perf_counter() +# # print(ipfs_ds.variables, ipfs_ds.coords) # Print available variables and coordinates for debugging +# # stop_read = time.perf_counter() +# # print(f"Total time to read dataset variables 
and coordinates: {stop_read - start_read:.2f} seconds") +# start_read = time.perf_counter() +# # Force a read of some data to ensure it's loaded (e.g., first time slice of 'temp' variable) +# if "FPAR" in ipfs_ds.variables and "time" in ipfs_ds.coords: +# print("Fetching 'FPAR' data...") + +# # Define date range +# date_from = "2000-05-15" +# date_to = "2004-05-30" + +# # Define bounding box from polygon coordinates +# min_lon, max_lon = 4.916695, 5.258908 +# min_lat, max_lat = 51.921763, 52.160344 + +# print(ipfs_ds["FPAR"].sel( +# time=slice(date_from, date_to), +# latitude=slice(min_lat, max_lat), +# longitude=slice(min_lon, max_lon) +# )) + +# # Fetch data for the specified time and region +# data_fetched = ipfs_ds["FPAR"].sel( +# time=slice(date_from, date_to), +# latitude=slice(min_lat, max_lat), +# longitude=slice(min_lon, max_lon) +# ).values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# elif len(ipfs_ds.data_vars) > 0 : # Fallback: try to read from the first data variable +# first_var_name = list(ipfs_ds.data_vars.keys())[0] +# # Construct a minimal selection based on available dimensions +# selection = {dim: 0 for dim in ipfs_ds[first_var_name].dims} +# if selection: +# _ = ipfs_ds[first_var_name].isel(**selection).values +# else: # If no dimensions, try loading the whole variable (e.g. scalar) +# _ = ipfs_ds[first_var_name].values +# end_read = time.perf_counter() + +# print(f"\n--- [ShardedZarr] Read Stats ---") # Updated print +# print(f"Total time to open and read some data: {end_read - start_read:.2f} seconds") +# print("=" * 80) +# # Speed in MB/s +# if data_size > 0: +# speed = data_size / (end_read - start_read) / (1024 * 1024) +# print(f"Read speed: {speed:.2f} MB/s") +# else: +# print("No data fetched, cannot calculate speed.") + +# # ### +# # BENCHMARK FOR THE ORIGINAL ZarrHAMTStore +# # ### +# @pytest.mark.asyncio(loop_scope="session") +# async def test_benchmark_hamt_store(): +# """Benchmarks write and read performance for the ZarrHAMTStore.""" +# print("\n\n" + "=" * 80) +# print("🚀 STARTING BENCHMARK for ZarrHAMTStore") +# print("=" * 80) + +# rpc_base_url = f"https://ipfs-gateway.dclimate.net/" +# gateway_base_url = f"https://ipfs-gateway.dclimate.net/" +# # headers = { +# # "X-API-Key": "", +# # } +# headers = {} + +# async with KuboCAS( +# rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url, headers=headers +# ) as kubo_cas: + +# root_cid = "bafyr4igl3pmswu5pfzb6dcgcxj3ipxlpxxxad7j7tf45obxe5pkp4xgpwe" +# # root_node_id = CID.decode(root_cid) + +# hamt = await HAMT.build( +# cas=kubo_cas, root_node_id=root_cid, values_are_bytes=True, read_only=True +# ) +# start = time.perf_counter() +# ipfs_ds: xr.Dataset +# zhs = ZarrHAMTStore(hamt, read_only=True) +# ipfs_ds = xr.open_zarr(store=zhs) +# print(ipfs_ds) + +# # --- Read --- +# hamt = HAMT(cas=kubo_cas, values_are_bytes=True, root_node_id=root_cid, read_only=True) + +# # Initialize the store +# zhs = ZarrHAMTStore(hamt, read_only=True) + +# start_read = time.perf_counter() +# ipfs_ds = xr.open_zarr(store=zhs) +# # Force a read of some data to ensure it's loaded +# data_fetched = ipfs_ds.precip.values + +# # Calculate the size of the fetched data +# data_size = data_fetched.nbytes if data_fetched is not None else 0 +# print(f"Fetched data size: {data_size / (1024 * 1024):.4f} MB") +# end_read = time.perf_counter() + +# print(f"\n--- [HAMT] Read Stats ---") +# 
print(f"Total time to open and read: {end_read - start_read:.2f} seconds") + +# if data_size > 0: +# speed = data_size / (end_read - start_read) / (1024 * 1024) +# print(f"Read speed: {speed:.2f} MB/s") +# else: +# print("No data fetched, cannot calculate speed.") diff --git a/tests/test_kubo_pin.py b/tests/test_kubo_pin.py new file mode 100644 index 0000000..703dc92 --- /dev/null +++ b/tests/test_kubo_pin.py @@ -0,0 +1,62 @@ +import dag_cbor +import pytest + +from py_hamt import KuboCAS + + +@pytest.mark.asyncio(loop_scope="session") +async def test_pinning(create_ipfs, global_client_session): + """ + Tests pinning a CID using KuboCAS with explicit URLs. + Verifies that a CID can be pinned and is retrievable after pinning. + """ + rpc_url, gateway_url = create_ipfs + + async with KuboCAS( + rpc_base_url=rpc_url, + gateway_base_url=gateway_url, + client=global_client_session, + ) as kubo_cas: + # Save data to get a CID + data = b"test1" + encoded_data = dag_cbor.encode(data) + cid = await kubo_cas.save(encoded_data, codec="dag-cbor") + # Pin the CID + await kubo_cas.pin_cid(cid, target_rpc=rpc_url) + listed_pins = await kubo_cas.pin_ls(target_rpc=rpc_url) + # Verify the CID is pinned by querying the pin list + assert str(cid) in listed_pins, f"CID {cid} was not pinned" + + # Unpine the CID + await kubo_cas.unpin_cid(cid, target_rpc=rpc_url) + + # Verify the CID is no longer pinned + listed_pins_after_unpin = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(cid) not in listed_pins_after_unpin, f"CID {cid} was not unpinned" + + # Pin again, then perform a pin update + await kubo_cas.pin_cid(cid, target_rpc=rpc_url) + + data = b"test2" + encoded_data = dag_cbor.encode(data) + new_cid = await kubo_cas.save(encoded_data, codec="dag-cbor") + + # Update the pinned CID + await kubo_cas.pin_update(cid, new_cid, target_rpc=rpc_url) + + # Verify the old CID is no longer pinned and the new CID is pinned + listed_pins_after_update = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(cid) not in listed_pins_after_update, ( + f"Old CID {cid} was not unpinned after update" + ) + assert str(new_cid) in listed_pins_after_update, ( + f"New CID {new_cid} was not pinned after update" + ) + + # unpin the new CID + await kubo_cas.unpin_cid(new_cid, target_rpc=rpc_url) + # Verify the new CID is no longer pinned + listed_pins_after_unpin_update = await kubo_cas.pin_ls(target_rpc=rpc_url) + assert str(new_cid) not in listed_pins_after_unpin_update, ( + f"New CID {new_cid} was not unpinned after update" + ) diff --git a/tests/test_sharded_store_deleting.py b/tests/test_sharded_store_deleting.py new file mode 100644 index 0000000..88222a8 --- /dev/null +++ b/tests/test_sharded_store_deleting.py @@ -0,0 +1,350 @@ +import asyncio +import json + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +@pytest.mark.asyncio 
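# In the tests below, chunk keys follow the zarr v3 layout
# "<array>/c/<i>/<j>[/<k>]" -- e.g. "temp/c/0/0" addresses the chunk at grid
# position (0, 0) of the "temp" array -- while "temp/zarr.json" is the array's
# metadata document.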
+async def test_delete_chunk_success(create_ipfs: tuple[str, str]): + """Tests successful deletion of a chunk from the store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write a chunk + chunk_key = "temp/c/0/0" + chunk_data = b"test_chunk_data" + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + assert await store.exists(chunk_key) + + # Delete the chunk + await store.delete(chunk_key) + + # Verify chunk is deleted in cache and shard is marked dirty + linear_index = store._get_linear_chunk_index((0, 0)) + shard_idx, index_in_shard = store._get_shard_info(linear_index) + target_shard_list = await store._load_or_initialize_shard_cache(shard_idx) + assert target_shard_list[index_in_shard] is None + assert shard_idx in store._shard_data_cache._dirty_shards + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert not await store_read.exists(chunk_key) + assert await store_read.get(chunk_key, proto) is None + + +@pytest.mark.asyncio +async def test_delete_metadata_success(create_ipfs: tuple[str, str]): + """Tests successful deletion of a metadata key.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write metadata + metadata_key = "temp/zarr.json" + metadata = json.dumps({"shape": [20, 20], "dtype": "float32"}).encode("utf-8") + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(metadata_key, proto.buffer.from_bytes(metadata)) + assert await store.exists(metadata_key) + + # Delete metadata + await store.delete(metadata_key) + + # Verify metadata is deleted and root is marked dirty + assert metadata_key not in store._root_obj["metadata"] + assert store._dirty_root is True + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert not await store_read.exists(metadata_key) + assert await store_read.get(metadata_key, proto) is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(create_ipfs: tuple[str, str]): + """Tests deletion of a nonexistent metadata key.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Temp write to temp/c/0/0 to ensure it exists + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data")) + + # flush it + await store.flush() + assert not store._shard_data_cache._dirty_shards # No dirty shards after flush + + # Try to delete nonexistent metadata key + with pytest.raises(KeyError, match="Metadata key 'nonexistent.json' not found"): + await 
store.delete("nonexistent.json") + + # Try to delete nonexistent chunk key (out of bounds) + with pytest.raises(IndexError, match="Chunk coordinate"): + await store.delete("temp/c/3/0") # Out of bounds for 2x2 chunk grid + + # Try to delete nonexistent chunk key (within bounds but not set) + await store.delete("temp/c/0/0") # Should not raise, as it sets to None + assert not await store.exists("temp/c/0/0") + assert ( + store._shard_data_cache._dirty_shards + ) # Shard is marked dirty even if chunk was already None + + +@pytest.mark.asyncio +async def test_delete_read_only_store(create_ipfs: tuple[str, str]): + """Tests deletion attempt on a read-only store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize writable store and add data + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + chunk_key = "temp/c/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + await store_write.set(chunk_key, proto.buffer.from_bytes(b"test_data")) + metadata_key = "temp/zarr.json" + await store_write.set( + metadata_key, proto.buffer.from_bytes(b'{"shape": [20, 20]}') + ) + root_cid = await store_write.flush() + + # Open as read-only + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + # Try to delete chunk + with pytest.raises( + PermissionError, match="Cannot delete from a read-only store" + ): + await store_read_only.delete(chunk_key) + + # Try to delete metadata + with pytest.raises( + PermissionError, match="Cannot delete from a read-only store" + ): + await store_read_only.delete(metadata_key) + + +@pytest.mark.asyncio +async def test_delete_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent delete operations to ensure shard locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write multiple chunks + proto = zarr.core.buffer.default_buffer_prototype() + chunk_keys = ["temp/c/0/0", "temp/c/1/0", "temp/c/0/1"] + for key in chunk_keys: + await store.set(key, proto.buffer.from_bytes(f"data_{key}".encode("utf-8"))) + assert await store.exists(key) + + # Define concurrent delete tasks + async def delete_task(key): + await store.delete(key) + + # Run concurrent deletes + tasks = [delete_task(key) for key in chunk_keys] + await asyncio.gather(*tasks) + + # Verify all chunks are deleted + for key in chunk_keys: + assert not await store.exists(key) + assert await store.get(key, proto) is None + + # Verify shards are marked dirty + assert ( + store._shard_data_cache._dirty_shards + ) # At least one shard should be dirty + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + for key in chunk_keys: + assert not await store_read.exists(key) + assert await store_read.get(key, proto) is None + + +@pytest.mark.asyncio +async def test_delete_with_dataset( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests deletion of chunks and metadata in a store with a full dataset.""" + rpc_base_url, gateway_base_url = 
create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + # Delete a chunk + chunk_key = "temp/c/0/0/0" + assert await store.exists(chunk_key) + await store.delete(chunk_key) + assert not await store.exists(chunk_key) + + # Delete metadata + metadata_key = "temp/zarr.json" + assert await store.exists(metadata_key) + await store.delete(metadata_key) + assert not await store.exists(metadata_key) + + # Flush and verify + new_root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert not await store_read.exists(chunk_key) + assert not await store_read.exists(metadata_key) + + # Verify other data remains intact + other_chunk_key = "temp/c/1/0/0" + assert await store_read.exists(other_chunk_key) + + +@pytest.mark.asyncio +async def test_supports_writes_property(create_ipfs: tuple[str, str]): + """Tests the supports_writes property.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test writable store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store_write.supports_writes is True + + # Test read-only store + root_cid = await store_write.flush() + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read_only.supports_writes is False + + +@pytest.mark.asyncio +async def test_supports_partial_writes_property(create_ipfs: tuple[str, str]): + """Tests the supports_partial_writes property.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test for both read-only and writable stores + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store_write.supports_partial_writes is False + + root_cid = await store_write.flush() + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read_only.supports_partial_writes is False diff --git a/tests/test_sharded_store_grafting.py b/tests/test_sharded_store_grafting.py new file mode 100644 index 0000000..59acfd3 --- /dev/null +++ b/tests/test_sharded_store_grafting.py @@ -0,0 +1,423 @@ +import asyncio + +import dag_cbor +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", 
periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +@pytest.mark.asyncio +async def test_graft_store_success(create_ipfs: tuple[str, str]): + """Tests successful grafting of a source store onto a target store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source store + source_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Write some chunk data to source store + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0" + chunk_data = b"test_source_data" + await source_store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + source_root_cid = await source_store.flush() + + # Initialize target store with larger shape to accommodate graft + target_shape = (40, 20) + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=target_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Graft source store onto target store at offset (1, 0) + chunk_offset = (1, 0) + await target_store.graft_store(source_root_cid, chunk_offset) + + # Verify grafted data + grafted_chunk_key = "temp/c/1/0" # Offset (1,0) corresponds to chunk (1,0) + assert await target_store.exists(grafted_chunk_key) + grafted_data = await target_store.get(grafted_chunk_key, proto) + assert grafted_data is not None + assert grafted_data.to_bytes() == chunk_data + + # Verify original chunk position in source store is not present in target + assert not await target_store.exists("temp/c/0/0") + + # Verify target store's geometry unchanged + assert target_store._array_shape == target_shape + assert target_store._chunks_per_dim == (4, 2) # ceil(40/10) = 4 + assert target_store._total_chunks == 8 # 4 * 2 + assert target_store._num_shards == 2 # ceil(8/4) = 2 + assert ( + target_store._shard_data_cache._dirty_shards + ) # Grafting marks shards as dirty + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists(grafted_chunk_key) + read_data = await target_store_read.get(grafted_chunk_key, proto) + assert read_data is not None + assert read_data.to_bytes() == chunk_data + assert target_store_read._array_shape == target_shape + + +@pytest.mark.asyncio +async def test_graft_store_with_dataset( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests grafting a store containing a full dataset.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset to source store + source_store = await 
ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=source_store, mode="w") + source_root_cid = await source_store.flush() + + # Initialize target store with larger shape + target_shape = ( + array_shape_tuple[0] + 20, + array_shape_tuple[1], + array_shape_tuple[2], + ) + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=target_shape, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + + # Graft source store at offset (1, 0, 0) + chunk_offset = (1, 0, 0) + await target_store.graft_store(source_root_cid, chunk_offset) + + # Verify grafted chunk data + proto = zarr.core.buffer.default_buffer_prototype() + source_chunk_key = "temp/c/0/0/0" + target_chunk_key = "temp/c/1/0/0" # Offset by 1 in time dimension + assert await target_store.exists(target_chunk_key) + source_data = await source_store.get(source_chunk_key, proto) + target_data = await target_store.get(target_chunk_key, proto) + assert source_data is not None + assert target_data is not None + assert source_data.to_bytes() == target_data.to_bytes() + + # Verify metadata is not grafted + assert not await target_store.exists("temp/zarr.json") + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists(target_chunk_key) + read_data = await target_store_read.get(target_chunk_key, proto) + assert read_data is not None + assert read_data.to_bytes() == source_data.to_bytes() + + +@pytest.mark.asyncio +async def test_graft_store_empty_source(create_ipfs: tuple[str, str]): + """Tests grafting an empty source store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize empty source store + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + source_root_cid = await source_store.flush() + + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Graft empty source store + await target_store.graft_store(source_root_cid, chunk_offset=(1, 1)) + + # Verify no chunks were grafted + assert not await target_store.exists("temp/c/1/1") + assert ( + not target_store._shard_data_cache._dirty_shards + ) # No shards marked dirty since no changes + + # Flush and verify + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert not await target_store_read.exists("temp/c/1/1") + + +@pytest.mark.asyncio +async def test_graft_store_invalid_cases(create_ipfs: tuple[str, str]): + """Tests error handling in graft_store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test read-only target store + target_store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + 
root_cid=await target_store.flush(), + ) + with pytest.raises( + PermissionError, match="Cannot graft onto a read-only store" + ): + await target_store_read_only.graft_store("some_cid", chunk_offset=(0, 0)) + + # Test invalid source CID + # invalid_cid = "invalid_cid" + # with pytest.raises(ValueError, match="Store to graft could not be loaded"): + # await target_store.graft_store(invalid_cid, chunk_offset=(0, 0)) + + # Test source store with invalid configuration + invalid_root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], + "chunk_shape": [5, 5], + "sharding_config": {"chunks_per_shard": 4}, + "shard_cids": [None] * 4, + }, + } + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) + # source_store._chunks_per_dim = None # Simulate unconfigured store + # with pytest.raises(ValueError, match="Inconsistent number of shards"): + # await target_store.graft_store(invalid_root_cid, chunk_offset=(0, 0)) + + # Test grafting out-of-bounds offset + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + # Write some data to source store + proto = zarr.core.buffer.default_buffer_prototype() + await source_store.set("temp/c/0/0", proto.buffer.from_bytes(b"data")) + source_root_cid = await source_store.flush() + with pytest.raises(ValueError, match="Shard index 10 out of bounds."): + await target_store.graft_store( + source_root_cid, chunk_offset=(10, 0) + ) # Out of bounds for target (4x4 chunks) + + +@pytest.mark.asyncio +async def test_graft_store_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent graft_store operations to ensure shard locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source stores + source_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + source_store1 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + proto = zarr.core.buffer.default_buffer_prototype() + await source_store1.set("temp/c/0/0", proto.buffer.from_bytes(b"data1")) + source_cid1 = await source_store1.flush() + + source_store2 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=source_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + await source_store2.set("temp/c/0/0", proto.buffer.from_bytes(b"data2")) + source_cid2 = await source_store2.flush() + + # Initialize target store + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 40), + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Define graft tasks + async def graft_task(cid, offset): + await target_store.graft_store(cid, chunk_offset=offset) + + # Run concurrent grafts + tasks = [ + graft_task(source_cid1, (1, 1)), + graft_task(source_cid2, (2, 2)), + ] + await asyncio.gather(*tasks) + + # Verify grafted data + assert await target_store.exists("temp/c/1/1") + assert await target_store.exists("temp/c/2/2") + data1 = await target_store.get("temp/c/1/1", proto) + data2 = await target_store.get("temp/c/2/2", 
proto) + assert data1 is not None + assert data2 is not None + assert data1.to_bytes() in [b"data1", b"data2"] + assert data2.to_bytes() in [b"data1", b"data2"] + assert data1.to_bytes() != data2.to_bytes() # Ensure distinct data + + # Flush and verify persistence + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + assert await target_store_read.exists("temp/c/1/1") + assert await target_store_read.exists("temp/c/2/2") + + +@pytest.mark.asyncio +async def test_graft_store_overlapping_chunks(create_ipfs: tuple[str, str]): + """Tests grafting when target already has data at some chunk positions.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize source store + source_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + proto = zarr.core.buffer.default_buffer_prototype() + source_chunk_key = "temp/c/0/0" + source_data = b"source_data" + await source_store.set(source_chunk_key, proto.buffer.from_bytes(source_data)) + source_root_cid = await source_store.flush() + + # Initialize target store with some existing data + target_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(40, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + target_chunk_key = "temp/c/1/1" + existing_data = b"existing_data" + await target_store.set(target_chunk_key, proto.buffer.from_bytes(existing_data)) + + # Graft source store at offset (1, 0) + await target_store.graft_store(source_root_cid, chunk_offset=(1, 0)) + + # Verify that existing data was not overwritten + read_data = await target_store.get(target_chunk_key, proto) + assert read_data is not None + assert read_data.to_bytes() == existing_data + assert ( + target_store._shard_data_cache._dirty_shards + ) # Shard is marked dirty due to attempted write + + # Verify other grafted chunks + grafted_chunk_key = "temp/c/1/0" # Corresponds to source (0,0) at offset (1,0) + assert await target_store.exists(grafted_chunk_key) + read_data = await target_store.get(grafted_chunk_key, proto) + assert read_data is not None + assert read_data.to_bytes() == source_data + + # Flush and verify + target_root_cid = await target_store.flush() + target_store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=target_root_cid + ) + targeted_target_chunk = await target_store_read.get(target_chunk_key, proto) + assert targeted_target_chunk is not None + assert targeted_target_chunk.to_bytes() == existing_data + + grafted_chunk_data = await target_store_read.get(grafted_chunk_key, proto) + assert grafted_chunk_data is not None + assert grafted_chunk_data.to_bytes() == source_data diff --git a/tests/test_sharded_store_resizing.py b/tests/test_sharded_store_resizing.py new file mode 100644 index 0000000..ce43465 --- /dev/null +++ b/tests/test_sharded_store_resizing.py @@ -0,0 +1,440 @@ +import asyncio +import json +import math + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", 
periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +@pytest.mark.asyncio +async def test_resize_store_success(create_ipfs: tuple[str, str]): + """Tests successful resizing of the store's main shard index.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + initial_shape = (20, 20) + chunk_shape = (10, 10) + chunks_per_shard = 4 + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=initial_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + + # Verify initial geometry + assert store._array_shape == initial_shape + assert store._chunk_shape == chunk_shape + assert store._chunks_per_shard == chunks_per_shard + initial_chunks_per_dim = (2, 2) # ceil(20/10) = 2 + assert store._chunks_per_dim == initial_chunks_per_dim + assert store._total_chunks == 4 # 2 * 2 + assert store._num_shards == 1 # ceil(4/4) = 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + + # Resize to a larger shape + new_shape = (30, 30) + await store.resize_store(new_shape=new_shape) + + # Verify updated geometry + assert store._array_shape == new_shape + assert store._chunks_per_dim == (3, 3) # ceil(30/10) = 3 + assert store._total_chunks == 9 # 3 * 3 + assert store._num_shards == 3 # ceil(9/4) = 3 + assert len(store._root_obj["chunks"]["shard_cids"]) == 3 + assert store._root_obj["chunks"]["array_shape"] == list(new_shape) + assert store._dirty_root is True + + # Verify shard cids extended correctly + assert store._root_obj["chunks"]["shard_cids"][1] is None + assert store._root_obj["chunks"]["shard_cids"][2] is None + + # Flush and verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == new_shape + assert store_read._num_shards == 3 + assert len(store_read._root_obj["chunks"]["shard_cids"]) == 3 + + # Resize to a smaller shape + smaller_shape = (10, 10) + await store.resize_store(new_shape=smaller_shape) + assert store._array_shape == smaller_shape + assert store._chunks_per_dim == (1, 1) # ceil(10/10) = 1 + assert store._total_chunks == 1 # 1 * 1 + assert store._num_shards == 1 # ceil(1/4) = 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + assert store._dirty_root is True + + # Flush and verify + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == smaller_shape + assert store_read._num_shards == 1 + + +@pytest.mark.asyncio +async def test_resize_store_zero_sized_array(create_ipfs: tuple[str, str]): + """Tests resizing to/from a zero-sized array.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize with zero-sized dimension + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 0), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert 
len(store._root_obj["chunks"]["shard_cids"]) == 0 + + # Resize to non-zero shape + new_shape = (20, 20) + await store.resize_store(new_shape=new_shape) + assert store._array_shape == new_shape + assert store._chunks_per_dim == (2, 2) + assert store._total_chunks == 4 + assert store._num_shards == 1 + assert len(store._root_obj["chunks"]["shard_cids"]) == 1 + assert store._dirty_root is True + + # Resize back to zero-sized + zero_shape = (0, 20) + await store.resize_store(new_shape=zero_shape) + assert store._array_shape == zero_shape + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 + + # Verify persistence + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._array_shape == zero_shape + assert store_read._num_shards == 0 + + +@pytest.mark.asyncio +async def test_resize_store_invalid_cases(create_ipfs: tuple[str, str]): + """Tests error handling in resize_store.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test read-only store + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=await store.flush(), + ) + with pytest.raises(PermissionError, match="Cannot resize a read-only store"): + await store_read_only.resize_store(new_shape=(30, 30)) + + # Test wrong number of dimensions + with pytest.raises( + ValueError, + match="New shape must have the same number of dimensions as the old shape", + ): + await store.resize_store(new_shape=(30, 30, 30)) + + # Test uninitialized store (simulate by setting attributes to None) + store._chunk_shape = None # type: ignore + store._chunks_per_shard = None # type: ignore + with pytest.raises( + RuntimeError, match="Store is not properly initialized for resizing" + ): + await store.resize_store(new_shape=(30, 30)) + + +@pytest.mark.asyncio +async def test_resize_variable_success( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests successful resizing of a variable's metadata.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + variable_name = "temp" + zarr_metadata_key = f"{variable_name}/zarr.json" + assert zarr_metadata_key in store._root_obj["metadata"] + + # Resize variable + new_shape = (150, 18, 36) # Extend time dimension + await store.resize_variable(variable_name=variable_name, new_shape=new_shape) + + # Verify metadata updated + new_metadata_cid = store._root_obj["metadata"][zarr_metadata_key] + new_metadata_bytes = await kubo_cas.load(new_metadata_cid) + 
new_metadata = json.loads(new_metadata_bytes) + assert new_metadata["shape"] == list(new_shape) + assert store._dirty_root is True + + # Verify store's main array shape unchanged + assert store._array_shape == array_shape_tuple + + # Flush and verify persistence + new_root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + read_metadata_cid = store_read._root_obj["metadata"][zarr_metadata_key] + read_metadata_bytes = await kubo_cas.load(read_metadata_cid) + read_metadata = json.loads(read_metadata_bytes) + assert read_metadata["shape"] == list(new_shape) + + +@pytest.mark.asyncio +async def test_resize_variable_invalid_cases( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests error handling in resize_variable.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Write dataset + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store, mode="w") + root_cid = await store.flush() + + # Re-open store + store = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + # Test read-only store + store_read_only = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + with pytest.raises(PermissionError, match="Cannot resize a read-only store"): + await store_read_only.resize_variable("temp", new_shape=(150, 18, 36)) + + # Test non-existent variable + with pytest.raises( + KeyError, + match="Cannot find metadata for key 'nonexistent/zarr.json' to resize", + ): + await store.resize_variable("nonexistent", new_shape=(150, 18, 36)) + + # Test invalid metadata (simulate by setting invalid metadata) + invalid_metadata = json.dumps({"not_shape": [1, 2, 3]}).encode("utf-8") + invalid_cid = await kubo_cas.save(invalid_metadata, codec="raw") + store._root_obj["metadata"]["invalid/zarr.json"] = invalid_cid + with pytest.raises(ValueError, match="Shape not found in metadata"): + await store.set( + "invalid/zarr.json", + zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + invalid_metadata + ), + ) + + +@pytest.mark.asyncio +async def test_resize_store_with_data_preservation(create_ipfs: tuple[str, str]): + """Tests that resizing the store preserves existing data.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Write a chunk + chunk_key = "temp/c/0/0" + chunk_data = b"test_data" + proto = zarr.core.buffer.default_buffer_prototype() + await store.set(chunk_key, proto.buffer.from_bytes(chunk_data)) + root_cid = await store.flush() + + # Verify chunk exists + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert await store_read.exists(chunk_key) + read_chunk = await store_read.get(chunk_key, proto) + assert read_chunk is not None + assert read_chunk.to_bytes() == 
chunk_data + + # Resize store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + new_shape = (30, 30) + await store_write.resize_store(new_shape=new_shape) + new_root_cid = await store_write.flush() + + # Verify chunk still exists + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert await store_read.exists(chunk_key) + read_chunk = await store_read.get(chunk_key, proto) + assert read_chunk is not None + assert read_chunk.to_bytes() == chunk_data + assert store_read._array_shape == new_shape + assert store_read._num_shards == 3 # ceil((3*3)/4) = 3 + + +@pytest.mark.asyncio +async def test_resize_store_in_set_method(create_ipfs: tuple[str, str]): + """Tests that setting zarr.json triggers resize_store appropriately.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Set zarr.json with a new shape + new_shape = [30, 30] + metadata = json.dumps({"shape": new_shape, "dtype": "float32"}).encode("utf-8") + proto = zarr.core.buffer.default_buffer_prototype() + await store.set("temp/zarr.json", proto.buffer.from_bytes(metadata)) + + # Verify resize occurred + assert store._array_shape == tuple(new_shape) + assert store._chunks_per_dim == (3, 3) + assert store._total_chunks == 9 + assert store._num_shards == 3 + assert store._root_obj["chunks"]["array_shape"] == new_shape + assert len(store._root_obj["chunks"]["shard_cids"]) == 3 + + # Verify metadata stored + assert "temp/zarr.json" in store._root_obj["metadata"] + root_cid = await store.flush() + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + metadata_buffer = await store_read.get("temp/zarr.json", proto) + assert metadata_buffer is not None + assert json.loads(metadata_buffer.to_bytes())["shape"] == new_shape + + +@pytest.mark.asyncio +async def test_resize_concurrency(create_ipfs: tuple[str, str]): + """Tests concurrent resize_store operations to ensure locking works.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Define multiple resize tasks + async def resize_task(shape): + await store.resize_store(new_shape=shape) + + # Run concurrent resize operations + tasks = [ + resize_task((30, 30)), + resize_task((40, 40)), + resize_task((50, 50)), + ] + await asyncio.gather(*tasks) + + # Verify final state (last resize should win, but all are safe due to locking) + assert store._array_shape in [(30, 30), (40, 40), (50, 50)] + expected_chunks_per_dim = tuple(math.ceil(s / 10) for s in store._array_shape) + assert store._chunks_per_dim == expected_chunks_per_dim + assert store._total_chunks == math.prod(expected_chunks_per_dim) + assert store._num_shards == math.ceil(store._total_chunks / 4) + assert len(store._root_obj["chunks"]["shard_cids"]) == store._num_shards + assert store._dirty_root is True diff --git a/tests/test_sharded_zarr_store.py b/tests/test_sharded_zarr_store.py new file mode 100644 index 0000000..ea41734 --- /dev/null +++ 
b/tests/test_sharded_zarr_store.py @@ -0,0 +1,1432 @@ +import asyncio + +import dag_cbor +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr.core.buffer +from multiformats import CID +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest + +from py_hamt import KuboCAS, ShardedZarrStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset for benchmarking.""" + # Using a slightly larger dataset for a more meaningful benchmark + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + temp = np.random.randn(len(times), len(lats), len(lons)) + + ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], temp), + }, + coords={"time": times, "lat": lats, "lon": lons}, + ) + + # Define chunking for the store + ds = ds.chunk({"time": 20, "lat": 18, "lon": 36}) + yield ds + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_write_read( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests writing and reading a Zarr dataset using ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # --- Write --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + assert root_cid is not None + + # --- Read --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + ds_read = xr.open_zarr(store=store_read) + xr.testing.assert_identical(test_ds, ds_read) + + # Try to set a chunk directly in read-only mode + with pytest.raises(PermissionError): + proto = zarr.core.buffer.default_buffer_prototype() + await store_read.set("temp/c/0/0", proto.buffer.from_bytes(b"test_data")) + + +@pytest.mark.asyncio +async def test_load_or_initialize_shard_cache_concurrent_loads( + create_ipfs: tuple[str, str], +): + """Tests concurrent shard loading to trigger _pending_shard_loads wait.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Initialize store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Create a shard with data + shard_idx = 0 + shard_data = [ + CID.decode("bafyr4iacuutc5bgmirkfyzn4igi2wys7e42kkn674hx3c4dv4wrgjp2k2u") + for _ in range(4) + ] + shard_data_bytes = dag_cbor.encode(shard_data) + shard_cid_obj = await kubo_cas.save(shard_data_bytes, codec="dag-cbor") + store._root_obj["chunks"]["shard_cids"][shard_idx] = shard_cid_obj + store._dirty_root = True + await store.flush() + + # Simulate concurrent shard loads + async def load_shard(): + return await store._load_or_initialize_shard_cache(shard_idx) + + # Run multiple tasks concurrently + tasks = [load_shard() for _ in range(3)] + results = await asyncio.gather(*tasks) + + # Verify all tasks return the same shard data + for result in results: + assert len(result) == 
4 + assert all(isinstance(cid, CID) for cid in result) + assert result == shard_data + + # Verify shard is cached and no pending loads remain + assert await store._shard_data_cache.__contains__(shard_idx) + assert await store._shard_data_cache.get(shard_idx) == shard_data + assert shard_idx not in store._pending_shard_loads + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_append( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests appending data to an existing Zarr dataset in the ShardedZarrStore, + which specifically exercises the store resizing logic. + """ + rpc_base_url, gateway_base_url = create_ipfs + initial_ds = random_zarr_dataset + + ordered_dims = list(initial_ds.sizes) + array_shape_tuple = tuple(initial_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(initial_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. --- Write Initial Dataset --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=10, + ) + initial_ds.to_zarr(store=store_write, mode="w") + initial_cid = await store_write.flush() + assert initial_cid is not None + + # 2. --- Prepare Data to Append --- + # Create a new dataset with 50 more time steps + append_times = pd.date_range( + initial_ds.time[-1].values + pd.Timedelta(days=1), periods=50 + ) + append_temp = np.random.randn( + len(append_times), len(initial_ds.lat), len(initial_ds.lon) + ) + + append_ds = xr.Dataset( + { + "temp": (["time", "lat", "lon"], append_temp), + }, + coords={"time": append_times, "lat": initial_ds.lat, "lon": initial_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # 3. --- Perform Append Operation --- + store_append = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + root_cid=initial_cid, + ) + append_ds.to_zarr(store=store_append, mode="a", append_dim="time") + final_cid = await store_append.flush() + print(f"Data written to ShardedZarrStore with root CID: {final_cid}") + assert final_cid is not None + assert final_cid != initial_cid + + # 4. --- Verify the Final Dataset --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=final_cid, + ) + final_ds_read = xr.open_zarr(store=store_read) + + # The expected result is the concatenation of the two datasets + expected_final_ds = xr.concat([initial_ds, append_ds], dim="time") + + # Verify that the data read from the store is identical to the expected result + xr.testing.assert_identical(expected_final_ds, final_ds_read) + print("\n✅ Append test successful! Data verified.") + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init(create_ipfs: tuple[str, str]): + """ + Tests the initialization of the ShardedZarrStore. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + array_shape = (100, 100) + chunk_shape = (10, 10) + chunks_per_shard = 64 + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test successful creation + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape, + chunk_shape=chunk_shape, + chunks_per_shard=chunks_per_shard, + ) + assert store is not None + + # Test missing parameters for new store + with pytest.raises(ValueError): + await ShardedZarrStore.open(cas=kubo_cas, read_only=False) + + # Test opening read-only store without root_cid + with pytest.raises(ValueError): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_metadata( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests metadata handling in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + # Test exists + assert await store_read.exists("lat/zarr.json") + assert await store_read.exists("lon/zarr.json") + assert await store_read.exists("time/zarr.json") + assert await store_read.exists("temp/zarr.json") + assert await store_read.exists("lat/c/0") + assert await store_read.exists("lon/c/0") + assert await store_read.exists("time/c/0") + # assert not await store_read.exists("nonexistent") + + # Test does not exist + assert not await store_read.exists("temp/c/20/0/0") # Non-existent chunk + + # Test list + keys = [key async for key in store_read.list()] + assert len(keys) > 0 + assert "lat/zarr.json" in keys + + prefix = "lat" + keys_with_prefix = [key async for key in store_read.list_prefix(prefix=prefix)] + assert "lat/zarr.json" in keys_with_prefix + assert "lat/c/0" in keys_with_prefix + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_chunks( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests chunk data handling in the ShardedZarrStore. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + # Test get + chunk_key = "temp/c/0/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + chunk_data = await store_read.get(chunk_key, proto) + assert chunk_data is not None + + # Test delete + store_write = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + await store_write.delete(chunk_key) + await store_write.flush() + + store_read_after_delete = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=await store_write.flush() + ) + assert await store_read_after_delete.get(chunk_key, proto) is None + + +@pytest.mark.asyncio +async def test_chunk_and_delete_logic( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests chunk getting, deleting, and related error handling.""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + # Re-open as writable to test deletion + store_rw = await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, root_cid=root_cid + ) + + chunk_key = "temp/c/0/0/0" + proto = zarr.core.buffer.default_buffer_prototype() + + # Verify chunk exists and can be read + assert await store_rw.exists(chunk_key) + chunk_data = await store_rw.get(chunk_key, proto) + assert chunk_data is not None + + # Delete the chunk + await store_rw.delete(chunk_key) + new_root_cid = await store_rw.flush() + + # Verify it's gone + store_after_delete = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=new_root_cid + ) + assert not await store_after_delete.exists(chunk_key) + assert await store_after_delete.get(chunk_key, proto) is None + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_partial_reads( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests partial reads in the ShardedZarrStore. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0/0" + full_chunk = await store_read.get(chunk_key, proto) + assert full_chunk is not None + full_chunk_bytes = full_chunk.to_bytes() + + # Test RangeByteRequest + byte_range = RangeByteRequest(start=10, end=50) + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_range) + assert partial_chunk is not None + assert partial_chunk.to_bytes() == full_chunk_bytes[10:50] + + +@pytest.mark.asyncio +async def test_partial_reads_and_errors( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """Tests partial reads and error handling in get().""" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + chunk_key = "temp/c/0/0/0" + full_chunk = await store_read.get(chunk_key, proto) + assert full_chunk is not None + full_chunk_bytes = full_chunk.to_bytes() + + # Test RangeByteRequest + byte_range = RangeByteRequest(start=10, end=50) + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_range) + assert partial_chunk is not None + assert partial_chunk.to_bytes() == full_chunk_bytes[10:50] + + # Test invalid byte range + with pytest.raises(ValueError): + await store_read.get( + chunk_key, proto, byte_range=RangeByteRequest(start=50, end=10) + ) + + +@pytest.mark.asyncio +async def test_zero_sized_array(create_ipfs: tuple[str, str]): + """Test handling of arrays with a zero-length dimension.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(100, 0), + chunk_shape=(10, 10), + chunks_per_shard=64, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + root_cid = await store.flush() + + # Read it back and verify + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + assert store_read._total_chunks == 0 + assert store_read._num_shards == 0 + + +@pytest.mark.asyncio +async 
def test_store_eq_method(create_ipfs: tuple[str, str]): + """Tests the __eq__ method.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store1 = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(1, 1), + chunk_shape=(1, 1), + chunks_per_shard=1, + ) + root_cid = await store1.flush() + store2 = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + + assert store1 == store2 + + +@pytest.mark.asyncio +async def test_with_read_only(create_ipfs: tuple[str, str]): + """Tests the with_read_only method.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Create a writable store + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(5, 5), + chunks_per_shard=4, + ) + + # Test same mode returns same instance + same_store = store_write.with_read_only(False) + assert same_store is store_write + + # Test switching to read-only + store_read_only = store_write.with_read_only(True) + assert store_read_only is not store_write + assert store_read_only.read_only is True + assert store_write.read_only is False + + # Test that clone shares the same state + assert store_read_only.cas is store_write.cas + assert store_read_only._root_cid == store_write._root_cid + assert store_read_only._root_obj is store_write._root_obj + assert store_read_only._shard_data_cache is store_write._shard_data_cache + # _dirty_shards is now managed by the cache, not the store directly + assert ( + store_read_only._shard_data_cache._dirty_shards + is store_write._shard_data_cache._dirty_shards + ) + assert store_read_only._array_shape == store_write._array_shape + assert store_read_only._chunk_shape == store_write._chunk_shape + assert store_read_only._chunks_per_shard == store_write._chunks_per_shard + + # Test switching back to writable + store_write_again = store_read_only.with_read_only(False) + assert store_write_again is not store_read_only + assert store_write_again.read_only is False + + # Test write operations are blocked on read-only store + proto = zarr.core.buffer.default_buffer_prototype() + test_data = proto.buffer.from_bytes(b"test_data") + + with pytest.raises(PermissionError): + await store_read_only.set("test_key", test_data) + + with pytest.raises(PermissionError): + await store_read_only.delete("test_key") + + # Test write operations work on writable store + await store_write.set("test_key", test_data) + await store_write.delete("test_key") + + +@pytest.mark.asyncio +async def test_listing_and_metadata( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests metadata handling and listing in the ShardedZarrStore. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w", consolidated=True) + root_cid = await store_write.flush() + + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + # Test exists for metadata (correcting for xarray group structure) + assert await store_read.exists("temp/zarr.json") + assert await store_read.exists("lat/zarr.json") + assert not await store_read.exists("nonexistent.json") + + # Test listing + keys = {key async for key in store_read.list()} + assert "temp/zarr.json" in keys + assert "lat/zarr.json" in keys + + # Test list_prefix + prefix_keys = {key async for key in store_read.list_prefix("temp/")} + assert "temp/zarr.json" in prefix_keys + + # Test list_dir for root + dir_keys = {key async for key in store_read.list_dir("")} + assert "temp" in dir_keys + assert "lat" in dir_keys + assert "lon" in dir_keys + assert "zarr.json" in dir_keys + + # Test listing with a prefix + prefix = "temp/" + with pytest.raises( + NotImplementedError, match="Listing with a prefix is not implemented yet." + ): + async for key in store_read.list_dir(prefix): + print(f"Key with prefix '{prefix}': {key}") + + with pytest.raises( + ValueError, match="Byte range requests are not supported for metadata keys." + ): + proto = zarr.core.buffer.default_buffer_prototype() + byte_range = zarr.abc.store.RangeByteRequest(start=10, end=50) + await store_read.get("lat/zarr.json", proto, byte_range=byte_range) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_errors(create_ipfs: tuple[str, str]): + """ + Tests initialization errors for ShardedZarrStore. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test missing parameters for a new store + with pytest.raises(ValueError, match="must be provided for a new store"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=False) + + # Test opening a read-only store without a root_cid + with pytest.raises(ValueError, match="must be provided for a read-only store"): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + # Test invalid chunks_per_shard + with pytest.raises(ValueError, match="must be a positive integer"): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=0, + ) + + # Test invalid chunk_shape + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive" + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(0,), + chunks_per_shard=10, + ) + + # Test invalid array_shape + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative" + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(-10,), + chunk_shape=(5,), + chunks_per_shard=10, + ) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_partial_values( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests the get_partial_values method of ShardedZarrStore, including RangeByteRequest, + OffsetByteRequest, SuffixByteRequest, and full reads, along with error handling for + invalid byte ranges. + """ + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + + ordered_dims = list(test_ds.sizes) + array_shape_tuple = tuple(test_ds.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(test_ds.chunks[dim][0] for dim in ordered_dims) + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. --- Write Dataset to ShardedZarrStore --- + store_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + test_ds.to_zarr(store=store_write, mode="w") + root_cid = await store_write.flush() + assert root_cid is not None + + # 2. --- Open Store for Reading --- + store_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid + ) + proto = zarr.core.buffer.default_buffer_prototype() + + # 3. --- Find a Chunk Key to Test --- + chunk_key = "temp/c/0/0/0" # Default chunk key + # async for key in store_read.list(): + # print(f"Found key: {key}") + # if key.startswith("temp/c/") and not key.endswith(".json"): + # chunk_key = key + # break + + assert chunk_key is not None, "Could not find a chunk key to test." + print(f"Testing with chunk key: {chunk_key}") + + # 4. --- Get Full Chunk Data for Comparison --- + full_chunk_buffer = await store_read.get(chunk_key, proto) + assert full_chunk_buffer is not None + full_chunk_data = full_chunk_buffer.to_bytes() + chunk_len = len(full_chunk_data) + print(f"Full chunk size: {chunk_len} bytes") + + # Ensure chunk is large enough for meaningful partial read tests + assert chunk_len > 100, "Chunk size too small for partial value tests" + + # 5. 
--- Define Byte Requests --- + range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes + offset_req = OffsetByteRequest(offset=chunk_len - 30) # Last 30 bytes + suffix_req = SuffixByteRequest(suffix=20) # Last 20 bytes + + key_ranges_to_test = [ + (chunk_key, range_req), + (chunk_key, offset_req), + (chunk_key, suffix_req), + (chunk_key, None), # Full read + ] + + # 6. --- Call get_partial_values --- + results = await store_read.get_partial_values(proto, key_ranges_to_test) + + # 7. --- Assertions --- + assert len(results) == 4, "Expected 4 results from get_partial_values" + + assert results[0] is not None, "RangeByteRequest result should not be None" + assert results[1] is not None, "OffsetByteRequest result should not be None" + assert results[2] is not None, "SuffixByteRequest result should not be None" + assert results[3] is not None, "Full read result should not be None" + + # Check RangeByteRequest result + expected_range = full_chunk_data[10:50] + assert results[0].to_bytes() == expected_range, ( + "RangeByteRequest result does not match" + ) + print(f"RangeByteRequest: OK (Got {len(results[0].to_bytes())} bytes)") + + # Check OffsetByteRequest result + expected_offset = full_chunk_data[chunk_len - 30 :] + assert results[1].to_bytes() == expected_offset, ( + "OffsetByteRequest result does not match" + ) + print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") + + # Check SuffixByteRequest result + expected_suffix = full_chunk_data[-20:] + assert results[2].to_bytes() == expected_suffix, ( + "SuffixByteRequest result does not match" + ) + print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") + + # Check full read result + assert results[3].to_bytes() == full_chunk_data, ( + "Full read via get_partial_values does not match" + ) + print(f"Full Read: OK (Got {len(results[3].to_bytes())} bytes)") + + # 8. --- Test Invalid Byte Range --- + invalid_range_req = RangeByteRequest(start=50, end=10) + with pytest.raises( + ValueError, + match="Byte range start.*cannot be greater than end", + ): + await store_read.get_partial_values(proto, [(chunk_key, invalid_range_req)]) + + print("\n✅ get_partial_values test successful! 
All partial reads verified.") + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_parse_chunk_key(create_ipfs: tuple[str, str]): + """Tests chunk key parsing edge cases.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20), + chunk_shape=(10, 10), + chunks_per_shard=4, + ) + + # Test metadata key + assert store._parse_chunk_key("zarr.json") is None + assert store._parse_chunk_key("group1/zarr.json") is None + + # Test excluded array prefixes + assert store._parse_chunk_key("time/c/0") is None + assert store._parse_chunk_key("lat/c/0/0") is None + assert store._parse_chunk_key("lon/c/0/0") is None + + # Test dimensionality mismatch + with pytest.raises(IndexError, match="tuple index out of range"): + store._parse_chunk_key("temp/c/0/0/0/0") + + # Test invalid coordinates + with pytest.raises(ValueError, match="invalid literal"): + store._parse_chunk_key("temp/c/0/invalid") + with pytest.raises(IndexError, match="Chunk coordinate"): + store._parse_chunk_key("temp/c/0/-1") + + with pytest.raises(IndexError, match="Chunk coordinate"): + store._parse_chunk_key("temp/c/3/0") + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_invalid_shapes(create_ipfs: tuple[str, str]): + """Tests initialization with invalid shapes and manifest errors.""" + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test negative chunk_shape dimension + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive" + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(-5, 5), + chunks_per_shard=10, + ) + + # Test negative array_shape dimension + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative" + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, -10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + + # Test zero-sized array + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(0, 10), + chunk_shape=(5, 5), + chunks_per_shard=10, + ) + assert store._total_chunks == 0 + assert store._num_shards == 0 + assert store._root_obj is not None + assert len(store._root_obj["chunks"]["shard_cids"]) == 0 # Line 163 + + # Test invalid manifest version + invalid_root_obj = { + "manifest_version": "invalid_version", + "metadata": {}, + "chunks": { + "array_shape": [10, 10], + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks_per_shard": 10}, + "shard_cids": [None] * 4, + }, + } + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) + + # Test inconsistent shard count + invalid_root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [ + 10, + 10, + ], # 4 chunks at 10 chunks per shard -> 1 shard expected + "chunk_shape": [5, 5], + "cid_byte_length": 59, + "sharding_config": {"chunks_per_shard": 10}, + "shard_cids": [None] * 5, # Wrong number of shards + }, + } + invalid_root_cid = await kubo_cas.save( + dag_cbor.encode(invalid_root_obj), codec="dag-cbor" + ) +
with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_root_cid + ) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_lazy_concat( + create_ipfs: tuple[str, str], random_zarr_dataset: xr.Dataset +): + """ + Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores, + ensuring the combined dataset can be queried as a single dataset with data fetched + correctly from the respective stores. + """ + rpc_base_url, gateway_base_url = create_ipfs + base_ds = random_zarr_dataset + + # 1. --- Prepare Two Datasets with Distinct Time Ranges --- + # First dataset: August 1, 2024 to September 30, 2024 (61 days) + aug_sep_times = pd.date_range("2024-08-01", "2024-09-30", freq="D") + aug_sep_temp = np.random.randn( + len(aug_sep_times), len(base_ds.lat), len(base_ds.lon) + ) + ds1 = xr.Dataset( + { + "temp": (["time", "lat", "lon"], aug_sep_temp), + }, + coords={"time": aug_sep_times, "lat": base_ds.lat, "lon": base_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # Second dataset: October 1, 2024 to November 30, 2024 (61 days) + oct_nov_times = pd.date_range("2024-10-01", "2024-11-30", freq="D") + oct_nov_temp = np.random.randn( + len(oct_nov_times), len(base_ds.lat), len(base_ds.lon) + ) + ds2 = xr.Dataset( + { + "temp": (["time", "lat", "lon"], oct_nov_temp), + }, + coords={"time": oct_nov_times, "lat": base_ds.lat, "lon": base_ds.lon}, + ).chunk({"time": 20, "lat": 18, "lon": 36}) + + # Expected concatenated dataset for verification + expected_combined_ds = xr.concat([ds1, ds2], dim="time") + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 2. --- Write First Dataset to ShardedZarrStore --- + ordered_dims = list(ds1.sizes) + array_shape_tuple = tuple(ds1.sizes[dim] for dim in ordered_dims) + chunk_shape_tuple = tuple(ds1.chunks[dim][0] for dim in ordered_dims) + + store1_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + ds1.to_zarr(store=store1_write, mode="w") + root_cid1 = await store1_write.flush() + assert root_cid1 is not None + + # 3. --- Write Second Dataset to ShardedZarrStore --- + array_shape_tuple = tuple(ds2.sizes[dim] for dim in ordered_dims) + store2_write = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=array_shape_tuple, + chunk_shape=chunk_shape_tuple, + chunks_per_shard=64, + ) + ds2.to_zarr(store=store2_write, mode="w") + root_cid2 = await store2_write.flush() + assert root_cid2 is not None + + # 4. --- Read and Lazily Concatenate Datasets --- + store1_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid1 + ) + store2_read = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=root_cid2 + ) + + ds1_read = xr.open_zarr(store=store1_read, chunks="auto") + ds2_read = xr.open_zarr(store=store2_read, chunks="auto") + + # Verify that datasets are lazy (Dask-backed) + assert ds1_read.temp.chunks is not None + assert ds2_read.temp.chunks is not None + + # Lazily concatenate along time dimension + combined_ds = xr.concat([ds1_read, ds2_read], dim="time") + + # Verify that the combined dataset is still lazy + assert combined_ds.temp.chunks is not None + + # 5. 
--- Query Across Both Datasets --- + # Select a time slice that spans both datasets (e.g., Sep 15 to Oct 15) + query_slice = combined_ds.sel(time=slice("2024-09-15", "2024-10-15")) + + # Verify that the query is still lazy + assert query_slice.temp.chunks is not None + + # Compute the result to trigger data loading + query_result = query_slice.compute() + + # 6. --- Verify Results --- + # Compare with the expected concatenated dataset + expected_query_result = expected_combined_ds.sel( + time=slice("2024-09-15", "2024-10-15") + ) + xr.testing.assert_identical(query_result, expected_query_result) + + # Verify specific values at a point to ensure data integrity + sample_time = pd.Timestamp("2024-09-30") # From ds1 + sample_result = query_result.sel(time=sample_time, method="nearest").temp.values + expected_sample = expected_combined_ds.sel( + time=sample_time, method="nearest" + ).temp.values + np.testing.assert_array_equal(sample_result, expected_sample) + + sample_time = pd.Timestamp("2024-10-01") # From ds2 + sample_result = query_result.sel(time=sample_time, method="nearest").temp.values + expected_sample = expected_combined_ds.sel( + time=sample_time, method="nearest" + ).temp.values + np.testing.assert_array_equal(sample_result, expected_sample) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_lazy_concat_with_cids(create_ipfs: tuple[str, str]): + """ + Tests lazy concatenation of two xarray datasets stored in separate ShardedZarrStores + using provided CIDs for finalized and non-finalized data, ensuring the non-finalized + dataset is sliced after the finalization date (inclusive) and the combined dataset + can be queried as a single dataset with data fetched correctly from the respective stores. + """ + rpc_base_url, gateway_base_url = create_ipfs + + # Provided CIDs + finalized_cid = "bafyr4iacuutc5bgmirkfyzn4igi2wys7e42kkn674hx3c4dv4wrgjp2k2u" + non_finalized_cid = "bafyr4ihicmzx4uw4pefk7idba3mz5r5g27au3l7d62yj4gguxx6neaa5ti" + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # 1. --- Open Finalized Dataset --- + store_finalized = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=finalized_cid + ) + ds_finalized = xr.open_zarr(store=store_finalized, chunks="auto") + + # Verify that the dataset is lazy (Dask-backed) + assert ds_finalized["2m_temperature"].chunks is not None + + # Determine the finalization date (last date in finalized dataset) + finalization_date = np.datetime64(ds_finalized.time.max().values) + + # 2. --- Open Non-Finalized Dataset and Slice After Finalization Date --- + store_non_finalized = await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=non_finalized_cid + ) + ds_non_finalized = xr.open_zarr(store=store_non_finalized, chunks="auto") + + # Verify that the dataset is lazy + assert ds_non_finalized["2m_temperature"].chunks is not None + + # Slice non-finalized dataset to start *after* the finalization date + # (finalization_date is inclusive for finalized data, so start at +1 hour) + start_time = finalization_date + np.timedelta64(1, "h") + ds_non_finalized_sliced = ds_non_finalized.sel(time=slice(start_time, None)) + + # Verify that the sliced dataset starts after the finalization date + if ds_non_finalized_sliced.time.size > 0: + assert ds_non_finalized_sliced.time.min() > finalization_date + + # 3. 
--- Lazily Concatenate Datasets --- + combined_ds = xr.concat([ds_finalized, ds_non_finalized_sliced], dim="time") + + # Verify that the combined dataset is still lazy + assert combined_ds["2m_temperature"].chunks is not None + + # 4. --- Query Across Both Datasets --- + # Select a time slice that spans both datasets + # Use a range that includes the boundary (e.g., finalization date and after) + query_start = finalization_date - np.timedelta64(1, "D") # 1 day before + query_end = finalization_date + np.timedelta64(1, "D") # 1 day after + query_slice = combined_ds.sel( + time=slice(query_start, query_end), latitude=0, longitude=0 + ) + # Make sure the query slice aligned with the query_start and query_end + + assert query_slice.time.min() >= query_start + assert query_slice.time.max() <= query_end + + # Verify that the query is still lazy + assert query_slice["2m_temperature"].chunks is not None + + # Compute the result to trigger data loading + query_result = query_slice.compute() + + # 5. --- Verify Results --- + # Verify data integrity at specific points + # Check the last finalized time (from finalized dataset) + sample_time_finalized = finalization_date + if sample_time_finalized in query_result.time.values: + sample_result = query_result.sel( + time=sample_time_finalized, method="nearest" + )["2m_temperature"].values + expected_sample = ds_finalized.sel( + time=sample_time_finalized, latitude=0, longitude=0, method="nearest" + )["2m_temperature"].values + np.testing.assert_array_equal(sample_result, expected_sample) + + # Check the first non-finalized time (from non-finalized dataset, if available) + if ds_non_finalized_sliced.time.size > 0: + sample_time_non_finalized = ds_non_finalized_sliced.time.min().values + if sample_time_non_finalized in query_result.time.values: + sample_result = query_result.sel( + time=sample_time_non_finalized, method="nearest" + )["2m_temperature"].values + expected_sample = ds_non_finalized_sliced.sel( + time=sample_time_non_finalized, + latitude=0, + longitude=0, + method="nearest", + )["2m_temperature"].values + np.testing.assert_array_equal(sample_result, expected_sample) + + # 6. 
--- Additional Validation --- + # Verify that the concatenated dataset has no overlapping times + time_values = combined_ds.time.values + assert np.all(np.diff(time_values) > np.timedelta64(0, "ns")), ( + "Overlapping or unsorted time values detected" + ) + + # Verify that the query result covers the expected time range + if query_result.time.size > 0: + assert query_result.time.min() >= query_start + assert query_result.time.max() <= query_end + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_basic(): + """Test basic functionality of MemoryBoundedLRUCache.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + # Very small cache for testing to ensure eviction + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # 500 bytes limit + + # Create some test data + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + small_shard = [test_cid] * 2 + medium_shard = [test_cid] * 5 + + # Test basic put/get + await cache.put(0, small_shard) + assert await cache.get(0) == small_shard + assert await cache.__contains__(0) + assert cache.cache_size == 1 + + # Test that get moves item to end (most recently used) + await cache.put(1, medium_shard) + await cache.get(0) # Should move shard 0 to end + + # Add more data to trigger eviction - this should be large enough to force eviction + large_shard = [test_cid] * 20 + await cache.put(2, large_shard) + + # Check basic cache behavior - at least one should be evicted due to memory limits + + # Add an even larger shard to definitely trigger eviction + huge_shard = [test_cid] * 50 + await cache.put(3, huge_shard) + + # Cache should be constrained by memory limit and perform evictions + # The exact behavior depends on actual memory usage, but we should see some eviction + assert ( + cache.estimated_memory_usage <= cache.max_memory_bytes or cache.cache_size == 1 + ) + # At least some items should remain in cache + assert cache.cache_size >= 1 + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_dirty_protection(): + """Test that dirty shards are never evicted.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + # Very small cache to force eviction + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # 500 bytes + + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + small_shard = [test_cid] * 3 + large_shard = [test_cid] * 20 # This should exceed memory limit + + # Add a dirty shard + await cache.put(0, small_shard, is_dirty=True) + assert cache.dirty_cache_size == 1 + + # Add a clean shard + await cache.put(1, small_shard) + # Cache size should be 2, but might be less if eviction occurred + assert cache.cache_size >= 1 # At least the dirty shard should remain + assert cache.dirty_cache_size == 1 + + # Add a large clean shard that should trigger eviction + await cache.put(2, large_shard) + + # Dirty shard 0 should still be there (protected) + assert await cache.get(0) is not None # Dirty shard protected + + # Either shard 1 or shard 2 might be evicted depending on memory constraints + # The important thing is that the dirty shard (0) is never evicted + cached_1 = await cache.get(1) + cached_2 = await cache.get(2) + + # At least one of the clean shards should be evicted due to memory pressure + evicted_count = (1 if cached_1 is None else 0) + (1 if cached_2 is None else 0) + assert evicted_count >= 1, ( + "At least one clean shard should be evicted due to memory 
pressure" + ) + + # Test marking dirty shard as clean + await cache.mark_clean(0) + assert cache.dirty_cache_size == 0 + + # Now shard 0 can be evicted + even_larger_shard = [test_cid] * 30 + await cache.put(3, even_larger_shard) + assert await cache.get(0) is None # Now evicted since it's clean + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_memory_estimation(): + """Test memory usage estimation.""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=10000) + + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # Test with None values (should use minimal memory) + sparse_shard = [None] * 100 + await cache.put(0, sparse_shard) + sparse_usage = cache.estimated_memory_usage + + # Test with CID values (should use more memory) + dense_shard = [test_cid] * 100 + await cache.put(1, dense_shard) + dense_usage = cache.estimated_memory_usage + + # Dense shard should use significantly more memory than sparse + assert dense_usage > sparse_usage # CIDs are larger than None + + # Test cache clear + await cache.clear() + assert cache.estimated_memory_usage == 0 + assert cache.cache_size == 0 + assert cache.dirty_cache_size == 0 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cache_integration(create_ipfs): + """Test that ShardedZarrStore properly uses the memory-bounded cache.""" + rpc_base_url, gateway_base_url = create_ipfs + + # Create a store with very small cache to test eviction + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10, 10), + chunk_shape=(2, 2, 2), + chunks_per_shard=8, + max_cache_memory_bytes=2000, # 2KB cache limit + ) + + # Create test data + test_data = np.random.randn(2, 2, 2).astype(np.float32) + buffer = zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + test_data.tobytes() + ) + + # Test cache behavior by writing data and checking dirty status + # Let's first manually mark a shard as dirty to test the mechanism + await store._shard_data_cache.put(0, [None] * 8, is_dirty=True) + assert store._shard_data_cache.dirty_cache_size == 1 + + # Now test actual store operations - write to multiple chunks + await store.set("c/0/0/0", buffer) + await store.set("c/0/0/1", buffer) + await store.set("c/0/1/0", buffer) + + # The cache should now have at least one dirty shard + assert store._shard_data_cache.dirty_cache_size > 0 + + # Read from the chunks we actually wrote to + result1 = await store.get( + "c/0/0/0", zarr.core.buffer.default_buffer_prototype() + ) + result2 = await store.get( + "c/0/0/1", zarr.core.buffer.default_buffer_prototype() + ) + result3 = await store.get( + "c/0/1/0", zarr.core.buffer.default_buffer_prototype() + ) + + assert result1 is not None + assert result2 is not None + assert result3 is not None + + # Flush to make shards clean + await store.flush() + + # After flush, dirty count should be 0 + assert store._shard_data_cache.dirty_cache_size == 0 + + # But cache should still contain some shards + assert store._shard_data_cache.cache_size > 0 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cache_eviction_during_read(create_ipfs): + """Test cache eviction behavior during read-heavy workloads.""" + rpc_base_url, gateway_base_url = create_ipfs + + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) 
as kubo_cas: + # Create and populate a store + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(20, 20, 20), + chunk_shape=(2, 2, 2), + chunks_per_shard=10, + max_cache_memory_bytes=1500, # Small cache to force eviction + ) + + # Write data to multiple shards + test_data = np.random.randn(2, 2, 2).astype(np.float32) + buffer = zarr.core.buffer.default_buffer_prototype().buffer.from_bytes( + test_data.tobytes() + ) + + for i in range(0, 10, 2): # Write to shards 0, 1, 2, 3, 4 + await store.set(f"c/{i}/0/0", buffer) + + # Flush to make all shards clean + root_cid = await store.flush() + + # Now open in read-only mode to test pure read behavior + readonly_store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=True, + root_cid=root_cid, + max_cache_memory_bytes=1000, # Even smaller cache + ) + + # Read from many different locations to fill and overflow cache + read_results = [] + for i in range(0, 10, 2): + result = await readonly_store.get( + f"c/{i}/0/0", zarr.core.buffer.default_buffer_prototype() + ) + read_results.append(result) + + # All reads should succeed even with cache eviction + assert all(result is not None for result in read_results) + + # Cache should have evicted some shards due to memory limit + assert ( + readonly_store._shard_data_cache.cache_size <= 5 + ) # Not all shards should be cached + + # No shards should be dirty in read-only mode + assert readonly_store._shard_data_cache.dirty_cache_size == 0 diff --git a/tests/test_sharded_zarr_store_coverage.py b/tests/test_sharded_zarr_store_coverage.py new file mode 100644 index 0000000..72463c7 --- /dev/null +++ b/tests/test_sharded_zarr_store_coverage.py @@ -0,0 +1,896 @@ +import dag_cbor +import pytest +import zarr.abc.store +import zarr.core.buffer + +from py_hamt import KuboCAS +from py_hamt.sharded_zarr_store import ShardedZarrStore + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_init_exceptions(create_ipfs: tuple[str, str]): + """ + Tests various initialization exceptions in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test RuntimeError when base shape information is not set + # with pytest.raises(RuntimeError, match="Base shape information is not set."): + # store = ShardedZarrStore(kubo_cas, False, None) + # store._update_geometry() + + # Test ValueError for non-positive chunk_shape dimensions + with pytest.raises( + ValueError, match="All chunk_shape dimensions must be positive." + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(10, 0), + chunks_per_shard=10, + ) + + # Test ValueError for non-negative array_shape dimensions + with pytest.raises( + ValueError, match="All array_shape dimensions must be non-negative." + ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, -10), + chunk_shape=(10, 10), + chunks_per_shard=10, + ) + + # Test ValueError when array_shape is not provided for a new store + with pytest.raises( + ValueError, + match="array_shape and chunk_shape must be provided for a new store.", + ): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=False, chunk_shape=(10, 10) + ) + + # Test ValueError for non-positive chunks_per_shard + with pytest.raises( + ValueError, match="chunks_per_shard must be a positive integer." 
+ ): + await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10, 10), + chunk_shape=(10, 10), + chunks_per_shard=0, + ) + + # Test ValueError when root_cid is not provided for a read-only store + with pytest.raises( + ValueError, match="root_cid must be provided for a read-only store." + ): + await ShardedZarrStore.open(cas=kubo_cas, read_only=True) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_load_root_exceptions(create_ipfs: tuple[str, str]): + """ + Tests exceptions raised during the loading of the root object. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + # Test RuntimeError when _root_cid is not set + # with pytest.raises(RuntimeError, match="Cannot load root without a root_cid."): + # store = ShardedZarrStore(kubo_cas, True, None) + # await store._load_root_from_cid() + + # Test ValueError for an incompatible manifest version + invalid_manifest_root = { + "manifest_version": "invalid_version", + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [], + }, + } + invalid_manifest_cid = await kubo_cas.save( + dag_cbor.encode(invalid_manifest_root), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Incompatible manifest version"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=invalid_manifest_cid + ) + + # Test ValueError for an inconsistent number of shards + inconsistent_shards_root = { + "manifest_version": "sharded_zarr_v1", + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [ + None, + None, + None, + ], # Should be 2 shards, but array shape dictates 2 total chunks + }, + } + inconsistent_shards_cid = await kubo_cas.save( + dag_cbor.encode(inconsistent_shards_root), codec="dag-cbor" + ) + with pytest.raises(ValueError, match="Inconsistent number of shards"): + await ShardedZarrStore.open( + cas=kubo_cas, read_only=True, root_cid=inconsistent_shards_cid + ) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_shard_handling_exceptions( + create_ipfs: tuple[str, str], caplog +): + """ + Tests exceptions and logging during shard handling. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + + # Test TypeError when a shard does not decode to a list + invalid_shard_cid = await kubo_cas.save( + dag_cbor.encode({"not": "a list"}), "dag-cbor" + ) + store._root_obj["chunks"]["shard_cids"][0] = invalid_shard_cid + with pytest.raises(TypeError, match="Shard 0 did not decode to a list."): + await store._load_or_initialize_shard_cache(0) + + # bad __eq__ method + assert store != {"not a ShardedZarrStore": "test"} + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_set_exceptions(create_ipfs: tuple[str, str]): + """ + Tests exceptions raised during get and set operations. 
+ """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + proto = zarr.core.buffer.default_buffer_prototype() + + # Test RuntimeError when root object is not loaded in get + # store_no_root = ShardedZarrStore(kubo_cas, True, "some_cid") + # with pytest.raises( + # RuntimeError, match="Load the root object first before accessing data." + # ): + # await store_no_root.get("key", proto) + + # Set some bytes to /c/0 to ensure it exists + await store.set( + "/c/0", + proto.buffer.from_bytes(b'{"shape": [10], "dtype": "float32"}'), + ) + + # Test ValueError for invalid byte range in get + with pytest.raises( + ValueError, + match="Byte range start .* cannot be greater than end .*", + ): + await store.get( + "/c/0", + proto, + byte_range=zarr.abc.store.RangeByteRequest(start=10, end=5), + ) + + # Test NotImplementedError for set_partial_values + with pytest.raises(NotImplementedError): + await store.set_partial_values([]) + + # Test ValueError when shape is not found in metadata during set + with pytest.raises(ValueError, match="Shape not found in metadata."): + await store.set( + "test/zarr.json", proto.buffer.from_bytes(b'{"not": "a shape"}') + ) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_other_exceptions(create_ipfs: tuple[str, str]): + """ + Tests other miscellaneous exceptions in the ShardedZarrStore. + """ + rpc_base_url, gateway_base_url = create_ipfs + async with KuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + store = await ShardedZarrStore.open( + cas=kubo_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=1, + ) + + # Test RuntimeError for uninitialized store in flush + # store_no_root = ShardedZarrStore(kubo_cas, False, None) + # with pytest.raises(RuntimeError, match="Store not initialized for writing."): + # await store_no_root.flush() + + # Test ValueError when resizing a store with a different number of dimensions + with pytest.raises( + ValueError, + match="New shape must have the same number of dimensions as the old shape.", + ): + await store.resize_store(new_shape=(10, 10)) + + # Test KeyError when resizing a variable that doesn't exist + with pytest.raises( + KeyError, + match="Cannot find metadata for key 'nonexistent/zarr.json' to resize.", + ): + await store.resize_variable("nonexistent", new_shape=(20,)) + + # Test RuntimeError when listing a store with no root object + # with pytest.raises(RuntimeError, match="Root object not loaded."): + # async for _ in store_no_root.list(): + # pass + + # # Test RuntimeError when listing directories of a store with no root object + # with pytest.raises(RuntimeError, match="Root object not loaded."): + # async for _ in store_no_root.list_dir(""): + # pass + + # with pytest.raises(ValueError, match="Linear chunk index cannot be negative."): + # await store_no_root._get_shard_info(-1) + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_empty_shard(): + """Test line 40: empty shard handling in _get_shard_size""" + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=1000) + empty_shard = [] + + # Test that empty shard is handled correctly (line 40) + size = cache._get_shard_size(empty_shard) + assert size > 0 # sys.getsizeof should return 
some size even for empty list + + await cache.put(0, empty_shard) + retrieved = await cache.get(0) + assert retrieved == empty_shard + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_update_existing(): + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=10000) + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # First put + shard1 = [test_cid] * 2 + await cache.put(0, shard1) + + shard2 = [test_cid] * 3 + await cache.put(0, shard2, is_dirty=True) + + retrieved = await cache.get(0) + assert retrieved == shard2 + assert cache.dirty_cache_size == 1 + + +@pytest.mark.asyncio +async def test_memory_bounded_lru_cache_eviction_break(): + """Test line 96: eviction break when no clean shards available""" + from multiformats import CID + + from py_hamt.sharded_zarr_store import MemoryBoundedLRUCache + + cache = MemoryBoundedLRUCache(max_memory_bytes=500) # Small cache + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + + # Add several dirty shards to fill cache + large_shard = [test_cid] * 10 + for i in range(3): + await cache.put(i, large_shard, is_dirty=True) + + # Try to add another large shard - should trigger line 96 break + huge_shard = [test_cid] * 20 + await cache.put(3, huge_shard) + + # All dirty shards should still be in cache (not evicted) + for i in range(3): + assert await cache.get(i) is not None + assert cache.dirty_cache_size == 3 + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_duplicate_root_loading(): + """Test line 277: duplicate root object loading""" + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + # Create mock CAS that returns malformed data to trigger line 277 + class MockCAS: + def __init__(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + # Return valid DAG-CBOR for a sharded zarr root + root_obj = { + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None, None], + }, + } + return dag_cbor.encode(root_obj) + + # Create store with mock CAS + mock_cas = MockCAS() + store = ShardedZarrStore(mock_cas, True, "test_cid") + + # This should trigger line 277 where root_obj gets set twice + await store._load_root_from_cid() + + assert store._root_obj is not None + assert store._array_shape == (10,) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_invalid_root_object_structure(): + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, root_obj): + self.root_obj = root_obj + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + return dag_cbor.encode(self.root_obj) + + # Test line 274: root object is not a dict + mock_cas_not_dict = MockCAS("not a dictionary") + store = ShardedZarrStore(mock_cas_not_dict, True, "test_cid") + with pytest.raises(ValueError, match="Root object is not a valid dictionary"): + await store._load_root_from_cid() + + # Test line 274: root object missing 'chunks' key + mock_cas_no_chunks = MockCAS({ + "metadata": {}, + "manifest_version": "sharded_zarr_v1", + }) + store = ShardedZarrStore(mock_cas_no_chunks, True, "test_cid") + with 
pytest.raises(ValueError, match="Root object is not a valid dictionary"): + await store._load_root_from_cid() + + mock_cas_invalid_shard_cids = MockCAS({ + "manifest_version": "sharded_zarr_v1", + "metadata": {}, + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": "not a list", # Should be a list + }, + }) + store = ShardedZarrStore(mock_cas_invalid_shard_cids, True, "test_cid") + with pytest.raises(ValueError, match="shard_cids is not a list"): + await store._load_root_from_cid() + + # Test line 280: generic exception handling (invalid DAG-CBOR) + class MockCASInvalidDagCbor: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + return b"invalid dag-cbor data" # This will cause dag_cbor.decode to fail + + mock_cas_invalid_cbor = MockCASInvalidDagCbor() + store = ShardedZarrStore(mock_cas_invalid_cbor, True, "test_cid") + with pytest.raises(ValueError, match="Failed to decode root object"): + await store._load_root_from_cid() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_invalid_manifest_version(): + import dag_cbor + + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, manifest_version): + self.manifest_version = manifest_version + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + root_obj = { + "manifest_version": self.manifest_version, + "metadata": {}, + "chunks": { + "array_shape": [10], + "chunk_shape": [5], + "sharding_config": {"chunks_per_shard": 1}, + "shard_cids": [None, None], + }, + } + return dag_cbor.encode(root_obj) + + mock_cas = MockCAS("wrong_version") + store = ShardedZarrStore(mock_cas, True, "test_cid") + + with pytest.raises(ValueError, match="Incompatible manifest version"): + await store._load_root_from_cid() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_shard_fetch_retry(): + from py_hamt.sharded_zarr_store import ShardedZarrStore + + class MockCAS: + def __init__(self, fail_count=2): + self.fail_count = fail_count + self.attempts = 0 + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + self.attempts += 1 + if self.attempts <= self.fail_count: + raise ConnectionError("Mock connection error") + # Success on final attempt + import dag_cbor + + return dag_cbor.encode([None] * 4) + + async def save(self, data, codec=None): + from multiformats import CID + + return CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + + mock_cas = MockCAS(fail_count=2) # Fail twice, succeed on 3rd attempt + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID to fetch + from multiformats import CID + + shard_cid = CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + store._root_obj["chunks"]["shard_cids"][0] = shard_cid + + shard_data = await store._load_or_initialize_shard_cache(0) + assert shard_data is not None + assert len(shard_data) == 4 + + # Test case where all retries fail + mock_cas_fail = MockCAS(fail_count=5) # Fail more than max_retries + store_fail = await ShardedZarrStore.open( + cas=mock_cas_fail, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + 
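# Editor's aside — the assertion just below expects ShardedZarrStore to stop retrying
# after a bounded number of CAS fetches ("Failed to fetch shard 0 after 3 attempts").
# The store's internal retry loop is not part of this hunk, so the sketch below only
# illustrates the general pattern being exercised; the helper name `fetch_with_retries`,
# the attempt count, and the linear backoff are assumptions, not py-hamt's API.
import asyncio


async def fetch_with_retries(load, cid, max_attempts: int = 3, base_delay: float = 0.1):
    """Retry an async CAS load a bounded number of times before failing loudly."""
    for attempt in range(1, max_attempts + 1):
        try:
            return await load(cid)
        except Exception as exc:
            if attempt == max_attempts:
                # Surface a clear error once the retry budget is exhausted.
                raise RuntimeError(
                    f"Failed to fetch shard after {max_attempts} attempts"
                ) from exc
            await asyncio.sleep(base_delay * attempt)
# Hypothetical usage: shard_bytes = await fetch_with_retries(cas.load, shard_cid)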
store_fail._root_obj["chunks"]["shard_cids"][0] = shard_cid + + with pytest.raises(RuntimeError, match="Failed to fetch shard 0 after 3 attempts"): + await store_fail._load_or_initialize_shard_cache(0) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_with_read_only_clone_attribute(): + """Test line 490: with_read_only clone attribute assignment""" + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Create clone with different read_only status (should hit line 490) + clone = store.with_read_only(True) + + # Verify line 490: clone._root_obj = self._root_obj + assert clone._root_obj is store._root_obj + assert clone.read_only is True + assert store.read_only is False + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_get_method_line_565(): + """Test line 565: get method start (line 565 is the method definition)""" + import zarr.core.buffer + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid, offset=None, length=None, suffix=None): + return b"metadata_content" + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # Add metadata to test the get method + from multiformats import CID + + metadata_cid = CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + store._root_obj["metadata"]["test.json"] = metadata_cid + + # Test get method (line 565 is the method signature) + proto = zarr.core.buffer.default_buffer_prototype() + result = await store.get("test.json", proto) + + assert result is not None + assert result.to_bytes() == b"metadata_content" + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_exists_exception_handling(): + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # This should trigger the exception handling and return False + exists = await store.exists("invalid/chunk/key/format") + assert exists is False + + # Test exists with valid chunk key that's out of bounds (should also return False) + exists = await store.exists("c/100") # Out of bounds chunk + assert exists is False + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_cas_save_failure(): + """Test RuntimeError when cas.save fails in set method""" + import zarr.core.buffer + + from py_hamt import ShardedZarrStore + + class MockCASFailingSave: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def save(self, data, codec=None): + # Always fail to save + raise ConnectionError("Mock CAS save failure") + + mock_cas = MockCASFailingSave() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + proto = zarr.core.buffer.default_buffer_prototype() + test_data = proto.buffer.from_bytes(b"test_data") 
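# Editor's aside — the pytest.raises check that follows expects ShardedZarrStore.set
# to surface a CAS save failure as a RuntimeError naming the offending key. The real
# method body is not shown in this patch; the snippet below is a minimal sketch of
# that wrapping pattern, and the helper name `save_with_context` (plus the "raw"
# codec choice) is an assumption for illustration only.
async def save_with_context(cas, key: str, payload: bytes):
    """Delegate to the content-addressed store, attaching the key to any failure."""
    try:
        return await cas.save(payload, codec="raw")
    except Exception as exc:
        raise RuntimeError(f"Failed to save data for key {key}") from exc
# Hypothetical usage: cid = await save_with_context(kubo_cas, "c/0/0/0", chunk_bytes)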
+ + with pytest.raises(RuntimeError, match="Failed to save data for key test_key"): + await store.set("test_key", test_data) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_flush_dirty_shard_not_found(): + """Test RuntimeError when dirty shard not found in cache during flush""" + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def save(self, data, codec=None): + return CID.decode( + "bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm" + ) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=2, + ) + + # First put a shard in the cache and mark it as dirty + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + shard_data = [test_cid, None] + await store._shard_data_cache.put(0, shard_data, is_dirty=True) + + # Verify the shard is dirty + assert store._shard_data_cache.dirty_cache_size == 1 + assert 0 in store._shard_data_cache._dirty_shards + + # Mock the cache.get to return None for the dirty shard (simulating cache corruption) + original_get = store._shard_data_cache.get + + async def mock_get_returns_none(shard_idx): + if shard_idx == 0: # Return None for the dirty shard + return None + return await original_get(shard_idx) + + with patch.object( + store._shard_data_cache, "get", side_effect=mock_get_returns_none + ): + # This should hit line 529 (RuntimeError for dirty shard not found in cache) + with pytest.raises(RuntimeError, match="Dirty shard 0 not found in cache"): + await store.flush() + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_failed_to_load_or_initialize_shard(): + """Test RuntimeError when shard fails to load or initialize""" + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + import dag_cbor + + return dag_cbor.encode([None] * 4) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID to fetch + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Mock the cache to always return None, even after put operations + async def mock_get_always_none(shard_idx): + return None # Always return None to simulate cache failure + + async def mock_put_does_nothing(shard_idx, shard_data, is_dirty=False): + pass # Do nothing, so cache remains empty + + with patch.object(store._shard_data_cache, "get", side_effect=mock_get_always_none): + with patch.object( + store._shard_data_cache, "put", side_effect=mock_put_does_nothing + ): + with pytest.raises( + RuntimeError, match="Failed to load or initialize shard 0" + ): + await store._load_or_initialize_shard_cache(0) + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_timeout_cleanup_logic(): + """Test timeout cleanup logic in _load_or_initialize_shard_cache""" + import asyncio + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + 
return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + # Never completes to simulate timeout + await asyncio.sleep(100) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Manually create a pending load event to simulate the scenario + pending_event = asyncio.Event() + store._pending_shard_loads[0] = pending_event + + # Verify the pending load is set up + assert 0 in store._pending_shard_loads + assert not store._pending_shard_loads[0].is_set() + + # Mock wait_for to properly await the coroutine but still raise TimeoutError + async def mock_wait_for(coro, timeout=None): + # Properly cancel the coroutine to avoid the warning + if hasattr(coro, "close"): + coro.close() + raise asyncio.TimeoutError() + + with patch("asyncio.wait_for", side_effect=mock_wait_for): + with pytest.raises(RuntimeError, match="Timeout waiting for shard 0 to load"): + await store._load_or_initialize_shard_cache(0) + + # The event should be set and removed from pending loads + assert 0 not in store._pending_shard_loads # Should be cleaned up + + +@pytest.mark.asyncio +async def test_sharded_zarr_store_pending_load_cache_miss(): + """Test RuntimeError when pending load completes but shard not found in cache""" + import asyncio + from unittest.mock import patch + + from multiformats import CID + + from py_hamt import ShardedZarrStore + + class MockCAS: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def load(self, cid): + import dag_cbor + + return dag_cbor.encode([None] * 4) + + mock_cas = MockCAS() + store = await ShardedZarrStore.open( + cas=mock_cas, + read_only=False, + array_shape=(10,), + chunk_shape=(5,), + chunks_per_shard=4, + ) + + # Set up a shard CID + test_cid = CID.decode("bafyreihyrpefhacm6kkp4ql6j6udakdit7g3dmkzfriqfykhjw6cad7lrm") + store._root_obj["chunks"]["shard_cids"][0] = test_cid + + # Create a pending load event and manually add it + pending_event = asyncio.Event() + store._pending_shard_loads[0] = pending_event + + # Set up mocks: wait_for succeeds (doesn't timeout) but cache.get returns None + async def mock_wait_for(coro, timeout=None): + # Properly handle the coroutine to avoid warnings + if hasattr(coro, "close"): + coro.close() + # Simulate successful wait - the pending event gets set + pending_event.set() + return True # Successful wait + + async def mock_cache_get(shard_idx): + # Always return None to simulate cache miss after pending load + return None + + # Test the scenario where pending load "completes" but shard not in cache (line 428-430) + with patch("asyncio.wait_for", side_effect=mock_wait_for): + with patch.object(store._shard_data_cache, "get", side_effect=mock_cache_get): + with pytest.raises( + RuntimeError, + match="Shard 0 not found in cache after pending load completed", + ): + await store._load_or_initialize_shard_cache(0) diff --git a/tests/test_zarr_ipfs.py b/tests/test_zarr_ipfs.py index 95f29b6..a7769b0 100644 --- a/tests/test_zarr_ipfs.py +++ b/tests/test_zarr_ipfs.py @@ -139,11 +139,6 @@ async def test_write_read( with pytest.raises(NotImplementedError): await zhs.set_partial_values([]) - with pytest.raises(NotImplementedError): - await 
zhs.get_partial_values( - zarr.core.buffer.default_buffer_prototype(), [] - ) - previous_zarr_json: zarr.core.buffer.Buffer | None = await zhs.get( "zarr.json", zarr.core.buffer.default_buffer_prototype() ) diff --git a/tests/test_zarr_ipfs_encrypted.py b/tests/test_zarr_ipfs_encrypted.py index 8d65d1c..93aa74b 100644 --- a/tests/test_zarr_ipfs_encrypted.py +++ b/tests/test_zarr_ipfs_encrypted.py @@ -4,9 +4,8 @@ import pandas as pd import pytest import xarray as xr -import zarr -import zarr.core.buffer from Crypto.Random import get_random_bytes +from dag_cbor.ipld import IPLDKind from py_hamt import HAMT, KuboCAS, SimpleEncryptedZarrHAMTStore from py_hamt.zarr_hamt_store import ZarrHAMTStore @@ -91,9 +90,7 @@ async def test_encrypted_write_read( correct_key = get_random_bytes(32) wrong_key = get_random_bytes(32) header = b"test-encryption-header" - - root_cid = None - + root_cid: IPLDKind = None # --- Write Phase --- async with KuboCAS( rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url @@ -109,7 +106,7 @@ async def test_encrypted_write_read( assert ezhs_write != hamt_write assert ezhs_write.supports_writes - test_ds.to_zarr(store=ezhs_write, mode="w", zarr_format=3) # Use mode='w' + test_ds.to_zarr(store=ezhs_write, mode="w", zarr_format=3) # type: ignore await hamt_write.make_read_only() root_cid = hamt_write.root_node_id @@ -195,13 +192,8 @@ async def test_encrypted_write_read( with pytest.raises(NotImplementedError): await ezhs_read_ok.set_partial_values([]) - with pytest.raises(NotImplementedError): - await ezhs_read_ok.get_partial_values( - zarr.core.buffer.default_buffer_prototype(), [] - ) - with pytest.raises(Exception): - await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) + await ezhs_read_ok.set("new_key", np.array([b"a"], dtype=np.bytes_)) # type: ignore with pytest.raises(Exception): await ezhs_read_ok.delete("zarr.json") diff --git a/tests/test_zarr_ipfs_partial.py b/tests/test_zarr_ipfs_partial.py new file mode 100644 index 0000000..84de087 --- /dev/null +++ b/tests/test_zarr_ipfs_partial.py @@ -0,0 +1,461 @@ +import time + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +import zarr +import zarr.core.buffer + +# Make sure to import the ByteRequest types +from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest + +from py_hamt import HAMT, InMemoryCAS, KuboCAS +from py_hamt.zarr_hamt_store import ZarrHAMTStore + + +@pytest.fixture(scope="module") +def random_zarr_dataset(): + """Creates a random xarray Dataset. 
+ + Returns: + tuple: (dataset_path, expected_data) + - dataset_path: Path to the zarr store + - expected_data: The original xarray Dataset for comparison + """ + # Coordinates of the random data + times = pd.date_range("2024-01-01", periods=100) + lats = np.linspace(-90, 90, 18) + lons = np.linspace(-180, 180, 36) + + # Create random variables with different shapes + temp = np.random.randn(len(times), len(lats), len(lons)) + precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons))) + + # Create the dataset + ds = xr.Dataset( + { + "temp": ( + ["time", "lat", "lon"], + temp, + {"units": "celsius", "long_name": "Surface Temperature"}, + ), + "precip": ( + ["time", "lat", "lon"], + precip, + {"units": "mm/day", "long_name": "Daily Precipitation"}, + ), + }, + coords={ + "time": times, + "lat": ("lat", lats, {"units": "degrees_north"}), + "lon": ("lon", lons, {"units": "degrees_east"}), + }, + attrs={"description": "Test dataset with random weather data"}, + ) + + yield ds + + +# This test also collects miscellaneous statistics about performance, run with pytest -s to see these statistics being printed out +@pytest.mark.asyncio(loop_scope="session") # ← match the loop of the fixture +async def test_write_read( + create_ipfs, + random_zarr_dataset: xr.Dataset, +): # noqa for fixture which is imported above but then "redefined" + rpc_base_url, gateway_base_url = create_ipfs + test_ds = random_zarr_dataset + print("=== Writing this xarray Dataset to a Zarr v3 on IPFS ===") + print(test_ds) + + async with KuboCAS( # <-- own and auto-close session + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as kubo_cas: + hamt = await HAMT.build(cas=kubo_cas, values_are_bytes=True) + zhs = ZarrHAMTStore(hamt) + assert zhs.supports_writes + start = time.perf_counter() + # Do an initial write along with an append which is a common xarray/zarr operation + # Ensure chunks are not too small for partial value tests + test_ds.to_zarr(store=zhs, chunk_store={"time": 50, "lat": 18, "lon": 36}) # type: ignore + test_ds.to_zarr(store=zhs, mode="a", append_dim="time", zarr_format=3) # type: ignore + end = time.perf_counter() + elapsed = end - start + print("=== Write Stats") + print(f"Total time in seconds: {elapsed:.2f}") + print("=== Root CID") + await hamt.make_read_only() + cid = hamt.root_node_id + + print(f"=== Verifying Gateway Suffix Support (CID: {cid}) ===") + # Get the gateway URL without the /ipfs/ part + + print("=== Reading data back in and checking if identical") + hamt_read = await HAMT.build( # Renamed to avoid confusion + cas=kubo_cas, root_node_id=cid, values_are_bytes=True, read_only=True + ) + start = time.perf_counter() + zhs_read = ZarrHAMTStore(hamt_read, read_only=True) # Use the read-only hamt + ipfs_ds = xr.open_zarr(store=zhs_read) + print(ipfs_ds) + + # Check both halves, since each are an identical copy + ds1 = ipfs_ds.isel(time=slice(0, len(ipfs_ds.time) // 2)) + ds2 = ipfs_ds.isel(time=slice(len(ipfs_ds.time) // 2, len(ipfs_ds.time))) + xr.testing.assert_identical(ds1, ds2) + xr.testing.assert_identical(test_ds, ds1) + xr.testing.assert_identical(test_ds, ds2) + + end = time.perf_counter() + elapsed = end - start + print("=== Read Stats") + print(f"Total time in seconds: {elapsed:.2f}") + + # --- Start: New Partial Values Tests --- + + print("=== Testing get_partial_values ===") + proto = zarr.core.buffer.default_buffer_prototype() + + # Find a chunk key to test with (e.g., the first chunk of 'temp') + chunk_key = None + async for k in zhs_read.list(): + 
if k.startswith("temp/") and k != "temp/.zarray": + chunk_key = k + break + + assert chunk_key is not None, "Could not find a chunk key to test." + print(f"Testing with chunk key: {chunk_key}") + + # Get the full chunk data for comparison + full_chunk_buffer = await zhs_read.get(chunk_key, proto) + assert full_chunk_buffer is not None + full_chunk_data = full_chunk_buffer.to_bytes() + chunk_len = len(full_chunk_data) + print(f"Full chunk size: {chunk_len} bytes") + + # Ensure the chunk is large enough for meaningful tests + assert chunk_len > 100, "Chunk size too small for partial value tests" + + # Define some byte requests + range_req = RangeByteRequest(start=10, end=50) # Request 40 bytes + offset_req = OffsetByteRequest(offset=chunk_len - 30) # Request last 30 bytes + suffix_req = SuffixByteRequest(suffix=20) # Request last 20 bytes + + key_ranges_to_test = [ + (chunk_key, range_req), + (chunk_key, offset_req), + (chunk_key, suffix_req), + (chunk_key, None), # Full read + ] + + # Call get_partial_values + results = await zhs_read.get_partial_values(proto, key_ranges_to_test) # type: ignore + + # Assertions + assert len(results) == 4 + assert results[0] is not None + assert results[1] is not None + assert results[2] is not None + assert results[3] is not None + + # Check RangeByteRequest result + expected_range = full_chunk_data[10:50] + assert results[0].to_bytes() == expected_range, "RangeByteRequest failed" + print(f"RangeByteRequest: OK (Got {len(results[0].to_bytes())} bytes)") + + # Check OffsetByteRequest result + expected_offset = full_chunk_data[chunk_len - 30 :] + assert results[1].to_bytes() == expected_offset, "OffsetByteRequest failed" + print(f"OffsetByteRequest: OK (Got {len(results[1].to_bytes())} bytes)") + + # Check SuffixByteRequest result + expected_suffix = full_chunk_data[-20:] + # Broken until Kubo 0.36.0 + assert results[2].to_bytes() == expected_suffix, "SuffixByteRequest failed" + print(f"SuffixByteRequest: OK (Got {len(results[2].to_bytes())} bytes)") + + # Check full read result + assert results[3].to_bytes() == full_chunk_data, ( + "Full read via get_partial_values failed" + ) + print(f"Full Read: OK (Got {len(results[3].to_bytes())} bytes)") + + # --- End: New Partial Values Tests --- + + # Tests for code coverage's sake + assert await zhs_read.exists("zarr.json") + # __eq__ + assert zhs_read == zhs_read + assert zhs_read != hamt_read + assert not zhs_read.supports_writes + assert not zhs_read.supports_partial_writes + assert not ( + zhs_read.supports_deletes + ) # Should be true in read-only mode for HAMT? 
Usually False + + hamt_keys = set() + async for k in zhs_read.hamt.keys(): + hamt_keys.add(k) + + zhs_keys: set[str] = set() + async for k in zhs_read.list(): + zhs_keys.add(k) + assert hamt_keys == zhs_keys + + zhs_keys_prefix: set[str] = set() + async for k in zhs_read.list_prefix(""): + zhs_keys_prefix.add(k) + assert hamt_keys == zhs_keys_prefix + + with pytest.raises(NotImplementedError): + await zhs_read.set_partial_values([]) + + +@pytest.mark.asyncio +async def test_zarr_hamt_store_byte_request_errors(): + """Tests error handling for unsupported or invalid ByteRequest types.""" + cas = InMemoryCAS() + hamt = await HAMT.build(cas=cas, values_are_bytes=True) + zhs = ZarrHAMTStore(hamt) + proto = zarr.core.buffer.default_buffer_prototype() + await zhs.set("some_key", proto.buffer.from_bytes(b"0123456789")) + + # Test for ValueError with an invalid range (end < start) + invalid_range_req = RangeByteRequest(start=10, end=5) + with pytest.raises(ValueError, match="End must be >= start for RangeByteRequest"): + await zhs.get("some_key", proto, byte_range=invalid_range_req) + + # Test for TypeError with a custom, unsupported request type + class DummyUnsupportedRequest: + pass + + unsupported_req = DummyUnsupportedRequest() + with pytest.raises(TypeError, match="Unsupported ByteRequest type"): + await zhs.get("some_key", proto, byte_range=unsupported_req) + + +@pytest.mark.asyncio +async def test_ipfs_gateway_compression_behavior(create_ipfs): + """ + Test to verify whether IPFS gateways decompress data before applying + byte range requests, which would negate compression benefits for partial reads. + + This test creates highly compressible data, stores it via IPFS, and then + compares the bytes returned by partial vs full reads to determine if + the gateway is operating on compressed or decompressed data. + """ + rpc_base_url, gateway_base_url = create_ipfs + + print("\n=== IPFS Gateway Compression Behavior Test ===") + + # Create highly compressible test data + print("Creating highly compressible test data...") + data = np.zeros((50, 50, 50), dtype=np.float32) # 500KB of zeros + # Add small amount of variation + data[0:5, 0:5, 0:5] = np.random.randn(5, 5, 5) + + ds = xr.Dataset({"compressible_var": (["x", "y", "z"], data)}) + + print(f"Original data shape: {data.shape}") + print(f"Original data size: {data.nbytes:,} bytes") + + # Custom CAS to track actual network transfers + class NetworkTrackingKuboCAS(KuboCAS): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.load_sizes = {} + self.save_sizes = {} + + async def save(self, data, codec=None): + cid = await super().save(data, codec) + self.save_sizes[str(cid)] = len(data) + print(f"Saved to IPFS: {str(cid)[:12]}... ({len(data):,} bytes)") + return cid + + async def load(self, cid, offset=None, length=None, suffix=None): + result = await super().load(cid, offset, length, suffix) + + range_desc = "full" + if offset is not None or length is not None or suffix is not None: + range_desc = f"offset={offset}, length={length}, suffix={suffix}" + + key = f"{str(cid)[:12]}... 
({range_desc})" + self.load_sizes[key] = len(result) + print(f"Loaded from IPFS: {key} -> {len(result):,} bytes") + return result + + async with NetworkTrackingKuboCAS( + rpc_base_url=rpc_base_url, gateway_base_url=gateway_base_url + ) as tracking_cas: + # Write dataset with compression + print("\n=== Writing to IPFS with Blosc compression ===") + hamt = await HAMT.build(cas=tracking_cas, values_are_bytes=True) + store = ZarrHAMTStore(hamt) + + # Use smaller chunks to ensure meaningful compression + ds.chunk({"x": 25, "y": 25, "z": 25}).to_zarr( + store=store, mode="w", zarr_format=3 + ) + + await hamt.make_read_only() + root_cid = hamt.root_node_id + print(f"Root CID: {root_cid}") + + # Read back and test compression behavior + print("\n=== Testing Compression vs Partial Reads ===") + hamt_read = await HAMT.build( + cas=tracking_cas, + root_node_id=root_cid, + values_are_bytes=True, + read_only=True, + ) + store_read = ZarrHAMTStore(hamt_read, read_only=True) + + # Find the largest data chunk (likely the actual array data) + chunk_key = None + chunk_size = 0 + async for key in store_read.list(): + if ( + "compressible_var" in key + and not key.endswith(".zarray") + and not key.endswith("zarr.json") + ): + # Get size to find the largest chunk + proto = zarr.core.buffer.default_buffer_prototype() + chunk_buffer = await store_read.get(key, proto) + if chunk_buffer and len(chunk_buffer.to_bytes()) > chunk_size: + chunk_key = key + chunk_size = len(chunk_buffer.to_bytes()) + + assert chunk_key is not None, "No data chunk found" + print(f"Testing with largest chunk: {chunk_key}") + + # Get full chunk for baseline + proto = zarr.core.buffer.default_buffer_prototype() + full_chunk = await store_read.get(chunk_key, proto) + full_compressed_size = len(full_chunk.to_bytes()) + print(f"Full chunk compressed size: {full_compressed_size:,} bytes") + + # Calculate expected uncompressed size + # 25x25x25 float32 = 62,500 bytes uncompressed + expected_uncompressed_size = 25 * 25 * 25 * 4 + compression_ratio = expected_uncompressed_size / full_compressed_size + print(f"Estimated compression ratio: {compression_ratio:.1f}:1") + + # Test different partial read sizes + test_ranges = [ + (0, full_compressed_size // 4, "25% of compressed"), + (0, full_compressed_size // 2, "50% of compressed"), + (full_compressed_size // 4, full_compressed_size // 2, "25%-50% range"), + ] + + print("\n=== Partial Read Analysis ===") + print("If gateway operates on compressed data:") + print(" - Partial reads should return exactly the requested byte ranges") + print(" - Network transfer should be proportional to compressed size") + print("If gateway decompresses before range requests:") + print(" - Partial reads may return more data than expected") + print(" - Network transfer loses compression benefits") + print() + + compression_preserved = True + + for start, end, description in test_ranges: + length = end - start + byte_req = RangeByteRequest(start=start, end=end) + + # Clear the load tracking for this specific test + original_load_count = len(tracking_cas.load_sizes) + + partial_chunk = await store_read.get(chunk_key, proto, byte_range=byte_req) + partial_size = len(partial_chunk.to_bytes()) + + # Find the new load entry + new_loads = list(tracking_cas.load_sizes.items())[original_load_count:] + network_bytes = new_loads[0][1] if new_loads else partial_size + + expected_percentage = length / full_compressed_size + actual_percentage = partial_size / full_compressed_size + network_efficiency = network_bytes / 
full_compressed_size + + print(f"Range request: {description}") + print( + f" Requested: {length:,} bytes ({expected_percentage:.1%} of compressed)" + ) + print( + f" Received: {partial_size:,} bytes ({actual_percentage:.1%} of compressed)" + ) + print( + f" Network transfer: {network_bytes:,} bytes ({network_efficiency:.1%} of compressed)" + ) + + # Key test: if we get significantly more data than requested, + # the gateway is likely decompressing before applying ranges + if partial_size > length * 1.1: # 10% tolerance for overhead + compression_preserved = False + print( + f" ⚠️ Received {partial_size / length:.1f}x more data than requested!" + ) + print(" ⚠️ Gateway appears to decompress before applying ranges") + else: + print(" ✅ Range applied efficiently to compressed data") + + # Verify the partial data makes sense + full_data = full_chunk.to_bytes() + expected_partial = full_data[start:end] + assert partial_chunk.to_bytes() == expected_partial, ( + "Partial data doesn't match expected range" + ) + print(" ✅ Partial data content verified") + print() + + # Summary analysis + print("=== Final Analysis ===") + if compression_preserved: + print("✅ IPFS gateway preserves compression benefits for partial reads") + print(" - Byte ranges are applied to compressed data") + print(" - Network transfers are efficient") + else: + print("⚠️ IPFS gateway appears to decompress before applying ranges") + print(" - Partial reads may not provide expected bandwidth savings") + print(" - Consider alternative storage strategies (sharding, etc.)") + + print("\nCompression statistics:") + print(f" - Uncompressed chunk size: {expected_uncompressed_size:,} bytes") + print(f" - Compressed chunk size: {full_compressed_size:,} bytes") + print(f" - Compression ratio: {compression_ratio:.1f}:1") + print(f" - Compression preserved in partial reads: {compression_preserved}") + + +@pytest.mark.asyncio +async def test_in_memory_cas_partial_reads(): + """ + Tests the partial read logic (offset, length, suffix) in the InMemoryCAS. 
+ """ + cas = InMemoryCAS() + test_data = b"0123456789abcdefghijklmnopqrstuvwxyz" # 36 bytes long + data_id = await cas.save(test_data, "raw") + + # Test case 1: offset and length + result = await cas.load(data_id, offset=10, length=5) + assert result == b"abcde" + + # Test case 2: offset only + result = await cas.load(data_id, offset=30) + assert result == b"uvwxyz" + + # Test case 3: suffix only + result = await cas.load(data_id, suffix=6) + assert result == b"uvwxyz" + + # Test case 4: Full read (for completeness) + result = await cas.load(data_id) + assert result == test_data + + # Test case 5: Key not found (covers `try...except KeyError`) + with pytest.raises(KeyError, match="Object not found in in-memory store"): + await cas.load(b"\x00\x01\x02\x03\x04") # Some random, non-existent key + + # Test case 6: Invalid key type (covers `isinstance` check) + with pytest.raises(TypeError, match="InMemoryCAS only supports byte‐hash keys"): + await cas.load(12345) # Pass an integer instead of bytes diff --git a/tests/testing_utils.py b/tests/testing_utils.py index e8e4422..229762c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -26,17 +26,15 @@ def cid_strategy() -> SearchStrategy: def ipld_strategy() -> SearchStrategy: - return st.one_of( - [ - st.none(), - st.booleans(), - st.integers(min_value=-9223372036854775808, max_value=9223372036854775807), - st.floats(allow_infinity=False, allow_nan=False), - st.text(), - st.binary(), - cid_strategy(), - ] - ) + return st.one_of([ + st.none(), + st.booleans(), + st.integers(min_value=-9223372036854775808, max_value=9223372036854775807), + st.floats(allow_infinity=False, allow_nan=False), + st.text(), + st.binary(), + cid_strategy(), + ]) key_value_list = st.lists(