diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a83306d57170..f2cffd8f67d8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4977,47 +4977,60 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. """ - lax_internal._check_user_dtype_supported(dtype, "view") if type is not None: - raise NotImplementedError("`type` argument of array.view()") - if dtype is None: - return arr - arr_dtype = _dtype(arr) - if arr_dtype == dtype: - return arr - # bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type. - # We work around this by casting bool to uint8. - if arr_dtype == bool_: - arr = arr.astype(uint8) - nbits_in = 8 * arr_dtype.itemsize - nbits_out = 8 * np.dtype(dtype).itemsize - if nbits_in == nbits_out: - if dtype == bool_: - return lax.bitcast_convert_type(arr, uint8).astype(dtype) - return lax.bitcast_convert_type(arr, dtype) - if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0: + raise NotImplementedError("`type` argument of array.view() is not supported.") + + _check_arraylike("view", arr) + arr = asarray(arr) + + lax_internal._check_user_dtype_supported(dtype, "view") + dtype = dtypes.canonicalize_dtype(dtype) + + if arr.ndim == 0: + if arr.dtype.itemsize != dtype.itemsize: + raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") + return _view(lax.expand_dims(arr, (0,)), dtype).squeeze() + + if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0: raise ValueError("When changing to a larger dtype, its size must be a divisor " "of the total size in bytes of the last axis of the array.") - byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'), - 32: np.dtype('uint32'), 64: np.dtype('uint64')} - if nbits_in not in byte_dtypes: - raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}") - if nbits_out not in byte_dtypes: - raise NotImplementedError(f"arr.view(dtype) for {dtype=}") - dt_in = byte_dtypes[nbits_in] - dt_out = byte_dtypes[nbits_out] - arr_bytes = lax.bitcast_convert_type(arr, dt_in) - if nbits_in < nbits_out: - arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out) - shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1))) - arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out) - else: - shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim))) - arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out) - arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,)) - if dtype == bool_: - return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype) - return lax.bitcast_convert_type(arr_bytes, dtype) + + if arr.dtype == dtype: + return arr + + # lax.bitcast_convert_type does not support bool or complex; in these cases we + # cast to a compatible type and recursively call _view for simplicity. + if arr.dtype == bool: + return _view(arr.astype('uint8'), dtype) + + if issubdtype(arr.dtype, complexfloating): + new_shape = (*arr.shape[:-1], arr.shape[-1] * 2) + new_dtype = finfo(arr.dtype).dtype + arr = (zeros(new_shape, new_dtype) + .at[..., 0::2].set(arr.real) + .at[..., 1::2].set(arr.imag)) + return _view(arr, dtype) + + if dtype == bool: + return _view(arr, uint8).astype(bool) + + if issubdtype(dtype, complexfloating): + out = _view(arr, finfo(dtype).dtype).astype(dtype) + return out[..., 0::2] + 1j * out[..., 1::2] + + # lax.bitcast_convert_type adds or subtracts dimensions depending on the + # relative bitwidths of the dtypes; we account for that with reshapes. + if arr.dtype.itemsize < dtype.itemsize: + factor = dtype.itemsize // arr.dtype.itemsize + arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor) + return lax.bitcast_convert_type(arr, dtype) + + if arr.dtype.itemsize > dtype.itemsize: + out = lax.bitcast_convert_type(arr, dtype) + return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) + + return lax.bitcast_convert_type(arr, dtype) + def _notimplemented_flat(self): raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: " diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 849375c42997..7e69469c8157 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3430,17 +3430,15 @@ def testItemsize(self, shape, dtype): self._CompileAndCheck(jnp_op, args_maker) @jtu.sample_product( - shape=[(8,), (3, 8)], # last dim = 8 to ensure shape compatibility - a_dtype=default_dtypes + unsigned_dtypes + bool_dtypes, - dtype=default_dtypes + unsigned_dtypes + bool_dtypes, + # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. + shape=[(0,), (32,), (2, 16)], + a_dtype=all_dtypes, + dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes, ) def testView(self, shape, a_dtype, dtype): if jtu.device_under_test() == 'tpu': if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - if not config.x64_enabled: - if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8: - self.skipTest("x64 types are disabled by jax_enable_x64") rng = jtu.rand_fullrange(self.rng()) args_maker = lambda: [rng(shape, a_dtype)] np_op = lambda x: np.asarray(x).view(dtype) @@ -3450,6 +3448,25 @@ def testView(self, shape, a_dtype, dtype): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product([ + {'a_dtype': a_dtype, 'dtype': dtype} + for a_dtype in all_dtypes + for dtype in all_dtypes + if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize + ]) + def testViewScalar(self, a_dtype, dtype): + if jtu.device_under_test() == 'tpu': + if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + rng = jtu.rand_fullrange(self.rng()) + args_maker = lambda: [jnp.array(rng((), a_dtype))] + np_op = lambda x: np.asarray(x).view(dtype) + jnp_op = lambda x: jnp.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + def testPathologicalFloats(self): args_maker = lambda: [np.array([ 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf