Skip to content

Commit

Permalink
PRNGKeyArray: fix dynamic slice index dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 10, 2023
1 parent 70f0cc4 commit 6ada878
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
6 changes: 4 additions & 2 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,8 @@ def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides) -> ir.Va

@staticmethod
def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> ir.Value:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
index_avals = ctx.avals_in[1:]
dtype = dtypes.canonicalize_dtype(index_avals[0].dtype if index_avals else 'int64')
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
Expand All @@ -492,7 +493,8 @@ def dynamic_slice_mlir(ctx, aval_out, x, start_indices) -> ir.Value:

@staticmethod
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices) -> ir.Value:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
index_avals = ctx.avals_in[2:]
dtype = dtypes.canonicalize_dtype(index_avals[0].dtype if index_avals else 'int64')
key_shape = aval_out.dtype.impl.key_shape
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
start_indices = (*start_indices, *zeros)
Expand Down
11 changes: 10 additions & 1 deletion tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,10 +1801,19 @@ def test_slice(self):

def test_dynamic_slice(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (2, 4))

def test_dynamic_update_slice(self):
ks = self.make_keys(3, 4)
k = self.make_keys(1, 4)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (3, 4))

def test_transpose(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x.T)(ks)
Expand Down

0 comments on commit 6ada878

Please sign in to comment.