Skip to content

Commit

Permalink
Fix rank promotion warning in DeviceArray.view()
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 8, 2021
1 parent 8e86952 commit 61be1e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -5828,11 +5828,11 @@ def _view(arr, dtype=None, type=None):
dt_out = byte_dtypes[nbits_out]
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
if nbits_in < nbits_out:
shifts = arange(0, nbits_out, nbits_in, dtype=dt_out)
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1)))
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
else:
shifts = arange(0, nbits_in, nbits_out, dtype=dt_in)
shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim)))
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
if dtype == bool_:
Expand Down
1 change: 1 addition & 0 deletions tests/lax_numpy_test.py
Expand Up @@ -3787,6 +3787,7 @@ def testNbytes(self, shape, dtype):
for shape in [(8,), (3, 8)] # last dim = 8 to ensure shape compatibility
for a_dtype in (default_dtypes + unsigned_dtypes + bool_dtypes)
for dtype in (default_dtypes + unsigned_dtypes + bool_dtypes)))
@jax.numpy_rank_promotion('raise')
def testView(self, shape, a_dtype, dtype):
if jtu.device_under_test() == 'tpu':
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
Expand Down

0 comments on commit 61be1e5

Please sign in to comment.