# Gather and Scatter Operations
These operations read (`gather`) or write (`scatter`) tensor values at positions given by an index tensor. Their main advantage is speed: the whole selection runs as a single vectorized tensor operation instead of a Python loop.
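The indexing rules for the 2D case, as stated in the PyTorch `gather` and `scatter` documentation, are summarized below. For `gather`, the output always has the shape of `index`.

```latex
\begin{aligned}
\texttt{gather},\ \texttt{dim}=0:&\quad \text{out}[i][j] = \text{input}\big[\text{index}[i][j]\big][j] \\
\texttt{gather},\ \texttt{dim}=1:&\quad \text{out}[i][j] = \text{input}[i]\big[\text{index}[i][j]\big] \\
\texttt{scatter},\ \texttt{dim}=0:&\quad \text{self}\big[\text{index}[i][j]\big][j] = \text{src}[i][j] \\
\texttt{scatter},\ \texttt{dim}=1:&\quad \text{self}[i]\big[\text{index}[i][j]\big] = \text{src}[i][j]
\end{aligned}
```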

#### References
* [PyTorch Basic Operations](https://jhui.github.io/2018/02/09/PyTorch-Basic-operations/)
* [Stack Overflow Gather](https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms)
* [Scatter](https://pytorch.org/docs/stable/tensors.html?highlight=scatter#torch.Tensor.scatter)
* [Gather](https://pytorch.org/docs/stable/torch.html#torch.gather)
* [What does scatter do in layman terms](https://discuss.pytorch.org/t/what-does-the-scatter-function-do-in-layman-terms/28037/4)
* [Using Scatter to Convert to One-Hot](https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 3
nb_classes = 5

#### Gather 1D

In [2]:
a = torch.tensor([[1,2,3]])
b = torch.tensor([[1,0,2]])
r = torch.gather(input=a, dim=1, index=b)
print(r)

tensor([[2, 1, 3]])
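Two details of `gather` that the example above does not show: the result takes the shape of `index` (not of the input), and indices may repeat. A minimal sketch using a hypothetical index `[[2, 2]]`, cross-checked against an explicit loop:

```python
import torch

a = torch.tensor([[1, 2, 3]])

# A 1x2 index yields a 1x2 result; repeating index 2 picks a[0][2] twice
idx = torch.tensor([[2, 2]])
r = torch.gather(a, dim=1, index=idx)
print(r)  # tensor([[3, 3]])

# Same selection as an explicit loop: out[0][j] = a[0][idx[0][j]]
manual = torch.tensor([[a[0, idx[0, j]].item() for j in range(idx.shape[1])]])
print(torch.equal(r, manual))  # True
```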


#### Gather 2D

In [3]:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
print(r)

tensor([[1, 2],
        [3, 2]])


In [4]:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
print(r)

tensor([[1, 1],
        [4, 3]])
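To make the two 2D cases concrete, here is a loop-based sketch of `gather` for 2D tensors. The helper `gather_2d_reference` is hypothetical and only meant to spell out the rule; it is checked against `torch.gather` on the same tensors used in the two cells above.

```python
import torch

def gather_2d_reference(inp, dim, index):
    """Hypothetical loop version of torch.gather for 2D tensors (illustration only)."""
    out = torch.empty_like(index, dtype=inp.dtype)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = index[i, j]
            # dim=0: index replaces the row number; dim=1: it replaces the column number
            out[i, j] = inp[k, j] if dim == 0 else inp[i, k]
    return out

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
for d in (0, 1):
    assert torch.equal(gather_2d_reference(t, d, idx), torch.gather(t, d, idx))
print(gather_2d_reference(t, 1, idx))
```

The vectorized `torch.gather` call does the same work without the Python-level loops, which is where its speed advantage comes from.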


### Scatter
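This operation works in the opposite direction of gather. It uses the same indexing rules, but instead of selecting values from the input it pushes each value of `src` into the position named by `index`: for `dim=1`, `self[i][index[i][j]] = src[i][j]`.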
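The scatter rule can be spelled out as a plain loop. In the sketch below, the helper `scatter_dim1_reference` is hypothetical and exists only for illustration; it is compared with `Tensor.scatter` on the same data as the next cell. The sketch also shows `scatter_add`, which sums values when several entries of `index` point at the same location. With plain `scatter`, the PyTorch documentation states that the result for duplicate indices is nondeterministic.

```python
import torch

def scatter_dim1_reference(target, index, src):
    """Hypothetical loop version of target.scatter(1, index, src) for 2D tensors."""
    out = target.clone()  # scatter (without the underscore) does not modify target
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            # src[i][j] is pushed to column index[i][j] of row i
            out[i, index[i, j]] = src[i, j]
    return out

x = torch.tensor([[10, 11, 12]])
b = torch.tensor([[0, 2, 1]])
y = torch.zeros_like(x)
assert torch.equal(scatter_dim1_reference(y, b, x), y.scatter(1, b, x))

# With a repeated index, scatter_add accumulates instead of overwriting
dup = torch.tensor([[0, 0, 1]])
print(torch.zeros_like(x).scatter_add(1, dup, x).tolist())  # [[21, 12, 0]]
```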

In [5]:
x = torch.arange(3).view(1,3)+10
print('input:',x)
y = torch.zeros_like(x)
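# index must have the same number of dimensions as x (here it also has the same shape)
b = torch.tensor([[0,2,1]])
print('index:',b)
# Scatter the values of x into y along dimension 1 (out-of-place: returns a new tensor, y is unchanged)
y.scatter(dim=1, index=b, src=x)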

input: tensor([[10, 11, 12]])
index: tensor([[0, 2, 1]])


tensor([[10, 12, 11]])
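A common use of scatter, covered in the one-hot thread listed in the references, is building one-hot vectors from class ids. A minimal sketch, assuming hypothetical labels `[2, 3, 4]` and the same `batch_size` and `nb_classes` values as the first cell:

```python
import torch
import torch.nn.functional as F

batch_size, nb_classes = 3, 5
labels = torch.tensor([2, 3, 4])  # hypothetical class id per sample

# Start from zeros and write 1.0 at column labels[i] of row i (in-place scatter_)
one_hot = torch.zeros(batch_size, nb_classes).scatter_(1, labels.unsqueeze(1), 1.0)
print(one_hot)

# Cross-check with the built-in helper (it returns int64, so cast to float)
assert torch.equal(one_hot, F.one_hot(labels, nb_classes).float())
```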

#### Custom CrossEntropy

In [6]:
# Define a custom cross-entropy
# x: logits, shape [batch x C]
# target: class indices, shape [batch]
def my_cross_entropy(x, target):
    # Negative log-probabilities of x along the class dimension (dim=1)
    log_prob = -1.0 * F.log_softmax(x, 1)
    print('log_prob:')
    print(log_prob)
    # Same selection with advanced indexing: row i, column target[i]
    loss_no_gather = log_prob[torch.arange(log_prob.shape[0]), target]
    print('No gather loss(SLOW):')
    print(loss_no_gather.unsqueeze(1))
    # unsqueeze makes target shape [batch x 1]. gather does not broadcast:
    # index needs the same number of dimensions as log_prob ([batch x C]).
    # Along dim=1 it picks log_prob[i][target[i]], giving shape [batch x 1]
    loss = log_prob.gather(1, target.unsqueeze(1))
    print('Gather loss')
    print(loss)
    loss = loss.mean()
    return loss

# Reference CrossEntropy
criterion = nn.CrossEntropyLoss()

In [7]:
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.randint(0, nb_classes, (batch_size,))

with torch.no_grad():
    loss_reference = criterion(x, y)
    loss = my_cross_entropy(x, y)

print('loss_reference:',loss_reference)
print('my_cross_entropy:',loss)

log_prob:
tensor([[3.6357, 1.6960, 1.2000, 2.6488, 0.8716],
        [3.4037, 3.1011, 1.7399, 0.3241, 3.7723],
        [2.2717, 0.9338, 1.9005, 1.6440, 1.8258]])
No gather loss(SLOW):
tensor([[1.2000],
        [0.3241],
        [1.8258]])
Gather loss
tensor([[1.2000],
        [0.3241],
        [1.8258]])
loss_reference: tensor(1.1166)
my_cross_entropy: tensor(1.1166)
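The gather-based loss can also be cross-checked against the built-in `F.nll_loss` and `F.cross_entropy`. The sketch below additionally illustrates how the two operations are linked: the gradient of `gather` is a scatter-add of the incoming gradient back into the picked positions. The seed and target values are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([2, 3, 4])  # arbitrary class ids

log_prob = F.log_softmax(x, dim=1)
# Same gather-based loss as my_cross_entropy, without the prints
loss = -log_prob.gather(1, target.unsqueeze(1)).mean()
assert torch.allclose(loss, F.nll_loss(log_prob, target))
assert torch.allclose(loss, F.cross_entropy(x, target))

# Backward of gather: a gradient of 1 flows only to the gathered positions
picked = torch.zeros(3, 5, requires_grad=True)
picked.gather(1, target.unsqueeze(1)).sum().backward()
expected = torch.zeros(3, 5).scatter_add(1, target.unsqueeze(1), torch.ones(3, 1))
assert torch.equal(picked.grad, expected)
print(picked.grad)
```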
