Skip to content

Commit

Permalink
[sparse] improve error messages for unimplemented primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 25, 2022
1 parent 563c426 commit c37c1e6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
55 changes: 49 additions & 6 deletions jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,52 @@

sparse_rules : Dict[core.Primitive, Callable] = {}

_zero_preserving_unary_primitives = [
lax.abs_p,
lax.asin_p,
lax.asinh_p,
lax.atan_p,
lax.atanh_p,
lax.bessel_i1e_p,
lax.expm1_p,
lax.log1p_p,
lax.neg_p,
lax.real_p,
lax.imag_p,
lax.sign_p,
lax.sin_p,
lax.sinh_p,
lax.sqrt_p,
lax.tan_p,
lax.tanh_p,
lax.convert_element_type_p
]

_densifying_primitives : List[core.Primitive] = [
lax.acos_p,
lax.acosh_p,
lax.bessel_i0e_p,
lax.cos_p,
lax.cosh_p,
lax.eq_p,
lax.exp_p,
lax.ge_p,
lax.gt_p,
lax.le_p,
lax.lt_p,
lax.log_p,
lax.ne_p,
lax.xor_p
]

def _raise_unimplemented_primitive(primitive):
if primitive in _densifying_primitives:
raise NotImplementedError(f"sparse rule for {primitive} is not implemented because it "
"would result in dense output. If this is your intent, use "
"sparse.todense() to convert your arguments to dense matrices.")
raise NotImplementedError(f"sparse rule for {primitive} is not implemented.")


Array = Any
ArrayOrSparse = Any

Expand Down Expand Up @@ -245,7 +291,7 @@ def process_primitive(self, primitive, tracers, params):
spvalues = [t._spvalue for t in tracers]
if any(spvalue.is_sparse() for spvalue in spvalues):
if primitive not in sparse_rules:
raise NotImplementedError(f"sparse rule for {primitive}")
_raise_unimplemented_primitive(primitive)
out_spvalues = sparse_rules[primitive](spenv, *(t._spvalue for t in tracers), **params)
else:
out_bufs = primitive.bind(*(spenv.data(spvalue) for spvalue in spvalues), **params)
Expand Down Expand Up @@ -337,7 +383,7 @@ def write(var: core.Var, a: SparsifyValue) -> None:

if any(val.is_sparse() for val in invals):
if prim not in sparse_rules:
raise NotImplementedError(f"sparse rule for {prim}")
_raise_unimplemented_primitive(prim)
out = sparse_rules[prim](spenv, *invals, **eqn.params)
else:
if prim is xla.xla_call_p:
Expand Down Expand Up @@ -425,10 +471,7 @@ def func(spenv, *spvalues, **kwargs):

# TODO(jakevdp): some of these will give incorrect results when there are duplicated indices.
# how should we handle this?
for _prim in [
lax.abs_p, lax.expm1_p, lax.log1p_p, lax.neg_p, lax.sign_p, lax.sin_p,
lax.sinh_p, lax.sqrt_p, lax.tan_p, lax.tanh_p, lax.convert_element_type_p
]:
for _prim in _zero_preserving_unary_primitives:
sparse_rules[_prim] = _zero_preserving_unary_op(_prim)

def _dot_general_sparse(spenv, *spvalues, dimension_numbers, precision, preferred_element_type):
Expand Down
12 changes: 12 additions & 0 deletions tests/sparsify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ class SparsifyTest(jtu.JaxTestCase):
def sparsify(cls, f):
return sparsify(f, use_tracer=False)

def testNotImplementedMessages(self):
x = BCOO.fromdense(jnp.arange(5.0))
# Test a densifying primitive
with self.assertRaisesRegex(NotImplementedError,
r"^sparse rule for cos is not implemented because it would result in dense output\."):
self.sparsify(lax.cos)(x)

# Test a generic not implemented primitive.
with self.assertRaisesRegex(NotImplementedError,
r"^sparse rule for complex is not implemented\.$"):
self.sparsify(lax.complex)(x, x)

def testTracerIsInstanceCheck(self):
@self.sparsify
def f(x):
Expand Down

0 comments on commit c37c1e6

Please sign in to comment.