In [1]:
import torch
import torch.nn.functional as F
import numpy as np

##  `scatter` and `gather`

### [`scatter_()`](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_)
> Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. 

For a 2-D tensor, `self` is updated as:
```
self[index[i][j]][j] = src[i][j]  # if dim == 0
self[i][index[i][j]] = src[i][j]  # if dim == 1
```
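To make the index rule concrete, here is a minimal pure-Python sketch of the 2-D case on nested lists. The helper `scatter_2d` and the names `src_list` and `index_list` are hypothetical, introduced only for illustration; they are not part of PyTorch.

```python
def scatter_2d(dst, dim, index, src):
    """Reference sketch of the 2-D scatter rule on nested lists (in place)."""
    for i, row in enumerate(index):
        for j, idx in enumerate(row):
            if dim == 0:
                dst[idx][j] = src[i][j]  # self[index[i][j]][j] = src[i][j]
            else:
                dst[i][idx] = src[i][j]  # self[i][index[i][j]] = src[i][j]
    return dst


src_list = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
index_list = [[0, 1, 2, 0]]

# Fresh 3x5 zero grids, scattered along dim 0 and dim 1 respectively.
along_dim0 = scatter_2d([[0] * 5 for _ in range(3)], 0, index_list, src_list)
along_dim1 = scatter_2d([[0] * 5 for _ in range(3)], 1, index_list, src_list)
```

Only positions covered by `index` are visited, so the second row of `src_list` is never read. The loop writes in order, so a repeated index means the last write wins; the PyTorch documentation notes that with non-unique indices the real `scatter_()` may pick any one of the colliding values.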


In [2]:
src = torch.arange(1, 11).reshape((2, 5))
src

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

In [3]:
index = torch.tensor([[0, 1, 2, 0]])

In [4]:
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

- `src[0][0]` (`index[0][0]` = 0) --> `self[index[0][0]][0]` == `self[0][0]` = 1
- `src[0][1]` (`index[0][1]` = 1) --> `self[index[0][1]][1]` == `self[1][1]` = 2
- `src[0][2]` (`index[0][2]` = 2) --> `self[index[0][2]][2]` == `self[2][2]` = 3
- `src[0][3]` (`index[0][3]` = 0) --> `self[index[0][3]][3]` == `self[0][3]` = 4

In [5]:
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

tensor([[4, 2, 3, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])

- `src[0][0]` (`index[0][0]` = 0) --> `self[0][index[0][0]]` == `self[0][0]` = 1 (later overwritten, see the last bullet)
- `src[0][1]` (`index[0][1]` = 1) --> `self[0][index[0][1]]` == `self[0][1]` = 2
- `src[0][2]` (`index[0][2]` = 2) --> `self[0][index[0][2]]` == `self[0][2]` = 3
- `src[0][3]` (`index[0][3]` = 0) --> `self[0][index[0][3]]` == `self[0][0]` = 4 (**overwrites** the 1 written by `src[0][0]`, because both target `self[0][0]`)

In [6]:
torch.full((2, 4), 2.).scatter_(
    1, torch.tensor([[2], [3]]),
    1.23, reduce='multiply')

tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])

Here the source is the scalar `1.23`, applied at every indexed position:

- `index[0][0]` = 2 --> `self[0][index[0][0]]` == `self[0][2]` = 2.0 * 1.23 = 2.46
- `index[1][0]` = 3 --> `self[1][index[1][0]]` == `self[1][3]` = 2.0 * 1.23 = 2.46


The same applies to `reduce='add'`, which adds the value at the indexed positions instead of multiplying.
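As a quick sketch of the `add` case, the cell below mirrors the `multiply` example with the same index and scalar. The variable names `base`, `idx` and `added` are illustrative only.

```python
import torch

base = torch.full((2, 4), 2.)
idx = torch.tensor([[2], [3]])

# Add the scalar 1.23 at column 2 of row 0 and column 3 of row 1.
added = base.clone().scatter_(1, idx, 1.23, reduce='add')
```

Each indexed entry should go from 2.0 to 3.23, while all other entries stay at 2.0.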

One Hot Encoding using `scatter_()`

In [7]:
x = torch.tensor([1, 2, 4, 2, 1, 0, 3])
num_classes = 5

reshaped_x = x.view(-1, 1)
one_hot = torch.full((x.shape[0], num_classes), 0).scatter_(
    1, # dimension
    reshaped_x, # index from reshaped x
    1, # value to fill for one-hot encoding
)

In [8]:
reshaped_x, one_hot

(tensor([[1],
         [2],
         [4],
         [2],
         [1],
         [0],
         [3]]),
 tensor([[0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0]]))
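As a cross-check, the scatter-based encoding can be compared with `F.one_hot`, which PyTorch provides for this purpose. This is a small sketch; the names `labels`, `via_scatter` and `via_one_hot` are illustrative.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([1, 2, 4, 2, 1, 0, 3])
num_classes = 5

# Scatter a 1 into column labels[i] of row i.
via_scatter = torch.zeros(labels.shape[0], num_classes, dtype=torch.long).scatter_(
    1, labels.view(-1, 1), 1
)

# Built-in equivalent; it returns an int64 tensor.
via_one_hot = F.one_hot(labels, num_classes=num_classes)

same = torch.equal(via_scatter, via_one_hot)
```

The scatter version is still useful when the fill value is not 1 or the target tensor already exists.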

### [`gather()`](https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather)
> Gathers values along an axis specified by `dim`.


For a 2-D tensor the output is specified by:

```
out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1
```
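The rule can be sketched in pure Python on nested lists. The helper `gather_2d` and the names `table`, `picked_dim0` and `picked_dim1` are hypothetical and only illustrate the indexing; they are not PyTorch APIs.

```python
def gather_2d(inp, dim, index):
    """Reference sketch of the 2-D gather rule; the output has the shape of `index`."""
    out = [[0] * len(row) for row in index]
    for i, row in enumerate(index):
        for j, idx in enumerate(row):
            # dim == 0: index picks the row; dim == 1: index picks the column.
            out[i][j] = inp[idx][j] if dim == 0 else inp[i][idx]
    return out


table = [[1, 2, 5], [3, 4, 7]]
picked_dim1 = gather_2d(table, 1, [[0, 0], [1, 0]])
picked_dim0 = gather_2d(table, 0, [[0, 0], [1, 0]])
```

Unlike `scatter_()`, nothing is written into the input, so repeated indices are harmless: the same element is simply read more than once.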

In [9]:
t = torch.tensor([[1, 2, 5], [3, 4, 7]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

tensor([[1, 1],
        [4, 3]])

- `output[0][0]` (`index[0][0]` = 0) = `input[0][index[0][0]]` == `input[0][0]` = 1
- `output[0][1]` (`index[0][1]` = 0) = `input[0][index[0][1]]` == `input[0][0]` = 1
- `output[1][0]` (`index[1][0]` = 1) = `input[1][index[1][0]]` == `input[1][1]` = 4
- `output[1][1]` (`index[1][1]` = 0) = `input[1][index[1][1]]` == `input[1][0]` = 3

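Calculate the NLL (negative log-likelihood) loss with `gather()`

$$ H(p, q) = -\sum_{x} q(x)\log p(x) $$

where the sum runs over the $K$ classes, $p(x)$ is the predicted probability of class $x$ and $q(x)$ is the ground truth. With a one-hot $q$, only the true-class term survives, so the per-sample loss is the negative log-probability of the true class, which is exactly the entry `gather()` selects.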

In [10]:
pred = torch.tensor([
    [0.6, 0.5, 0.8, 0.3, 0.1],
    [0.1, 0.9, 0.4, 0.2, 0.2],
    [0.1, 0.3, 0.4, 0.2, 0.9],  
]) # raw predictions of each sample (across all 5 classes)
target_gt = torch.tensor([2, 1, 4]) # targets with class id

In [11]:
# log softmax probabilities
logprob = F.log_softmax(pred, dim=-1)
logprob

tensor([[-1.4982, -1.5982, -1.2982, -1.7982, -1.9982],
        [-1.9148, -1.1148, -1.6148, -1.8148, -1.8148],
        [-1.9318, -1.7318, -1.6318, -1.8318, -1.1318]])

In [12]:
- logprob.gather(
    dim=1, # only gather data at axis=1 (since axis=0 refers to different data sample in the batch)
    index=target_gt.unsqueeze(1), # gather value based on ground truth (only select index of true label) 
)

tensor([[1.2982],
        [1.1148],
        [1.1318]])
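The gathered values can be cross-checked against PyTorch's built-in losses. The sketch below assumes `reduction='none'` so that per-sample values are returned instead of the batch mean; the names `manual`, `builtin` and `from_logits` are illustrative.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([
    [0.6, 0.5, 0.8, 0.3, 0.1],
    [0.1, 0.9, 0.4, 0.2, 0.2],
    [0.1, 0.3, 0.4, 0.2, 0.9],
])
target_gt = torch.tensor([2, 1, 4])

logprob = F.log_softmax(pred, dim=-1)

# Manual NLL: pick the log-probability of the true class and negate it.
manual = -logprob.gather(dim=1, index=target_gt.unsqueeze(1)).squeeze(1)

# nll_loss expects log-probabilities; cross_entropy expects raw scores.
builtin = F.nll_loss(logprob, target_gt, reduction='none')
from_logits = F.cross_entropy(pred, target_gt, reduction='none')
```

All three tensors should agree, since `F.cross_entropy` combines `log_softmax` and `nll_loss` in one call.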

## `einsum`

https://stackoverflow.com/questions/55894693/understanding-pytorch-einsum

- [`torch.einsum()`](https://pytorch.org/docs/stable/_modules/torch/functional.html#einsum)
- [`numpy.einsum()`](https://numpy.org/devdocs/reference/generated/numpy.einsum.html)

- NumPy allows both lowercase and uppercase letters `[a-zA-Z]` in the subscript string. Older PyTorch releases accepted only lowercase letters `[a-z]`; the current `torch.einsum` documentation lists `[a-zA-Z]` as well.

- NumPy accepts nd-arrays, plain Python lists (or tuples), lists of lists (or tuples of tuples, lists of tuples, tuples of lists) and even PyTorch tensors as operands (i.e. inputs), because the operands only have to be array_like, not strictly NumPy nd-arrays. PyTorch, by contrast, expects the operands to be PyTorch tensors. It will throw a `TypeError` if you pass plain Python lists/tuples (or combinations of them) or NumPy nd-arrays.

- NumPy's `einsum` supports several keyword arguments (e.g. `optimize`) in addition to the operands, while `torch.einsum` takes only the equation and the operands.
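The operand and keyword differences can be illustrated with a small sketch. The inputs `a` and `b` are made-up 2x2 matrices; the lists are converted with `torch.tensor` before being passed to `torch.einsum`.

```python
import numpy as np
import torch

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]

# NumPy takes array_like operands directly and accepts the `optimize` keyword.
np_result = np.einsum('ij,jk->ik', a, b, optimize=True)

# PyTorch operands are tensors, so the lists are converted first.
torch_result = torch.einsum('ij,jk->ik', torch.tensor(a), torch.tensor(b))
```

Both calls compute the same matrix product; only the accepted input types and keyword arguments differ.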

In [13]:
aten = torch.tensor([
    [11, 12, 13, 14],
    [21, 22, 23, 24],
    [31, 32, 33, 34],
    [41, 42, 43, 44],
])
bten = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
    [4, 4, 4, 4],
])

vec = torch.tensor([0, 1, 2, 3])

### Matrix multiplication

In [14]:
torch.einsum('ij, jk -> ik', aten, bten)

tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

### Extract elements along the main-diagonal

In [15]:
torch.einsum('ii -> i', aten)

tensor([11, 22, 33, 44])

In [16]:
torch.diag(aten)

tensor([11, 22, 33, 44])

### Hadamard product (i.e. element-wise product of two tensors)

In [17]:
torch.einsum('ij, ij -> ij', aten, bten)

tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

### Element-wise squaring

In [18]:
torch.einsum('ij, ij -> ij', aten, aten)

tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

In [19]:
aten ** 2

tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

In [20]:
torch.einsum('i, i -> i', vec, vec)

tensor([0, 1, 4, 9])

In [21]:
vec * vec

tensor([0, 1, 4, 9])

### Trace (i.e. sum of main-diagonal elements)

In [22]:
torch.einsum('ii -> ', aten)

tensor(110)

### Matrix transpose

In [23]:
torch.einsum('ij -> ji', aten)

tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

In [24]:
aten.transpose(0, 1)

tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

### Outer Product (of vectors)

In [25]:
torch.einsum('i, j -> ij', vec, vec)

tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

In [26]:
torch.outer(vec, vec)

tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

### Inner Product (of vectors) 

In [27]:
torch.einsum('i, i -> ', vec, vec)

tensor(14)

In [28]:
torch.dot(vec, vec)

tensor(14)

### Sum along axis

In [29]:
torch.einsum('ij -> j', aten)

tensor([104, 108, 112, 116])

In [30]:
torch.sum(aten, 0)

tensor([104, 108, 112, 116])

### Batch Matrix Multiplication

In [31]:
batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 

In [32]:
torch.bmm(batch_tensor_1, batch_tensor_2)  

tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

In [33]:
torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

### Sum over multiple axes (i.e. marginalization)

In [34]:
nDten = torch.randn((3,5,4,6,8,2,7,9))
nDten.shape

torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

In [35]:
# keep dimension 5 (i.e. "n" here) and marginalize out all the others
esum = torch.einsum("ijklmnop -> n", nDten)
esum

tensor([-106.5269,   58.0306])

In [36]:
# same result with torch.sum: sum over every axis except axis 5
tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))
torch.allclose(tsum, esum)

True

### Double dot product / Frobenius inner product (same as `torch.sum` of the Hadamard product)

In [37]:
torch.einsum("ij, ij -> ", aten, bten)

tensor(1300)
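The equivalence with the summed Hadamard product can be checked directly. This is a small sketch that redefines `aten` and `bten` so it runs on its own; `frob_einsum` and `frob_sum` are illustrative names.

```python
import torch

aten = torch.tensor([
    [11, 12, 13, 14],
    [21, 22, 23, 24],
    [31, 32, 33, 34],
    [41, 42, 43, 44],
])
bten = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
    [4, 4, 4, 4],
])

# Contract both indices at once ...
frob_einsum = torch.einsum('ij, ij -> ', aten, bten)

# ... or multiply element-wise and sum everything.
frob_sum = torch.sum(aten * bten)
```

Both expressions reduce the element-wise product to a single scalar.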