Skip to content

Commit

Permalink
jnp.concatenate: add fast path based on lax.reshape for array inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 10, 2021
1 parent 42b540c commit 0470f4f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
16 changes: 16 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,8 +2826,24 @@ def tile(A, reps):
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))

def _concatenate_array(arr, axis: int):
# Fast path for concatenation when the input is an ndarray rather than a list.
arr = asarray(arr)
if arr.ndim == 0 or arr.shape[0] == 0:
raise ValueError("Need at least one array to concatenate.")
if axis is None:
return lax.reshape(arr, (arr.size,))
if arr.ndim == 1:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
axis = _canonicalize_axis(axis, arr.ndim - 1)
shape = arr.shape[1:axis + 1] + (arr.shape[0] * arr.shape[axis + 1],) + arr.shape[axis + 2:]
dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)]
return lax.reshape(arr, shape, dimensions)

@_wraps(np.concatenate)
def concatenate(arrays, axis: int = 0):
if isinstance(arrays, ndarray):
return _concatenate_array(arrays, axis)
_check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
Expand Down
15 changes: 15 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,21 @@ def args_maker():
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
"shape": shape, "dtype": dtype, "axis": axis}
for shape in [(4, 1), (4, 3), (4, 5, 6)]
for dtype in all_dtypes
for axis in [None] + list(range(1 - len(shape), len(shape) - 1))))
def testConcatenateArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.concatenate(x, axis=axis)
jnp_fun = lambda x: jnp.concatenate(x, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testConcatenateAxisNone(self):
# https://github.com/google/jax/issues/3419
a = jnp.array([[1, 2], [3, 4]])
Expand Down

0 comments on commit 0470f4f

Please sign in to comment.