jnp.ufunc: add __hash__ method and jit methods by default
This allows the JIT cache to work properly with ufunc methods, because a
bound method is otherwise created with a new ID on each attribute access.
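
To see the problem the message describes, consider a minimal plain-Python sketch; Demo and method below are hypothetical stand-ins for jnp.ufunc and its methods, not JAX code:

import jax

class Demo:
  def method(self, x):
    return x + 1

d = Demo()
# Each attribute access builds a fresh bound-method object:
assert d.method is not d.method

# A jit wrapper built around d.method at call time would therefore key the
# JIT cache on a brand-new object every time. Jitting the unbound function
# once, with self marked static, keys the cache on the instance instead --
# which is why the instance needs a stable __hash__ and __eq__:
jitted = jax.jit(Demo.method, static_argnames=['self'])
out = jitted(d, 1.0)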
jakevdp committed Aug 14, 2023
1 parent 619377e commit b3a02e1
Showing 2 changed files with 78 additions and 9 deletions.
41 changes: 35 additions & 6 deletions jax/_src/numpy/ufunc_api.py
@@ -38,13 +38,37 @@ class ufunc:
This is a class for LAX-backed implementations of numpy ufuncs.
"""
def __init__(self, func, /, nin, nout, *, name=None, nargs=None, identity=None):
# TODO(jakevdp): validate the signature of func via eval_shape.
# We want ufunc instances to work properly when marked as static,
# and for this reason it's important that their properties not be
# mutated. We prevent this by storing them in a dunder attribute,
# and accessing them via read-only properties.
self.__name__ = name or func.__name__
self._call = vectorize(func)
self.nin = operator.index(nin)
self.nout = operator.index(nout)
self.nargs = nargs or self.nin
self.identity = identity
self.__static_props = {
'func': func,
'call': vectorize(func),
'nin': operator.index(nin),
'nout': operator.index(nout),
'nargs': operator.index(nargs or nin),
'identity': identity
}

_func = property(lambda self: self.__static_props['func'])
_call = property(lambda self: self.__static_props['call'])
nin = property(lambda self: self.__static_props['nin'])
nout = property(lambda self: self.__static_props['nout'])
nargs = property(lambda self: self.__static_props['nargs'])
identity = property(lambda self: self.__static_props['identity'])

def __hash__(self):
# Do not include _call, because it is computed from _func.
return hash((self._func, self.__name__, self.identity,
self.nin, self.nout, self.nargs))

def __eq__(self, other):
# Do not include _call, because it is computed from _func.
return isinstance(other, ufunc) and (
(self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
(other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))

def __repr__(self):
return f"<jnp.ufunc '{self.__name__}'>"
@@ -57,6 +81,7 @@ def __call__(self, *args, out=None, where=None, **kwargs):
return self._call(*args, **kwargs)

@_wraps(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None):
if self.nin != 2:
raise ValueError("reduce only supported for binary ufuncs")
@@ -120,6 +145,7 @@ def body_fun(i, val):
return result

@_wraps(np.ufunc.accumulate, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def accumulate(self, a, axis=0, dtype=None, out=None):
if self.nin != 2:
raise ValueError("accumulate only supported for binary ufuncs")
@@ -150,6 +176,7 @@ def scan_fun(carry, _):
return _moveaxis(result, 0, axis)

@_wraps(np.ufunc.at, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
def at(self, a, indices, b=None, /, *, inplace=True):
if inplace:
raise NotImplementedError(_AT_INPLACE_WARNING)
@@ -184,6 +211,7 @@ def scan_fun(carry, x):
return carry[1]

@_wraps(np.ufunc.reduceat, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def reduceat(self, a, indices, axis=0, dtype=None, out=None):
if self.nin != 2:
raise ValueError("reduceat only supported for binary ufuncs")
@@ -220,6 +248,7 @@ def loop_body(i, out):
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)

@_wraps(np.ufunc.outer, module="numpy.ufunc")
@partial(jax.jit, static_argnums=[0])
def outer(self, A, B, /, **kwargs):
if self.nin != 2:
raise ValueError("outer only supported for binary ufuncs")
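With the methods jitted once at class level and ufunc instances hashable, two equal instances now share compiled computations. A sketch of the expected cache behavior (the global trace counter is an illustration, not part of the test suite):

import jax.numpy as jnp

traces = 0

def scalar_add(x, y):
  global traces
  traces += 1  # incremented only while JAX traces the function
  return x + y

u = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)
x = jnp.arange(5.0)

u.reduce(x)         # first call: traced and compiled
n = traces
u.reduce(x)         # same instance, same shapes: cache hit
v = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)
v.reduce(x)         # u == v and hash(u) == hash(v), so also a cache hit
assert traces == n  # no re-tracing after the first call
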
46 changes: 43 additions & 3 deletions tests/lax_numpy_ufuncs_test.py
@@ -33,14 +33,26 @@ def scalar_add(x, y):
return x + y


def scalar_div(x, y):
assert np.shape(x) == np.shape(y) == ()
return x / y


def scalar_mul(x, y):
assert np.shape(x) == np.shape(y) == ()
return x * y


def scalar_sub(x, y):
assert np.shape(x) == np.shape(y) == ()
return x - y


SCALAR_FUNCS = [
{'func': scalar_add, 'nin': 2, 'nout': 1, 'identity': 0},
{'func': scalar_div, 'nin': 2, 'nout': 1, 'identity': None},
{'func': scalar_mul, 'nin': 2, 'nout': 1, 'identity': 1},
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
]

broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
@@ -54,6 +66,34 @@ def wrapped(*args, **kwargs):


class LaxNumpyUfuncTests(jtu.JaxTestCase):

@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun.identity, identity)
self.assertEqual(jnp_fun.nin, nin)
self.assertEqual(jnp_fun.nout, nout)
self.assertEqual(jnp_fun.nargs, nin)

@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_properties_readonly(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.

@jtu.sample_product(SCALAR_FUNCS)
def test_ufunc_hash(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun, jnp_fun_2)
self.assertEqual(hash(jnp_fun), hash(jnp_fun_2))

other_fun = jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0)
self.assertNotEqual(jnp_fun, other_fun)
# Note: don't test hash for non-equality because it may collide.

@jtu.sample_product(
SCALAR_FUNCS,
lhs_shape=broadcast_compatible_shapes,
@@ -124,7 +164,7 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses?
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
SCALAR_FUNCS,
@@ -146,7 +186,7 @@ def np_fun(x, idx, y):
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]

self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses?
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
SCALAR_FUNCS,
@@ -167,7 +207,7 @@ def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]

self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False) # TODO(jakevdp): why the cache misses?
self._CompileAndCheck(jnp_fun, args_maker)


if __name__ == "__main__":
