
DeepHash fails on JAX arrays in tripartite make with equinox Modules #1427

@peabody124


Summary

deepdiff.DeepHash (used by DataJoint's tripartite make for referential integrity checks) fails when the fetched data contains JAX arrays, as is common when using equinox Modules as data containers.

Minimal Reproduction

import jax.numpy as jnp
import deepdiff

# JAX 0-d scalar
deepdiff.DeepHash(jnp.float32(1.0))
# TypeError: iteration over a 0-d array

# JAX 1-d array (iteration yields 0-d arrays, which then fail the same way)
deepdiff.DeepHash(jnp.ones(3))
# TypeError: iteration over a 0-d array

# equinox Module with JAX fields (common pattern for datasets)
import equinox as eqx

class MyDataset(eqx.Module):
    data: jnp.ndarray
    scalar: jnp.ndarray
    def __init__(self):
        self.data = jnp.ones((3, 4))
        self.scalar = jnp.float32(1.0)

deepdiff.DeepHash(MyDataset())
# TypeError: iteration over a 0-d array

Root Cause

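DeepHash._hash checks isinstance(obj, Iterable) (line ~79 of deephash.py). JAX arrays implement __iter__, so they match, but iterating a 0-d JAX array (a scalar) raises TypeError: iteration over a 0-d array. A 1-d array fails the same way one level down, because iterating it yields 0-d arrays. numpy scalars do not hit this path: np.float32(1.0) returns a numpy scalar rather than a 0-d array, and deepdiff registers numpy scalar types in its numbers tuple, so they are handled by the numbers type check before the Iterable check is reached.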
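The dispatch problem can be reproduced without JAX or deepdiff. The sketch below is a toy model, not deepdiff's code: ZeroDimArray and OneDimArray are hypothetical stand-ins that mimic how JAX arrays iterate, and toy_dispatch is a simplified dispatcher that checks numbers before the generic Iterable branch, as described above.

```python
from collections.abc import Iterable
from numbers import Number


class ZeroDimArray:
    """Hypothetical stand-in for a JAX 0-d array: Iterable by type, not in practice."""

    def __init__(self, value):
        self.value = value

    def __iter__(self):
        # Mimics the error JAX raises when a 0-d array is iterated
        raise TypeError("iteration over a 0-d array")


class OneDimArray:
    """Hypothetical stand-in for a JAX 1-d array: iteration yields 0-d arrays."""

    def __init__(self, values):
        self.values = list(values)

    def __iter__(self):
        return (ZeroDimArray(v) for v in self.values)


def toy_dispatch(obj):
    """Simplified dispatch order: the numbers check first, then the Iterable branch."""
    if isinstance(obj, Number):
        return ("number", obj)
    if isinstance(obj, Iterable):
        return ("iterable", [toy_dispatch(item) for item in obj])
    return ("object", obj)


def try_dispatch(obj):
    """Return the dispatch result, or the TypeError message if dispatch fails."""
    try:
        return toy_dispatch(obj)
    except TypeError as exc:
        return f"TypeError: {exc}"


print(isinstance(ZeroDimArray(1.0), Iterable))  # True, because __iter__ is defined
print(try_dispatch(1.0))                        # ('number', 1.0)
print(try_dispatch(ZeroDimArray(1.0)))          # TypeError: iteration over a 0-d array
print(try_dispatch(OneDimArray([1.0, 2.0])))    # TypeError: iteration over a 0-d array
```

Because the Iterable test only looks for __iter__, any type that defines it but refuses to iterate falls into the generic branch and fails there.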

This affects the tripartite make pattern because _populate1 calls DeepHash(fetched_data) on whatever make_fetch returns. If the fetched data includes JAX arrays (e.g., a dataset loaded from the database), the integrity check fails before computation even starts.
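A simplified sketch of where the hash sits in that flow, assuming (not confirmed here) that _populate1 hashes the make_fetch output before computing and compares it with a hash of a fresh fetch before inserting. populate_one and its parameters are hypothetical names, and hash_fn stands in for DeepHash. The point it illustrates is that when hashing raises, make_compute never runs.

```python
def populate_one(key, make_fetch, make_compute, make_insert, hash_fn):
    """Hypothetical simplification of a tripartite make with an integrity check."""
    fetched = make_fetch(key)
    # With DeepHash and JAX arrays in `fetched`, this is the call that raises
    fetched_hash = hash_fn(fetched)
    computed = make_compute(key, *fetched)
    # Assumed check: upstream data must be unchanged at insert time
    if hash_fn(make_fetch(key)) != fetched_hash:
        raise RuntimeError("fetched data changed during computation")
    make_insert(key, *computed)


def fetch(key):
    return (key * 2,)


def compute(key, x):
    log.append("compute")
    return (x + 1,)


def insert(key, y):
    log.append(("insert", y))


def failing_hash(data):
    # Stands in for DeepHash raising on a JAX 0-d array
    raise TypeError("iteration over a 0-d array")


# A working hash lets the run reach compute and insert
log = []
populate_one(3, fetch, compute, insert, hash_fn=repr)
ok_log = list(log)  # ['compute', ('insert', 7)]

# A failing hash aborts the run before compute is ever called
log = []
try:
    populate_one(3, fetch, compute, insert, hash_fn=failing_hash)
except TypeError as exc:
    log.append(f"aborted: {exc}")
failed_log = list(log)  # ['aborted: iteration over a 0-d array']
```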

Workaround

We monkey-patch DeepHash._hash to convert JAX arrays to numpy before hashing:

import numpy as np
import deepdiff.deephash as _dh
from deepdiff.helper import get_id
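# ArrayImpl is private to jaxlib and its import path has changed across versions;
# isinstance(obj, jax.Array) is the public alternative.
from jaxlib._jax import ArrayImpl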

_orig_hash = _dh.DeepHash._hash

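def _patched_hash(self, obj, parent, parents_ids=frozenset()):
    if isinstance(obj, ArrayImpl):
        # Remember the JAX object's id before converting it
        jax_key = get_id(obj)
        obj = np.asarray(obj)
        if obj.ndim == 0:
            # 0-d arrays become plain Python scalars, which DeepHash hashes as numbers
            obj = obj.item()
        result = _orig_hash(self, obj, parent, parents_ids)
        # Also cache the result under the original JAX object's id
        self.hashes[jax_key] = result
        return result
    # Everything else goes through the unpatched implementation
    return _orig_hash(self, obj, parent, parents_ids)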

_dh.DeepHash._hash = _patched_hash

Suggested Fix

DeepHash could handle array-like objects more robustly by:

  1. Checking for numpy/JAX array types before the Iterable check
  2. Converting array-likes to numpy via np.asarray() (which both numpy and JAX support)
  3. Alternatively, registering JAX scalar types alongside numpy scalar types in the numbers tuple
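Options 1 and 2 together can be sketched as a normalization step that runs before any Iterable handling. This is a toy hasher, not a patch to deepdiff: normalize_array_like and toy_hash are hypothetical names, and treating anything that exposes __array__ as array-like is an assumption (numpy arrays and scalars implement it, as do JAX arrays). The example only exercises numpy, so it runs without JAX.

```python
import hashlib
from collections.abc import Iterable

import numpy as np


def normalize_array_like(obj):
    """Hypothetical pre-hash step: convert array-likes before the Iterable check."""
    if hasattr(obj, "__array__"):
        arr = np.asarray(obj)
        # 0-d arrays become Python scalars instead of being iterated
        return arr.item() if arr.ndim == 0 else arr.tolist()
    return obj


def toy_hash(obj):
    """Toy recursive hash: array-likes first, then scalars, then iterables."""
    obj = normalize_array_like(obj)
    if isinstance(obj, (bool, int, float, str)):
        return hashlib.sha256(repr(obj).encode()).hexdigest()
    if isinstance(obj, Iterable):
        parts = "".join(toy_hash(item) for item in obj)
        return hashlib.sha256(f"[{parts}]".encode()).hexdigest()
    raise TypeError(f"unsupported type: {type(obj).__name__}")


print(toy_hash(np.array(1.0)) == toy_hash(1.0))           # True
print(toy_hash(np.ones(3)) == toy_hash([1.0, 1.0, 1.0]))  # True
```

Converting to Python lists discards dtype and shape, so np.ones(3) and [1.0, 1.0, 1.0] hash identically here; a real fix might want to fold dtype and shape into the hash.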

Versions

  • deepdiff: 9.0.0
  • jax: 0.9.2
  • datajoint: 0.14.9
  • equinox: 0.12.5

Metadata


Assignees: No one assigned
Labels: enhancement (Indicates new improvements)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
