Looking into the output of `sum` along different dimensions in PyTorch to derive an algorithm.

In [2]:
import torch
import math

shape = (2, 3, 2, 4)
t = torch.arange(1, math.prod(shape) + 1).view(*shape)
t

tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8]],

         [[ 9, 10, 11, 12],
          [13, 14, 15, 16]],

         [[17, 18, 19, 20],
          [21, 22, 23, 24]]],


        [[[25, 26, 27, 28],
          [29, 30, 31, 32]],

         [[33, 34, 35, 36],
          [37, 38, 39, 40]],

         [[41, 42, 43, 44],
          [45, 46, 47, 48]]]])

In [3]:
t.sum(0, keepdim=True)

tensor([[[[26, 28, 30, 32],
          [34, 36, 38, 40]],

         [[42, 44, 46, 48],
          [50, 52, 54, 56]],

         [[58, 60, 62, 64],
          [66, 68, 70, 72]]]])

For dimension 0, the two (3, 2, 4) arrays have been summed element-wise:

```
    [ 1,  2,  3,  4]
    [ 5,  6,  7,  8]

    [ 9, 10, 11, 12]
    [13, 14, 15, 16]              [ 1,  2,  3,  4]        [25, 26, 27, 28]        [26, 28, 30, 32]
                                  [ 5,  6,  7,  8]        [29, 30, 31, 32]        [34, 36, 38, 40]
    [17, 18, 19, 20]
    [21, 22, 23, 24]              [ 9, 10, 11, 12]        [33, 34, 35, 36]        [42, 44, 46, 48]
                        ------>   [13, 14, 15, 16]   +    [37, 38, 39, 40]   =    [50, 52, 54, 56]

    [25, 26, 27, 28]              [17, 18, 19, 20]        [41, 42, 43, 44]        [58, 60, 62, 64]
    [29, 30, 31, 32]              [21, 22, 23, 24]        [45, 46, 47, 48]        [66, 68, 70, 72]

    [33, 34, 35, 36]
    [37, 38, 39, 40]

    [41, 42, 43, 44]
    [45, 46, 47, 48]
```
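The picture above can be checked directly. Summing over dimension 0 should give the same result as adding the two (3, 2, 4) sub-tensors `t[0]` and `t[1]` element-wise. The following is a minimal check that rebuilds `t` exactly as in the first cell:

```python
import math
import torch

shape = (2, 3, 2, 4)
t = torch.arange(1, math.prod(shape) + 1).view(*shape)

# Element-wise sum of the two (3, 2, 4) sub-tensors along dimension 0
pairwise = t[0] + t[1]

# t.sum(0) drops dimension 0, so its shape is (3, 2, 4) like `pairwise`
same = torch.equal(pairwise, t.sum(0))
# keepdim=True only re-inserts the summed dimension with size 1
same_keepdim = torch.equal(pairwise.unsqueeze(0), t.sum(0, keepdim=True))
print(same, same_keepdim)  # True True
```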

What I want to do is iterate over each cell of the output tensor of shape (1, 3, 2, 4), left to right and top to bottom (the order in which the values are stored in memory), and calculate the sum for each cell.
To do this, I have come up with the following parameters:

**For each sum operation, how do we know which elements have to be summed together?**
- `stride`: the distance, in elements of the flat array, from one term of a sum to the next
- `nstride`: how many elements are added together in each sum

In this case `stride = 24` because each element of the first (3, 2, 4) array has to be added to the element at the same position in the second (3, 2, 4) array, which lies 24 elements further along in memory. `nstride = 2` because there are only two arrays.

This can be generalised: `stride` is the product of the sizes of the dimensions after the one being summed (here 3 * 2 * 4 = 24), and `nstride` is the size of the dimension being summed, which makes sense.
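As a sketch of that generalisation, the following computes `stride` and `nstride` for every dimension, assuming a contiguous (row-major) tensor. The helper name `stride_and_nstride` is hypothetical and used only for illustration.

```python
import math

shape = (2, 3, 2, 4)

def stride_and_nstride(shape, dim):
    """Hypothetical helper: (stride, nstride) for summing over `dim`
    in a contiguous row-major tensor of the given shape."""
    stride = math.prod(shape[dim + 1:])  # distance between consecutive terms
    nstride = shape[dim]                 # number of terms in each sum
    return stride, nstride

params = [stride_and_nstride(shape, d) for d in range(len(shape))]
print(params)  # [(24, 2), (8, 3), (4, 2), (1, 4)]
```

For the contiguous `t` built in the first cell, `t.stride()` should report `(24, 8, 4, 1)`. These are the same values as the first element of each pair.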

**After each sum operation, how do we know which element will be the first of the next sum?**
- `step`: the distance from the first element of one sum to the first element of the next
- `nstep`: how many sums have to be performed

In this case `step = 1` because the elements of the first (3, 2, 4) array are summed element-wise with those of the second. So each sum starts at the element right after the one where the previous sum started. And `nstep = 24` because that's how many elements are in each (3, 2, 4) array.
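This minimal pure-Python sketch puts the four parameters together for dimension 0. It works on the flat, memory-order list of values 1..48 instead of on the tensor. It assumes the tensor is contiguous, so flat index `i` holds the `i`-th value printed above. The function name `sum_flat` is made up for this example.

```python
import math

shape = (2, 3, 2, 4)
flat = list(range(1, math.prod(shape) + 1))  # same values as t, in memory order

def sum_flat(flat, stride, nstride, step, nstep):
    """Hypothetical helper: perform `nstep` sums. Sum k starts at offset
    k * step and adds `nstride` elements spaced `stride` apart."""
    out = []
    for k in range(nstep):
        first = k * step  # start offset of the k-th sum
        out.append(sum(flat[first + j * stride] for j in range(nstride)))
    return out

# Dimension 0: stride = 3 * 2 * 4 = 24, nstride = 2, step = 1, nstep = 24
result = sum_flat(flat, stride=24, nstride=2, step=1, nstep=24)
print(result[:8])  # [26, 28, 30, 32, 34, 36, 38, 40]
```

Reading `result` in order gives the values of `t.sum(0, keepdim=True)` row by row, i.e. the output tensor in memory order.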

In [4]:
t.sum(1, keepdim=True)

tensor([[[[ 27,  30,  33,  36],
          [ 39,  42,  45,  48]]],


        [[[ 99, 102, 105, 108],
          [111, 114, 117, 120]]]])
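
Applying the same parameters to dimension 1 shows where a single `step`/`nstep` pair runs out. Here `stride = 2 * 4 = 8` and `nstride = 3`. The first 8 sums start at flat offsets 0..7, but the 9th sum has to start at offset 24 (the first element of the second (3, 2, 4) block), not at 8.

One possible fix adds an extra outer loop over the dimensions before the summed one. This is an assumption, not a finished algorithm. It uses `nouter = prod(shape[:dim])` blocks, each `nstride * stride` elements long, with `nstep = stride` sums of `step = 1` inside each block. The function name `sum_dim_flat` is hypothetical.

```python
import math

shape = (2, 3, 2, 4)
flat = list(range(1, math.prod(shape) + 1))  # same values as t, in memory order

def sum_dim_flat(flat, shape, dim):
    """Hypothetical sketch: sum a contiguous row-major tensor (given as a
    flat list) over `dim`, returning the result in memory order."""
    stride = math.prod(shape[dim + 1:])  # distance between terms of one sum
    nstride = shape[dim]                 # terms per sum
    nstep = stride                       # sums per outer block, with step = 1
    nouter = math.prod(shape[:dim])      # outer blocks before `dim`
    block = nstride * stride             # elements spanned by one outer block
    out = []
    for o in range(nouter):
        for k in range(nstep):
            first = o * block + k        # start offset of this sum
            out.append(sum(flat[first + j * stride] for j in range(nstride)))
    return out

res1 = sum_dim_flat(flat, shape, 1)
print(res1)
# [27, 30, 33, 36, 39, 42, 45, 48, 99, 102, 105, 108, 111, 114, 117, 120]
```

With `dim = 0`, the outer loop runs once and this reduces to the `stride`/`step` scheme described for dimension 0. With `dim = 3`, the inner loop runs once per block and the outer loop does the stepping.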

In [5]:
t.sum(2, keepdim=True)

tensor([[[[ 6,  8, 10, 12]],

         [[22, 24, 26, 28]],

         [[38, 40, 42, 44]]],


        [[[54, 56, 58, 60]],

         [[70, 72, 74, 76]],

         [[86, 88, 90, 92]]]])
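
For dimension 2 (`stride = 4`, `nstride = 2`), the sums again come in blocks: 4 sums with `step = 1`, then a jump of 8 elements to the next (2, 4) block. For a contiguous tensor, standard PyTorch calls can make this block structure explicit. Viewing `t` as `(prod(shape[:dim]), shape[dim], prod(shape[dim + 1:]))` and summing the middle axis should give the same values as `t.sum(dim)`. This is only a cross-check of the structure, not the flat-array algorithm itself.

```python
import math
import torch

shape = (2, 3, 2, 4)
t = torch.arange(1, math.prod(shape) + 1).view(*shape)

dim = 2
outer = math.prod(shape[:dim])      # 2 * 3 = 6 outer blocks
n = shape[dim]                      # nstride = 2 terms per sum
inner = math.prod(shape[dim + 1:])  # stride = 4

# Collapse the tensor into (outer, n, inner) and sum the middle axis
collapsed = t.view(outer, n, inner).sum(1)  # shape (6, 4)
expected = t.sum(dim, keepdim=True)         # shape (2, 3, 1, 4)

print(torch.equal(collapsed.view(expected.shape), expected))  # True
print(collapsed[0].tolist())  # [6, 8, 10, 12]
```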

In [6]:
t.sum(3, keepdim=True)

tensor([[[[ 10],
          [ 26]],

         [[ 42],
          [ 58]],

         [[ 74],
          [ 90]]],


        [[[106],
          [122]],

         [[138],
          [154]],

         [[170],
          [186]]]])
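
Dimension 3 is the mirror image of dimension 0. The terms of each sum are adjacent in memory (`stride = 1`, `nstride = 4`), and each sum starts right where the previous one ended (`step = 4`, `nstep = 2 * 3 * 2 = 12`). So a single `step`/`nstep` pair covers the whole tensor again. Here is a short pure-Python check on the flat values, assuming a contiguous layout:

```python
import math

shape = (2, 3, 2, 4)
flat = list(range(1, math.prod(shape) + 1))  # same values as t, in memory order

stride, nstride = 1, shape[3]                 # 4 adjacent terms per sum
step, nstep = shape[3], math.prod(shape[:3])  # 12 sums, each 4 elements apart

sums = [
    sum(flat[k * step + j * stride] for j in range(nstride))
    for k in range(nstep)
]
print(sums)
# [10, 26, 42, 58, 74, 90, 106, 122, 138, 154, 170, 186]
```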