2 changes: 2 additions & 0 deletions py_hamt/__init__.py
@@ -1,10 +1,12 @@
 from .hamt import HAMT, blake3_hashfn
 from .store import Store, DictStore, IPFSStore
+from .zarr_encryption_transformer import create_zarr_encryption_transformers

 __all__ = [
     "HAMT",
     "blake3_hashfn",
     "Store",
     "DictStore",
     "IPFSStore",
     "create_zarr_encryption_transformers",
 ]
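
With this change the new factory becomes importable from the package top level; a quick check (not part of the diff):

    from py_hamt import create_zarr_encryption_transformers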
21 changes: 18 additions & 3 deletions py_hamt/hamt.py
@@ -6,7 +6,6 @@
 import dag_cbor
 from dag_cbor.ipld import IPLDKind
 from multiformats import multihash
-
 from .store import Store

@@ -204,6 +203,11 @@ class HAMT(MutableMapping):
     You can modify the read status of a HAMT through the `make_read_only` or `enable_write` functions, so that the HAMT will block on making a change until all mutating operations are done.
     """

+    transformer_encode: None | Callable[[str, bytes], bytes]
+    """This function is called to transform a value being inserted right before it gets sent to the Store."""
+    transformer_decode: None | Callable[[str, bytes], bytes]
+    """This function is called to decode a value being retrieved after it is returned from the Store."""
+
     cache: dict[StoreID, Node]
     """@private"""
     max_cache_size_bytes: int
@@ -250,6 +254,8 @@ def __init__(
         read_only: bool = False,
         root_node_id: IPLDKind = None,
         max_cache_size_bytes=10_000_000,  # default to 10 megabytes
+        transformer_encode: None | Callable[[str, bytes], bytes] = None,
+        transformer_decode: None | Callable[[str, bytes], bytes] = None,
     ):
Comment on lines +257 to 259
⚠️ Potential issue

Update __deepcopy__ to include transformer functions.

The __deepcopy__ method (lines 281-295) needs to be updated to include the transformer functions when creating a copy of the HAMT instance.

Apply this diff to fix the issue:

 def __deepcopy__(self, memo) -> "HAMT":
     if not self.read_only:
         self.lock.acquire(blocking=True)

     copy_hamt = HAMT(
         store=self.store,
         hash_fn=self.hash_fn,
         read_only=self.read_only,
         root_node_id=self.root_node_id,
+        transformer_encode=self.transformer_encode,
+        transformer_decode=self.transformer_decode,
     )

     if not self.read_only:
         self.lock.release()

     return copy_hamt

Also applies to: 269-271

         self.store = store
         self.hash_fn = hash_fn
@@ -261,6 +267,9 @@ def __init__(
         self.read_only = read_only
         self.lock = Lock()

+        self.transformer_encode = transformer_encode
+        self.transformer_decode = transformer_decode
+
         if root_node_id is None:
             root_node = Node()
             self.root_node_id = self.write_node(root_node)
@@ -340,7 +349,10 @@ def __setitem__(self, key_to_insert: str, val_to_insert: IPLDKind):
         if not self.read_only:
             self.lock.acquire(blocking=True)

-        val_ptr = self.store.save_raw(dag_cbor.encode(val_to_insert))
+        val = dag_cbor.encode(val_to_insert)
+        if self.transformer_encode is not None:
+            val = self.transformer_encode(key_to_insert, val)
+        val_ptr = self.store.save_raw(val)

         node_stack: list[tuple[Link, Node]] = []
         root_node = self.read_node(self.root_node_id)
@@ -536,7 +548,10 @@ def __getitem__(self, key: str) -> IPLDKind:
         if not found_a_result:
             raise KeyError

-        return dag_cbor.decode(self.store.load(result_ptr))
+        result_bytes = self.store.load(result_ptr)
+        if self.transformer_decode is not None:
+            result_bytes = self.transformer_decode(key, result_bytes)
+        return dag_cbor.decode(result_bytes)

     def __len__(self) -> int:
         key_count = 0
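Taken together, these hooks give every value a byte-level detour on its way to and from the Store. A minimal roundtrip sketch (not part of the diff; it assumes only the public `HAMT` and `DictStore` exports, and uses a toy XOR transformer purely for illustration):

    from py_hamt import HAMT, DictStore

    def toy_encode(key: str, val: bytes) -> bytes:
        # hypothetical stand-in for a real cipher: XOR every byte with 0x5A
        return bytes(b ^ 0x5A for b in val)

    def toy_decode(key: str, val: bytes) -> bytes:
        # XOR is its own inverse, so decode is identical
        return bytes(b ^ 0x5A for b in val)

    hamt = HAMT(
        store=DictStore(),
        transformer_encode=toy_encode,
        transformer_decode=toy_decode,
    )
    hamt["vars/x"] = b"hello"          # dag-cbor encode -> transformer_encode -> store
    assert hamt["vars/x"] == b"hello"  # load -> transformer_decode -> dag-cbor decode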
78 changes: 78 additions & 0 deletions py_hamt/zarr_encryption_transformer.py
@@ -0,0 +1,78 @@
from typing import Callable
from pathlib import Path


import io

from Crypto.Cipher import ChaCha20_Poly1305
from Crypto.Random import get_random_bytes


# Metadata files to avoid encrypting (and decrypting)
encryption_exclude_files = [
    # top level meta data
    ".zattrs",
    ".zgroup",
    ".zmetadata",
    # found within variables, this includes .zattrs
    ".zarray",
    # important for coordinate variables, so that we can read bounds
    "0",
]

TransformerFN = Callable[[str, bytes], bytes]


def create_zarr_encryption_transformers(
    encryption_key: bytes,
    encrypted_vars: [str],
) -> tuple[TransformerFN, TransformerFN]:
    """An encryption filter for ZARR data.
    This class is serialized and stored along with the Zarr it is used with, so instead
    of storing the encryption key, we store the hash of the encryption key, so it can be
    looked up in the key registry at run time as needed.
    Parameters
    ----------
    key_hash: str
        The hex digest of an encryption key. A key may be generated using
        :func:`generate_encryption_key`. The hex digest is obtained by registering the
        key using :func:`register_encryption_key`.
    """
Comment on lines +30 to +40
🛠️ Refactor suggestion

Update docstring to match function parameters.

The docstring refers to a key_hash parameter, but the function accepts encryption_key. Update the docstring to reflect the actual parameters.

-    Parameters
-    ----------
-    key_hash: str
-        The hex digest of an encryption key. A key may be generated using
-        :func:`generate_encryption_key`. The hex digest is obtained by registering the
-        key using :func:`register_encryption_key`.
+    Parameters
+    ----------
+    encryption_key: bytes
+        The encryption key used for ChaCha20-Poly1305 cipher.
+    encrypted_vars: [str]
+        List of variable names to be encrypted.


    # codec_id = "xchacha20poly1305"
    header = b"dClimate-Zarr"

    def _should_transform_key(key: str) -> bool:
        if Path(key).name in encryption_exclude_files:
            return False
        return key.split("/")[0] in encrypted_vars

    def encode(key: str, val: bytes) -> bytes:
        if not _should_transform_key(key):
            return val
        raw = io.BytesIO()
        raw.write(val)
        nonce = get_random_bytes(24)  # XChaCha uses 24 byte (192 bit) nonce
        cipher = ChaCha20_Poly1305.new(key=encryption_key, nonce=nonce)
        cipher.update(header)
        ciphertext, tag = cipher.encrypt_and_digest(raw.getbuffer())

        return nonce + tag + ciphertext

    def decode(key: str, val: bytes):
        if not _should_transform_key(key):
            return val

        nonce, tag, ciphertext = val[:24], val[24:40], val[40:]
        cipher = ChaCha20_Poly1305.new(key=encryption_key, nonce=nonce)
        cipher.update(header)
        plaintext = cipher.decrypt_and_verify(ciphertext, tag)

Comment on lines +62 to +70
🛠️ Refactor suggestion

Add error handling for malformed input in decode function.

The decode function should handle potential errors when processing malformed input:

  1. Input shorter than expected (< 40 bytes)
  2. Invalid tag during verification
 def decode(key: str, val: bytes):
     if not _should_transform_key(key):
         return val

+    if len(val) < 40:
+        raise ValueError("Input too short: encrypted data must be at least 40 bytes")
+
     nonce, tag, ciphertext = val[:24], val[24:40], val[40:]
     cipher = ChaCha20_Poly1305.new(key=encryption_key, nonce=nonce)
     cipher.update(header)
-    plaintext = cipher.decrypt_and_verify(ciphertext, tag)
+    try:
+        plaintext = cipher.decrypt_and_verify(ciphertext, tag)
+    except ValueError as e:
+        raise ValueError(f"Decryption failed: {str(e)}")

        # if out is not None:
        #     outbuf = io.BytesIO(plaintext)
        #     outbuf.readinto(out)
        #     return out

Comment on lines +71 to +75
🛠️ Refactor suggestion

Remove commented out code.

The commented out code should be removed as it's not being used and could create confusion.

-        # if out is not None:
-        #     outbuf = io.BytesIO(plaintext)
-        #     outbuf.readinto(out)
-        #     return out

        return plaintext

    return encode, decode
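
To make the on-store byte layout concrete: each encrypted chunk is stored as nonce (24 bytes) + tag (16 bytes) + ciphertext, with the fixed header mixed in as associated data. A self-contained roundtrip sketch using only the PyCryptodome calls already imported above (the plaintext value is made up for illustration):

    from Crypto.Cipher import ChaCha20_Poly1305
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(32)
    header = b"dClimate-Zarr"
    plaintext = b"example chunk bytes"

    # encode side: a 24-byte nonce selects the XChaCha20-Poly1305 variant
    nonce = get_random_bytes(24)
    cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
    cipher.update(header)  # bind the header as associated data
    ciphertext, tag = cipher.encrypt_and_digest(plaintext)
    blob = nonce + tag + ciphertext  # 24 + 16 + len(plaintext) bytes

    # decode side: split at the same fixed offsets
    nonce2, tag2, ct = blob[:24], blob[24:40], blob[40:]
    cipher2 = ChaCha20_Poly1305.new(key=key, nonce=nonce2)
    cipher2.update(header)
    assert cipher2.decrypt_and_verify(ct, tag2) == plaintext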
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "msgspec>=0.18.6",
     "multiformats[full]>=0.3.1.post4",
     "requests>=2.32.3",
+    "pycryptodome",
 ]

 [build-system]
@@ -26,6 +27,6 @@ dev = [
     "snakeviz>=2.2.0",
     "pandas>=2.2.3",
     "numpy>=2.1.3",
-    "xarray>=2024.11.0",
-    "zarr>=2.18.3",
+    "xarray==2024.11.0",
+    "zarr==2.18.3",
 ]
173 changes: 173 additions & 0 deletions tests/test_zarr_encryption_ipfs.py
@@ -0,0 +1,173 @@
import os
import shutil
import tempfile

from multiformats import CID
import numpy as np
import pandas as pd
import xarray as xr
import pytest
import time


from Crypto.Random import get_random_bytes

from py_hamt import HAMT, IPFSStore, create_zarr_encryption_transformers


@pytest.fixture
def random_zarr_dataset():
    """Creates a random xarray Dataset and saves it to a temporary zarr store.

    Returns:
        tuple: (dataset_path, expected_data)
        - dataset_path: Path to the zarr store
        - expected_data: The original xarray Dataset for comparison
    """
    # Create temporary directory for zarr store
    temp_dir = tempfile.mkdtemp()
    zarr_path = os.path.join(temp_dir, "test.zarr")

    # Coordinates of the random data
    times = pd.date_range("2024-01-01", periods=100)
    lats = np.linspace(-90, 90, 18)
    lons = np.linspace(-180, 180, 36)

    # Create random variables with different shapes
    temp = np.random.randn(len(times), len(lats), len(lons))
    precip = np.random.gamma(2, 0.5, size=(len(times), len(lats), len(lons)))

    # Create the dataset
    ds = xr.Dataset(
        {
            "temp": (
                ["time", "lat", "lon"],
                temp,
                {"units": "celsius", "long_name": "Surface Temperature"},
            ),
            "precip": (
                ["time", "lat", "lon"],
                precip,
                {"units": "mm/day", "long_name": "Daily Precipitation"},
            ),
        },
        coords={
            "time": times,
            "lat": ("lat", lats, {"units": "degrees_north"}),
            "lon": ("lon", lons, {"units": "degrees_east"}),
        },
        attrs={"description": "Test dataset with random weather data"},
    )

    ds.to_zarr(zarr_path, mode="w")

    yield zarr_path, ds

    # Cleanup
    shutil.rmtree(temp_dir)


def test_upload_then_read(random_zarr_dataset: tuple[str, xr.Dataset]):
    zarr_path, expected_ds = random_zarr_dataset
    test_ds = xr.open_zarr(zarr_path)

    # update precip and temp to have crypto: ["chacha"]
    # TODO: THIS SHOULD BE DONE IN THE ZARRAY but it doesn't appear like xarray allows this
    test_ds["precip"].attrs["crypto"] = ["xchacha20poly1305"]
    test_ds["temp"].attrs["crypto"] = ["xchacha20poly1305"]
    print("Writing this xarray Dataset to IPFS")
    print(test_ds)

    start_time = time.time()
    encryption_key = get_random_bytes(32)

    encode, decode = create_zarr_encryption_transformers(
        encryption_key, ["temp", "precip"]
    )
    hamt1 = HAMT(
        store=IPFSStore(pin_on_add=False),
        transformer_encode=encode,
        transformer_decode=decode,
    )
    test_ds.to_zarr(store=hamt1, mode="w")
⚠️ Potential issue

Address “Unsupported type for store_like” pipeline error.
According to the pipeline failure logs, xarray raises a “TypeError: Unsupported type for store_like: 'HAMT'” on this line. You may need to provide a store implementation that xarray recognizes, or adapt the HAMT object to fulfill xarray’s requirements for a Zarr store-like.

🧰 Tools
🪛 GitHub Actions: Triggered on push from TheGreatAlgo to branch/tag crypto-mod

[error] 174-174: TypeError: Unsupported type for store_like: 'HAMT'


    end_time = time.time()
    total_time = end_time - start_time
    print(f"Adding with encryption took {total_time:.2f} seconds")

    start_time = time.time()
    hamt2 = HAMT(
        store=IPFSStore(pin_on_add=False),
    )
    test_ds.to_zarr(store=hamt2, mode="w")
    end_time = time.time()
    total_time = end_time - start_time
    print(f"Adding without encryption took {total_time:.2f} seconds")

    hamt1_root: CID = hamt1.root_node_id  # type: ignore
    hamt2_root: CID = hamt2.root_node_id  # type: ignore
    print(f"No pin on add root CID: {hamt1_root}")
    print(f"Pin on add root CID: {hamt2_root}")

    print("Reading in from IPFS")
    hamt1_read = HAMT(
        store=IPFSStore(),
        root_node_id=hamt1_root,
        read_only=True,
        transformer_encode=encode,
        transformer_decode=decode,
    )
    hamt2_read = HAMT(
        store=IPFSStore(),
        root_node_id=hamt2_root,
        read_only=True,
    )
    start_time = time.time()
    loaded_ds1 = xr.open_zarr(store=hamt1_read)
    print(loaded_ds1)
    loaded_ds2 = xr.open_zarr(store=hamt2_read)
    end_time = time.time()
    # Assert the values are the same
    # Check if the values of 'temp' and 'precip' are equal in all datasets
    assert np.array_equal(loaded_ds1["temp"].values, expected_ds["temp"].values), (
        "Temp values in loaded_ds1 and expected_ds are not identical!"
    )
    assert np.array_equal(loaded_ds1["precip"].values, expected_ds["precip"].values), (
        "Precip values in loaded_ds1 and expected_ds are not identical!"
    )
    assert np.array_equal(loaded_ds2["temp"].values, expected_ds["temp"].values), (
        "Temp values in loaded_ds2 and expected_ds are not identical!"
    )
    assert np.array_equal(loaded_ds2["precip"].values, expected_ds["precip"].values), (
        "Precip values in loaded_ds2 and expected_ds are not identical!"
    )
    # xr.testing.assert_identical(loaded_ds1, loaded_ds2)
    # xr.testing.assert_identical(loaded_ds1, expected_ds)
    total_time = (end_time - start_time) / 2
    print(
        f"Took {total_time:.2f} seconds on average to load between the two loaded datasets"
    )

    # Test with no encryption key

    encrypted_no_transformer = HAMT(
        store=IPFSStore(),
        root_node_id=hamt1_root,
        read_only=True,
    )
    loaded_failure = xr.open_zarr(store=encrypted_no_transformer)
    # Accessing data should raise an exception since we don't have the encryption key or the transformer
    with pytest.raises(Exception):
        loaded_failure["temp"].values
Comment on lines +160 to +161
🛠️ Refactor suggestion

Use a more specific exception type.
Using pytest.raises(Exception) is overly broad, and static analysis flags the attribute access as "useless." A better practice is to specify the exact exception type expected when accessing encrypted data without the transformer.

Example fix:

-    with pytest.raises(Exception):
-        loaded_failure["temp"].values
+    with pytest.raises(ValueError):
+        _ = loaded_failure["temp"].values
🧰 Tools
🪛 Ruff (0.8.2)

153-153: pytest.raises(Exception) should be considered evil (B017)

154-154: Found useless attribute access. Either assign it to a variable or remove it. (B018)

assert "temp" in loaded_ds1
assert "precip" in loaded_ds1
assert loaded_ds1.temp.attrs["units"] == "celsius"

assert loaded_ds1.temp.shape == expected_ds.temp.shape

assert "temp" in loaded_ds2
assert "precip" in loaded_ds2
assert loaded_ds2.temp.attrs["units"] == "celsius"

assert loaded_ds2.temp.shape == expected_ds.temp.shape