Skip to content

Commit

Permalink
jnp.angle: support deg keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 20, 2022
1 parent 6032528 commit eac5302
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
7 changes: 4 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,8 @@ def isreal(x):
return lax.eq(i, lax._const(i, 0))

@_wraps(np.angle)
@jit
def angle(z):
@partial(jit, static_argnames=['deg'])
def angle(z, deg=False):
re = real(z)
im = imag(z)
dtype = _dtype(re)
Expand All @@ -1503,7 +1503,8 @@ def angle(z):
dtype = dtypes.canonicalize_dtype(float_)
re = lax.convert_element_type(re, dtype)
im = lax.convert_element_type(im, dtype)
return lax.atan2(im, re)
result = lax.atan2(im, re)
return degrees(result) if deg else result


@_wraps(np.diff)
Expand Down
15 changes: 9 additions & 6 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ def _get_y_shapes(y_dtype, shape, rowvar):
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
"test_name", "check_dtypes", "tolerance", "inexact"])
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])

def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True,
tolerance=None, inexact=False):
tolerance=None, inexact=False, kwargs=None):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact)
test_name, check_dtypes, tolerance, inexact, kwargs)

JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("abs", 1, number_dtypes + unsigned_dtypes + bool_dtypes,
Expand Down Expand Up @@ -213,6 +213,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False, inexact=True),
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}),
op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
Expand Down Expand Up @@ -545,7 +547,7 @@ def testLoad(self, dtype):
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance,
"inexact": rec.inexact}
"inexact": rec.inexact, "kwargs": rec.kwargs or {}}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs))
Expand All @@ -555,7 +557,9 @@ def testLoad(self, dtype):
JAX_COMPOUND_OP_RECORDS)))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
tolerance, inexact):
tolerance, inexact, kwargs):
np_op = partial(np_op, **kwargs)
jnp_op = partial(jnp_op, **kwargs)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(np_op)
np_op = jtu.ignore_warning(category=RuntimeWarning,
Expand Down Expand Up @@ -5998,7 +6002,6 @@ def testWrappedSignaturesMatch(self):

# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'angle': ['deg'],
'argmax': ['keepdims'],
'argmin': ['keepdims'],
'asarray': ['like'],
Expand Down

0 comments on commit eac5302

Please sign in to comment.