Skip to content

Commit

Permalink
Add a @jit decorator around jnp.linspace.
Browse files Browse the repository at this point in the history
Don't test integer dtype values. The exact rounding semantics may be quite sensitive to, e.g., jit compilation, and this is not something end users should be relying on. Simplify implementation to only use version that gets the endpoints correct.

Use the same approach NumPy does to ensure the endpoint is included when endpoint=True: explicitly set the endpoint.

Various minor cleanups.
  • Loading branch information
hawkinsp committed Sep 8, 2021
1 parent 9c782e2 commit 086cbdf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 33 deletions.
40 changes: 25 additions & 15 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -3365,6 +3365,14 @@ def wrapper(*args, **kwargs):
@_wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, int(num), endpoint, retstep, dtype,
operator.index(axis))

@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis: int = 0):
"""Implementation of linspace differentiable in start and stop args."""
lax._check_user_dtype_supported(dtype, "linspace")
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
Expand All @@ -3383,24 +3391,22 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
broadcast_stop = broadcast_to(stop, bounds_shape)
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
bounds_shape.insert(axis, 1)
iota_shape = [1,] * len(bounds_shape)
iota_shape[axis] = num
div = (num - 1) if endpoint else num
if num > 1:
delta = lax.convert_element_type(stop - start, computation_dtype) / div
if issubdtype(dtype, integer):
# This is similar to how numpy computes linspace, but it
# can fail to recover the endpoints in float32 arithmetic.
out = (reshape(broadcast_start, bounds_shape) +
reshape(lax.iota(dtype, num), iota_shape) *
reshape(delta, bounds_shape))
out = lax.floor(out)
else:
# This approach recovers the endpoints with float32 arithmetic,
# but can lead to rounding errors for integer outputs.
step = reshape(lax.iota(computation_dtype, num), iota_shape) / div
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)
iota_shape = [1,] * len(bounds_shape)
iota_shape[axis] = div
# This approach recovers the endpoints with float32 arithmetic,
# but can lead to rounding errors for integer outputs.
real_dtype = finfo(computation_dtype).dtype
step = reshape(lax.iota(real_dtype, div), iota_shape) / div
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)

if endpoint:
out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))],
_canonicalize_axis(axis, out.ndim))

elif num == 1:
delta = nan if endpoint else stop - start
out = reshape(broadcast_start, bounds_shape)
Expand All @@ -3409,6 +3415,10 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
empty_shape.insert(axis, 0)
delta = nan
out = reshape(array([], dtype=dtype), empty_shape)

if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
out = lax.floor(out)

if retstep:
return lax.convert_element_type(out, dtype), delta
else:
Expand Down
28 changes: 10 additions & 18 deletions tests/lax_numpy_test.py
Expand Up @@ -992,8 +992,6 @@ def testArgWhere(self, shape, dtype):
message="Calling nonzero on 0d arrays.*")(np.argwhere)
jnp_fun = jnp.argwhere
args_maker = lambda: [rng(shape, dtype)]
if shape in (scalar_shapes + [()]) and numpy_version < (1, 18):
self.skipTest("np.argwhere() result for scalar input changed in numpy 1.18.")
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of this
Expand Down Expand Up @@ -3335,9 +3333,6 @@ def testExpandDimsStaticDim(self, arg_shape, dtype, dim):
jnp_fun = lambda x: jnp.expand_dims(x, dim)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CompileAndCheck(jnp_fun, args_maker)

if isinstance(dim, (tuple, list)) and numpy_version < (1, 18, 0):
raise SkipTest("support for multiple axes added in NumPy 1.18.0")
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -4991,10 +4986,9 @@ def testIndex_exp(self):

@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
"_retstep={}_dtype={}").format(
start_shape, stop_shape, num, endpoint, retstep,
dtype.__name__ if dtype else "None"),
{"testcase_name": f"_start_shape={start_shape}_stop_shape={stop_shape}"
f"_num={num}_endpoint={endpoint}_retstep={retstep}"
f"_dtype={dtype.__name__ if dtype else 'None'}",
"start_shape": start_shape, "stop_shape": stop_shape,
"num": num, "endpoint": endpoint, "retstep": retstep,
"dtype": dtype}
Expand All @@ -5003,11 +4997,12 @@ def testIndex_exp(self):
for num in [0, 1, 2, 5, 20]
for endpoint in [True, False]
for retstep in [True, False]
for dtype in number_dtypes + [None,]))
# floating-point compute between jitted platforms and non-jit + rounding
# cause unavoidable variation in integer truncation for some inputs, so
# we currently only test inexact 'dtype' arguments.
for dtype in inexact_dtypes + [None,]))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype):
if num == 1 and not endpoint and numpy_version < (1, 18):
raise SkipTest("Numpy < 1.18 has a linspace bug.")
rng = jtu.rand_default(self.rng())
# relax default tolerances slightly
tol = jtu.tolerance(dtype if dtype else np.float32) * 10
Expand Down Expand Up @@ -5037,15 +5032,12 @@ def np_op(start, stop):

self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
# floating-point compute between jitted platforms and non-jit + rounding
# cause unavoidable variation in integer truncation for some inputs.
if dtype in (inexact_dtypes + [None,]):
self._CompileAndCheck(jnp_op, args_maker,
check_dtypes=False, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_op, args_maker,
check_dtypes=False, atol=tol, rtol=tol)

@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(dtype), "dtype": dtype}
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
for dtype in number_dtypes))
def testLinspaceEndpoints(self, dtype):
"""Regression test for Issue #3014."""
Expand Down

0 comments on commit 086cbdf

Please sign in to comment.