Skip to content

Commit

Permalink
Implement np.{empty,empty_like,ptp,isreal,iscomplex,sinc,vander,posit…
Browse files Browse the repository at this point in the history
…ive}.

Issue google#70
  • Loading branch information
hawkinsp committed Feb 5, 2019
1 parent 0243816 commit 4f75386
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
8 changes: 8 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ jax.numpy package
dsplit
einsum
equal
empty
empty_like
exp
exp2
expand_dims
Expand All @@ -95,11 +97,13 @@ jax.numpy package
imag
inner
isclose
iscomplex
isfinite
isinf
isnan
isneginf
isposinf
isreal
issubdtype
issubsctype
kaiser
Expand Down Expand Up @@ -144,8 +148,10 @@ jax.numpy package
pad
polyval
power
positive
prod
product
ptp
rad2deg
radians
ravel
Expand All @@ -160,6 +166,7 @@ jax.numpy package
row_stack
sign
sin
sinc
sinh
sometrue
sort
Expand All @@ -183,6 +190,7 @@ jax.numpy package
tril
triu
true_divide
vander
var
vdot
vsplit
Expand Down
53 changes: 52 additions & 1 deletion jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def _one_to_one_binop(numpy_fn, lax_fn, promote_like=False):
fabs = _one_to_one_unop(onp.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(onp.bitwise_not, lax.bitwise_not)
negative = _one_to_one_unop(onp.negative, lax.neg)
positive = _one_to_one_unop(onp.positive, lambda x: x)
sign = _one_to_one_unop(onp.sign, lax.sign)

floor = _one_to_one_unop(onp.floor, lax.floor, True)
Expand Down Expand Up @@ -429,6 +430,7 @@ def sqrt(x):
x, = _promote_to_result_dtype(onp.sqrt, x)
return power(x, _constant_like(x, 0.5))


@_wraps(onp.square)
def square(x):
x, = _promote_to_result_dtype(onp.square, x)
Expand Down Expand Up @@ -463,6 +465,14 @@ def reciprocal(x):
return lax.div(lax._const(x, 1), x)


@_wraps(onp.sinc)
def sinc(x):
x, = _promote_to_result_dtype(onp.sinc, x)
pi_x = lax.mul(lax._const(x, pi), x)
return where(lax.eq(x, lax._const(x, 0)),
lax._const(x, 1), lax.div(lax.sin(pi_x), pi_x))


@_wraps(onp.transpose)
def transpose(x, axis=None):
axis = onp.arange(ndim(x))[::-1] if axis is None else axis
Expand Down Expand Up @@ -525,6 +535,16 @@ def real(x):
return lax.real(x) if iscomplexobj(x) else x


@_wraps(onp.iscomplex)
def iscomplex(x):
i = imag(x)
return lax.ne(i, _const(i, 0))

@_wraps(onp.isreal)
def isreal(x):
i = imag(x)
return lax.eq(i, _const(i, 0))

@_wraps(onp.angle)
def angle(x):
if iscomplexobj(x):
Expand Down Expand Up @@ -879,6 +899,15 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims))


@_wraps(onp.ptp)
def ptp(a, axis=None, out=None, keepdims=False):
if out is not None:
raise ValueError("ptp does not support the `out` argument.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)


@_wraps(onp.allclose)
def allclose(a, b, rtol=1e-05, atol=1e-08):
return all(isclose(a, b, rtol, atol))
Expand Down Expand Up @@ -1095,13 +1124,17 @@ def zeros(shape, dtype=onp.dtype("float64")):
shape = (shape,) if onp.isscalar(shape) else shape
return lax.full(shape, 0, dtype)


@_wraps(onp.ones)
def ones(shape, dtype=onp.dtype("float64")):
shape = (shape,) if onp.isscalar(shape) else shape
return lax.full(shape, 1, dtype)


# We can't create uninitialized arrays in XLA; use zeros for empty.
empty_like = zeros_like
empty = zeros


@_wraps(onp.eye)
def eye(N, M=None, k=None, dtype=onp.dtype("float64")):
M = N if M is None else M
Expand Down Expand Up @@ -1597,6 +1630,24 @@ def kron(a, b):
return lax.reshape(a_broadcast * b_broadcast, out_shape)


@_wraps(onp.vander)
def vander(x, N=None, increasing=False):
x = asarray(x)
dtype = _dtype(x)
if ndim(x) != 1:
raise ValueError("x must be a one-dimensional array")
x_shape = shape(x)
N = N or x_shape[0]
if N < 0:
raise ValueError("N must be nonnegative")

iota = lax.iota(dtype, N)
if not increasing:
iota = lax.sub(lax._const(iota, N - 1), iota)

return power(x[..., None], iota)


### Misc


Expand Down
25 changes: 25 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
op_record("isfinite", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
Expand All @@ -90,11 +92,13 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default(), []),
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("sinc", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5),
["rev"]),
Expand Down Expand Up @@ -161,6 +165,7 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero(), []),
op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default(), []),
op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default(), []),
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
]

JAX_ARGMINMAX_RECORDS = [
Expand Down Expand Up @@ -1144,5 +1149,25 @@ def args_maker():
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_n={}_increasing={}".format(
jtu.format_shape_dtype_string([shape], dtype),
n, increasing),
"dtype": dtype, "shape": shape, "n": n, "increasing": increasing,
"rng": jtu.rand_default()}
for dtype in inexact_dtypes
for shape in [0, 5]
for n in [2, 4]
for increasing in [False, True]))
def testVander(self, shape, dtype, n, increasing, rng):
onp_fun = lambda arg: onp.vander(arg, N=n, increasing=increasing)
lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing)
args_maker = lambda: [rng([shape], dtype)]
# np.vander seems to return float64 for all floating types. We could obey
# those semantics, but they seem like a bug.
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False)


if __name__ == "__main__":
absltest.main()

0 comments on commit 4f75386

Please sign in to comment.