# Advanced Indexing

A brief tutorial on numpy's [advanced indexing](https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing) as it pertains to the `Indexing` dialect (see [here](https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707)).

Note that this notebook uses starred unpacking in index expressions, which requires Python >= 3.11.

In [1]:
import numpy as np
from numpy.random import randint

## Intro

Consider extracting a collection of slices from a 3-D array of shape $3 \times 30 \times 4$:

In [2]:
input = randint(0, 10, (3, 30, 4))
a = input[:, 0, :]
b = input[:, 5, :]
c = input[:, 14, :]
d = input[:, 24, :]
# ...

This can also be done more easily using "advanced indexing":

In [3]:
abcd = input[:,(0, 5, 14, 24), :]
assert np.array_equal(abcd[:, 0, :], a)
assert np.array_equal(abcd[:, 1, :], b)
assert np.array_equal(abcd[:, 2, :], c)
assert np.array_equal(abcd[:, 3, :], d)
assert abcd.shape == (3, 4, 4)

You can also simultaneously reshape the result by providing the indices in some shape:

In [4]:
indices = np.array([[0, 5], [14, 24]])
abcd = input[:, indices, :]
assert np.array_equal(abcd[:, 0, 0, :], a)
assert np.array_equal(abcd[:, 0, 1, :], b)
assert np.array_equal(abcd[:, 1, 0, :], c)
assert np.array_equal(abcd[:, 1, 1, :], d)
assert abcd.shape == (3, 2, 2, 4)

Notice that the result is expanded to accommodate both the shapes of the `input` and the `indices`. By the way, another way to index into `abcd` is starred-tuple unpacking:

In [5]:
assert np.array_equal(abcd[:, *(0, 0), :], a)
assert np.array_equal(abcd[:, *(0, 1), :], b)
assert np.array_equal(abcd[:, *(1, 0), :], c)
assert np.array_equal(abcd[:, *(1, 1), :], d)
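On interpreters older than Python 3.11, the same index can be spelled as an explicit tuple, since unpacking inside a parenthesized tuple has been allowed for much longer. A minimal sketch; `x`, `idx`, `picked`, `pos`, and `key` are illustrative names, not part of the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, (3, 30, 4))     # stand-in for `input`
idx = np.array([[0, 5], [14, 24]])
picked = x[:, idx, :]                   # shape (3, 2, 2, 4)

pos = (1, 0)                            # position inside the 2x2 index grid
# slice(None) is what a bare `:` means inside square brackets
key = (slice(None), *pos, slice(None))

assert np.array_equal(picked[key], picked[:, 1, 0, :])
assert np.array_equal(picked[key], x[:, 14, :])
```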

It's also important to understand what `abcd` is: the original array has dimensions $3 \times 30 \times 4$, while `abcd` has interior dimensions distinct from those of `input`. $2 \times 2$ might seem somehow smaller than, or related to, $30$, but it's not; the shape of the indexing array determines the output size _at those dimensions_:

In [6]:
indices_ = randint(0, 30, (100, 100))
abcd_ = input[:, indices_, :]
assert abcd_.shape == (3, 100, 100, 4)
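The shape rule for a single index array can be written down directly: the indexed axis is replaced by the shape of the index array. The sketch below checks that rule and shows that indexing with the flattened indices followed by a reshape gives the same result. The names `x`, `idx`, `axis`, and `expected_shape` are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, (3, 30, 4))     # stand-in for `input`
idx = rng.integers(0, 30, (7, 2, 5))    # any shape works for the index array
axis = 1                                # the axis being indexed

result = x[:, idx, :]
# The indexed axis (size 30) is replaced by the whole shape of `idx`
expected_shape = x.shape[:axis] + idx.shape + x.shape[axis + 1:]
assert result.shape == expected_shape

# Equivalent: index with the flattened indices, then restore the index shape
flat = x[:, idx.ravel(), :]             # shape (3, 70, 4)
assert np.array_equal(flat.reshape(expected_shape), result)
```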

What's happening here is that a new, larger array really is created: unlike basic slicing, which returns a ["view"](https://numpy.org/doc/stable/user/basics.copies.html) on the original data, advanced indexing in numpy always returns a copy, so `abcd_` holds its own (repeated) copies of the elements of `input`.
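According to the numpy indexing documentation, basic slicing returns a view while advanced indexing returns a copy. `np.shares_memory` gives a quick way to check which of the two a given expression produced. A minimal sketch, with `x`, `basic`, and `fancy` as illustrative names:

```python
import numpy as np

x = np.arange(3 * 30 * 4).reshape(3, 30, 4)

basic = x[:, 0:4, :]                # basic slicing
fancy = x[:, (0, 5, 14, 24), :]     # advanced indexing

assert np.shares_memory(x, basic)       # view: same underlying buffer
assert not np.shares_memory(x, fancy)   # copy: separate buffer

fancy[...] = -1                     # writing to the copy leaves x untouched
assert x.min() == 0
basic[...] = -1                     # writing through the view changes x
assert x[:, 0:4, :].max() == -1
```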

## Gather

The previous example (indexing along one dimension) directly corresponds to the numpy operation [`take`](https://numpy.org/doc/stable/reference/generated/numpy.take.html):

In [7]:
assert np.array_equal(np.take(input, indices, axis=1), abcd)

and therefore directly corresponds to the operation [`tensor.gather`](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorgather-mlirtensorgatherop) as well (with a few small differences):

```mlir
%abcd = tensor.gather %input[%indices] gather_dims([1]) :
    (tensor<3x30x4xf32>, tensor<2x2x 1xindex>) -> tensor<2x2x3x1x4xf32>
```

Note the slight (matter of convention) differences:

1. The `%indices` tensor is "unsqueezed" in the trailing dimension (there's a sort-of redundant `x1` in the shape `2x2x1`)
2. The result tensor (`%abcd`) has axes permuted so that the shape of the `%indices` tensor is leading (the `2x2x` is "in front"). The effect of this choice of conventions is that (figuratively) `%input[:, %indices[0, 0], :] == %abcd[0, 0, ...]` i.e., the actual indices can be placed in leading position in both the indexing tensor and the result.
3. The result tensor preserves the rank of the `%input` tensor (`3x1x4` instead of `3x4`). This can actually be reconciled with numpy by using the "rank-reducing" form of `tensor.gather`.
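The three conventions can be mimicked in numpy to see where each dimension of `2x2x3x1x4` comes from. This is only a shape-level emulation under the assumptions stated in the list above, not how `tensor.gather` is lowered; `x`, `idx`, and the intermediate names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, (3, 30, 4))     # plays the role of %input
idx = np.array([[0, 5], [14, 24]])      # numpy-style indices, shape (2, 2)

# 1. "unsqueezed" indices carry a trailing dimension of size 1
idx_unsqueezed = idx[..., np.newaxis]                 # shape (2, 2, 1)

# numpy layout: index dims sit where the indexed axis was
taken = np.take(x, idx_unsqueezed[..., 0], axis=1)    # shape (3, 2, 2, 4)

# 2. move the index dims to the front
leading = np.moveaxis(taken, (1, 2), (0, 1))          # shape (2, 2, 3, 4)

# 3. keep the gathered dimension as a size-1 axis (rank of %input preserved)
full_rank = np.expand_dims(leading, axis=3)           # shape (2, 2, 3, 1, 4)

for i, j in np.ndindex(2, 2):
    assert np.array_equal(full_rank[i, j, :, 0, :], x[:, idx[i, j], :])
```

Dropping the last step (keeping `leading`) corresponds to the rank-reduced shape `2x2x3x4`.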

`tensor.gather` is "more powerful" than `numpy.take`; a slightly more complicated example:

```mlir
%out = tensor.gather %input[%indices] gather_dims([0, 2]) :
  (tensor<20x3x40xf32>, tensor<5x6x 2xindex>) -> tensor<5x6x3xf32>
```

which corresponds to:

In [8]:
input = randint(0, 10, (20, 3, 40))
indices_0 = randint(0, 20, (5, 6))
indices_1 = randint(0, 40, (5, 6))
out = input[indices_0, :, indices_1]
assert out.shape == (5, 6, 3)

for index in np.ndindex(5, 6):
  coord_0, coord_1 = indices_0[index], indices_1[index]
  assert np.array_equal(out[index], input[coord_0, :, coord_1])

It's important to consider how advanced indexing works in order to draw the analogy. With `indices_0 = indices[:, :, 0]` and `indices_1 = indices[:, :, 1]`, the expression `input[indices[:, :, 0], :, indices[:, :, 1]]` means: take the $5 \times 6$ collection of numbers `indices[:, :, 0]` (which will index the $0$th dimension of `input`), pair them with the corresponding numbers in `indices[:, :, 1]` (which will index the $2$nd dimension of `input`), and use those pairs to slice `input` (along the $1$st dimension).
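The pairing can be made explicit by starting from a single packed index array of shape `5x6x2`, as in the MLIR example, and splitting it along its trailing dimension. A minimal sketch; `x`, `packed`, and `gathered` are illustrative names:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, (20, 3, 40))

# One coordinate pair per output position, packed in the trailing dimension
packed = np.stack(
    [rng.integers(0, 20, (5, 6)), rng.integers(0, 40, (5, 6))], axis=-1
)
assert packed.shape == (5, 6, 2)

# Component 0 indexes dimension 0, component 1 indexes dimension 2
gathered = x[packed[..., 0], :, packed[..., 1]]
assert gathered.shape == (5, 6, 3)

for i, j in np.ndindex(5, 6):
    k, m = packed[i, j]
    assert np.array_equal(gathered[i, j], x[k, :, m])
```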

Note that something like

In [9]:
indices_0 = randint(0, 10, (5, 6))
indices_1 = randint(0, 10, (10, 3))

won't work because the shapes are different and cannot be reconciled (via broadcast):

In [10]:
indices_0 = randint(0, 10, (5, 6))
indices_1 = randint(0, 10, (10, 3))

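try:
    out = input[indices_0, :, indices_1]
except IndexError as e:
    assert e.args[0].strip() == 'shape mismatch: indexing arrays could not be broadcast together with shapes (5,6) (10,3)'
else:
    # make sure the cell fails loudly if no exception was raised
    raise AssertionError('expected an IndexError from the shape mismatch')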
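By contrast, index arrays whose shapes differ but are broadcast-compatible do work. The sketch below uses `np.broadcast_shapes` (available in numpy 1.20 and later) to check compatibility up front; `x`, `rows`, and `cols` are illustrative names:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, (20, 3, 40))

rows = rng.integers(0, 20, (5, 1))      # broadcasts along the second index dim
cols = rng.integers(0, 40, (1, 6))      # broadcasts along the first index dim
assert np.broadcast_shapes(rows.shape, cols.shape) == (5, 6)

out = x[rows, :, cols]
assert out.shape == (5, 6, 3)
for i, j in np.ndindex(5, 6):
    assert np.array_equal(out[i, j], x[rows[i, 0], :, cols[0, j]])

# Incompatible shapes are rejected by the same broadcasting rules
try:
    np.broadcast_shapes((5, 6), (10, 3))
    compatible = True
except ValueError:
    compatible = False
assert not compatible
```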

## Scatter

[`tensor.scatter`](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorscatter-mlirtensorscatterop) loosely corresponds to numpy's [`put`](https://numpy.org/doc/stable/reference/generated/numpy.put.html) operation (or, more closely, to assignment through advanced indexing, as used below) and thus is symmetrically related to `tensor.gather`:

```mlir
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) :
    tensor<5x6x4x4xf32> into tensor<4x100x4xf32>[tensor<5x6x 1xindex>] -> tensor<4x100x4xf32>
```

corresponds to

In [11]:
source = randint(0, 10, (5, 6, 4, 4))
dest = np.zeros((4, 100, 4))
indices = randint(0, 100, (5, 6))

# ensure that indices are unique to prevent WAW hazard
while len(np.unique(indices)) != 30:
  indices = randint(0, 100, (5, 6))

# shuffle axes so that source index dims "line" up with dest index dim
source_ = np.moveaxis(source, (0, 1), (1, 2))
assert source_.shape == (4, 5, 6, 4)

dest[:, indices, :] = source_

for index in np.ndindex(5, 6):
  coord = indices[*index]
  assert np.array_equal(dest[:, coord, :], source[*index])

**Stated in words**: scatter thirty slices (of size $4 \times 4$) from `source` into `dest` where the "location" `k` (i.e., `dest[:, k, :]`) of the slice `source[i, j, :, :]` is specified by `k = indices[i, j]`. Two things to note:

1. We need to ensure that the "location" `k` is unique so that inserting a slice doesn't overwrite another slice that had already been inserted. In general, non-unique `k` has "undefined behavior" semantics (i.e., all bets are off regarding the results).
2. Leading index dims (`5, 6, ...`) correspond to `tensor.scatter` semantics, but numpy expects index dims to line up with the dimension that is being indexed in `dest` (hence the `np.moveaxis`).
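The statement in words translates into a plain loop that writes one slice at a time, which can serve as a reference for the vectorized assignment. A minimal sketch; `src`, `idx`, `dest_loop`, and `dest_vec` are illustrative names, and drawing the locations from a permutation is just one assumed way to guarantee uniqueness:

```python
import numpy as np

rng = np.random.default_rng(0)
src = rng.integers(0, 10, (5, 6, 4, 4))
# A permutation guarantees thirty distinct locations in [0, 100)
idx = rng.permutation(100)[:30].reshape(5, 6)

# Reference: scatter one 4x4 slice at a time
dest_loop = np.zeros((4, 100, 4))
for i, j in np.ndindex(5, 6):
    dest_loop[:, idx[i, j], :] = src[i, j]

# Vectorized: line the index dims of src up with the indexed dim of dest
dest_vec = np.zeros((4, 100, 4))
dest_vec[:, idx, :] = np.moveaxis(src, (0, 1), (1, 2))

assert np.array_equal(dest_loop, dest_vec)
```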

Scattering along multiple dimensions looks symmetrical with the `tensor.gather` case: 

```mlir
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) :
   tensor<5x6x4xf32> into tensor<1000x4x1000xf32>[tensor<5x6x 2xindex>] -> tensor<1000x4x1000xf32>
```

corresponds to

In [12]:
source = randint(0, 10, (5, 6, 4))
dest = np.zeros((1000, 4, 1000))
indices = randint(0, 1000, (5, 6, 2))

# ensure that all 60 coordinates are distinct to prevent WAW hazard
# (stricter than necessary: unique (k, m) pairs would suffice)
while len(np.unique(indices)) != 60:
  indices = randint(0, 1000, (5, 6, 2))

# no shuffle needed because instead we're explicitly slicing the indices
dest[indices[:, :, 0], :, indices[:, :, 1]] = source

for index in np.ndindex(5, 6):
  coord = indices[*index]
  assert np.array_equal(dest[coord[0], :, coord[1]], source[index])
    
assert dest.shape == (1000, 4, 1000)

Note we don't need a `np.moveaxis` here because we're explicitly slicing `indices`; imagine writing thirty $4$-entry slices (i.e., the slices `source[i, j, :]`), one per position of the $5 \times 6$ grid of index pairs, into `dest[k, :, m]`, where `k = indices[i, j, 0]` and `m = indices[i, j, 1]`.
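For the two-dimensional case, what has to be unique is the pair `(k, m)`, not each coordinate on its own. The sketch below draws thirty distinct pairs for a deliberately small `dst` (an assumption made so that individual coordinates can repeat), checks pair uniqueness with `np.unique(..., axis=0)`, and confirms that gathering with the same indices recovers the source. All names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
src = rng.integers(0, 10, (5, 6, 4))
dst = np.zeros((50, 4, 50))             # smaller than in the text, on purpose

# Draw 30 distinct flat positions in the 50 x 50 grid, then split into (k, m)
flat = rng.choice(50 * 50, size=30, replace=False)
k, m = np.unravel_index(flat, (50, 50))
idx = np.stack([k, m], axis=-1).reshape(5, 6, 2)

# The pairs are unique even if an individual k or m value repeats
assert len(np.unique(idx.reshape(-1, 2), axis=0)) == 30

# Scatter, then gather with the same indices: the round trip returns src
dst[idx[..., 0], :, idx[..., 1]] = src
assert np.array_equal(dst[idx[..., 0], :, idx[..., 1]], src)
```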

## Biblio

* [[RFC] Adding Gather, Scatter Ops](https://discourse.llvm.org/t/rfc-adding-gather-scatter-ops/64757)

* [[RFC] Structured Codegen Beyond Rectangular Arrays](https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707)

* [What does the gather function do in pytorch in layman terms?](https://stackoverflow.com/a/54706716)