In [1]:
import torch
import torch.nn.functional as F

##  `scatter` and `gather`

[`scatter_()`](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_):
Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. 

For a 2-D tensor, `self` is updated as:
```
self[index[i][j]][j] = src[i][j]  # if dim == 0
self[i][index[i][j]] = src[i][j]  # if dim == 1
```


In [2]:
src = torch.arange(1, 11).reshape((2, 5))
src

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

In [3]:
index = torch.tensor([[0, 1, 2, 0]])

In [4]:
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

- `src[0][0]` (`index[0][0]` = 0) --> `self[index[0][0]][0]` == `self[0][0]` = 1
- `src[0][1]` (`index[0][1]` = 1) --> `self[index[0][1]][1]` == `self[1][1]` = 2
- `src[0][2]` (`index[0][2]` = 2) --> `self[index[0][2]][2]` == `self[2][2]` = 3
- `src[0][3]` (`index[0][3]` = 0) --> `self[index[0][3]][3]` == `self[0][3]` = 4
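
The index rule above can also be spelled out as explicit loops. The following is a minimal sketch for the 2-D case only; `scatter_2d_reference` is a hypothetical helper written for illustration, not part of PyTorch.

```python
import torch

def scatter_2d_reference(out, dim, index, src):
    """Loop-based sketch of 2-D scatter (illustrative helper, not a PyTorch API)."""
    result = out.clone()  # leave the caller's tensor untouched
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            k = index[i, j].item()
            if dim == 0:
                result[k, j] = src[i, j]  # self[index[i][j]][j] = src[i][j]
            else:
                result[i, k] = src[i, j]  # self[i][index[i][j]] = src[i][j]
    return result

src = torch.arange(1, 11).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0]])
base = torch.zeros(3, 5, dtype=src.dtype)

manual = scatter_2d_reference(base, 0, index, src)
builtin = base.clone().scatter_(0, index, src)
print(manual)
print(torch.equal(manual, builtin))
```

Only the positions covered by `index` are visited, which is why the second row of `src` never reaches the output here.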

In [5]:
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

tensor([[4, 2, 3, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])

- `src[0][0]` (`index[0][0]` = 0) --> `self[0][index[0][0]]` == `self[0][0]` = 1
- `src[0][1]` (`index[0][1]` = 1) --> `self[0][index[0][1]]` == `self[0][1]` = 2
- `src[0][2]` (`index[0][2]` = 2) --> `self[0][index[0][2]]` == `self[0][2]` = 3
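- `src[0][3]` (`index[0][3]` = 0) --> `self[0][index[0][3]]` == `self[0][0]` = 4 (**overwrites** the 1 written from `src[0][0]`, since both map to `self[0][0]`; the PyTorch documentation notes that the result is non-deterministic when indices are not unique)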
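
When two index entries target the same position, as `index[0][0]` and `index[0][3]` do here, one option is to accumulate the colliding values instead of letting one replace the other. A small sketch using `scatter_add_()` with the same `src` and `index`:

```python
import torch

src = torch.arange(1, 11).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0]])

# scatter_add_ sums every value routed to the same position,
# so self[0][0] receives src[0][0] + src[0][3] = 1 + 4.
summed = torch.zeros(3, 5, dtype=src.dtype).scatter_add_(1, index, src)
print(summed)
```

Whether summing is the right way to resolve a collision depends on the use case; this only illustrates the difference from plain `scatter_()`.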

In [6]:
torch.full((2, 4), 2.).scatter_(
    1, torch.tensor([[2], [3]]),
    1.23, reduce='multiply')

tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])

- `index[0][0]` = 2 --> `self[0][index[0][0]]` == `self[0][2]` = 2 * 1.23 = 2.46
- `index[1][0]` = 3 --> `self[1][index[1][0]]` == `self[1][3]` = 2 * 1.23 = 2.46

Here the source is the scalar 1.23, so the same value is applied at every indexed position.


Similarly, `reduce='add'` adds the value to the existing entry instead of multiplying.
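
As a quick sketch of the additive variant, the same call with `reduce='add'` should add 1.23 to each indexed entry (the variable names are only illustrative):

```python
import torch

index = torch.tensor([[2], [3]])

# Each indexed entry becomes 2.0 + 1.23; all other entries stay at 2.0.
added = torch.full((2, 4), 2.).scatter_(1, index, 1.23, reduce='add')
print(added)
```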

One-hot encoding using `scatter_()`

In [7]:
x = torch.tensor([1, 2, 4, 2, 1, 0, 3])
num_classes = 5

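reshaped_x = x.view(-1, 1)  # column vector of class ids, shape (N, 1)
one_hot = torch.zeros(x.shape[0], num_classes, dtype=x.dtype).scatter_(
    1, # dimension to scatter along (the class axis)
    reshaped_x, # one column index per row, taken from x
    1, # value written at each indexed position
)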

In [8]:
reshaped_x, one_hot

(tensor([[1],
         [2],
         [4],
         [2],
         [1],
         [0],
         [3]]),
 tensor([[0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0]]))
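
As a sanity check, the scatter-based encoding can be compared against PyTorch's built-in `F.one_hot`. This is a minimal sketch; `via_scatter` and `via_one_hot` are illustrative names.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1, 2, 4, 2, 1, 0, 3])
num_classes = 5

# Scatter a 1 into the column given by each class id.
via_scatter = torch.zeros(x.shape[0], num_classes, dtype=torch.long).scatter_(
    1, x.view(-1, 1), 1
)
# Built-in equivalent; returns an int64 tensor of shape (N, num_classes).
via_one_hot = F.one_hot(x, num_classes=num_classes)

print(torch.equal(via_scatter, via_one_hot))
```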

[`gather()`](https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather): Gathers values along an axis specified by `dim`.


For a 2-D tensor the output is specified by:

```
out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1
```

In [9]:
t = torch.tensor([[1, 2, 5], [3, 4, 7]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

tensor([[1, 1],
        [4, 3]])

- `output[0][0]` (`index[0][0]` = 0) = `input[0][index[0][0]]` == `input[0][0]` = 1
- `output[0][1]` (`index[0][1]` = 0) = `input[0][index[0][1]]` == `input[0][0]` = 1
- `output[1][0]` (`index[1][0]` = 1) = `input[1][index[1][0]]` == `input[1][1]` = 4
- `output[1][1]` (`index[1][1]` = 0) = `input[1][index[1][1]]` == `input[1][0]` = 3
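
The same rule can be written as explicit loops, which also makes the `dim == 0` case easy to try. This is a sketch for 2-D inputs only; `gather_2d_reference` and the `index0` tensor are made up for illustration.

```python
import torch

def gather_2d_reference(inp, dim, index):
    """Loop-based sketch of 2-D gather (illustrative helper, not a PyTorch API)."""
    out = torch.empty(index.shape, dtype=inp.dtype)  # output has the shape of index
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            k = index[i, j].item()
            # dim == 0: out[i][j] = input[index[i][j]][j]
            # dim == 1: out[i][j] = input[i][index[i][j]]
            out[i, j] = inp[k, j] if dim == 0 else inp[i, k]
    return out

t = torch.tensor([[1, 2, 5], [3, 4, 7]])
index1 = torch.tensor([[0, 0], [1, 0]])        # same index as the cell above
index0 = torch.tensor([[0, 1, 0], [1, 0, 1]])  # picks a row for each column

print(gather_2d_reference(t, 1, index1))
print(gather_2d_reference(t, 0, index0))
```

Unlike `scatter_()`, the output takes the shape of `index`, and repeated indices are harmless because they only read from `input`.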

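Calculate the NLL (negative log-likelihood) loss with `gather()`

$$ H(p, q) = -\sum_{x} q(x)\log p(x) $$

where the sum runs over the $K$ classes, $p(x)$ is the predicted probability of class $x$, and $q(x)$ is the ground truth. Because $q$ is one-hot, every term except the one for the true class $c$ vanishes, leaving $-\log p(c)$. The loss for each sample is therefore the negated log-probability at the target index, which is exactly what `gather()` selects.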

In [10]:
pred = torch.tensor([
    [0.6, 0.5, 0.8, 0.3, 0.1],
    [0.1, 0.9, 0.4, 0.2, 0.2],
    [0.1, 0.3, 0.4, 0.2, 0.9],  
]) # raw scores (logits) for each sample across all 5 classes
target_gt = torch.tensor([2, 1, 4]) # ground-truth class id for each sample

In [11]:
# log softmax probabilities
logprob = F.log_softmax(pred, dim=-1)
logprob

tensor([[-1.4982, -1.5982, -1.2982, -1.7982, -1.9982],
        [-1.9148, -1.1148, -1.6148, -1.8148, -1.8148],
        [-1.9318, -1.7318, -1.6318, -1.8318, -1.1318]])

In [12]:
-logprob.gather(
    dim=1, # gather along the class axis; dim 0 indexes the samples in the batch
    index=target_gt.unsqueeze(1), # shape (N, 1): select only the log-probability of the true class
)

tensor([[1.2982],
        [1.1148],
        [1.1318]])
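
To cross-check the result, the per-sample values can be compared with `F.nll_loss`, and their mean with `F.cross_entropy` applied to the raw scores. A minimal sketch reusing the tensors above (`per_sample` is an illustrative name):

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([
    [0.6, 0.5, 0.8, 0.3, 0.1],
    [0.1, 0.9, 0.4, 0.2, 0.2],
    [0.1, 0.3, 0.4, 0.2, 0.9],
])
target_gt = torch.tensor([2, 1, 4])

logprob = F.log_softmax(pred, dim=-1)
# Negated log-probability of the true class, one value per sample.
per_sample = -logprob.gather(dim=1, index=target_gt.unsqueeze(1)).squeeze(1)

# nll_loss expects log-probabilities; cross_entropy expects raw scores.
print(torch.allclose(per_sample, F.nll_loss(logprob, target_gt, reduction='none')))
print(torch.allclose(per_sample.mean(), F.cross_entropy(pred, target_gt)))
```

Both built-ins default to averaging over the batch, hence `reduction='none'` for the per-sample comparison.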