diff --git a/py_hamt/zarr_encryption_transformers.py b/py_hamt/zarr_encryption_transformers.py index 3014f88..7b5033f 100644 --- a/py_hamt/zarr_encryption_transformers.py +++ b/py_hamt/zarr_encryption_transformers.py @@ -1,6 +1,8 @@ -from typing import Callable -from pathlib import Path +import json +from typing import Callable, Literal +import dag_cbor +import xarray as xr from Crypto.Cipher import ChaCha20_Poly1305 from Crypto.Random import get_random_bytes @@ -11,27 +13,34 @@ def create_zarr_encryption_transformers( encryption_key: bytes, header: bytes, exclude_vars: list[str] = [], + detect_exclude: xr.Dataset + | Literal["auto-from-read"] + | Literal[False] = "auto-from-read", ) -> tuple[TransformerFN, TransformerFN]: """ Uses XChaCha20_Poly1305 from the pycryptodome library to perform encryption, while ignoring zarr metadata files. https://pycryptodome.readthedocs.io/en/latest/src/cipher/chacha20_poly1305.html - Note that the encryption key must always be 32 bytes long. A header is required by the underlying encryption algorithm. Every time a zarr chunk is encrypted, a random 24-byte nonce is generated. This is saved with the chunk for use when reading back. + Note that the encryption key must be exactly 32 bytes long. A header is required by the underlying encryption algorithm. Every time a zarr chunk is encrypted, a random 24-byte nonce is generated. This is saved with the chunk for use when reading back. - zarr.json metadata files in a zarr v3 are always ignored, to allow for calculating an encrypted zarr's structure without having the encryption key. + zarr.json metadata files in a zarr v3 are always ignored and passed through unencrypted. - With `exclude_vars` you may also set some variables to be unencrypted. This allows for partially encrypted zarrs which can be loaded into xarray but the values of encrypted variables cannot be accessed (errors will be thrown). You should generally include your coordinate variables along with your data variables in here. + With `exclude_vars` you may also set some variables to be unencrypted. This allows for partially encrypted zarrs. This should generally include your coordinate variables, along with any data variables you want to keep open. + + `detect_exclude` allows you to put in a xarray Dataset. This will be used to automatically add coordinate variables to the exclusion list. When you reading back a dataset and you do not know the unencrypted variables ahead of time, you can set this to the default "auto-from-read", which will attempt to use any metadata or any decryption errors to detect unencrypted variables. + + To do no automatic detection, set `detect_exclude` to False. # Example code ```python - from py_hamt import HAMT, IPFSStore, IPFSZarr3 + from py_hamt import HAMT, IPFSStore, IPFSZarr3, create_zarr_encryption_transformers ds = ... # example xarray Dataset with precip and temp data variables encryption_key = bytes(32) # change before using, only for demonstration purposes! header = "sample-header".encode() encrypt, decrypt = create_zarr_encryption_transformers( - encryption_key, header, exclude_vars=["temp"] + encryption_key, header, exclude_vars=["temp"], detect_exclude=ds ) hamt = HAMT( store=IPFSStore(), transformer_encode=encrypt, transformer_decode=decrypt @@ -40,7 +49,15 @@ def create_zarr_encryption_transformers( ds.to_zarr(store=ipfszarr3, mode="w") print("Attempting to read and print metadata of partially encrypted zarr") - enc_ds = xr.open_zarr(store=ipfszarr3, read_only=True) + wrong_key = bytes([0xAA]) * 32 + wrong_header = "".encode() + bad_encrypt, auto_detecting_decrypt = create_zarr_encryption_transformers( + wrong_key, + wrong_header, + ) + hamt = HAMT(store=IPFSStore(), transformer_encode=bad_encrypt, transformer_decode=auto_detecting_decrypt, root_node_id=ipfszarr3.hamt.root_node_id) + ipfszarr3 = IPFSZarr3(hamt, read_only=True) + enc_ds = xr.open_zarr(store=ipfszarr3) print(enc_ds) assert enc_ds.temp.sum() == ds.temp.sum() try: @@ -49,20 +66,27 @@ def create_zarr_encryption_transformers( print("Couldn't read encrypted variable") ``` """ - if len(encryption_key) != 32: raise ValueError("Encryption key is not 32 bytes") - def _should_transform(key: str) -> bool: - p = Path(key) + exclude_var_set = set(exclude_vars) - # Find the first directory name in the path since zarr v3 chunks are stored in a nested directory structure - # e.g. for Path("precip/c/0/0/1") it would return "precip" - if p.parts[0] in exclude_vars: - return False + if isinstance(detect_exclude, xr.Dataset): + ds = detect_exclude + for coord in list(ds.coords): + exclude_var_set.add(coord) # type: ignore - # Don't transform metadata files - if p.name == "zarr.json": + def _should_transform(key: str) -> bool: + # Find the first directory name in the path since zarr v3 chunks are stored in a nested directory structure + # e.g. for "precip/c/0/0/1" this would find "precip" + first_slash = key.find("/") + if first_slash != -1: + var_name = key[0:first_slash] + if var_name in exclude_var_set: + return False + + # Don't transform metadata files, even for encrypted variables + if key[-9:] == "zarr.json": return False return True @@ -78,14 +102,53 @@ def encrypt(key: str, val: bytes) -> bytes: # + concatenates two byte variables x,y so that it looks like xy return nonce + tag + ciphertext + seen_metadata: set[str] = set() + def decrypt(key: str, val: bytes) -> bytes: + # Look through files, this relies on the fact that xarray itself will attempt to read the root zarr.json and other metadata files first before any data will ever be accessed + # Important that this goes before _should_transform since that will return before we get a chance to look at metadata, and it needs information that we can glean here + if ( + detect_exclude == "auto-from-read" + and key[-9:] == "zarr.json" + and key not in seen_metadata + ): + seen_metadata.add(key) + + # Assume the zarr.json is unencrypted, which it should be if made with zarr encryption transformers + metadata = json.loads(dag_cbor.decode(val)) # type: ignore + + # If the global zarr.json, check if it has the list of coordinates in the consolidated metadata + if ( + "consolidated_metadata" in metadata + and metadata["consolidated_metadata"] is not None + ): + variables = metadata["consolidated_metadata"]["metadata"] + for var in variables: + for dimension in variables[var]["dimension_names"]: + exclude_var_set.add(dimension) + # Otherwise just scan a variable's individual metadata, but first make sure it's not the root zarr.json + elif "dimension_names" in metadata: + for dimension in metadata["dimension_names"]: + exclude_var_set.add(dimension) + if not _should_transform(key): return val - nonce, tag, ciphertext = val[:24], val[24:40], val[40:] - cipher = ChaCha20_Poly1305.new(key=encryption_key, nonce=nonce) - cipher.update(header) - plaintext = cipher.decrypt_and_verify(ciphertext, tag) - return plaintext + try: + nonce, tag, ciphertext = val[:24], val[24:40], val[40:] + cipher = ChaCha20_Poly1305.new(key=encryption_key, nonce=nonce) + cipher.update(header) + plaintext = cipher.decrypt_and_verify(ciphertext, tag) + return plaintext + except Exception as e: + # If if we are auto detecting coordinates, and there's an error with decrypting, then assume the issue is that this is a partially encrypted zarr, so we need to mark this variable as being one of the unencrypted ones and return like normal + if detect_exclude == "auto-from-read": + first_slash = key.find("/") + if first_slash != -1: + var_name = key[0:first_slash] + exclude_var_set.add(var_name) + return val + else: + raise e return (encrypt, decrypt) diff --git a/tests/test_zarr_ipfs.py b/tests/test_zarr_ipfs.py index 91d59ec..3344181 100644 --- a/tests/test_zarr_ipfs.py +++ b/tests/test_zarr_ipfs.py @@ -187,7 +187,8 @@ def test_encryption(random_zarr_dataset: tuple[str, xr.Dataset]): encrypt, decrypt = create_zarr_encryption_transformers( encryption_key, header="sample-header".encode(), - exclude_vars=["lat", "lon", "time", "temp"], + exclude_vars=["temp"], + detect_exclude=test_ds, ) hamt = HAMT( store=IPFSStore(), transformer_encode=encrypt, transformer_decode=decrypt @@ -195,15 +196,27 @@ def test_encryption(random_zarr_dataset: tuple[str, xr.Dataset]): ipfszarr3 = IPFSZarr3(hamt) test_ds.to_zarr(store=ipfszarr3) # type: ignore + ipfszarr3 = IPFSZarr3(ipfszarr3.hamt, read_only=True) ipfs_ds = xr.open_zarr(store=ipfszarr3) xr.testing.assert_identical(ipfs_ds, test_ds) # Now trying to load without a decryptor, xarray should be able to read the metadata and still perform operations on the unencrypted variable print("=== Attempting to read and print metadata of partially encrypted zarr") + bad_key = bytes([0xAA]) * 32 + bad_header = "".encode() + bad_encrypt, auto_detecting_decrypt = create_zarr_encryption_transformers( + bad_key, + bad_header, + ) ds = xr.open_zarr( store=IPFSZarr3( - HAMT(store=IPFSStore(), root_node_id=ipfszarr3.hamt.root_node_id), + HAMT( + store=IPFSStore(), + root_node_id=ipfszarr3.hamt.root_node_id, + transformer_encode=bad_encrypt, + transformer_decode=auto_detecting_decrypt, + ), read_only=True, ) ) @@ -213,6 +226,25 @@ def test_encryption(random_zarr_dataset: tuple[str, xr.Dataset]): with pytest.raises(Exception): ds.precip.sum() + # For code coverage's sake + # Don't auto detect, and thus encounter an error when trying to read back an unencrypted variable with the wrong encryption key and header + bad_encrypt, bad_decrypt = create_zarr_encryption_transformers( + bad_key, bad_header, detect_exclude=False + ) + with pytest.raises(Exception): + ds = xr.open_zarr( + store=IPFSZarr3( + HAMT( + store=IPFSStore(), + root_node_id=ipfszarr3.hamt.root_node_id, + transformer_encode=bad_encrypt, + transformer_decode=bad_decrypt, + ), + read_only=True, + ) + ) + assert ds.temp.sum() == test_ds.temp.sum() + # This test assumes the other zarr ipfs tests are working fine, so if other things are breaking check those first def test_authenticated_gateway(random_zarr_dataset: tuple[str, xr.Dataset]):