Skip to content

Commit

Permalink
Merge pull request #7936 from jakevdp:jnp-insert
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 397153039
  • Loading branch information
jax authors committed Sep 16, 2021
2 parents 0851e05 + b895f53 commit 75f941b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
51 changes: 51 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4415,6 +4415,57 @@ def delete(arr, obj, axis=None):
raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")
return arr[tuple(slice(None) for i in range(axis)) + (mask,)]

@_wraps(np.insert)
def insert(arr, obj, values, axis=None):
_check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
arr = asarray(arr)
values = asarray(values)

if axis is None:
arr = ravel(arr)
axis = 0
axis = core.concrete_or_error(None, axis, "axis argument of jnp.insert()")
axis = _canonicalize_axis(axis, arr.ndim)
if isinstance(obj, slice):
indices = arange(*obj.indices(arr.shape[axis]))
else:
indices = asarray(obj)

if indices.ndim > 1:
raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
f"array, or a scalar; got {obj}")
if not np.issubdtype(indices.dtype, np.integer):
if indices.size == 0 and not isinstance(obj, ndarray):
indices = indices.astype(int)
else:
# Note: np.insert allows boolean inputs but the behavior is deprecated.
raise ValueError("jnp.insert(): index array must be "
f"integer typed; got {obj}")
values = array(values, ndmin=arr.ndim, dtype=arr.dtype, copy=False)

if indices.size == 1:
index = ravel(indices)[0]
if indices.ndim == 0:
values = moveaxis(values, 0, axis)
indices = full(values.shape[axis], index)
n_input = arr.shape[axis]
n_insert = 0 if len(indices) == 0 else _max(values.shape[axis], len(indices))
out_shape = list(arr.shape)
out_shape[axis] += n_insert
out = zeros_like(arr, shape=tuple(out_shape))

indices = where(indices < 0, indices + n_input, indices)
indices = clip(indices, 0, n_input)

values_ind = indices.at[argsort(indices)].add(arange(n_insert))
arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False)
arr_ind = where(arr_mask, size=n_input)[0]

out = out.at[(slice(None),) * axis + (values_ind,)].set(values)
out = out.at[(slice(None),) * axis + (arr_ind,)].set(arr)

return out


@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Expand Up @@ -185,6 +185,7 @@
in1d as in1d,
inf as inf,
inner as inner,
insert as insert,
int16 as int16,
int32 as int32,
int64 as int64,
Expand Down
54 changes: 54 additions & 0 deletions tests/lax_numpy_test.py
Expand Up @@ -2152,6 +2152,60 @@ def testDeleteMaskArray(self, shape, dtype, axis):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"dtype": dtype, "shape": shape, "axis": axis}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))))
def testInsertInteger(self, shape, dtype, axis):
x = jnp.empty(shape)
max_ind = x.size if axis is None else x.shape[axis]
rng = jtu.rand_default(self.rng())
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)]
np_fun = lambda *args: np.insert(*args, axis=axis)
jnp_fun = lambda *args: jnp.insert(*args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"dtype": dtype, "shape": shape, "axis": axis}
for shape in nonempty_nonscalar_array_shapes
for dtype in all_dtypes
for axis in [None] + list(range(-len(shape), len(shape)))))
def testInsertSlice(self, shape, dtype, axis):
x = jnp.empty(shape)
max_ind = x.size if axis is None else x.shape[axis]
rng = jtu.rand_default(self.rng())
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item())
args_maker = lambda: [rng(shape, dtype), rng((), dtype)]
np_fun = lambda x, val: np.insert(x, slc, val, axis=axis)
jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.parameters([
[[[1, 1], [2, 2], [3, 3]], 1, 5, None],
[[[1, 1], [2, 2], [3, 3]], 1, 5, 1],
[[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1],
[[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1],
[[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None],
[[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None],
[[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None],
[[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1]
])
def testInsertExamples(self, arr, index, values, axis):
# Test examples from the np.insert docstring
args_maker = lambda: (
np.asarray(arr), index if isinstance(index, slice) else np.array(index),
np.asarray(values), axis)
self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_out_dims={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down

0 comments on commit 75f941b

Please sign in to comment.