Skip to content

Commit

Permalink
Merge pull request #11906 from alonfnt:dtype-arg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 467970800
  • Loading branch information
jax authors committed Aug 16, 2022
2 parents 6ae46c3 + 99c5e91 commit 332d7d0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
18 changes: 9 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,14 +1616,14 @@ def pad(array, pad_width, mode="constant", **kwargs):


@_wraps(np.stack, skip_params=['out'])
def stack(arrays, axis: int = 0, out=None):
def stack(arrays, axis: int = 0, out=None, dtype=None):
if not len(arrays):
raise ValueError("Need at least one array to stack.")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
if isinstance(arrays, (np.ndarray, ndarray)):
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis)
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
else:
_stackable(*arrays) or _check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
Expand All @@ -1633,7 +1633,7 @@ def stack(arrays, axis: int = 0, out=None):
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)
return concatenate(new_arrays, axis=axis, dtype=dtype)

@_wraps(np.tile)
def tile(A, reps):
Expand Down Expand Up @@ -1696,33 +1696,33 @@ def concatenate(arrays, axis: int = 0, dtype=None):


@_wraps(np.vstack)
def vstack(tup):
def vstack(tup, dtype=None):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_2d)(tup)
else:
arrs = [atleast_2d(m) for m in tup]
return concatenate(arrs, axis=0)
return concatenate(arrs, axis=0, dtype=dtype)
row_stack = vstack


@_wraps(np.hstack)
def hstack(tup):
def hstack(tup, dtype=None):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_1d)(tup)
arr0_ndim = arrs.ndim - 1
else:
arrs = [atleast_1d(m) for m in tup]
arr0_ndim = arrs[0].ndim
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)


@_wraps(np.dstack)
def dstack(tup):
def dstack(tup, dtype=None):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_3d)(tup)
else:
arrs = [atleast_3d(m) for m in tup]
return concatenate(arrs, axis=2)
return concatenate(arrs, axis=2, dtype=dtype)


@_wraps(np.column_stack)
Expand Down
51 changes: 33 additions & 18 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3311,9 +3311,11 @@ def testColumnStack(self, shape, dtypes, array_input):
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "{}_axis={}_array={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
"shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
{"testcase_name": "{}_axis={}_array={}_out={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input,
np.dtype(out_dtype).name),
"shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input,
"out_dtype": out_dtype}
for dtypes in [
[np.float32],
[np.float32, np.float32],
Expand All @@ -3323,23 +3325,30 @@ def testColumnStack(self, shape, dtypes, array_input):
]
for shape in [(), (2,), (3, 4), (1, 100)]
for axis in range(-len(shape), len(shape) + 1)
for array_input in [True, False]))
def testStack(self, shape, axis, dtypes, array_input):
for array_input in [True, False]
for out_dtype in [np.float32, np.int32]))
def testStack(self, shape, axis, dtypes, array_input, out_dtype):
rng = jtu.rand_default(self.rng())
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
jnp_fun = partial(jnp.stack, axis=axis)

if numpy_version < (1, 24):
np_fun = _promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
else:
np_fun = _promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype))

jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_{}_array={}".format(
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
{"testcase_name": "_op={}_{}_array={}_out={}".format(
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input,
np.dtype(out_dtype).name),
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input, "out_dtype": out_dtype}
for op in ["hstack", "vstack", "dstack"]
for dtypes in [
[np.float32],
Expand All @@ -3349,15 +3358,21 @@ def testStack(self, shape, axis, dtypes, array_input):
[np.float32, np.int32, np.float64],
]
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
for array_input in [True, False]))
def testHVDStack(self, shape, op, dtypes, array_input):
for array_input in [True, False]
for out_dtype in [np.float32, np.int32]))
def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
rng = jtu.rand_default(self.rng())
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = _promote_like_jnp(getattr(np, op))
jnp_fun = getattr(jnp, op)

if numpy_version < (1, 24) or op == "dstack":
np_fun = _promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
else:
np_fun = partial(_promote_like_jnp(getattr(np, op)), dtype=out_dtype)

jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
Expand Down Expand Up @@ -6388,7 +6403,7 @@ def testWrappedSignaturesMatch(self):
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
'eye': ['order', 'like'],
'hstack': ['dtype', 'casting'],
'hstack': ['casting'],
'identity': ['like'],
'in1d': ['kind'],
'isin': ['kind'],
Expand All @@ -6400,11 +6415,11 @@ def testWrappedSignaturesMatch(self):
'histogramdd': ['normed'],
'ones': ['order', 'like'],
'ones_like': ['subok', 'order'],
'row_stack': ['dtype', 'casting'],
'stack': ['dtype', 'casting'],
'row_stack': ['casting'],
'stack': ['casting'],
'tri': ['like'],
'unique': ['equal_nan'],
'vstack': ['dtype', 'casting'],
'vstack': ['casting'],
'zeros_like': ['subok', 'order']
}

Expand Down

0 comments on commit 332d7d0

Please sign in to comment.