# Shapes and Indexing with JAX

Here we'll look at how JAX manages shapes, broadcasting, and indexing.

## Shapes

Like NumPy, JAX uses `.shape` to report the shape of an array and `.reshape` to change it. Because JAX arrays are immutable, there is no equivalent of PyTorch's shape-changing `.view()` or NumPy's in-place `resize()`. Calls to `.reshape` return a new array, which may share the underlying data (the default) or copy it if needed.

Unlike NumPy or Torch, JAX does NOT provide a way to get at the underlying data and instead manages all data under the hood. This is to prevent unexpected outcomes and preserve immutability.

In [78]:
import jax
import jax.numpy as jnp

key = jax.random.key(42)

key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (100, 20), dtype=jnp.float32)

x2 = x.reshape(20, 5, 5, 4)

print("x: ", x.shape)
print("x2: ", x2.shape)
print("x is x2: ", x is x2)

try:
    print("x.data is x2.data: ", x.data is x2.data)
except AttributeError:
    print("Ha ha! No data for you!")

x:  (100, 20)
x2:  (20, 5, 5, 4)
x is x2:  False
Ha ha! No data for you!
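
As a small sketch of the reshape behaviour described above (the array `a` and its values are made up for illustration and are not part of the cells above): `-1` lets `reshape` infer one dimension, and the function form `jnp.reshape` behaves the same as the method.

```python
import jax.numpy as jnp

a = jnp.arange(12.0)             # illustrative data
b = a.reshape(3, -1)             # -1: infer the missing dimension (4 here)
c = jnp.reshape(a, (2, 2, 3))    # function form, same semantics as the method
round_trip = c.reshape(a.shape)  # reshaping back recovers the original layout

print(b.shape, c.shape)
print(bool(jnp.array_equal(round_trip, a)))
```

Every call returns a new array object, so `a` itself is never modified.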


## Broadcasting

As in NumPy, shape broadcasting is performed under the hood.

Note that shapes are aligned from the last axis, so any missing axes are added at the beginning.

In [79]:
key, subkey_x, subkey_y = jax.random.split(key, 3)
x = jax.random.normal(subkey_x, (5, 3), dtype=jnp.float32)
y = jax.random.normal(subkey_y, (3,), dtype=jnp.float32)

print(f"Add {x.shape} and {y.shape}:")
z = x + y
print(z.shape)


w = jax.random.normal(key, (3, 3), dtype=jnp.float32)

print(f"Add {w.shape} and {y.shape}:")
q = w + y
print(q.shape)

print(f"Umm ... what? Which axis is broadcasted?")
print(f"w: {w}")
print(f"y: {y}")
print(f"w + y: {q}")
print(f"explicitly broadcast y[None, :] : {w + y[None, :]}")
print("Ok, so it expands the first dimension of y to match w.")

Add (5, 3) and (3,):
(5, 3)
Add (3, 3) and (3,):
(3, 3)
Umm ... what? Which axis is broadcasted?
w: [[-0.7197971   1.5521808  -0.8557356 ]
 [ 0.270705    0.18473469  0.97469676]
 [-0.341975   -1.2624321  -0.22399448]]
y: [-1.8259704  -0.40702963  0.553828  ]
w + y: [[-2.5457675   1.1451511  -0.3019076 ]
 [-1.5552654  -0.22229494  1.5285248 ]
 [-2.1679454  -1.6694617   0.3298335 ]]
explicitly broadcast y[None, :] : [[-2.5457675   1.1451511  -0.3019076 ]
 [-1.5552654  -0.22229494  1.5285248 ]
 [-2.1679454  -1.6694617   0.3298335 ]]
Ok, so it expands the first dimension of y to match w.
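
To check the alignment rule without printing whole arrays, a sketch like the following may help. It uses `jnp.broadcast_shapes`; the arrays `m` and `v` are made up for illustration. A vector whose length matches the first axis does not broadcast until we add a trailing axis explicitly.

```python
import jax.numpy as jnp

# (3,) is treated as (1, 3), so it lines up with the last axis of (5, 3).
print(jnp.broadcast_shapes((5, 3), (3,)))   # (5, 3)

m = jnp.ones((5, 3))
v = jnp.arange(5.0)                         # length matches the *first* axis of m

try:
    m + v                                   # trailing axes 3 and 5 do not match
except (TypeError, ValueError) as err:
    print("incompatible shapes:", type(err).__name__)

# Adding a trailing axis makes v a (5, 1) column, which broadcasts across columns.
print((m + v[:, None]).shape)               # (5, 3)
```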


Let's see what JAX does with axes having one element.

In [80]:
key, subkey_x, subkey_y = jax.random.split(key, 3)
x = jax.random.normal(subkey_x, (1, 3), dtype=jnp.float32)
y = jax.random.normal(subkey_y, (3, 1), dtype=jnp.float32)

print(f"x: {x}")
print(f"y: {y}")
print(f"x * y: {x * y}")


x: [[-0.39489815 -0.42596528 -1.0719717 ]]
y: [[0.28066227]
 [0.4480957 ]
 [1.9998703 ]]
x * y: [[-0.11083301 -0.11955238 -0.30086198]
 [-0.17695217 -0.19087322 -0.4803459 ]
 [-0.7897451  -0.8518753  -2.1438043 ]]
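
Both size-one axes are stretched, so a row times a column is an outer product. A minimal check, with made-up values for `row` and `col`:

```python
import jax.numpy as jnp

row = jnp.array([[1.0, 2.0, 3.0]])     # shape (1, 3)
col = jnp.array([[10.0], [20.0]])      # shape (2, 1)

prod = row * col                       # both size-1 axes are stretched -> (2, 3)
outer = jnp.outer(col.ravel(), row.ravel())

print(prod.shape)
print(bool(jnp.array_equal(prod, outer)))
```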


## Slice Indexing

Reads are all normal.

In [81]:
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (2, 3), dtype=jnp.float32)

print(f"x: {x}")
print(f"x[0, 0]: {x[0, 0]}")
print(f"x[:, 1]: {x[:, 1]}")
print(f"x[1, :]: {x[1, :]}")

x: [[ 0.6733885  -0.50893074 -0.28400496]
 [ 2.03002     0.61162853 -1.9237745 ]]
x[0, 0]: 0.6733884811401367
x[:, 1]: [-0.50893074  0.61162853]
x[1, :]: [ 2.03002     0.61162853 -1.9237745 ]


We can use either `None` or `jnp.newaxis` to add new dimensions.

In [82]:
print(f"x[None, :, :]: {x[None, :, :]}")
print(f"x[jnp.newaxis, :, :]: {x[jnp.newaxis, :, :]}")

x[None, :, :]: [[[ 0.6733885  -0.50893074 -0.28400496]
  [ 2.03002     0.61162853 -1.9237745 ]]]
x[jnp.newaxis, :, :]: [[[ 0.6733885  -0.50893074 -0.28400496]
  [ 2.03002     0.61162853 -1.9237745 ]]]
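
The new axis can go anywhere, and `jnp.expand_dims` is the function form of the same operation. A quick shape-only sketch (the zero array `a` is illustrative):

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3))

print(a[None].shape)                     # (1, 2, 3): new leading axis
print(a[..., None].shape)                # (2, 3, 1): new trailing axis
print(a[:, None, :].shape)               # (2, 1, 3): new middle axis
print(jnp.expand_dims(a, axis=1).shape)  # (2, 1, 3): same as a[:, None, :]
```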


Can we do skipped steps and reversals? Yes we can.

In [83]:
print(f"Skipping every other element: x[0, ::2]: {x[0, ::2]}")
print(f"Reversing: x[::-1]: {x[::-1]}")
print(f"Reversing rows: x[::-1, :]: {x[::-1, :]}")
print(f"Reversing rows and skipping every other element: x[::-1, ::2]: {x[::-1, ::2]}")


Skipping every other element: x[0, ::2]: [ 0.6733885  -0.28400496]
Reversing: x[::-1]: [[ 2.03002     0.61162853 -1.9237745 ]
 [ 0.6733885  -0.50893074 -0.28400496]]
Reversing rows: x[::-1, :]: [[ 2.03002     0.61162853 -1.9237745 ]
 [ 0.6733885  -0.50893074 -0.28400496]]
Reversing rows and skipping every other element: x[::-1, ::2]: [[ 2.03002    -1.9237745 ]
 [ 0.6733885  -0.28400496]]


Assignment runs afoul of JAX's immutability requirements, though.

In [84]:
x[0, 0] = 100

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

Instead, we use the `.at[...]` property, in a pattern reminiscent of Theano long ago. Indexing `.at` creates an object that wraps the original array together with the index; when `.set()` is called on this wrapper, it returns a new array and leaves the original untouched.

In [None]:
print(x.at[0, 0].__class__)

y = x.at[0, 0].set(100)
print(y)

a = x.at[0,0]
z = a.set(20)
w = a.set(10)

print("a", a)
print("x", x)
print("z", z)
print("w", w)


<class 'jax._src.numpy.array_methods._IndexUpdateRef'>
[[100.          -0.50893074  -0.28400496]
 [  2.03002      0.61162853  -1.9237745 ]]
a _IndexUpdateRef(Array([[ 0.6733885 , -0.50893074, -0.28400496],
       [ 2.03002   ,  0.61162853, -1.9237745 ]], dtype=float32), (0, 0))
x [[ 0.6733885  -0.50893074 -0.28400496]
 [ 2.03002     0.61162853 -1.9237745 ]]
z [[20.         -0.50893074 -0.28400496]
 [ 2.03002     0.61162853 -1.9237745 ]]
w [[10.         -0.50893074 -0.28400496]
 [ 2.03002     0.61162853 -1.9237745 ]]
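
Besides `.set()`, the same wrapper offers other update methods such as `.add()`, `.multiply()`, and `.max()`. The sketch below uses made-up arrays (`counts`, `base`, `idx`) to show them. Note that `.add()` accumulates when an index appears more than once.

```python
import jax.numpy as jnp

counts = jnp.zeros(5)
idx = jnp.array([0, 0, 2])               # index 0 appears twice

added = counts.at[idx].add(1.0)          # duplicates accumulate: 2, 0, 1, 0, 0
print(added)

base = jnp.arange(5.0)
scaled = base.at[1:4].multiply(10.0)     # 0, 10, 20, 30, 4
capped = base.at[idx].max(1.5)           # elementwise max at idx: 1.5, 1, 2, 3, 4
print(scaled)
print(capped)
```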


Beyond that, assignment works like you would expect.

In [None]:
key, subkey_x, subkey_y = jax.random.split(key, 3)
x = jax.random.normal(subkey_x, (2, 3, 4), dtype=jnp.float32)
y = jax.random.normal(subkey_y, (2, 4), dtype=jnp.float32)

z = x.at[:, 1, :].set(y)

print("x", x)
print("y", y)
print("z", z)

x [[[-0.45876116 -1.820128    0.09428611 -1.1634603 ]
  [ 0.5830652   0.8229201  -1.2494267   0.4772025 ]
  [ 0.38525388  0.6251398   0.9846895   0.21361268]]

 [[ 0.2637497  -0.25655866  0.5197545  -1.1509591 ]
  [ 0.2594074  -2.5684536  -2.7484417  -0.6710001 ]
  [-1.3932319   0.08468539 -0.15895508  0.5357991 ]]]
y [[ 1.5729724  -0.6441063  -1.0299827   0.6285919 ]
 [ 1.1504353   0.09341063 -0.08898792  1.6021241 ]]
z [[[-0.45876116 -1.820128    0.09428611 -1.1634603 ]
  [ 1.5729724  -0.6441063  -1.0299827   0.6285919 ]
  [ 0.38525388  0.6251398   0.9846895   0.21361268]]

 [[ 0.2637497  -0.25655866  0.5197545  -1.1509591 ]
  [ 1.1504353   0.09341063 -0.08898792  1.6021241 ]
  [-1.3932319   0.08468539 -0.15895508  0.5357991 ]]]


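Note that JAX handles out-of-bounds indices differently from NumPy. Instead of raising an error, reads clamp the index to the valid range, and updates through `.at[...]` at out-of-bounds indices are silently dropped by default. So you'll have to be careful when implementing things that use indices in JAX.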

In [85]:
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (10, 25), dtype=jnp.float32)
print(f"x[1000, 1000] = {x[1000, 1000]}")
print(f"x[{x.shape[0] - 1}, {x.shape[1] - 1}] = {x[x.shape[0] - 1, x.shape[1] - 1]}")

x[1000, 1000] = -2.163405656814575
x[9, 24] = -2.163405656814575
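
If clamping is not what you want, the `mode` argument of `.at[...].get()` lets you choose a different policy. The sketch below, on a made-up 1-D array `a`, compares a clamped read, a read with `mode="fill"`, and an out-of-bounds update, which is dropped.

```python
import jax.numpy as jnp

a = jnp.arange(5.0)

clamped = a[10]                                          # read: clamped to the last element
filled = a.at[10].get(mode="fill", fill_value=-1.0)      # read: returns the fill value instead
updated = a.at[10].set(99.0)                             # update: dropped, array unchanged

print(clamped, filled)
print(updated)
```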


## Gather / Scatter

A very common pattern in Torch is gather/scatter: you obtain a set of indices from, e.g., `argmax`, `sort`, or `topk`, and then need to extract the matching elements from another array. JAX has `jnp.take_along_axis`, which is equivalent to `torch.gather`, but it does not expose anything equivalent to `torch.scatter`, which is the adjoint of gather and appears in its derivative. JAX does, however, support NumPy-style advanced (integer-array) indexing, so a scatter can be written with explicit index arrays.

In [86]:
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (2, 10, 300), dtype=jnp.float32)

indices = jnp.argmax(x, axis=-1)
print("indices shape", indices.shape)

highest_values = jnp.take_along_axis(x, indices[..., None], axis=-1).squeeze(-1)
print("highest_values shape", highest_values.shape)

# Now suppose we want to scatter these highest values back into an array that is otherwise zero.
# With NumPy-style advanced indexing, we build an explicit index array for every axis:
y = jnp.zeros_like(x)
first_index = jnp.arange(y.shape[0])[:, None, None].repeat(y.shape[1], axis=1)
second_index = jnp.arange(y.shape[1])[None, :, None].repeat(y.shape[0], axis=0)
y = y.at[first_index, second_index, indices[...,None]].set(highest_values[..., None])
assert jnp.allclose(jnp.max(y, axis=-1), highest_values)




indices shape (2, 10)
highest_values shape (2, 10)
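
To make the "adjoint" relationship concrete, here is a small sketch on made-up data (`a`, `idx`): the gradient of a summed gather is a scatter-add of ones into the gathered positions.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([[1], [3], [0]])             # one column index per row

def gathered_sum(arr):
    # Gather one element per row, then reduce to a scalar so we can differentiate.
    return jnp.take_along_axis(arr, idx, axis=-1).sum()

grad = jax.grad(gathered_sum)(a)

# The same result, written as an explicit scatter-add of ones.
rows = jnp.arange(a.shape[0])[:, None]
expected = jnp.zeros_like(a).at[rows, idx].add(1.0)

print(grad)
print(bool(jnp.array_equal(grad, expected)))
```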


In [87]:
# we can even write a `put_along_axis` function that is like `torch.scatter`
def put_along_axis(arr, indices, values, axis):
    assert arr.ndim == indices.ndim == values.ndim
    assert indices.shape == values.shape
    assert axis is not None and -arr.ndim <= axis < arr.ndim

    if axis < 0:
        axis = arr.ndim + axis

    index_list = []
    for i in range(arr.ndim):
        if i == axis:
            index_list.append(indices)
        else:
            indices_i = jnp.arange(arr.shape[i])
            s = []
            for j in range(arr.ndim):
                if j == i:
                    s.append(slice(None))
                else:
                    s.append(None)
            
            indices_i = indices_i[tuple(s)]

            for j in range(arr.ndim):
                if j != i and j != axis:
                    indices_i = indices_i.repeat(arr.shape[j], axis=j)

            index_list.append(indices_i)

    return arr.at[tuple(index_list)].set(values)


z = jnp.zeros_like(x)
z = put_along_axis(z, indices[...,None], highest_values[...,None], axis=-1)
assert jnp.allclose(z, y)

print("put_along_axis works!")





put_along_axis works!


In [88]:
# ChatGPT recommended this alternative using `jnp.indices`:
def put_along_axis(arr, indices, values, axis):
    assert arr.ndim == indices.ndim == values.ndim
    assert indices.shape == values.shape
    assert axis is not None and -arr.ndim <= axis < arr.ndim

    # normalize axis
    axis = axis % arr.ndim

    # Generate broadcastable coordinate grids for all axes.
    # With sparse=True, grid i has shape (1, ..., arr.shape[i], ..., 1).
    grids = jnp.indices(arr.shape, sparse=True)

    # Replace the coordinate array at 'axis' with the provided indices;
    # the remaining grids broadcast against it, so no explicit expansion is needed.
    index_grids = list(grids)
    index_grids[axis] = indices

    return arr.at[tuple(index_grids)].set(values)

z = jnp.zeros_like(x)
z = put_along_axis(z, indices[...,None], highest_values[...,None], axis=-1)
assert jnp.allclose(z, y)

print("ChatGPT's put_along_axis works!")

ChatGPT's put_along_axis works!
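
The same idea extends to more than one index per row. The sketch below is an illustration only: `scatter_last_axis` is a hypothetical helper, not a JAX function, and it assumes the scatter happens along the last axis. It writes the top three values of each row back into a zero array and checks the round trip with `jnp.take_along_axis`.

```python
import jax
import jax.numpy as jnp

def scatter_last_axis(arr, indices, values):
    """Hypothetical helper: write values[..., k] to arr[..., indices[..., k]]."""
    # One broadcastable coordinate grid per leading axis, e.g. (2, 1) and (1, 10).
    lead = jnp.indices(arr.shape[:-1], sparse=True)
    # Add a trailing axis so the grids broadcast against indices of shape (..., k).
    lead = tuple(g[..., None] for g in lead)
    return arr.at[lead + (indices,)].set(values)

data = jax.random.normal(jax.random.key(0), (2, 10, 300))
top_vals, top_idx = jax.lax.top_k(data, 3)            # both have shape (2, 10, 3)

scattered = scatter_last_axis(jnp.zeros_like(data), top_idx, top_vals)

# Gathering at the same indices recovers the scattered values.
recovered = jnp.take_along_axis(scattered, top_idx, axis=-1)
print(bool(jnp.array_equal(recovered, top_vals)))
```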


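There are the usual indexing gotchas, which is why we need the explicit index arrays above. Mixing slices with an index array indexes the cross product of the two, so every `(i, j)` row receives all twenty values: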

In [89]:
# This, however, does not do the right thing.
z = jnp.zeros_like(x)
z = z.at[:, :, indices[...,None]].set(highest_values[...,None])

print("z: ", jnp.max(z, axis=-1))
print("y: ", jnp.max(y, axis=-1))
print("z idxs: ", jnp.argmax(z, axis=-1))
print("y idxs: ", jnp.argmax(y, axis=-1))

print("It wrote everything to the same index! ... not what we want.")


z:  [[4.259979 4.259979 4.259979 4.259979 4.259979 4.259979 4.259979 4.259979
  4.259979 4.259979]
 [4.259979 4.259979 4.259979 4.259979 4.259979 4.259979 4.259979 4.259979
  4.259979 4.259979]]
y:  [[3.1397703 2.2660837 2.938662  2.764445  2.3730185 2.7462356 2.8064163
  2.482857  2.963299  2.9200146]
 [3.800043  2.3390687 3.7038352 3.3755596 4.259979  2.5296812 3.1239321
  2.4584854 4.072891  2.772977 ]]
z idxs:  [[168 168 168 168 168 168 168 168 168 168]
 [168 168 168 168 168 168 168 168 168 168]]
y idxs:  [[ 34  62  73  74  75  77 193  24 242 198]
 [173 180 272   0 168  47  59 192 271 182]]
It wrote everything to the same index! ... not what we want.
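
Looking at shapes alone makes the difference easier to see. In this sketch (a small made-up array `a` and an all-zero `idx`), slices combined with an index array give one result per pair of slice position and index entry, while explicit index arrays for every axis broadcast together elementwise.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 4))
idx = jnp.zeros((2, 3, 1), dtype=jnp.int32)

# Slices keep their own axes; the index array's shape replaces the last axis.
print(a[:, :, idx].shape)        # (2, 3, 2, 3, 1)

# Explicit index arrays broadcast against each other instead.
i = jnp.arange(2)[:, None, None]
j = jnp.arange(3)[None, :, None]
print(a[i, j, idx].shape)        # (2, 3, 1)
```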


However, the real indexing primitives in JAX are `jax.lax.gather` and `jax.lax.scatter`; apparently, these are used for the complex indexing under the hood. They look like multi-index versions of Torch's gather/scatter.

Note that the API docs warn to prefer `jax.numpy` operations over `jax.lax` where possible; `lax` implements most of the standard operations (`concatenate`, `broadcast`, `atan`, `exp`, etc.)

Gather is no exception, and the API docs strongly recommend against using `gather` and `scatter` directly. However, they do provide some low-level options that could be useful, namely the `indices_are_sorted` argument and the `mode` argument (for example `mode="promise_in_bounds"`), both of which can also be passed through `.at[...]`:

In [93]:
print("time with indices_are_sorted=True, mode='promise_in_bounds'")
%timeit x.at[first_index, second_index, indices[..., None]].get(indices_are_sorted=True, mode="promise_in_bounds")
print("time with indices_are_sorted=False")
%timeit x.at[first_index, second_index, indices[..., None]].get(indices_are_sorted=False)


time with indices_are_sorted=True, mode='promise_in_bounds'
321 μs ± 7.63 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
time with indices_are_sorted=False
329 μs ± 13 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


So it didn't help much here, but then again, we're running on CPU.
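
One caveat when timing: JAX dispatches work asynchronously, so measurements are usually more trustworthy when each call waits for its result. The sketch below rebuilds a comparable gather on fresh data. `time_call` is a hypothetical helper written for this example, and the numbers it prints will vary by machine and backend.

```python
import time
import jax
import jax.numpy as jnp

def time_call(fn, repeats=100):
    """Hypothetical helper: mean wall-clock seconds per call, waiting on each result."""
    jax.block_until_ready(fn())              # warm-up run
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn())          # wait until the result is actually computed
    return (time.perf_counter() - start) / repeats

a = jax.random.normal(jax.random.key(0), (2, 10, 300))
idx = jnp.argmax(a, axis=-1)[..., None]
i = jnp.arange(2)[:, None, None]
j = jnp.arange(10)[None, :, None]

def with_hints():
    return a.at[i, j, idx].get(indices_are_sorted=True, mode="promise_in_bounds")

def without_hints():
    return a.at[i, j, idx].get()

print(f"with hints:    {time_call(with_hints):.2e} s")
print(f"without hints: {time_call(without_hints):.2e} s")
```

Both variants return the same values; the hints only change what the compiler is allowed to assume.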

## Dynamic vs. Static Indexing

JAX JIT cannot handle indexing whose result size depends on runtime values. So, for instance, if we want to run a nearest-neighbor search and take all results that are within a threshold, this function CANNOT be traced.

In [159]:
def nearest_neighbor_search(query, keys, threshold, max_to_take=None):
    assert query.ndim == keys.ndim
    assert query.shape[-1] == keys.shape[-1]

    distances = jnp.linalg.norm(query[..., None, :] - keys[..., None, :, :], axis=-1)

    # Sort results by distance
    indices = jnp.argsort(distances, axis=-1)
    sorted_distances = jnp.take_along_axis(distances, indices, axis=-1)

    threshold_mask = sorted_distances > threshold

    # argmax over a boolean mask returns the first True, so this is the number of keys
    # within the threshold for each query (0 if no distance exceeds the threshold)
    valid_counts = jnp.argmax(threshold_mask, axis=-1)
    first_invalid_index = valid_counts.max()

    if max_to_take is not None:
        first_invalid_index = min(first_invalid_index, max_to_take)
        valid_counts = jnp.minimum(valid_counts, max_to_take)

    indices = indices[..., :first_invalid_index]
    print(indices.shape)

    # Use the first valid index to index into the keys
    key_indices = indices[..., None]
    sorted_keys = jnp.take_along_axis(keys[None, ...], key_indices, axis=-2)

    return sorted_keys, valid_counts, indices


In [160]:
import math

key, subkey_1, subkey_2 = jax.random.split(key, 3)

num_queries = 8
num_keys = 1024
key_dim = 32

expected_threshold = 2 * key_dim // math.log(num_keys, 2)

query = jax.random.normal(subkey_1, (num_queries, key_dim), dtype=jnp.float32)
keys = jax.random.normal(subkey_2, (num_keys, key_dim), dtype=jnp.float32)

key_result, valid_counts, indices = nearest_neighbor_search(query, keys, expected_threshold)

print("expected_threshold", expected_threshold)
print("key_result shape", key_result.shape)
print("valid_counts shape", valid_counts.shape)

print("valid_counts", valid_counts)
print("indices", indices[0])


(8, 66)
expected_threshold 6.0
key_result shape (8, 66, 32)
valid_counts shape (8,)
valid_counts [ 0 32 12 66 27  1 14 16]
indices [ 139   87   37  718   47   22  159 1006  304  233  768   27  417  782
  285  678  546  880  650  629  163   51  351  535  673  510  128  990
  100  463  874   71  616  509  452  354   24  522  214  456  153  536
  105  743  192  512   46  118  956  906  254  533  303  513  552 1007
  753  467 1009  820  349  806  242  710  383   93]


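But now if we try to compile this with JIT, it will not work, because the size of our slice depends on the data and so is not static. Under `jit`, every array shape must be known at trace time, so there is no way to compile this function as written; the options are to leave it uncompiled or to restructure it to return a fixed-size (padded or masked) result.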

In [161]:
jit_nearest_neighbor_search = jax.jit(nearest_neighbor_search)

key_result, valid_counts, indices = jit_nearest_neighbor_search(query, keys, expected_threshold)



IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, JitTracer<int32[]>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
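
The error message distinguishes two situations, and a rough sketch of each may help. Everything here is an assumption made for illustration rather than part of the notebook above: the function names `window` and `nearest_neighbors_padded`, the fixed `max_to_take`, and the use of `jax.lax.top_k` in place of a full sort. A dynamic *position* with a static *size* works under `jit` via `jax.lax.dynamic_slice`. A data-dependent *count* can be recast as a fixed-size result plus a validity mask.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Case 1: dynamic position, static size.
@partial(jax.jit, static_argnames="size")
def window(arr, start, size):
    # `start` may be a traced value; `size` must be a Python int known at trace time.
    return jax.lax.dynamic_slice(arr, (start,), (size,))

# Case 2: data-dependent count, expressed as a fixed-size result plus a mask.
@partial(jax.jit, static_argnames="max_to_take")
def nearest_neighbors_padded(query, keys, threshold, max_to_take):
    # Pairwise distances, shape (num_queries, num_keys).
    distances = jnp.linalg.norm(query[:, None, :] - keys[None, :, :], axis=-1)
    # top_k returns the largest entries, so negate to get the smallest distances.
    neg_dist, idx = jax.lax.top_k(-distances, max_to_take)
    valid = -neg_dist <= threshold           # False marks slots beyond the threshold
    return keys[idx], idx, valid, valid.sum(axis=-1)

a = jnp.arange(10.0)
print(window(a, 3, size=4))                  # elements 3 through 6

k1, k2 = jax.random.split(jax.random.key(0))
q = jax.random.normal(k1, (8, 32))
ks = jax.random.normal(k2, (1024, 32))
neighbors, idx, valid, counts = nearest_neighbors_padded(q, ks, 6.0, max_to_take=16)
print(neighbors.shape, valid.shape, counts)
```

The trade-off is that callers must consult `valid` (or `counts`) instead of relying on the output shape, and `max_to_take` has to be chosen up front.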

## Conclusion

JAX is committed to immutability and traceability, and as a result of these design decisions, JAX indexing is not as comprehensive as indexing in Torch or NumPy. But most things are possible; we just lose the ability to automatically optimize when we use dynamic indices.