Skip to content

Commit

Permalink
jnp.ndarray.view: implement all dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 16, 2023
1 parent f323952 commit a9a264d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
74 changes: 37 additions & 37 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4978,47 +4978,47 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.
"""
lax_internal._check_user_dtype_supported(dtype, "view")
if type is not None:
raise NotImplementedError("`type` argument of array.view()")
if dtype is None:
return arr
arr_dtype = _dtype(arr)
if arr_dtype == dtype:
return arr
# bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
# We work around this by casting bool to uint8.
if arr_dtype == bool_:
arr = arr.astype(uint8)
nbits_in = 8 * arr_dtype.itemsize
nbits_out = 8 * np.dtype(dtype).itemsize
if nbits_in == nbits_out:
if dtype == bool_:
return lax.bitcast_convert_type(arr, uint8).astype(dtype)
return lax.bitcast_convert_type(arr, dtype)
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
raise NotImplementedError("`type` argument of array.view() is not supported.")

_check_arraylike("view", arr)
arr = asarray(arr)

lax_internal._check_user_dtype_supported(dtype, "view")
dtype = dtypes.canonicalize_dtype(dtype)

if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
32: np.dtype('uint32'), 64: np.dtype('uint64')}
if nbits_in not in byte_dtypes:
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
if nbits_out not in byte_dtypes:
raise NotImplementedError(f"arr.view(dtype) for {dtype=}")
dt_in = byte_dtypes[nbits_in]
dt_out = byte_dtypes[nbits_out]
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
if nbits_in < nbits_out:
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1)))
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)

if arr.dtype == dtype:
return arr

if arr.dtype == bool:
return _view(arr.astype('uint8'), dtype)

if issubdtype(arr.dtype, complexfloating):
new_shape = (*arr.shape[:-1], arr.shape[-1] * 2)
new_dtype = finfo(arr.dtype).dtype
arr = zeros(new_shape, new_dtype).at[..., 0::2].set(arr.real).at[..., 1::2].set(arr.imag)
return _view(arr, dtype)

if dtype == bool:
return _view(arr, uint8).astype(bool)

if issubdtype(dtype, complexfloating):
out = _view(arr, finfo(dtype).dtype).astype(dtype)
return out[..., 0::2] + 1j * out[..., 1::2]

if arr.dtype.itemsize == dtype.itemsize:
return lax.bitcast_convert_type(arr, dtype)
elif arr.dtype.itemsize < dtype.itemsize:
factor = dtype.itemsize // arr.dtype.itemsize
arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor)
return lax.bitcast_convert_type(arr, dtype)
else:
shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim)))
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
if dtype == bool_:
return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
return lax.bitcast_convert_type(arr_bytes, dtype)
out = lax.bitcast_convert_type(arr, dtype)
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])

def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
Expand Down
10 changes: 4 additions & 6 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,17 +3430,15 @@ def testItemsize(self, shape, dtype):
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
shape=[(8,), (3, 8)], # last dim = 8 to ensure shape compatibility
a_dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
# Final dimension of shape must be a multiple of 16 to ensure compatibilty.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=[*all_dtypes, *([None] if config.x64_enabled else [])],
)
def testView(self, shape, a_dtype, dtype):
if jtu.device_under_test() == 'tpu':
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
if not config.x64_enabled:
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
Expand Down

0 comments on commit a9a264d

Please sign in to comment.