Skip to content

Commit

Permalink
jax/pallas support ellipsis indexing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634922391
  • Loading branch information
Google-ML-Automation authored and jax authors committed May 17, 2024
1 parent 02c19e9 commit 641d5c8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
16 changes: 12 additions & 4 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, Union
from typing import Any, Union, List

from jax._src import core
from jax._src import tree_util
Expand Down Expand Up @@ -190,9 +190,17 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer:
if len(indices) == 1 and indices[0] is ...:
indices = (slice(None),) * len(shape)
if any(idx is ... for idx in indices):
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
# e.g. x[0, ..., 1].
raise NotImplementedError("Ellipsis in indexer not supported yet.")
new_indices : List[Any] = []
num_ellipsis = sum(1 for idx in indices if idx is ...)
if num_ellipsis > 1:
raise ValueError("Only one ellipsis is supported.")
for idx in indices:
if idx is ...:
expand = (slice(None),) * (len(shape) - len(indices) + 1)
new_indices.extend(expand)
else:
new_indices.append(idx)
indices = tuple(new_indices)
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`: "
f"{indices=}, {shape=}")
Expand Down
56 changes: 56 additions & 0 deletions tests/pallas/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,5 +246,61 @@ def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref):
interpret=True,
)(x, y)

def test_ellipsis_indexing_iterpret_only(self):
# Interpreter only test! YMMV actually compiling this.
def permute_columns_in_row_kernel(left, right, new_left, new_right):
shape = left.shape
k = shape[-1]
ndim = len(shape)
left_slices = [
left[..., :1],
right[..., :1],
left[..., 1:k-1]
]
right_slices = [
right[..., 1:k],
left[..., k-1:k]
]
new_left[...] = np.concatenate(left_slices, axis=ndim - 1)
new_right[...] = np.concatenate(right_slices, axis=ndim - 1)

left = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
right = jnp.array([[7, 8, 9], [10, 11, 12]], dtype=jnp.float32)

output_shape = left.shape

# hack to reuse the same fn for np cat
import jax.numpy as np # noqa: F811
left_out, right_out = pl.pallas_call(
permute_columns_in_row_kernel,
grid=(1,),
out_shape=[
jax.ShapeDtypeStruct(output_shape, jnp.float32),
jax.ShapeDtypeStruct(output_shape, jnp.float32)
],
in_specs=[
pl.BlockSpec(lambda i: (0, 0), left.shape),
pl.BlockSpec(lambda i: (0, 0), right.shape)
],
out_specs=[
pl.BlockSpec(lambda i: (0, 0), output_shape),
pl.BlockSpec(lambda i: (0, 0), output_shape)
],
interpret=True,
)(left, right)


import numpy as np # noqa: F811
left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
left_out_np = left_np.copy()
right_out_np = right_np.copy()


permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np)
np.testing.assert_array_equal(left_out_np, left_out)
np.testing.assert_array_equal(right_out_np, right_out)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 641d5c8

Please sign in to comment.