# How to use torch.gather (and how not to...)

This notebook provides a short tutorial on using `torch.gather` through a few examples, while showcasing some common pitfalls and differences from the `numpy` behaviour.

In [1]:
import time

import numpy as np
import torch

In [2]:
torch.manual_seed(100)
np.random.seed(100)

## Basic usage

Let's first look at a basic example: Starting from a 2D matrix `(m, n)`, construct a new 2D matrix `(m, l)`, gathering elements across rows.

We'll start from a 3x3 matrix containing the numbers from 0 to 8.

In [3]:
# 2D matrix, gather across dim=1
x = torch.arange(9).reshape(3, 3)
print(x)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])


Let's construct a 3x4 matrix by gathering the following numbers from each row: `[[2, 2, 0, 0], [3, 4, 5, 5], [6, 6, 7, 7]]`. For each row, we therefore specify the corresponding column indices:

In [4]:
# For each row, specify which index along the row we want to take
torch.gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0], [0, 1, 2, 2], [0, 0, 1, 1]]))

tensor([[2, 2, 0, 0],
        [3, 4, 5, 5],
        [6, 6, 7, 7]])
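The same call also works along the other dimension. As a small illustrative sketch (the index values here are chosen for this example and are not from the notebook above), gathering with `dim=0` means each index entry selects a *row*, while the column is fixed by the position of the entry:

```python
import torch

x = torch.arange(9).reshape(3, 3)

# Gathering along dim=0: out[i][j] = x[index[i][j]][j],
# i.e. index[i][j] picks the row and j stays the column.
index = torch.tensor([[0, 1, 2], [2, 1, 0]])
out_dim0 = torch.gather(x, dim=0, index=index)
print(out_dim0)
# tensor([[0, 4, 8],
#         [6, 4, 2]])
```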

What if we want to use the same indices for every row, giving e.g. `[2, 2, 0, 0], [5, 5, 3, 3], [8, 8, 6, 6]`? The easiest thing to do is to again explicitly specify the indices for each row, as before.

In [5]:
# Easiest thing to do -> explicitly specify the indices for each row again, same as before.
torch.gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0], [2, 2, 0, 0], [2, 2, 0, 0]]))

tensor([[2, 2, 0, 0],
        [5, 5, 3, 3],
        [8, 8, 6, 6]])

What happens if we only specify the indices for one row? We might expect that the indices will be broadcast, so that the specified indices will be selected for each row. But that's not what happens!

In [6]:
torch.gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0]]))

tensor([[2, 2, 0, 0]])

We have only gathered the elements for the first row! Similarly, if we specify indices for two rows:

In [7]:
# Similar if two rows are specified in the index
torch.gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0], [2, 2, 0, 0]]))

tensor([[2, 2, 0, 0],
        [5, 5, 3, 3]])

Looking at the first case, the output has size `(1, 4)` instead of the expected `(3, 4)`. This is because `torch.gather` does not broadcast `index` against `input`; it only requires that `index.size(d) <= input.size(d)` for each dimension `d != dim`. In this case, `index.size(0) = 1`, so only the first row is gathered.
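To make this behaviour concrete, here is a minimal loop-based sketch of what `torch.gather` computes for a 2D input with `dim=1`. The helper name `gather_dim1_reference` is made up for this illustration; it is a slow reference under these assumptions, not how PyTorch implements the operation. The loops only run over the entries of `index`, which is why a `(1, 4)` index produces a `(1, 4)` output.

```python
import torch


def gather_dim1_reference(inp: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Loop version of torch.gather(inp, 1, index) for 2D tensors (illustrative only)."""
    out = torch.empty_like(index, dtype=inp.dtype)
    # The output is indexed exactly like `index`: rows of `inp` beyond
    # index.size(0) are never visited.
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, j] = inp[i, int(index[i, j])]
    return out


x = torch.arange(9).reshape(3, 3)
one_row = torch.tensor([[2, 2, 0, 0]])
ref_one_row = gather_dim1_reference(x, one_row)
print(ref_one_row)
# tensor([[2, 2, 0, 0]])
print(torch.equal(ref_one_row, torch.gather(x, 1, one_row)))
```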

If we want to achieve the desired behaviour, we need to broadcast manually:

In [8]:
# Alternative, broadcast manually
torch.gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0]]).expand(3, -1))

tensor([[2, 2, 0, 0],
        [5, 5, 3, 3],
        [8, 8, 6, 6]])

Why could this be annoying? The `numpy` behaviour is different! `np.take_along_axis` **does** broadcast `indices` against the input array, so it *only* works if, for each dimension `d != axis`, the sizes of `indices` and the input either match, or one of them is 1 (in which case it is broadcast).

In [9]:
# numpy will broadcast indices against input
x = np.arange(9).reshape(3, 3)
np.take_along_axis(x, indices=np.array([[2, 2, 0, 0]]), axis=1)

array([[2, 2, 0, 0],
       [5, 5, 3, 3],
       [8, 8, 6, 6]])

In [10]:
# This breaks as the shapes (2 vs 3 rows) cannot be broadcast in the non-gather dimension!
x = np.arange(9).reshape(3, 3)
try:
    np.take_along_axis(x, indices=np.array([[2, 2, 0, 0], [2, 2, 0, 0]]), axis=1)
except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
    print("This does not work!")

This does not work!


**Conclusion**: the `output` has **exactly** the same shape as `index`. Each non-gather dimension of `index` needs to be smaller than or equal to the corresponding `input` dimension, while the dimension along which we are gathering can have an arbitrary size.
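The shape rule can be written down as a few lines of plain Python. The function `gather_output_shape` below is a hypothetical helper written for this tutorial (it is not part of PyTorch); it simply encodes the conclusion above and is compared against the shapes that `torch.gather` actually returns.

```python
import torch


def gather_output_shape(input_shape, index_shape, dim):
    """Return the shape torch.gather would produce, or raise if the rule is violated."""
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dimensions")
    dim = dim + len(input_shape) if dim < 0 else dim
    for d, (n_in, n_idx) in enumerate(zip(input_shape, index_shape)):
        # Only non-gather dimensions are constrained.
        if d != dim and n_idx > n_in:
            raise ValueError(f"index.size({d}) = {n_idx} exceeds input.size({d}) = {n_in}")
    return tuple(index_shape)


x = torch.arange(9).reshape(3, 3)
for index_shape in [(3, 4), (1, 4), (2, 10)]:
    index = torch.zeros(index_shape, dtype=torch.long)
    actual = tuple(torch.gather(x, 1, index).shape)
    print(index_shape, "->", actual, actual == gather_output_shape(x.shape, index_shape, dim=1))
```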

## Broadcasting `gather`

We can adapt `torch.gather` to instead work similarly to `numpy.take_along_axis`, which is usually the more useful behaviour!

In [11]:
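def gather(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """torch.gather with broadcasting of `index` over the non-gather dimensions.

    `index` must have the same number of dimensions as `input`.
    """
    # Convert a negative dim (e.g. -1) to its non-negative equivalent
    dim = dim + input.ndim if dim < 0 else dim
    # Expand index to the input shape everywhere except along the gather dimension
    return torch.gather(
        input,
        dim,
        index.expand(*input.shape[:dim], index.shape[dim], *input.shape[dim + 1 :]),
    )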

Let's check the behaviour on our previous example.

In [12]:
x = torch.arange(9).reshape(3, 3)
gather(x, dim=1, index=torch.tensor([[2, 2, 0, 0]]))

tensor([[2, 2, 0, 0],
        [5, 5, 3, 3],
        [8, 8, 6, 6]])
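As a further sanity check, a broadcasting wrapper of this kind can be compared against `np.take_along_axis` on a 3D tensor with a negative `dim`. The sketch below redefines the wrapper under the illustrative name `gather_broadcast` so that it is self-contained; the tensor shapes and index values are arbitrary choices for this example.

```python
import numpy as np
import torch


def gather_broadcast(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Illustrative copy of the broadcasting wrapper (same idea as `gather` above)."""
    dim = dim + input.ndim if dim < 0 else dim
    shape = (*input.shape[:dim], index.shape[dim], *input.shape[dim + 1:])
    return torch.gather(input, dim, index.expand(shape))


x3 = torch.arange(24).reshape(2, 3, 4)
index3 = torch.tensor([[[3, 0]]])  # shape (1, 1, 2), broadcast over the first two dims

out_torch = gather_broadcast(x3, dim=-1, index=index3)
out_numpy = np.take_along_axis(x3.numpy(), index3.numpy(), axis=-1)

print(out_torch.shape)
print(np.array_equal(out_torch.numpy(), out_numpy))
```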

## *Batched* multi-dimensional gather

So far, the tutorial has looked at a basic 2D example. Let's now look at a more generic multi-dimensional gather operation and a common scenario where we would like to gather elements across a *batched* tensor.

Specifically, we would like to look at the following. Starting from an input tensor `x`, specified dimension `dim` and `idx` tensor of shape `idx.shape = x.shape[:dim] + (k,)` where each element is between `0` and `x.shape[dim] - 1`, we want to construct the output tensor `y` of shape `y.shape = idx.shape + x.shape[dim + 1:]`.

Let's take a look at an example. We have a 6D input tensor `x` of shape `(2, 3, 5, 7, 11, 13)`. We consider the first three dimensions `(2, 3, 5)` of the input as the *batch* dimensions, so that each batch element is a tensor of shape `(7, 11, 13)`. Here, `7` is the sequence dimension across which we would like to gather "elements", where each "element" is a matrix of size `(11, 13)`.

For each batch, we want to take the *same* 4 "elements" of the 7, so that the final output shape should be `(2, 3, 5, 4, 11, 13)`.

We will take a look at a few different ways of achieving the same behaviour.
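Before comparing the fast approaches, it can help to pin down the intended result with a slow loop on a small tensor. The helper `batched_gather_reference` below is a hypothetical name used only for this sketch, and the small shapes are chosen so that the loop finishes instantly; it assumes `idx` has the layout described above (batch dimensions followed by `k`).

```python
import itertools

import torch


def batched_gather_reference(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Loop reference: y[b, j] = x[b, idx[b, j]] for every batch index tuple b."""
    batch_shape = idx.shape[:-1]
    dim = len(batch_shape)  # the gather dimension follows the batch dimensions
    y = x.new_empty(*idx.shape, *x.shape[dim + 1:])
    for b in itertools.product(*(range(s) for s in batch_shape)):
        for j in range(idx.shape[-1]):
            y[b + (j,)] = x[b + (int(idx[b + (j,)]),)]
    return y


torch.manual_seed(0)
x_small = torch.randn(2, 3, 5, 4)        # batch dims (2, 3), sequence of 5 "elements" of size 4
idx_small = torch.randint(5, (2, 3, 2))  # pick 2 of the 5 elements per batch entry

y_ref = batched_gather_reference(x_small, idx_small)
y_gather = torch.gather(x_small, 2, idx_small[..., None].expand(*idx_small.shape, 4))
print(y_ref.shape)
print(torch.equal(y_ref, y_gather))
```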

In [13]:
# Gather across dim=3, where (2, 3, 5) are batch dimensions,
# 7 is the sequence dimension with each element being (11, 13) matrix
x = torch.randn(2, 3, 5, 7, 11, 13)

In [14]:
# Batched gather across dim=3, choose 4 out of 7
# We want output shape (2, 3, 5, 4, 11, 13)
idx = torch.randint(7, (2, 3, 5, 4))

### `torch.gather`

First, let's achieve this using `torch.gather`. We will also measure the time to get an idea of how long each method takes.

As we've seen before, the `index` and `output` shapes of `torch.gather` are *exactly* the same. This means that we need to manually broadcast dimensions of our `idx` to get the desired output.

In [15]:
t0 = time.time()
# NOTE: the index passed to gather must have the same shape as y1, so add and expand the last two dimensions
y1 = torch.gather(x, dim=3, index=idx[..., None, None].expand(*idx.shape, *x.shape[4:]))
print("Time taken: ", time.time() - t0)
print("Output shape: ", y1.shape)

Time taken:  0.00049591064453125
Output shape:  torch.Size([2, 3, 5, 4, 11, 13])


We need to be careful here! If we did not explicitly expand the `index`, the operation would work, but would not give us our desired output!

In [16]:
# Possible gotcha: forget to expand, or expand incorrectly
# (runs without error as long as idx.size(d) <= x.size(d) for all non-gather dimensions)
y1_wrong = torch.gather(x, dim=3, index=idx[..., None, None]) # this runs, but keeps only index 0 of the last two dimensions
print("Wrong output shape:", y1_wrong.shape)

Wrong output shape: torch.Size([2, 3, 5, 4, 1, 1])


In [17]:
# The unexpanded index only gathers the top-left entry ([0, 0]) of each (11, 13) element in the sequence
torch.allclose(y1_wrong, y1[..., :1, :1])

True

Of course, we can instead use our own broadcasting version and check it gives the expected answer:

In [18]:
torch.allclose(gather(x, dim=3, index=idx[..., None, None]), y1)

True

### Advanced indexing

Instead of using `torch.gather`, we can use advanced indexing to achieve the same behaviour. For this, we need to construct `idx0`, `idx1`, `idx2`, and `idx3` for each of the first four dimensions, so that:

`y[i, j, k, l] = x[idx0[i, j, k, l], idx1[i, j, k, l], idx2[i, j, k, l], idx3[i, j, k, l]]`.

Each of the index tensors has the same shape as the output's first four dimensions, i.e. `y.shape[:4]`.

In this case, `idx3` will be the `idx` tensor we have constructed previously, while the other indices just need to indicate that we are selecting all of the dimensions. For this, we use broadcasting for the non-gather batch dimensions.

In [19]:
t0 = time.time()
y2 = x[
    torch.arange(2)[:, None, None, None],
    torch.arange(3)[None, :, None, None],
    torch.arange(5)[None, None, :, None],
    idx,
]
print("Time taken: ", time.time() - t0)

Time taken:  0.0014488697052001953


Note that this took roughly three times as long as the method using `gather` in this run (about 1.4 ms versus 0.5 ms), although single-run timings like these are noisy. Let's check that we got the same result:

In [20]:
torch.allclose(y1, y2)

True
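Writing out one `torch.arange` per batch dimension by hand gets tedious as the number of batch dimensions grows. One possible way to generalise it is sketched below; `batched_select` is a hypothetical helper name, and it assumes `idx` consists of the batch dimensions followed by a single trailing dimension of selected positions.

```python
import torch


def batched_select(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Advanced indexing with one broadcastable arange per batch dimension."""
    batch_shape = idx.shape[:-1]
    n = len(batch_shape)
    # The d-th index has shape (1, ..., size, ..., 1) with `size` at position d,
    # so it broadcasts against idx (which has n + 1 dimensions).
    batch_indices = [
        torch.arange(size).reshape(*([1] * d), size, *([1] * (n - d)))
        for d, size in enumerate(batch_shape)
    ]
    return x[(*batch_indices, idx)]


torch.manual_seed(0)
x_demo = torch.randn(2, 3, 5, 7, 4)
idx_demo = torch.randint(7, (2, 3, 5, 2))

y_select = batched_select(x_demo, idx_demo)
y_expected = torch.gather(x_demo, 3, idx_demo[..., None].expand(*idx_demo.shape, 4))
print(y_select.shape)
print(torch.equal(y_select, y_expected))
```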

#### Flattening before indexing

Instead of constructing three indices for the batch dimensions (`idx0`, `idx1`, `idx2`), we can flatten these dimensions before indexing. This however does require a few annoying reshaping operations.

In [21]:
t0 = time.time()
y3 = x.flatten(end_dim=2)[
    torch.arange(2 * 3 * 5)[:, None], idx.flatten(end_dim=2)
].reshape(2, 3, 5, 4, 11, 13)
print("Time taken: ", time.time() - t0)

Time taken:  0.0012068748474121094


In [22]:
torch.allclose(y1, y3)

True

In a similar approach, we could instead flatten the dimension across which we are gathering together with the batch dimensions. In this case, we need to ensure that the index has the correct offset: batch entry number `b` starts at position `b * 7` in the flattened dimension.

In [23]:
# Flatten gather dim=3 as well, calculate correct offsets
offset = torch.arange(2 * 3 * 5).reshape(2, 3, 5, 1) * 7

In [24]:
t0 = time.time()
y4 = x.flatten(end_dim=3)[(idx + offset).flatten()].reshape(2, 3, 5, 4, 11, 13)
print("Time taken: ", time.time() - t0)

Time taken:  0.0008790493011474609


In [25]:
torch.allclose(y1, y4)

True
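The offset arithmetic is easiest to see on a tiny example, where the flat position of element `i` in batch entry `b` is `b * n + i`. The helper `flat_gather` below is a hypothetical generalisation of the cell above, assuming again that `idx` is made of the batch dimensions followed by one trailing dimension.

```python
import math

import torch


def flat_gather(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Gather by flattening batch + gather dims and offsetting the indices."""
    batch_shape = idx.shape[:-1]
    dim = len(batch_shape)
    n = x.shape[dim]  # size of the gather dimension
    # Batch entry b starts at flat position b * n.
    offset = torch.arange(math.prod(batch_shape)).reshape(*batch_shape, 1) * n
    flat = x.flatten(end_dim=dim)
    return flat[(idx + offset).flatten()].reshape(*idx.shape, *x.shape[dim + 1:])


# Tiny case: 2 batch entries, 3 elements each.
x_tiny = torch.arange(6).reshape(2, 3)
idx_tiny = torch.tensor([[2], [0]])
y_tiny = flat_gather(x_tiny, idx_tiny)
print(y_tiny)
# tensor([[2],
#         [3]])   flat positions 0 * 3 + 2 = 2 and 1 * 3 + 0 = 3

# Larger case, checked against torch.gather.
torch.manual_seed(0)
x_big = torch.randn(2, 3, 5, 7, 4)
idx_big = torch.randint(7, (2, 3, 5, 2))
y_flat = flat_gather(x_big, idx_big)
y_check = torch.gather(x_big, 3, idx_big[..., None].expand(*idx_big.shape, 4))
print(torch.equal(y_flat, y_check))
```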

## Conclusions

When gathering elements from a multi-dimensional tensor, `torch.gather` was the fastest method in the (single-run) timings above. There are however a few important things to keep in mind:

* `torch.gather` *does not* broadcast across non-gather dimensions. This means that we need to ensure that the `idx` tensor has the **same** shape as the desired output.
* If the size of a non-gather dimension in `idx` is not the same as in the input tensor, `torch.gather` will still work as long as the size is no larger than the corresponding dimension in the input tensor. This is usually not what we want though, as that dimension will be sliced!
* For most common use cases, we can implement our own modified version of `torch.gather` that performs broadcasting.
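The timings in this notebook come from a single call wrapped in `time.time()`, which is sensitive to warm-up and system noise. A somewhat more robust comparison, sketched here with the standard-library `timeit` module, repeats each call many times and reports the best average. The function names and repetition counts are arbitrary choices for this sketch, and the numbers will vary between machines.

```python
import timeit

import torch

torch.manual_seed(0)
x_bench = torch.randn(2, 3, 5, 7, 11, 13)
idx_bench = torch.randint(7, (2, 3, 5, 4))


def with_gather() -> torch.Tensor:
    index = idx_bench[..., None, None].expand(*idx_bench.shape, *x_bench.shape[4:])
    return torch.gather(x_bench, 3, index)


def with_indexing() -> torch.Tensor:
    return x_bench[
        torch.arange(2)[:, None, None, None],
        torch.arange(3)[None, :, None, None],
        torch.arange(5)[None, None, :, None],
        idx_bench,
    ]


for name, fn in [("gather", with_gather), ("indexing", with_indexing)]:
    fn()  # warm-up call, excluded from the measurement
    best = min(timeit.repeat(fn, number=100, repeat=5)) / 100
    print(f"{name}: {best * 1e6:.1f} us per call")
```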

Note, this notebook only covered `torch.gather`, but all of the conclusions and broadcasting gotchas can be applied to `torch.scatter` as well.
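As a brief illustration of the same gotcha with scatter (the tensors here are made up for this sketch), an `index` with a single row only writes into the first row of the target, and has to be expanded manually to affect every row:

```python
import torch

target = torch.zeros(3, 4)
index = torch.tensor([[2, 0]])  # shape (1, 2): one row of column indices

# No broadcasting: only row 0 of the target is written to.
only_first_row = target.scatter(1, index, torch.ones(1, 2))
print(only_first_row)
# tensor([[1., 0., 1., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])

# Expanding index (and providing a matching src) writes to every row.
all_rows = target.scatter(1, index.expand(3, -1), torch.ones(3, 2))
print(all_rows)
# tensor([[1., 0., 1., 0.],
#         [1., 0., 1., 0.],
#         [1., 0., 1., 0.]])
```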