DeepHash fails on JAX arrays in tripartite make with equinox Modules #1427
Description
Summary
deepdiff.DeepHash (used by DataJoint's tripartite make for referential integrity checks) fails when the fetched data contains JAX arrays, as is common when using equinox Modules as data containers.
Minimal Reproduction
import jax.numpy as jnp
import deepdiff
# JAX 0-d scalar
deepdiff.DeepHash(jnp.float32(1.0))
# TypeError: iteration over a 0-d array
# JAX 1-d array
deepdiff.DeepHash(jnp.ones(3))
# TypeError: iteration over a 0-d array
# equinox Module with JAX fields (common pattern for datasets)
import equinox as eqx
class MyDataset(eqx.Module):
    data: jnp.ndarray
    scalar: jnp.ndarray

    def __init__(self):
        self.data = jnp.ones((3, 4))
        self.scalar = jnp.float32(1.0)

deepdiff.DeepHash(MyDataset())
# TypeError: iteration over a 0-d array
Root Cause
DeepHash._hash checks isinstance(obj, Iterable) (line ~79 of deephash.py). JAX arrays implement __iter__, so they match. However, JAX 0-d arrays (scalars) raise TypeError: iteration over a 0-d array when iterated, unlike numpy 0-d arrays, which deepdiff handles via the numbers type check (numpy scalars are registered as numbers). The 1-d case presumably fails for the same reason: iterating a 1-d JAX array yields 0-d arrays, which are then hashed in turn.
This affects the tripartite make pattern because _populate1 calls DeepHash(fetched_data) on whatever make_fetch returns. If the fetched data includes JAX arrays (e.g., a dataset loaded from the database), the integrity check fails before computation even starts.
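The mismatch described above (an object that passes the Iterable check but cannot actually be iterated) can be reproduced without JAX or deepdiff. The sketch below uses a hypothetical stand-in class, ZeroDimArrayLike, and a hypothetical classify helper. Neither is part of deepdiff or JAX; they only illustrate why an isinstance(obj, Iterable) check alone is not enough.

```python
from collections.abc import Iterable


class ZeroDimArrayLike:
    """Hypothetical stand-in for a 0-d array: defines __iter__ but cannot be iterated."""

    def __init__(self, value):
        self.value = value

    def __iter__(self):
        # Mirrors the error message quoted in this issue.
        raise TypeError("iteration over a 0-d array")


def classify(obj):
    """Report how a dispatch based only on isinstance(obj, Iterable) treats obj."""
    if not isinstance(obj, Iterable):
        return "not iterable"
    try:
        iter(obj)
    except TypeError as exc:
        return f"matched Iterable, but iteration failed: {exc}"
    return "iterable"


print(classify(ZeroDimArrayLike(1.0)))  # passes the isinstance check, fails on iteration
print(classify([1.0, 2.0]))             # a genuinely iterable container
print(classify(1.0))                    # a plain float never reaches the Iterable branch
```

The isinstance check succeeds because collections.abc.Iterable only looks for an __iter__ method on the class; it never calls it.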
Workaround
We monkey-patch DeepHash._hash to convert JAX arrays to numpy before hashing:
import numpy as np
import deepdiff.deephash as _dh
from deepdiff.helper import get_id
from jaxlib._jax import ArrayImpl
_orig_hash = _dh.DeepHash._hash
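def _patched_hash(self, obj, parent, parents_ids=frozenset()):
    if isinstance(obj, ArrayImpl):
        # Remember the key of the original JAX array before replacing obj.
        jax_key = get_id(obj)
        obj = np.asarray(obj)
        if obj.ndim == 0:
            # 0-d arrays cannot be iterated; hash them as Python scalars.
            obj = obj.item()
        result = _orig_hash(self, obj, parent, parents_ids)
        # Also store the result under the key of the original JAX array.
        self.hashes[jax_key] = result
        return result
    return _orig_hash(self, obj, parent, parents_ids)

_dh.DeepHash._hash = _patched_hash
Suggested Fix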
DeepHash could handle array-like objects more robustly by:
- Checking for numpy/JAX array types before the Iterable check
- Converting array-likes to numpy via np.asarray() (which both numpy and JAX support)
- Or registering JAX scalar types alongside numpy scalar types in the numbers tuple
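As a rough illustration of the first two options, the sketch below normalizes array-like values before any Iterable-based dispatch would see them. It is not deepdiff code. The helper name normalize_array_like and the FakeScalarArray class are hypothetical. To stay dependency-free, it swaps np.asarray() for duck typing on ndim and tolist(), which NumPy arrays provide; that JAX arrays expose the same two attributes is an assumption here.

```python
def normalize_array_like(obj):
    """Convert array-like values to plain Python values, recursing into containers.

    Assumption: array-likes expose `ndim` and `tolist()`. The workaround in this
    issue uses np.asarray() for the conversion instead.
    """
    if hasattr(obj, "ndim") and hasattr(obj, "tolist"):
        # Handled before any iteration is attempted, so 0-d values never get iterated.
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: normalize_array_like(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [normalize_array_like(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(normalize_array_like(value) for value in obj)
    return obj


class FakeScalarArray:
    """Hypothetical 0-d array stand-in, used only for this demonstration."""

    ndim = 0

    def __init__(self, value):
        self._value = value

    def __iter__(self):
        raise TypeError("iteration over a 0-d array")

    def tolist(self):
        return self._value


record = {"scalar": FakeScalarArray(1.0), "values": [FakeScalarArray(2.0), 3]}
print(normalize_array_like(record))
```

Converting through tolist() discards dtype information, so two arrays with equal values but different dtypes would normalize to the same result. A fix inside DeepHash would need to decide whether that distinction should affect the hash.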
Versions
- deepdiff: 9.0.0
- jax: 0.9.2
- datajoint: 0.14.9
- equinox: 0.12.5