Skip to content

Commit

Permalink
Merge pull request #422 from hawkinsp/power
Browse files Browse the repository at this point in the history
Implement np.power for integers
  • Loading branch information
hawkinsp committed Feb 21, 2019
2 parents 4ba7d51 + 7b0bcbe commit 0fc5bbb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
25 changes: 24 additions & 1 deletion jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ndarray(six.with_metaclass(_ArrayMeta, onp.ndarray)):
complexfloating = onp.complexfloating
floating = onp.floating
integer = onp.integer
signedinteger = onp.signedinteger
unsignedinteger = onp.unsignedinteger

iinfo = onp.iinfo
finfo = onp.finfo
Expand Down Expand Up @@ -284,7 +286,6 @@ def _one_to_one_binop(numpy_fn, lax_fn, promote_like=False):
multiply = _one_to_one_binop(onp.multiply, lax.mul)
not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
subtract = _one_to_one_binop(onp.subtract, lax.sub)
power = _one_to_one_binop(onp.power, lax.pow, True)
arctan2 = _one_to_one_binop(onp.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(onp.minimum, lax.min)
maximum = _one_to_one_binop(onp.maximum, lax.max)
Expand Down Expand Up @@ -388,6 +389,28 @@ def _float_divmod(x1, x2):
return lax.round(div), mod


@_wraps(onp.power)
def power(x1, x2):
x1 = asarray(x1)
x2 = asarray(x2)
x1, x2 = _promote_args_like(onp.power, x1, x2)
dtype = lax._dtype(x1)
if not issubdtype(dtype, integer):
return lax.pow(x1, x2)

# Integer power => use binary exponentiation.

# TODO(phawkins): add integer pow support to XLA.
bits = 6 # Anything more would overflow for any x1 > 1
acc = ones(shape(x1), dtype=dtype)
for _ in xrange(bits):
acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)),
lax.mul(acc, x1), acc)
x1 = lax.mul(x1, x1)
x2 = lax.shift_right_logical(x2, _constant_like(x2, 1))
return acc


@_wraps(onp.logaddexp)
def logaddexp(x1, x2):
x1, x2 = _promote_shapes(*_promote_to_result_dtype(onp.logaddexp, x1, x2))
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None,
op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default(), []),
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Expand Down

0 comments on commit 0fc5bbb

Please sign in to comment.