Skip to content

Commit

Permalink
Implement complex logaddexp
Browse files Browse the repository at this point in the history
  • Loading branch information
wdphy16 committed Jul 3, 2021
1 parent b937066 commit 7b10965
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 18 deletions.
56 changes: 38 additions & 18 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,45 +656,65 @@ def power(x1, x2):
@custom_jvp
@_wraps(np.logaddexp)
def logaddexp(x1, x2):
x1, x2 = _promote_shapes("logaddexp", *_promote_dtypes_inexact(x1, x2))
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
amax = lax.max(x1, x2)
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
if issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.log1p(lax.exp(delta)))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))

def _wrap_between(x, _a):
"""Wraps `x` between `[-a, a]`."""
a = _constant_like(x, _a)
two_a = _constant_like(x, 2 * _a)
zero = _constant_like(x, 0)
rem = lax.rem(lax.add(x, a), two_a)
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
return lax.sub(rem, a)

@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
primal_out = logaddexp(x1, x2)
tangent_out = (t1 * exp(_replace_inf(x1) - _replace_inf(primal_out)) +
t2 * exp(_replace_inf(x2) - _replace_inf(primal_out)))
tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out

def _replace_inf(x):
return lax.select(isposinf(x), zeros_like(x), x)
return lax.select(isposinf(real(x)), zeros_like(x), x)


@custom_jvp
@_wraps(np.logaddexp2)
def logaddexp2(x1, x2):
x1, x2 = _promote_shapes("logaddexp2", *_promote_dtypes_inexact(x1, x2))
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
amax = lax.max(x1, x2)
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(-lax.abs(delta))),
_constant_like(x1, np.log(2)))))
if issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
_constant_like(x1, np.log(2)))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))

@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
primal_out = logaddexp2(x1, x2)
tangent_out = (t1 * 2 ** (_replace_inf(x1) - _replace_inf(primal_out)) +
t2 * 2 ** (_replace_inf(x2) - _replace_inf(primal_out)))
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out


Expand Down
69 changes: 69 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5379,6 +5379,44 @@ def testIssue2347(self):
np_object_list = np.array(object_list)
self.assertRaises(TypeError, jnp.array, np_object_list)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 2))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))))
def testLogaddexpComplex(self, shapes, dtypes):
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
def np_op(x1, x2):
return np.log(np.exp(x1) + np.exp(x2))

rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes))
tol = {np.complex64: 1e-5, np.complex128: 5e-15}
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol)
self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 2))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))))
def testLogaddexp2Complex(self, shapes, dtypes):
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
def np_op(x1, x2):
return np.log2(np.exp2(x1) + np.exp2(x2))

rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes))
tol = {np.complex64: 1e-5, np.complex128: 5e-15}
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol)
self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)

# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.

Expand Down Expand Up @@ -5488,6 +5526,37 @@ def f(x):

check_grads(f, (1.,), order=1)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)),
"shapes": shapes, "dtype": dtype}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(nonempty_shapes, 2))
for dtype in (np.complex128, )))
def testGradLogaddexpComplex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(rng(shape, dtype) for shape in shapes)
if jtu.device_under_test() == "tpu":
tol = 5e-2
else:
tol = 3e-2
check_grads(jnp.logaddexp, args, 1, ["fwd", "rev"], tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)),
"shapes": shapes, "dtype": dtype}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(nonempty_shapes, 2))
for dtype in (np.complex128, )))
def testGradLogaddexp2Complex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(rng(shape, dtype) for shape in shapes)
if jtu.device_under_test() == "tpu":
tol = 5e-2
else:
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)

class NumpySignaturesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 7b10965

Please sign in to comment.