Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 116 additions & 2 deletions py_hamt/hamt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from sys import getsizeof
from threading import Lock
from typing import Callable
import re
import json

import dag_cbor
from dag_cbor.ipld import IPLDKind
from multiformats import multihash

from .store import Store


Expand Down Expand Up @@ -124,7 +125,18 @@ def blake3_hashfn(input_bytes: bytes) -> bytes:
return raw_bytes


class HAMT(MutableMapping):
class HAMT:
    def __new__(cls, store, transformer=None, **kwargs):
        """
        Factory entry point: constructing HAMT dispatches to one of two
        concrete classes. A truthy ``transformer`` produces a
        TransformedHamt; otherwise a plain HAMTOriginal is returned.
        """
        if not transformer:
            return HAMTOriginal(store=store, **kwargs)
        return TransformedHamt(store=store, transformer=transformer, **kwargs)


class HAMTOriginal(MutableMapping):
Comment on lines +128 to +139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having to differentiate and add this layer of indirection, we can take in the transformer functions in the HAMT class initialization and then apply them right before values are either sent to or retrieved from the store.

"""
This HAMT presents a key value interface, like a python dictionary. The only limits are that keys can only be strings, and values can only be types amenable with [IPLDKind](https://dag-cbor.readthedocs.io/en/stable/api/dag_cbor.ipld.html#dag_cbor.ipld.IPLDKind). IPLDKind is a fairly flexible data model, but do note that integers must be within the bounds of a signed 64-bit integer.

Expand Down Expand Up @@ -268,6 +280,32 @@ def __init__(
self.root_node_id = root_node_id
# Make sure the cache has our root node
self.read_node(self.root_node_id)
self.check_crypto()

def check_crypto(self):
    """
    Validate the "crypto" entries in this HAMT's `.zarray` metadata against
    the transformer's codec id.

    Walks every key, decodes the JSON of each ``*/.zarray`` document, and
    cross-checks its "crypto" list against ``self.codec_id`` (an attribute
    that only exists when a transformer was supplied).

    Raises:
        Exception: If crypto metadata exists but no transformer codec id is set.
        Exception: If the stored codec id does not match the transformer's.
        Exception: If a transformer codec id is set but no crypto metadata exists.
    """
    has_codec_id = hasattr(self, "codec_id")
    found_filter = False
    for key in self:
        # Only array metadata documents can carry a "crypto" section.
        if not key.endswith("/.zarray"):
            continue
        metadata = json.loads(self[key].decode("utf-8"))
        crypto = metadata.get("crypto")
        if not crypto:
            continue
        found_filter = True
        if not has_codec_id:
            raise Exception("Codec ID not found in transformer")
        # Only the first filter entry is checked, mirroring how
        # _update_metadata_crypto records a single {"id": ...} entry.
        if crypto[0]["id"] != self.codec_id:
            raise Exception("Codec ID does not match transformer")
    if not found_filter and has_codec_id:
        raise Exception("Codec ID found in transformer but no crypto found in metadata")

# dunder for the python deepcopy module
def __deepcopy__(self, memo) -> "HAMT":
Expand Down Expand Up @@ -600,3 +638,79 @@ def ids(self):
links = top_node.get_links()
for link in links.values():
node_id_stack.append(link)


class TransformedHamt(HAMTOriginal):
    """
    A HAMT variant that applies a transformation function (e.g. an
    encryption codec) when setting or getting items, except for Zarr
    metadata keys and coordinate-index keys.
    """

    # Metadata documents are stored as plain JSON and must stay readable.
    ZARR_METADATA_KEYS = {".zarray", ".zgroup", ".zattrs", ".zmetadata"}
    # Matches single-index keys like lat/0, lon/0, time/0.
    INDEX_KEY_PATTERN = re.compile(r".*/\d+$")

    def __init__(self, *args, transformer=None, encrypt_vars=None, **kwargs):
        """
        Args:
            transformer: Object exposing ``codec_id``, ``encode`` and
                ``decode``; applied to values of transformed keys.
            encrypt_vars: Names of the data variables whose chunks should
                be transformed (e.g. ``["precip"]``).
        """
        self.transformer = transformer
        self.codec_id = transformer.codec_id
        # Avoid the shared mutable-default pitfall: make a per-instance list.
        self.encrypt_vars = encrypt_vars if encrypt_vars is not None else []
        super().__init__(*args, **kwargs)

Comment on lines +652 to +657
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix mutable default argument.

Using mutable default arguments can lead to unexpected behavior.

Apply this diff to fix the issue:

-    def __init__(self, *args, transformer=None, encrypt_vars=[], **kwargs):
+    def __init__(self, *args, transformer=None, encrypt_vars=None, **kwargs):
         self.transformer = transformer
         self.codec_id = transformer.codec_id
-        self.encrypt_vars = encrypt_vars
+        self.encrypt_vars = encrypt_vars if encrypt_vars is not None else []
         super().__init__(*args, **kwargs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(self, *args, transformer=None, encrypt_vars=[], **kwargs):
self.transformer = transformer
self.codec_id = transformer.codec_id
self.encrypt_vars = encrypt_vars
super().__init__(*args, **kwargs)
def __init__(self, *args, transformer=None, encrypt_vars=None, **kwargs):
self.transformer = transformer
self.codec_id = transformer.codec_id
self.encrypt_vars = encrypt_vars if encrypt_vars is not None else []
super().__init__(*args, **kwargs)
🧰 Tools
🪛 Ruff (0.8.2)

652-652: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

def _update_metadata_crypto(self, value):
    """
    Return ``value`` (JSON-encoded `.zarray` bytes) with this transformer's
    codec id appended to its "crypto" list, if not already present.

    Values that are not decodable JSON bytes are returned unchanged so
    callers can pass arbitrary payloads through safely.
    """
    try:
        # Decode bytes into a JSON dict.
        metadata = json.loads(value.decode('utf-8'))
    except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
        return value  # Not decodable: leave the value untouched.

    codec_filter = {'id': self.codec_id}
    # setdefault covers both the missing-key and already-present cases.
    crypto = metadata.setdefault('crypto', [])
    if codec_filter not in crypto:
        crypto.append(codec_filter)
    # Re-encode back to JSON bytes.
    return json.dumps(metadata).encode('utf-8')

def _is_data_var_metadata(self, key):
    """
    Return True when ``key`` is the `.zarray` metadata document of a
    variable selected for encryption (e.g. "precip/.zarray").

    Bug fix: the original used ``parts[1] in ".zarray"`` — a substring
    test that wrongly matched keys like "precip/z" or "precip/" (empty
    string is a substring of everything); exact equality is intended.
    """
    parts = key.split('/')
    return len(parts) == 2 and parts[1] == ".zarray" and parts[0] in self.encrypt_vars

def _should_transform_key(self, key):
    """
    Decide whether the value stored under ``key`` should be run through
    the transformer (encrypted on write / decrypted on read).

    Returns False for Zarr metadata keys, for single-index coordinate
    keys such as "lat/0", and for variables whose `.zarray` metadata has
    no "crypto" section; True otherwise (e.g. chunk keys like
    "precip/1.0.0").
    """
    if any(meta_key in key for meta_key in self.ZARR_METADATA_KEYS):
        return False  # Never transform metadata documents.

    if self.INDEX_KEY_PATTERN.fullmatch(key):
        # Single-index keys like "lat/0" are coordinate chunks. This aligns
        # with xarray, which treats these as coordinates, not data variables:
        # xarray fetches coordinate bounds when opening any zarr, so they
        # must stay unencrypted. Downside: 1-D data arrays cannot be
        # transformed under this rule.
        return False
    parts = key.split('/')
    # Look up the variable's .zarray metadata to see whether crypto is on.
    # NOTE(review): this raises KeyError when no "<var>/.zarray" entry
    # exists for this key — confirm every chunk key has a parent .zarray.
    metadata = self[parts[0] + "/.zarray"]
    if not metadata:
        return False
    metadata = json.loads(metadata.decode('utf-8'))
    crypto = metadata.get("crypto", [])
    # No crypto section recorded: leave the data untransformed.
    if not crypto:
        return False
    return True  # Everything else (chunked data such as "precip/1.0.0").

def __setitem__(self, key, value):
    """
    Store ``key``/``value``, encrypting chunk data and stamping the codec
    id into data-variable `.zarray` metadata when a transformer is active.

    Because the data variables are not known ahead of time, the user must
    name them via ``encrypt_vars``; metadata stamping happens only for
    keys belonging to those variables.
    """
    if not self.transformer:
        super().__setitem__(key, value)
        return
    stored = value
    if self._should_transform_key(key):
        stored = self.transformer.encode(stored)
    elif self._is_data_var_metadata(key):
        stored = self._update_metadata_crypto(stored)
    super().__setitem__(key, stored)

def __getitem__(self, key):
    """Fetch the stored value, decoding it when the transformer applies."""
    stored = super().__getitem__(key)
    if self.transformer and self._should_transform_key(key):
        return self.transformer.decode(stored)
    return stored
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ dev = [
"xarray>=2024.11.0",
"zarr>=2.18.3",
]

[tool.uv]
dev-dependencies = [
"pycryptodome",
]
Comment on lines +33 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Move pycryptodome to runtime dependencies.

Since the encryption functionality is part of the core feature set and not just for development/testing, pycryptodome should be moved to the runtime dependencies section.

Apply this diff to fix the dependency:

-[tool.uv]
-dev-dependencies = [
-    "pycryptodome",
-]
+[project]
+dependencies = [
     "dag-cbor>=0.3.3",
     "msgspec>=0.18.6",
     "multiformats[full]>=0.3.1.post4",
     "requests>=2.32.3",
+    "pycryptodome>=3.20.0",
 ]

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +33 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coderabbit also pointed this out, but yes this should be in the project dependencies rather than dev dependencies.

190 changes: 190 additions & 0 deletions tests/test_zarr_encryption_ipfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import os
import shutil
import tempfile

from multiformats import CID
import numpy as np
import pandas as pd
import xarray as xr
import pytest
import time

import io

from Crypto.Cipher import ChaCha20_Poly1305
from Crypto.Random import get_random_bytes

from py_hamt import HAMT, IPFSStore


class EncryptionFilter:
    """XChaCha20-Poly1305 encryption transformer for Zarr chunk data.

    Each chunk is sealed with a fresh random 24-byte nonce; the output
    layout is ``nonce (24) || tag (16) || ciphertext``. The constant
    ``header`` is mixed in as authenticated associated data, so payloads
    produced by other applications fail verification.

    Parameters
    ----------
    encryption_key: bytes
        A 32-byte (256-bit) secret key, e.g. from ``get_random_bytes(32)``.
    """

    codec_id = "xchacha20poly1305"
    header = b"dClimate-Zarr"

    def __init__(self, encryption_key: bytes):
        # ChaCha20-Poly1305 requires exactly a 256-bit key; fail fast on
        # anything else rather than erroring deep inside the cipher.
        if len(encryption_key) != 32:
            raise ValueError("encryption_key must be 32 bytes")
        self.encryption_key = encryption_key

    def encode(self, buf):
        """Encrypt ``buf`` and return ``nonce + tag + ciphertext`` bytes."""
        nonce = get_random_bytes(24)  # XChaCha uses 24 byte (192 bit) nonce
        cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
        cipher.update(self.header)
        ciphertext, tag = cipher.encrypt_and_digest(bytes(buf))

        return nonce + tag + ciphertext

    def decode(self, buf, out=None):
        """Decrypt and verify ``buf``; write into ``out`` when provided.

        Raises ValueError if the key is wrong or the data was tampered
        with (Poly1305 tag verification failure).
        """
        nonce, tag, ciphertext = buf[:24], buf[24:40], buf[40:]
        cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
        cipher.update(self.header)
        plaintext = cipher.decrypt_and_verify(ciphertext, tag)

        if out is not None:
            outbuf = io.BytesIO(plaintext)
            outbuf.readinto(out)
            return out

        return plaintext
Comment on lines +20 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance security by adding key validation and error handling.

The EncryptionFilter class needs improvements:

  1. Add validation for the encryption key length (should be 32 bytes for ChaCha20-Poly1305).
  2. Add error handling for encryption/decryption failures.

Apply this diff to improve the implementation:

 def __init__(self, encryption_key: str):
+    if not isinstance(encryption_key, bytes) or len(encryption_key) != 32:
+        raise ValueError("encryption_key must be 32 bytes")
     self.encryption_key = encryption_key

 def encode(self, buf):
+    try:
         raw = io.BytesIO()
         raw.write(buf)
         nonce = get_random_bytes(24)  # XChaCha uses 24 byte (192 bit) nonce
         cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
         cipher.update(self.header)
         ciphertext, tag = cipher.encrypt_and_digest(raw.getbuffer())
         return nonce + tag + ciphertext
+    except Exception as e:
+        raise ValueError(f"Encryption failed: {str(e)}")

 def decode(self, buf, out=None):
+    try:
         nonce, tag, ciphertext = buf[:24], buf[24:40], buf[40:]
         cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
         cipher.update(self.header)
         plaintext = cipher.decrypt_and_verify(ciphertext, tag)
+    except ValueError:
+        raise ValueError("Decryption failed: Invalid key or corrupted data")
+    except Exception as e:
+        raise ValueError(f"Decryption failed: {str(e)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class EncryptionFilter:
"""An encryption filter for ZARR data.
This class is serialized and stored along with the Zarr it is used with, so instead
of storing the encryption key, we store the hash of the encryption key, so it can be
looked up in the key registry at run time as needed.
Parameters
----------
key_hash: str
The hex digest of an encryption key. A key may be generated using
:func:`generate_encryption_key`. The hex digest is obtained by registering the
key using :func:`register_encryption_key`.
"""
codec_id = "xchacha20poly1305"
header = b"dClimate-Zarr"
def __init__(self, encryption_key: str):
self.encryption_key = encryption_key
def encode(self, buf):
raw = io.BytesIO()
raw.write(buf)
nonce = get_random_bytes(24) # XChaCha uses 24 byte (192 bit) nonce
cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
cipher.update(self.header)
ciphertext, tag = cipher.encrypt_and_digest(raw.getbuffer())
return nonce + tag + ciphertext
def decode(self, buf, out=None):
nonce, tag, ciphertext = buf[:24], buf[24:40], buf[40:]
cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
cipher.update(self.header)
plaintext = cipher.decrypt_and_verify(ciphertext, tag)
if out is not None:
outbuf = io.BytesIO(plaintext)
outbuf.readinto(out)
return out
return plaintext
class EncryptionFilter:
"""An encryption filter for ZARR data.
This class is serialized and stored along with the Zarr it is used with, so instead
of storing the encryption key, we store the hash of the encryption key, so it can be
looked up in the key registry at run time as needed.
Parameters
----------
key_hash: str
The hex digest of an encryption key. A key may be generated using
:func:`generate_encryption_key`. The hex digest is obtained by registering the
key using :func:`register_encryption_key`.
"""
codec_id = "xchacha20poly1305"
header = b"dClimate-Zarr"
def __init__(self, encryption_key: str):
if not isinstance(encryption_key, bytes) or len(encryption_key) != 32:
raise ValueError("encryption_key must be 32 bytes")
self.encryption_key = encryption_key
def encode(self, buf):
try:
raw = io.BytesIO()
raw.write(buf)
nonce = get_random_bytes(24) # XChaCha uses 24 byte (192 bit) nonce
cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
cipher.update(self.header)
ciphertext, tag = cipher.encrypt_and_digest(raw.getbuffer())
return nonce + tag + ciphertext
except Exception as e:
raise ValueError(f"Encryption failed: {str(e)}")
def decode(self, buf, out=None):
try:
nonce, tag, ciphertext = buf[:24], buf[24:40], buf[40:]
cipher = ChaCha20_Poly1305.new(key=self.encryption_key, nonce=nonce)
cipher.update(self.header)
plaintext = cipher.decrypt_and_verify(ciphertext, tag)
except ValueError:
raise ValueError("Decryption failed: Invalid key or corrupted data")
except Exception as e:
raise ValueError(f"Decryption failed: {str(e)}")
if out is not None:
outbuf = io.BytesIO(plaintext)
outbuf.readinto(out)
return out
return plaintext

Comment on lines +20 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than use a class, we should use a factory function that returns closures. This allows the transformers input on the HAMT class to be more flexible and accept a wider array of generic functions, rather than limit to either just this class type, or force clients to inherit from this class type.



@pytest.fixture
def random_zarr_dataset():
    """Yield ``(zarr_path, dataset)`` for a randomly generated weather dataset.

    The dataset is written to a temporary zarr store, which is removed
    once the consuming test has finished with it.

    Returns:
        tuple: (dataset_path, expected_data)
            - dataset_path: Path to the zarr store
            - expected_data: The original xarray Dataset for comparison
    """
    # Temporary directory hosting the zarr store.
    workdir = tempfile.mkdtemp()
    store_path = os.path.join(workdir, "test.zarr")

    # Coordinate axes for the synthetic data.
    time_axis = pd.date_range("2024-01-01", periods=100)
    lat_axis = np.linspace(-90, 90, 18)
    lon_axis = np.linspace(-180, 180, 36)

    # Random variables on the (time, lat, lon) grid.
    grid_shape = (len(time_axis), len(lat_axis), len(lon_axis))
    temperature = np.random.randn(*grid_shape)
    precipitation = np.random.gamma(2, 0.5, size=grid_shape)

    dataset = xr.Dataset(
        {
            "temp": (
                ["time", "lat", "lon"],
                temperature,
                {"units": "celsius", "long_name": "Surface Temperature"},
            ),
            "precip": (
                ["time", "lat", "lon"],
                precipitation,
                {"units": "mm/day", "long_name": "Daily Precipitation"},
            ),
        },
        coords={
            "time": time_axis,
            "lat": ("lat", lat_axis, {"units": "degrees_north"}),
            "lon": ("lon", lon_axis, {"units": "degrees_east"}),
        },
        attrs={"description": "Test dataset with random weather data"},
    )

    dataset.to_zarr(store_path, mode="w")

    yield store_path, dataset

    # Teardown: remove the temporary store.
    shutil.rmtree(workdir)


def test_upload_then_read(random_zarr_dataset: tuple[str, xr.Dataset]):
zarr_path, expected_ds = random_zarr_dataset
test_ds = xr.open_zarr(zarr_path)

print("Writing this xarray Dataset to IPFS")
print(test_ds)

start_time = time.time()
encryption_key = get_random_bytes(32)
hamt1 = HAMT(
store=IPFSStore(pin_on_add=False),
transformer=EncryptionFilter(encryption_key=encryption_key),
encrypt_vars=["temp", "precip"],
)
test_ds.to_zarr(store=hamt1, mode="w")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

⚠️ Potential issue

🧩 Analysis chain

Fix TypeError: Unsupported type for store_like.

The pipeline failure indicates that test_ds.to_zarr(store=hamt1, mode="w") is failing because TransformedHamt is not a supported store type.

Run the following script to check if TransformedHamt implements all required methods for a zarr store:


🏁 Script executed:

#!/bin/bash
# Description: Check if TransformedHamt implements all required methods for a zarr store.

# Test: Search for required zarr store methods in TransformedHamt.
ast-grep --pattern $'class TransformedHamt {
  $$$
  __getitem__($$$) {
    $$$
  }
  __setitem__($$$) {
    $$$
  }
  __delitem__($$$) {
    $$$
  }
  __iter__($$$) {
    $$$
  }
  __len__($$$) {
    $$$
  }
  $$$
}'

Length of output: 218


I'll generate a new shell script to locate the definition of TransformedHamt and inspect its implementation to verify if it includes the required methods for a Zarr store.

Please run this script and share the output so we can verify whether TransformedHamt implements the required methods for a Zarr store.


🏁 Script executed:

#!/bin/bash
# 1. Find the file containing the definition of TransformedHamt.
FILE=$(rg -l "class TransformedHamt" . | head -n 1)
if [ -z "$FILE" ]; then
  echo "No file found containing 'class TransformedHamt'"
  exit 1
fi
echo "TransformedHamt class found in file: $FILE"

# 2. Print context of the class definition.
echo "Extracting class definition (20 lines context):"
rg -n "class TransformedHamt" -A 20 "$FILE"

# 3. Check for the required methods in that file.
echo "Searching for required methods in $FILE:"
echo "---- __getitem__ ----"
rg -n "def __getitem__" "$FILE"
echo "---- __setitem__ ----"
rg -n "def __setitem__" "$FILE"
echo "---- __delitem__ ----"
rg -n "def __delitem__" "$FILE"
echo "---- __iter__ ----"
rg -n "def __iter__" "$FILE"
echo "---- __len__ ----"
rg -n "def __len__" "$FILE"

Length of output: 2274

TransformedHamt Implements the Required Methods but Lacks Proper Interface Inheritance

Our investigation confirms that the TransformedHamt class (in py_hamt/hamt.py) contains implementations for the standard mapping methods (__getitem__, __setitem__, __delitem__, __iter__, and __len__). However, despite this, the pipeline still fails with a TypeError ("Unsupported type for store_like") when passing an instance of this class to to_zarr.

• Though the methods exist, Zarr likely verifies the store by checking for a valid mutable mapping type—typically done via an inheritance from (or registration with) collections.abc.MutableMapping.

• It appears that TransformedHamt (or its base class HAMTOriginal) is not declared as a subclass of MutableMapping, so even with the required methods, it isn’t recognized as a valid Zarr store.

To resolve this issue, please review and modify the class inheritance hierarchy so that TransformedHamt (and/or its parent) explicitly extends collections.abc.MutableMapping. This should meet the interface expectations required by Zarr.

🧰 Tools
🪛 GitHub Actions: Triggered on push from TheGreatAlgo to branch/tag crypto

[error] 129-129: TypeError: Unsupported type for store_like: 'TransformedHamt'

end_time = time.time()
total_time = end_time - start_time
print(f"Adding with encryption took {total_time:.2f} seconds")

start_time = time.time()
hamt2 = HAMT(
store=IPFSStore(pin_on_add=False),
)
test_ds.to_zarr(store=hamt2, mode="w")
end_time = time.time()
total_time = end_time - start_time
print(f"Adding without encryption took {total_time:.2f} seconds")

hamt1_root: CID = hamt1.root_node_id # type: ignore
hamt2_root: CID = hamt2.root_node_id # type: ignore
print(f"No pin on add root CID: {hamt1_root}")
print(f"Pin on add root CID: {hamt2_root}")

print("Reading in from IPFS")
hamt1_read = HAMT(
store=IPFSStore(),
root_node_id=hamt1_root,
read_only=True,
transformer=EncryptionFilter(encryption_key=encryption_key),
)
hamt2_read = HAMT(
store=IPFSStore(),
root_node_id=hamt2_root,
read_only=True,
)
start_time = time.time()
loaded_ds1 = xr.open_zarr(store=hamt1_read)
print(loaded_ds1)
loaded_ds2 = xr.open_zarr(store=hamt2_read)
end_time = time.time()
xr.testing.assert_identical(loaded_ds1, loaded_ds2)
xr.testing.assert_identical(loaded_ds1, expected_ds)
total_time = (end_time - start_time) / 2
print(
f"Took {total_time:.2f} seconds on average to load between the two loaded datasets"
)

# Test with no encryption key
with pytest.raises(Exception, match="Codec ID not found in transformer"):
HAMT(
store=IPFSStore(),
root_node_id=hamt1_root,
read_only=True,
)

assert "temp" in loaded_ds1
assert "precip" in loaded_ds1
assert loaded_ds1.temp.attrs["units"] == "celsius"

assert loaded_ds1.temp.shape == expected_ds.temp.shape

assert "temp" in loaded_ds2
assert "precip" in loaded_ds2
assert loaded_ds2.temp.attrs["units"] == "celsius"

assert loaded_ds2.temp.shape == expected_ds.temp.shape
Loading
Loading