Skip to content

Commit

Permalink
lax.abs: better error for unsigned inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 5, 2023
1 parent 633f68a commit 60029e7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
Expand Up @@ -1710,6 +1710,7 @@ def _nary_lower_hlo(op: Callable, ctx,
_complex_elem_types = {np.float32, np.float64}
_int = {np.integer}
_bool = {np.bool_}
_signedint = {np.signedinteger}

_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
Expand Down Expand Up @@ -1944,7 +1945,7 @@ def _conj_transpose_rule(t, x, *, input_dtype):
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule

abs_p = unop(_complex_basetype, _num, 'abs')
abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp))

def _abs_jvp_rule(g, ans, x):
Expand Down

0 comments on commit 60029e7

Please sign in to comment.