Skip to content

Commit

Permalink
jnp.power: fix overflow case for x1=0
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 9, 2021
1 parent 0b88b0e commit 0c86c1f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,14 @@ def power(x1, x2):

# TODO(phawkins): add integer pow support to XLA.
bits = 6 # Anything more would overflow for any x1 > 1
acc = ones(shape(x1), dtype=dtype)
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
# Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
for _ in range(bits):
acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)),
lax.mul(acc, x1), acc)
acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
x1 = lax.mul(x1, x1)
x2 = lax.shift_right_logical(x2, _constant_like(x2, 1))
x2 = lax.shift_right_logical(x2, one)
return acc


Expand Down
10 changes: 10 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,16 @@ def testIntegerPower(self, ptype):
self.assertLen(eqns, 1)
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_y={}".format(x, y), "x": x, "y": y}
for x in [-1, 0, 1]
for y in [0, 32, 64, 128]))
def testIntegerPowerOverflow(self, x, y):
# Regression test for https://github.com/google/jax/issues/5987
args_maker = lambda: [x, y]
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
self._CompileAndCheck(jnp.power, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
Expand Down

0 comments on commit 0c86c1f

Please sign in to comment.