How does indexing work in numpy exactly?

In [2]:
import numpy as np

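The rule seems to be:
1. If there are fewer indices than dimensions in the array, implicitly add full slices (`:`) for the remaining trailing dimensions
2. Henceforth, assume the *number of indices* is the same as the number of dimensions in the array
3. Next, broadcast all the non-slice indices (integer arrays, and plain integers when they appear alongside arrays) against each other to implicitly make them the same shape
4. Now, compute the output shape like this:
    - Loop over all the dimensions
    - If the dimension is sliced, add the size of the slice
    - If the dimension is not sliced, *and the indices have not already been added*, add the broadcast shape of the indices
    - **Exception:** if the non-slice indices are *not* all adjacent (a slice sits between two of them), the broadcast index shape goes at the *front* of the output instead. For example, with `x` of shape `(10,11,12,13)`, `x[:,[0,1],:,[0,1]].shape == (2,10,12)`, not `(10,2,12)`. The loop above only gets this right when the first non-slice index is on axis 0.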

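Notice in particular: **slices are not like passing in `np.arange(axis_size)`**. Sliced dimensions are not broadcast against the array indices; each one shows up as its own, independent axis of the output. For `x` of shape `(10, 11)`, `x[:, [0, 1]]` has shape `(10, 2)`, but `x[np.arange(10), [0, 1]]` raises an `IndexError` because shapes `(10,)` and `(2,)` cannot be broadcast together. SLICES ARE DIFFERENT.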

In [74]:
# Shape checks on a 4-D array. Only shapes matter here, so use zeros
# instead of unseeded random values.
x = np.zeros((10, 11, 12, 13))
assert x[0, 1, 2, 3].shape == ()
assert x[[0, 1], [0, 1], [0, 1], [0, 1]].shape == (2,)
# Array indices separated by a slice: the broadcast index shape goes first.
assert x[[0, 1], [0, 1], :, [0, 1]].shape == (2, 12)
# Adjacent array indices: the index shape replaces those axes in place.
assert x[:, [0, 1], [0, 1], [0, 1]].shape == (10, 2)
assert x[:, [0, 1], [0, 1], [[0, 1], [2, 3], [4, 5]]].shape == (10, 3, 2)
# Separated array indices that do NOT start at axis 0: the index shape still
# moves to the front. Walking the axes left to right would wrongly give (10, 2, 12).
assert x[:, [0, 1], :, [0, 1]].shape == (2, 10, 12)
# A plain integer next to an array index is broadcast like an array index,
# so the same "separated -> front" rule applies.
assert x[0, :, [0, 1], :].shape == (2, 11, 13)

In [75]:
# 1d array, 1d index: the result has the index's shape, and
# element k of the result is x[idx[k]].
x = np.arange(0, 10, .1)[:10]
idx = np.array([0, 3, 1, 2])
for label, value in [("x", x), ("idx", idx), ("x[idx]", x[idx])]:
    print(label, "\n", f"{value}")

x 
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
idx 
 [0 3 1 2]
x[idx] 
 [0.  0.3 0.1 0.2]


In [76]:
# 1d array, 2d index: the result takes the index's (2, 2) shape, with
# each entry replaced by the element of x it points at.
x = np.arange(0, 10, .1)[:10]
idx = np.array([[0, 1], [2, 3]])
for label, value in [("x", x), ("idx", idx), ("x[idx]", x[idx])]:
    print(label, "\n", f"{value}")

x 
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
idx 
 [[0 1]
 [2 3]]
x[idx] 
 [[0.  0.1]
 [0.2 0.3]]


In [77]:
# 2d array, single 1d index: the index picks (and reorders) whole rows;
# the untouched trailing axis behaves like an implicit `:`.
x = np.arange(0, 10, .1).reshape(10, 10)[:5, :5]
idx = np.array([0, 3, 2, 1])
for label, value in [("x", x), ("idx", idx), ("x[idx]", x[idx])]:
    print(label, "\n", f"{value}")

x 
 [[0.  0.1 0.2 0.3 0.4]
 [1.  1.1 1.2 1.3 1.4]
 [2.  2.1 2.2 2.3 2.4]
 [3.  3.1 3.2 3.3 3.4]
 [4.  4.1 4.2 4.3 4.4]]
idx 
 [0 3 2 1]
x[idx] 
 [[0.  0.1 0.2 0.3 0.4]
 [3.  3.1 3.2 3.3 3.4]
 [2.  2.1 2.2 2.3 2.4]
 [1.  1.1 1.2 1.3 1.4]]


In [78]:
# 2d array, single 2d index on axis 0: the output shape is the index
# shape followed by the implicitly sliced trailing axis.
x = np.arange(0, 10, .1).reshape(10, 10)[:4, :5]
idx = np.array([[0, 1], [2, 3]])
z = x[idx]
for label, value in [("x", x), ("idx", idx), ("x[idx]", z)]:
    print(label, "\n", f"{value}")
expected_shape = idx.shape + x.shape[1:]  # shape of index, then shape of slice
assert z.shape == expected_shape

x 
 [[0.  0.1 0.2 0.3 0.4]
 [1.  1.1 1.2 1.3 1.4]
 [2.  2.1 2.2 2.3 2.4]
 [3.  3.1 3.2 3.3 3.4]]
idx 
 [[0 1]
 [2 3]]
x[idx] 
 [[[0.  0.1 0.2 0.3 0.4]
  [1.  1.1 1.2 1.3 1.4]]

 [[2.  2.1 2.2 2.3 2.4]
  [3.  3.1 3.2 3.3 3.4]]]


In [79]:
# 2d array, single 2d index on axis 1: the leading `:` keeps axis 0, and
# the index shape replaces axis 1 in place.
x = np.arange(0, 10, .1).reshape(10, 10)[:4, :5]
idx = np.array([[0, 1], [2, 3]])
z = x[:, idx]
for label, value in [("x", x), ("idx", idx), ("x[:,idx]", z)]:
    print(label, "\n", f"{value}")
expected_shape = x.shape[:1] + idx.shape  # shape of slice, then shape of index
assert z.shape == expected_shape

x 
 [[0.  0.1 0.2 0.3 0.4]
 [1.  1.1 1.2 1.3 1.4]
 [2.  2.1 2.2 2.3 2.4]
 [3.  3.1 3.2 3.3 3.4]]
idx 
 [[0 1]
 [2 3]]
x[:,idx] 
 [[[0.  0.1]
  [0.2 0.3]]

 [[1.  1.1]
  [1.2 1.3]]

 [[2.  2.1]
  [2.2 2.3]]

 [[3.  3.1]
  [3.2 3.3]]]


In [86]:
# Array indices on axes 0, 2 and 4 are separated by slices, so their
# *broadcast* shape goes first, followed by the sliced axes 1 and 3.
x = np.random.randn(21, 17, 5, 7, 9)
idx0 = np.random.randint(0, 21, [4, 3])
idx2 = np.random.randint(0, 5, [4, 3])
idx4 = np.array(2)  # 0-d index: shape () adds no output axes
y = x[idx0, :, idx2, :, idx4]
index_shape = np.broadcast(idx0, idx2, idx4).shape
assert y.shape == index_shape + (x.shape[1], x.shape[3])
# look ma no idx2: it has the same shape as idx0, so it broadcasts away
assert index_shape == idx0.shape

In [93]:
# Same as above, but idx2 has shape (3,). It broadcasts against idx0's
# (4, 3), so the index part of the output shape is still (4, 3).
x = np.random.randn(21, 17, 5, 7, 9)
idx0 = np.random.randint(0, 21, [4, 3])
idx2 = np.random.randint(0, 5, [3])
idx4 = np.array(2)  # 0-d index: shape () adds no output axes
y = x[idx0, :, idx2, :, idx4]
index_shape = np.broadcast(idx0, idx2, idx4).shape
assert index_shape == (4, 3)
assert y.shape == index_shape + (x.shape[1], x.shape[3])

In [97]:
# Why the 0-d index `np.array(2)` above adds no output axes: its shape is (),
# which is the identity for tuple concatenation and for broadcasting.
assert np.array(2).shape == ()
assert np.broadcast(np.zeros(3), np.array(2)).shape == (3,)
(3,) + ()

(3,)