From b3a02e1b62a35b31bf7d1531a05678784654c521 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Aug 2023 09:51:50 -0700 Subject: [PATCH] jnp.ufunc: add __hash__ method and jit methods by default This allows the JIT cache to work properly with ufunc methods, because bound methods are created with a new ID each time. --- jax/_src/numpy/ufunc_api.py | 41 +++++++++++++++++++++++++----- tests/lax_numpy_ufuncs_test.py | 46 +++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index ac6818226d66..698aa63cad64 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -38,13 +38,37 @@ class ufunc: This is a class for LAX-backed implementations of numpy ufuncs. """ def __init__(self, func, /, nin, nout, *, name=None, nargs=None, identity=None): - # TODO(jakevdp): validate the signature of func via eval_shape. + # We want ufunc instances to work properly when marked as static, + # and for this reason it's important that their properties not be + # mutated. We prevent this by storing them in a dunder attribute, + # and accessing them via read-only properties. self.__name__ = name or func.__name__ - self._call = vectorize(func) - self.nin = operator.index(nin) - self.nout = operator.index(nout) - self.nargs = nargs or self.nin - self.identity = identity + self.__static_props = { + 'func': func, + 'call': vectorize(func), + 'nin': operator.index(nin), + 'nout': operator.index(nout), + 'nargs': operator.index(nargs or nin), + 'identity': identity + } + + _func = property(lambda self: self.__static_props['func']) + _call = property(lambda self: self.__static_props['call']) + nin = property(lambda self: self.__static_props['nin']) + nout = property(lambda self: self.__static_props['nout']) + nargs = property(lambda self: self.__static_props['nargs']) + identity = property(lambda self: self.__static_props['identity']) + + def __hash__(self): + # Do not include _call, because it is computed from _func. + return hash((self._func, self.__name__, self.identity, + self.nin, self.nout, self.nargs)) + + def __eq__(self, other): + # Do not include _call, because it is computed from _func. + return isinstance(other, ufunc) and ( + (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) == + (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs)) def __repr__(self): return f"" @@ -57,6 +81,7 @@ def __call__(self, *args, out=None, where=None, **kwargs): return self._call(*args, **kwargs) @_wraps(np.ufunc.reduce, module="numpy.ufunc") + @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None): if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -120,6 +145,7 @@ def body_fun(i, val): return result @_wraps(np.ufunc.accumulate, module="numpy.ufunc") + @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a, axis=0, dtype=None, out=None): if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") @@ -150,6 +176,7 @@ def scan_fun(carry, _): return _moveaxis(result, 0, axis) @_wraps(np.ufunc.accumulate, module="numpy.ufunc") + @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a, indices, b=None, /, *, inplace=True): if inplace: raise NotImplementedError(_AT_INPLACE_WARNING) @@ -184,6 +211,7 @@ def scan_fun(carry, x): return carry[1] @_wraps(np.ufunc.reduceat, module="numpy.ufunc") + @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a, indices, axis=0, dtype=None, out=None): if self.nin != 2: raise ValueError("reduceat only supported for binary ufuncs") @@ -220,6 +248,7 @@ def loop_body(i, out): return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) @_wraps(np.ufunc.outer, module="numpy.ufunc") + @partial(jax.jit, static_argnums=[0]) def outer(self, A, B, /, **kwargs): if self.nin != 2: raise ValueError("outer only supported for binary ufuncs") diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index ea263f63e52b..73e5cd010630 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -33,14 +33,26 @@ def scalar_add(x, y): return x + y +def scalar_div(x, y): + assert np.shape(x) == np.shape(y) == () + return x / y + + def scalar_mul(x, y): assert np.shape(x) == np.shape(y) == () return x * y +def scalar_sub(x, y): + assert np.shape(x) == np.shape(y) == () + return x - y + + SCALAR_FUNCS = [ {'func': scalar_add, 'nin': 2, 'nout': 1, 'identity': 0}, + {'func': scalar_div, 'nin': 2, 'nout': 1, 'identity': None}, {'func': scalar_mul, 'nin': 2, 'nout': 1, 'identity': 1}, + {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, ] broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] @@ -54,6 +66,34 @@ def wrapped(*args, **kwargs): class LaxNumpyUfuncTests(jtu.JaxTestCase): + + @jtu.sample_product(SCALAR_FUNCS) + def test_ufunc_properties(self, func, nin, nout, identity): + jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) + self.assertEqual(jnp_fun.identity, identity) + self.assertEqual(jnp_fun.nin, nin) + self.assertEqual(jnp_fun.nout, nout) + self.assertEqual(jnp_fun.nargs, nin) + + @jtu.sample_product(SCALAR_FUNCS) + def test_ufunc_properties_readonly(self, func, nin, nout, identity): + jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) + for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: + getattr(jnp_fun, attr) # no error on attribute access. + with self.assertRaises(AttributeError): + setattr(jnp_fun, attr, None) # error when trying to mutate. + + @jtu.sample_product(SCALAR_FUNCS) + def test_ufunc_hash(self, func, nin, nout, identity): + jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) + jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) + self.assertEqual(jnp_fun, jnp_fun_2) + self.assertEqual(hash(jnp_fun), hash(jnp_fun_2)) + + other_fun = jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0) + self.assertNotEqual(jnp_fun, other_fun) + # Note: don't test hash for non-equality because it may collide. + @jtu.sample_product( SCALAR_FUNCS, lhs_shape=broadcast_compatible_shapes, @@ -124,7 +164,7 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses? + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -146,7 +186,7 @@ def np_fun(x, idx, y): args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)] self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses? + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -167,7 +207,7 @@ def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses? + self._CompileAndCheck(jnp_fun, args_maker) if __name__ == "__main__":