Skip to content

Commit

Permalink
Roll-back #14526 because it breaks view() on scalar inputs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 510281592
  • Loading branch information
Jake VanderPlas authored and jax authors committed Feb 17, 2023
1 parent c467d84 commit e1333f3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 49 deletions.
83 changes: 38 additions & 45 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4978,54 +4978,47 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.
"""
if type is not None:
raise NotImplementedError("`type` argument of array.view() is not supported.")

_check_arraylike("view", arr)
arr = asarray(arr)

lax_internal._check_user_dtype_supported(dtype, "view")
dtype = dtypes.canonicalize_dtype(dtype)

if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")

if arr.dtype == dtype:
if type is not None:
raise NotImplementedError("`type` argument of array.view()")
if dtype is None:
return arr

# lax.bitcast_convert_type does not support bool or complex; in these cases we
# cast to a compatible type and recursively call _view for simplicity.
if arr.dtype == bool:
return _view(arr.astype('uint8'), dtype)

if issubdtype(arr.dtype, complexfloating):
new_shape = (*arr.shape[:-1], arr.shape[-1] * 2)
new_dtype = finfo(arr.dtype).dtype
arr = (zeros(new_shape, new_dtype)
.at[..., 0::2].set(arr.real)
.at[..., 1::2].set(arr.imag))
return _view(arr, dtype)

if dtype == bool:
return _view(arr, uint8).astype(bool)

if issubdtype(dtype, complexfloating):
out = _view(arr, finfo(dtype).dtype).astype(dtype)
return out[..., 0::2] + 1j * out[..., 1::2]

# lax.bitcast_convert_type adds or subtracts dimensions depending on the
# relative bitwidths of the dtypes; we account for that with reshapes.
if arr.dtype.itemsize < dtype.itemsize:
factor = dtype.itemsize // arr.dtype.itemsize
arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor)
arr_dtype = _dtype(arr)
if arr_dtype == dtype:
return arr
# bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
# We work around this by casting bool to uint8.
if arr_dtype == bool_:
arr = arr.astype(uint8)
nbits_in = 8 * arr_dtype.itemsize
nbits_out = 8 * np.dtype(dtype).itemsize
if nbits_in == nbits_out:
if dtype == bool_:
return lax.bitcast_convert_type(arr, uint8).astype(dtype)
return lax.bitcast_convert_type(arr, dtype)

if arr.dtype.itemsize > dtype.itemsize:
out = lax.bitcast_convert_type(arr, dtype)
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])

return lax.bitcast_convert_type(arr, dtype)
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
32: np.dtype('uint32'), 64: np.dtype('uint64')}
if nbits_in not in byte_dtypes:
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
if nbits_out not in byte_dtypes:
raise NotImplementedError(f"arr.view(dtype) for {dtype=}")
dt_in = byte_dtypes[nbits_in]
dt_out = byte_dtypes[nbits_out]
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
if nbits_in < nbits_out:
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1)))
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
else:
shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim)))
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
if dtype == bool_:
return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
return lax.bitcast_convert_type(arr_bytes, dtype)

def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
Expand Down
10 changes: 6 additions & 4 deletions tests/lax_numpy_test.py
Expand Up @@ -3430,15 +3430,17 @@ def testItemsize(self, shape, dtype):
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
shape=[(8,), (3, 8)], # last dim = 8 to ensure shape compatibility
a_dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
)
def testView(self, shape, a_dtype, dtype):
if jtu.device_under_test() == 'tpu':
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
if not config.x64_enabled:
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
Expand Down

0 comments on commit e1333f3

Please sign in to comment.