Skip to content

Commit

Permalink
Allow multiple indexers when doing discharge or swap in pallas
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629847808
  • Loading branch information
jax authors authored and selamw1 committed May 2, 2024
1 parent 4b85aa3 commit dbbeb0c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 33 deletions.
72 changes: 39 additions & 33 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,22 +242,23 @@ def _prepend_scatter(x, indexer, val, *, add=False):

def _get_discharge(x, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(y, squeeze_dims)
indexer = _convert_to_array_indexer(indexer)
if indexer is None:
return x
return x[None][(np.array(0, 'int32'), *indexer)]
result = x
for indexer in indexers:
if _is_trivial_indexer(indexer):
continue
if indexer is None:
continue
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(result, starts, sizes)
result = lax.squeeze(y, squeeze_dims)
else:
indexer = _convert_to_array_indexer(indexer)
result = result[None][(np.array(0, "int32"), *indexer)]
return result

def _indexer(idx, indexed_dims):
idx_ = iter(idx)
Expand All @@ -276,23 +277,28 @@ def _swap_discharge_rule(

def _swap_discharge(x, val, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if len(indexers) > 1:
raise NotImplementedError("Only single indexer is supported.")
indexer = indexers[0]
if _is_trivial_indexer(indexer):
return x, val
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
x_old = lax_slicing.dynamic_slice(x, starts, sizes)
val = lax.expand_dims(val, squeeze_dims)
y = lax_slicing.dynamic_update_slice(x, val, starts)
return lax.squeeze(x_old, squeeze_dims), y
indexer = _convert_to_array_indexer(indexer)
x_old = _prepend_gather(x, indexer)
return x_old, _prepend_scatter(x, indexer, val)

result = x
result_val = val
for indexer in indexers:
if _is_trivial_indexer(indexer):
continue
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
result_old = lax_slicing.dynamic_slice(result, starts, sizes)
result_val = lax.expand_dims(result_val, squeeze_dims)
y = lax_slicing.dynamic_update_slice(result, result_val, starts)
result = lax.squeeze(result_old, squeeze_dims)
result_val = y
else:
indexer = _convert_to_array_indexer(indexer)
result_old = _prepend_gather(result, indexer)
result_val = _prepend_scatter(result, indexer, result_val)
result = result_old
return result, result_val

@register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule(
Expand Down
49 changes: 49 additions & 0 deletions tests/pallas/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from jax._src import util
from jax._src.state import indexing
import numpy as np
import jax.numpy as jnp
from jax.experimental import pallas as pl

try:
import hypothesis as hp
Expand Down Expand Up @@ -197,5 +199,52 @@ def test_ndindexer(self, data):
indexer.get_indexer_shape())


def test_multi_indexing_interpreter_only(self):
# Interpreter only test! YMMV actually compiling this.
def permute(left, right, left_out_ref, right_out_ref):
left_out = jnp.zeros_like(left)
left_out = left_out.at[:, 0].set(left[:, 0])
left_out = left_out.at[:, 1].set(right[:, 0])
left_out = left_out.at[:, 2:].set(left[:, 1:-1])

right_out = jnp.zeros_like(right)
right_out = right_out.at[:, :-1].set(right[:, 1:])
right_out = right_out.at[:, -1].set(left[:, -1])

left_out_ref[...] = left_out
right_out_ref[...] = right_out

def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref):
shape = x_ref.shape
_, n = shape[-2], shape[-1]
x_ref = x_ref.at[: n // 2, : n // 2]
y_ref = y_ref.at[: n // 2, : n // 2]
x_out_ref = x_out_ref.at[: n // 2, : n // 2]
y_out_ref = y_out_ref.at[: n // 2, : n // 2]
permute(x_ref, y_ref, x_out_ref, y_out_ref)

n = 8
x = jnp.ones([n, n])
y = jnp.ones([n, n])
jitted_permute = jax.jit(invoke_permutes)
grid = (1,)
pl.pallas_call(
jitted_permute,
grid=grid,
out_shape=[
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape, y.dtype),
],
in_specs=[
pl.BlockSpec(lambda i: (0, 0), x.shape),
pl.BlockSpec(lambda i: (0, 0), y.shape),
],
out_specs=[
pl.BlockSpec(lambda i: (0, 0), x.shape),
pl.BlockSpec(lambda i: (0, 0), y.shape),
],
interpret=True,
)(x, y)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit dbbeb0c

Please sign in to comment.