Skip to content

Commit

Permalink
Improve implementation of cbrt() in JAX.
Browse files Browse the repository at this point in the history
Lower to XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.

Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.

Add testing for `sqrt`/`rsqrt` for more types.

[XLA:Python] Add cbrt to XLA:Python bindings.

PiperOrigin-RevId: 386316949
  • Loading branch information
hawkinsp authored and jax authors committed Jul 22, 2021
1 parent 6335768 commit 278ff13
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Expand Up @@ -44,6 +44,7 @@ Operators
broadcast
broadcasted_iota
broadcast_in_dim
cbrt
ceil
clamp
collapse
Expand Down
32 changes: 32 additions & 0 deletions jax/_src/lax/lax.py
Expand Up @@ -311,6 +311,10 @@ def rsqrt(x: Array) -> Array:
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
return rsqrt_p.bind(x)

def cbrt(x: Array) -> Array:
  r"""Elementwise cube root: :math:`\sqrt[3]{x}`."""
  # NOTE: `\cbrt` is not a standard LaTeX/MathJax macro, so the rendered docs
  # use the `\sqrt[3]{...}` form instead.
  return cbrt_p.bind(x)

def bitwise_not(x: Array) -> Array:
r"""Elementwise NOT: :math:`\neg x`."""
return not_p.bind(x)
Expand Down Expand Up @@ -2615,6 +2619,34 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))

# TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
# is 0.1.70 or newer.
if jax.lib._xla_extension_version < 28:
  def _cbrt_translation_rule(c, x):
    """Fallback lowering of cbrt as sign(x) * |x| ** (1/3) using XLA ops."""
    elem_dtype = c.get_shape(x).numpy_dtype()
    sign = xops.Sign(x)
    magnitude = xops.Abs(x)
    # The exponent constant is built in the operand's own dtype.
    one_third = xb.constant(c, np.array(1/3, dtype=elem_dtype))
    return xops.Mul(sign, xops.Pow(magnitude, one_third))
else:
  # Sufficiently new jaxlib: no fallback rule is supplied.
  _cbrt_translation_rule = None

cbrt_p = standard_unop(
    _float, 'cbrt', translation_rule=_cbrt_translation_rule)

# d/dx x**(1/3) = (1/3) * x**(-2/3) = (1/3) * ans**(-2), with ans = cbrt(x).
ad.defjvp2(
    cbrt_p,
    lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))

# TODO(b/194222106): remove the TPU-specific translation rule after XLA's cbrt
# is improved on TPU.
def _cbrt_tpu(y):
  """Cube root for TPU: a ``pow`` estimate refined by one Newton-Raphson step.

  Args:
    y: the operand of ``cbrt_p``.

  Returns:
    The elementwise cube root of ``y``. Zeros and infinities are returned
    unchanged, because the refinement below would evaluate ``0 * inf`` (NaN)
    for them.
  """
  abs_y = abs(y)
  # z estimates the reciprocal cube root |y| ** (-1/3).
  z = pow(abs_y, _const(y, -1/3))
  # Newton-Raphson step: https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf
  z1 = z + _const(y, 1/3) * (z - (z * z) * (z * (z * abs_y)))
  # |y| == 0 gives z == inf and |y| == inf gives z == 0; in both cases
  # z * abs_y is NaN, yet cbrt(y) == y, so pass y straight through.
  passthrough = bitwise_or(eq(y, _zeros(y)), eq(abs_y, _const(y, np.inf)))
  # cbrt(y) = y * |y| ** (-2/3) = z1 * (z1 * y), which also restores the sign.
  return select(passthrough, y, z1 * (z1 * y))

xla.backend_specific_translations['tpu'][cbrt_p] = xla.lower_fun(
    _cbrt_tpu, multiple_results=False)

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')

def _pow_jvp_lhs(g, ans, x, y):
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/lax_reference.py
Expand Up @@ -54,6 +54,8 @@ def round(x):

sqrt = np.sqrt
rsqrt = lambda x: np.ones_like(x) / np.sqrt(x)
cbrt = np.cbrt

square = np.square
reciprocal = np.reciprocal
tan = np.tan
Expand Down
8 changes: 1 addition & 7 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -461,6 +461,7 @@ def fn(x1, x2):
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)


add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
Expand Down Expand Up @@ -920,13 +921,6 @@ def fmod(x1, x2):
return lax.rem(*_promote_args("fmod", x1, x2))


@_wraps(np.cbrt)
def cbrt(x):
_check_arraylike("cbrt", x)
x, = _promote_dtypes_inexact(x)
return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. / 3.))


@_wraps(np.square)
def square(x):
_check_arraylike("square", x)
Expand Down
5 changes: 5 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1151,6 +1151,11 @@ def _atan2(y, x, **kwargs):
tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.rsqrt_p] = tf.math.rsqrt

def _cbrt(x):
  """TF implementation of cbrt, computed as sign(x) * |x| ** (1/3)."""
  sign = tf.math.sign(x)
  magnitude = tf.math.pow(tf.math.abs(x), 1/3)
  return sign * magnitude

tf_impl[lax.cbrt_p] = _cbrt

tf_impl[lax.lgamma_p] = tf.math.lgamma
tf_impl[lax.digamma_p] = tf.math.digamma
tf_impl[lax.igamma_p] = tf.math.igamma
Expand Down
5 changes: 3 additions & 2 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
Expand Up @@ -125,8 +125,9 @@ def limitations_for_harness(
# We keep here the explicit set of groups for which we don't have limitations
harness_groups_no_limitations = {
"abs", "add", "add_any", "and", "atan2",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
"concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
"clamp", "concatenate", "cos", "cosh", "complex", "conj",
"convert_element_type",
"cummax", "cummin", "device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
"imag",
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Expand Up @@ -442,6 +442,7 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
_make_unary_elementwise_harness(prim=lax.is_finite_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.lgamma_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.digamma_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cbrt_p, dtype=dtype)

for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_unary_elementwise_harness(prim=lax.neg_p, dtype=dtype)
Expand Down
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
Expand Up @@ -70,6 +70,8 @@
broadcast_shapes,
broadcast_to_rank,
broadcasted_iota,
cbrt,
cbrt_p,
ceil,
ceil_p,
clamp,
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_autodiff_test.py
Expand Up @@ -125,6 +125,16 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_inexact_dtypes, tol={np.float32: 3e-1}),
grad_test_spec(lax.sqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_float_dtypes),
grad_test_spec(lax.sqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_float_dtypes),
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),

grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
Expand Down
20 changes: 14 additions & 6 deletions tests/lax_test.py
Expand Up @@ -109,8 +109,11 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("atan2", 2, float_dtypes, jtu.rand_default),

op_record("sqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
op_record("rsqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
op_record("sqrt", 1, float_dtypes, jtu.rand_positive),
op_record("sqrt", 1, complex_dtypes, jtu.rand_default),
op_record("rsqrt", 1, float_dtypes, jtu.rand_positive),
op_record("rsqrt", 1, complex_dtypes, jtu.rand_default),
op_record("cbrt", 1, float_dtypes, jtu.rand_default),
op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("reciprocal", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
op_record("tan", 1, float_dtypes + complex_dtypes, jtu.rand_default, {np.float32: 3e-5}),
Expand Down Expand Up @@ -2606,10 +2609,15 @@ def testArgMaxOfNanChoosesNaN(self):
self.assertEqual(lax.argmax(np.array([0., np.nan]), axis=0,
index_dtype=np.int32), 1)

# Map each unary op name to the union of dtypes listed across all of its
# LAX_OPS records (an op such as "sqrt" can appear in more than one record).
unary_op_types = {}
for r in LAX_OPS:
  if r.nargs != 1:
    continue
  unary_op_types.setdefault(r.op, set()).update(
      np.dtype(t) for t in r.dtypes)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(rec.op),
"op_name": rec.op, "rec_dtypes": rec.dtypes}
for rec in LAX_OPS if rec.nargs == 1))
{"testcase_name": "_{}".format(op), "op_name": op, "rec_dtypes": dtypes}
for op, dtypes in unary_op_types.items()))
def testUnaryWeakTypes(self, op_name, rec_dtypes):
"""Test that all lax unary ops propagate weak_type information appropriately."""
# Find a valid dtype for the function.
Expand All @@ -2620,7 +2628,7 @@ def testUnaryWeakTypes(self, op_name, rec_dtypes):
lax_val = lax.full((), py_val, dtype)
break
else:
raise ValueError("no available dtypes")
raise ValueError(f"no available dtypes in {rec_dtypes}")

op = getattr(lax, op_name)
py_op = op(py_val)
Expand Down

0 comments on commit 278ff13

Please sign in to comment.