Skip to content

Commit

Permalink
add efficient path for array input to jnp.stack, jnp.[hvd]stack, jnp.…
Browse files Browse the repository at this point in the history
…column_stack
  • Loading branch information
jakevdp committed Jun 11, 2021
1 parent 3550732 commit 17710c0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
57 changes: 35 additions & 22 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2801,15 +2801,19 @@ def stack(arrays, axis: int =0, out=None):
raise ValueError("Need at least one array to stack.")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
_check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)
if isinstance(arrays, ndarray):
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis)
else:
_check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)

@_wraps(np.tile)
def tile(A, reps):
Expand Down Expand Up @@ -2868,32 +2872,41 @@ def concatenate(arrays, axis: int = 0):

@_wraps(np.vstack)
def vstack(tup):
return concatenate([atleast_2d(m) for m in tup], axis=0)
if isinstance(tup, ndarray):
arrs = jax.vmap(atleast_2d)(tup)
else:
arrs = [atleast_2d(m) for m in tup]
return concatenate(arrs, axis=0)
row_stack = vstack


@_wraps(np.hstack)
def hstack(tup):
arrs = [atleast_1d(m) for m in tup]
if arrs[0].ndim == 1:
return concatenate(arrs, 0)
return concatenate(arrs, 1)
if isinstance(tup, ndarray):
arrs = jax.vmap(atleast_1d)(tup)
arr0_ndim = arrs.ndim - 1
else:
arrs = [atleast_1d(m) for m in tup]
arr0_ndim = arrs[0].ndim
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)


@_wraps(np.dstack)
def dstack(tup):
return concatenate([atleast_3d(m) for m in tup], axis=2)
if isinstance(tup, ndarray):
arrs = jax.vmap(atleast_3d)(tup)
else:
arrs = [atleast_3d(m) for m in tup]
return concatenate(arrs, axis=2)


@_wraps(np.column_stack)
def column_stack(tup):
arrays = []
for v in tup:
arr = asarray(v)
if arr.ndim < 2:
arr = atleast_2d(arr).T
arrays.append(arr)
return concatenate(arrays, 1)
if isinstance(tup, ndarray):
arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
else:
arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)]
return concatenate(arrs, 1)


@_wraps(np.choose, skip_params=['out'])
Expand Down
48 changes: 30 additions & 18 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,29 +2762,33 @@ def testDigitize(self, xshape, binshape, right, reverse, dtype):
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
"shape": shape, "dtypes": dtypes}
{"testcase_name": "_{}_array={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
"shape": shape, "dtypes": dtypes, "array_input": array_input}
for dtypes in [
[np.float32],
[np.float32, np.float32],
[np.float32, np.int32, np.float32],
[np.float32, np.int64, np.float32],
[np.float32, np.int32, np.float64],
]
for shape in [(), (2,), (3, 4), (1, 5)]))
def testColumnStack(self, shape, dtypes):
for shape in [(), (2,), (3, 4), (1, 5)]
for array_input in [True, False]))
def testColumnStack(self, shape, dtypes, array_input):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(np.column_stack)
jnp_fun = jnp.column_stack
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),
"shape": shape, "axis": axis, "dtypes": dtypes}
{"testcase_name": "_{}_axis={}_array={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
"shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
for dtypes in [
[np.float32],
[np.float32, np.float32],
Expand All @@ -2793,19 +2797,23 @@ def testColumnStack(self, shape, dtypes):
[np.float32, np.int32, np.float64],
]
for shape in [(), (2,), (3, 4), (1, 100)]
for axis in range(-len(shape), len(shape) + 1)))
def testStack(self, shape, axis, dtypes):
for axis in range(-len(shape), len(shape) + 1)
for array_input in [True, False]))
def testStack(self, shape, axis, dtypes, array_input):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
jnp_fun = partial(jnp.stack, axis=axis)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_{}".format(
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
"shape": shape, "op": op, "dtypes": dtypes}
{"testcase_name": "_op={}_{}_array={}".format(
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
for op in ["hstack", "vstack", "dstack"]
for dtypes in [
[np.float32],
Expand All @@ -2814,10 +2822,14 @@ def testStack(self, shape, axis, dtypes):
[np.float32, np.int64, np.float32],
[np.float32, np.int32, np.float64],
]
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]))
def testHVDStack(self, shape, op, dtypes):
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
for array_input in [True, False]))
def testHVDStack(self, shape, op, dtypes, array_input):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(getattr(np, op))
jnp_fun = getattr(jnp, op)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
Expand Down

0 comments on commit 17710c0

Please sign in to comment.