DeepHash fails on JAX arrays in tripartite make with equinox Modules #1428
Replies: 2 comments
Thanks for the detailed report. This is a deepdiff limitation: JAX arrays implement […]

Immediate workaround (no DataJoint changes needed)

Since you control what […]

    import numpy as np
    import jax

    def to_hashable(obj):
        """Convert JAX arrays to numpy for DeepHash compatibility."""
        # isinstance(x, jax.Array) matches concrete JAX arrays; they do not
        # generally expose a __jax_array__ attribute, so hasattr would miss them.
        return jax.tree.map(
            lambda x: np.asarray(x) if isinstance(x, jax.Array) else x, obj
        )

Then in your tripartite make, wrap the yield:

    def make(self, key):
        # fetch
        data = (SomeTable & key).fetch1()
        model = reconstruct_equinox_model(data)
        yield to_hashable(model)  # DeepHash sees numpy arrays
        # compute
        result = train(model)
        yield result
        # insert
        self.insert1({**key, "result": result})

This converts JAX arrays (including 0-d scalars inside equinox Modules) to numpy before […]

Longer term

We're considering adding a […]
I'll need to think about that, but it is a bit tricky because the data fetching is rather involved. I guess we can look into migrating what is essentially preprocessing into the compute step, though.
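As a rough illustration of that idea, the sketch below keeps the fetch step to plain Python data and rebuilds the model only in the compute step, so the integrity hash never sees a JAX array. Everything here is hypothetical: `make_sketch`, `run_once`, and the `fetch_raw`, `build_model`, and `train` callables are placeholder names, and the `hashlib`/`pickle` digest is only a stand-in for DeepHash, not DataJoint's actual behavior.

```python
import hashlib
import pickle


def make_sketch(key, fetch_raw, build_model, train):
    """Generator shaped like the make discussed above (all names hypothetical)."""
    # fetch: yield only plain data, no JAX arrays or equinox Modules
    raw = fetch_raw(key)
    yield raw
    # compute: rebuild the model here, after the fetched data has been hashed
    model = build_model(raw)
    result = train(model)
    yield result


def run_once(key):
    """Stand-in driver: hash the fetched data, then advance to the compute step."""
    gen = make_sketch(
        key,
        fetch_raw=lambda k: {"weights": [0.1, 0.2], "bias": 0.5},
        # placeholder for the real model reconstruction
        build_model=lambda raw: ("model", raw["weights"], raw["bias"]),
        # placeholder for the real training routine
        train=lambda model: sum(model[1]) + model[2],
    )
    fetched = next(gen)
    # stand-in for the integrity hash taken over the fetched data
    digest = hashlib.sha256(pickle.dumps(fetched)).hexdigest()
    result = next(gen)
    return digest, result
```

The trade-off is that the reconstruction cost moves into the compute step, which may or may not suit an involved fetch.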
Summary
deepdiff.DeepHash (used by DataJoint's tripartite make for referential integrity checks) fails when the fetched data contains JAX arrays, as is common when using equinox Modules as data containers.
Minimal Reproduction
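A dependency-free sketch of the failure mode, for readers without JAX at hand. It does not call deepdiff or JAX: `ZeroDArray` is a hypothetical stand-in for a 0-d array, and `naive_walk` is an assumed simplification of an Iterable-first dispatch, not deepdiff's actual code.

```python
from collections.abc import Iterable


class ZeroDArray:
    """Hypothetical stand-in for a 0-d array: defines __iter__ but cannot be iterated."""

    def __init__(self, value):
        self.value = value

    def __iter__(self):
        raise TypeError("iteration over a 0-d array")


def naive_walk(obj):
    """Assumed simplification: iterate anything that passes the Iterable check."""
    if isinstance(obj, Iterable):
        return [naive_walk(item) for item in obj]
    return obj


scalar = ZeroDArray(3.0)
print(isinstance(scalar, Iterable))  # True, because the class defines __iter__

try:
    naive_walk(scalar)
    outcome = "ok"
except TypeError as exc:
    outcome = f"TypeError: {exc}"
print(outcome)  # TypeError: iteration over a 0-d array
```

The `isinstance` check passes on the mere presence of `__iter__`, so the error only surfaces once iteration is attempted.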
Root Cause
DeepHash._hash checks isinstance(obj, Iterable) (line ~79 of deephash.py). JAX arrays implement __iter__, so they match. But JAX 0-d arrays (scalars) raise TypeError: iteration over a 0-d array when iterated, unlike numpy 0-d arrays, which deepdiff handles via the numbers type check (numpy scalars are registered as numbers).
This affects the tripartite make pattern because _populate1 calls DeepHash(fetched_data) on whatever make_fetch returns. If the fetched data includes JAX arrays (e.g., a dataset loaded from the database), the integrity check fails before computation even starts.
Workaround
We monkey-patch DeepHash._hash to convert JAX arrays to numpy before hashing:
Suggested Fix
DeepHash could handle array-like objects more robustly by:
- […] Iterable check
- […] np.asarray() (which both numpy and JAX support)
- […] numbers tuple
Versions