Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.ndarray.view: implement all dtypes #14526

Merged
merged 1 commit into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 45 additions & 38 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4978,47 +4978,54 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:

This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.
"""
lax_internal._check_user_dtype_supported(dtype, "view")
if type is not None:
raise NotImplementedError("`type` argument of array.view()")
if dtype is None:
return arr
arr_dtype = _dtype(arr)
if arr_dtype == dtype:
return arr
# bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
# We work around this by casting bool to uint8.
if arr_dtype == bool_:
arr = arr.astype(uint8)
nbits_in = 8 * arr_dtype.itemsize
nbits_out = 8 * np.dtype(dtype).itemsize
if nbits_in == nbits_out:
if dtype == bool_:
return lax.bitcast_convert_type(arr, uint8).astype(dtype)
return lax.bitcast_convert_type(arr, dtype)
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
raise NotImplementedError("`type` argument of array.view() is not supported.")

_check_arraylike("view", arr)
arr = asarray(arr)

lax_internal._check_user_dtype_supported(dtype, "view")
dtype = dtypes.canonicalize_dtype(dtype)

if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
32: np.dtype('uint32'), 64: np.dtype('uint64')}
if nbits_in not in byte_dtypes:
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
if nbits_out not in byte_dtypes:
raise NotImplementedError(f"arr.view(dtype) for {dtype=}")
dt_in = byte_dtypes[nbits_in]
dt_out = byte_dtypes[nbits_out]
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
if nbits_in < nbits_out:
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1)))
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
else:
shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim)))
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
if dtype == bool_:
return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
return lax.bitcast_convert_type(arr_bytes, dtype)

if arr.dtype == dtype:
return arr

# lax.bitcast_convert_type does not support bool or complex; in these cases we
# cast to a compatible type and recursively call _view for simplicity.
if arr.dtype == bool:
return _view(arr.astype('uint8'), dtype)

if issubdtype(arr.dtype, complexfloating):
new_shape = (*arr.shape[:-1], arr.shape[-1] * 2)
new_dtype = finfo(arr.dtype).dtype
arr = (zeros(new_shape, new_dtype)
.at[..., 0::2].set(arr.real)
.at[..., 1::2].set(arr.imag))
return _view(arr, dtype)

if dtype == bool:
return _view(arr, uint8).astype(bool)

if issubdtype(dtype, complexfloating):
out = _view(arr, finfo(dtype).dtype).astype(dtype)
return out[..., 0::2] + 1j * out[..., 1::2]

# lax.bitcast_convert_type adds or subtracts dimensions depending on the
# relative bitwidths of the dtypes; we account for that with reshapes.
if arr.dtype.itemsize < dtype.itemsize:
factor = dtype.itemsize // arr.dtype.itemsize
arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor)
return lax.bitcast_convert_type(arr, dtype)

if arr.dtype.itemsize > dtype.itemsize:
out = lax.bitcast_convert_type(arr, dtype)
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])

return lax.bitcast_convert_type(arr, dtype)

def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
Expand Down
10 changes: 4 additions & 6 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,17 +3430,15 @@ def testItemsize(self, shape, dtype):
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
shape=[(8,), (3, 8)], # last dim = 8 to ensure shape compatibility
a_dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
)
def testView(self, shape, a_dtype, dtype):
if jtu.device_under_test() == 'tpu':
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
if not config.x64_enabled:
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
Expand Down