Looking into the output of `sum` along different dimensions in PyTorch to derive an algorithm.

In [None]:
import math
import torch

shape = (2, 3, 2, 4)
t = torch.arange(1, math.prod(shape) + 1).view(*shape)
t

tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8]],

         [[ 9, 10, 11, 12],
          [13, 14, 15, 16]],

         [[17, 18, 19, 20],
          [21, 22, 23, 24]]],


        [[[25, 26, 27, 28],
          [29, 30, 31, 32]],

         [[33, 34, 35, 36],
          [37, 38, 39, 40]],

         [[41, 42, 43, 44],
          [45, 46, 47, 48]]]])

---

In [3]:
t.sum(0, keepdim=True)

tensor([[[[26, 28, 30, 32],
          [34, 36, 38, 40]],

         [[42, 44, 46, 48],
          [50, 52, 54, 56]],

         [[58, 60, 62, 64],
          [66, 68, 70, 72]]]])

For the 0th dimension, the two (3, 2, 4) arrays have been summed together element-wise:

```
    [ 1  2  3  4]
    [ 5  6  7  8]

    [ 9 10 11 12]
    [13 14 15 16]      [ 1  2  3  4]   [25 26 27 28]   [26 28 30 32]
                       [ 5  6  7  8]   [29 30 31 32]   [34 36 38 40]
    [17 18 19 20]
    [21 22 23 24]      [ 9 10 11 12]   [33 34 35 36]   [42 44 46 48]
                  ---> [13 14 15 16] + [37 38 39 40] = [50 52 54 56]

    [25 26 27 28]      [17 18 19 20]   [41 42 43 44]   [58 60 62 64]
    [29 30 31 32]      [21 22 23 24]   [45 46 47 48]   [66 68 70 72]

    [33 34 35 36]
    [37 38 39 40]

    [41 42 43 44]
    [45 46 47 48]
```

What I want to do is iterate over each cell of the output tensor of shape (1, 3, 2, 4), left to right and top to bottom (the order in which the array is stored in memory), and calculate the sum for each cell.
To do this I have come up with the following parameters:

**For each sum operation, how do we know which elements have to be summed together?**
- `stride`: how many elements have to be skipped to get the next element of the sum
- `nstride`: how many elements have to be added to the sum

In this case `stride = 24` because each element of the first (3, 2, 4) array has to be matched with the corresponding element of the second (3, 2, 4) array, which sits 24 positions later in memory, and `nstride = 2` because there are only two such arrays.

This can be generalised: `stride` is the product of the sizes of the dimensions after the one being summed (3 * 2 * 4 = 24) and `nstride` is the size of the dimension being summed.

So, if the current element being summed is the first element of the source array (offset 0), it has to be summed with the 25th element (offset 24) to produce the first sum of the output array:

```
        ┏━━━━┓
        ┃   [ 1  2  3  4]
        ┃   [ 5  6  7  8]
        ┃
        ┃   [ 9 10 11 12]
    +24 ┃   [13 14 15 16]
        ┃
        ┃   [17 18 19 20]
        ┃   [21 22 23 24]
        ┃
        ┗━━━━┓
            [25 26 27 28]
            [29 30 31 32]

            [33 34 35 36]
            [37 38 39 40]

            [41 42 43 44]
            [45 46 47 48]
```
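As a minimal sketch of this first pair of parameters, the snippet below gathers the elements that make up a single output cell. It assumes a plain Python list standing in for the tensor's row-major memory, and `sum_one_cell` is a hypothetical helper name, not something taken from PyTorch or `tensor.c`.

```python
import math

shape = (2, 3, 2, 4)
# Flattened row-major contents of the tensor: 1, 2, ..., 48.
flat = list(range(1, math.prod(shape) + 1))

def sum_one_cell(flat, offset, stride, nstride):
    """Add `nstride` elements, starting at `offset` and spaced `stride` apart."""
    return sum(flat[offset + k * stride] for k in range(nstride))

dim = 0
stride = math.prod(shape[dim + 1:])  # 3 * 2 * 4 = 24
nstride = shape[dim]                 # 2

first = sum_one_cell(flat, 0, stride, nstride)
print(first)  # 1 + 25 = 26
```

Only the starting offset changes from one output cell to the next; `stride` and `nstride` stay fixed for the whole operation.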

**After each sum operation, how do we know which element will be the first of the next sum?**
- `step`: how many elements have to be skipped to calculate the next sum
- `nstep`: how many sums have to be performed

In this case `step = 1` because the elements of the first (3, 2, 4) array are summed element-wise with those of the second (3, 2, 4) array, so the next sum starts at the element right after the one where the previous sum started. And `nstep = 24` because that's how many elements there are in each (3, 2, 4) array.

So, once the first summation is done the current element being summed will be shifted by `step`, taking the previous visualisation:

```
               +1
              ┏━>━┓
        ┏━━━━━━━━━┓
        ┃   [ 1  2  3  4]
        ┃   [ 5  6  7  8]
        ┃
        ┃   [ 9 10 11 12]
    +24 ┃   [13 14 15 16]
        ┃
        ┃   [17 18 19 20]
        ┃   [21 22 23 24]
        ┃
        ┗━━━━━━━━┓
            [25 26 27 28]
            [29 30 31 32]

            [33 34 35 36]
            [37 38 39 40]

            [41 42 43 44]
            [45 46 47 48]
```
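Putting both pairs of parameters together for dimension 0 could look roughly like the sketch below. It again assumes a flat Python list in place of the tensor's memory, and the variable names `out`, `start`, `i` and `k` are only illustrative.

```python
flat = list(range(1, 49))  # row-major contents of the (2, 3, 2, 4) tensor

stride, nstride = 24, 2  # which elements go into one sum
step, nstep = 1, 24      # where each sum starts, and how many sums there are

out = []
for i in range(nstep):
    start = i * step  # first element of the i-th sum
    out.append(sum(flat[start + k * stride] for k in range(nstride)))

print(out[:8])  # [26, 28, 30, 32, 34, 36, 38, 40]
```

The first eight values match the first two rows of `t.sum(0, keepdim=True)` shown earlier.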

In [9]:
t.sum(3, keepdim=True)

tensor([[[[ 10],
          [ 26]],

         [[ 42],
          [ 58]],

         [[ 74],
          [ 90]]],


        [[[106],
          [122]],

         [[138],
          [154]],

         [[170],
          [186]]]])

The last dimension is an almost identical case to the first dimension, but the parameter values differ:

```
    [ 1  2  3  4]      [ 1 +  2 +  3 +  4]   [ 10]
    [ 5  6  7  8]      [ 5 +  6 +  7 +  8]   [ 26]
                                                
    [ 9 10 11 12]      [ 9 + 10 + 11 + 12]   [ 42]
    [13 14 15 16]      [13 + 14 + 15 + 16]   [ 58]
                                                
    [17 18 19 20]      [17 + 18 + 19 + 20]   [ 74]
    [21 22 23 24]      [21 + 22 + 23 + 24]   [ 90]
                  --->                     = 
                                                
    [25 26 27 28]      [25 + 26 + 27 + 28]   [106]
    [29 30 31 32]      [29 + 30 + 31 + 32]   [122]
                                                
    [33 34 35 36]      [33 + 34 + 35 + 36]   [138]
    [37 38 39 40]      [37 + 38 + 39 + 40]   [154]
                                                
    [41 42 43 44]      [41 + 42 + 43 + 44]   [170]
    [45 46 47 48]      [45 + 46 + 47 + 48]   [186]
```

In this case `stride = 1` because each element is summed with the one right after it and `nstride = 4` because the last dimension is of size 4. The generalisation from the previous case still holds here if we assume that the product of 0 elements is 1 (there are no remaining dimensions, as we are summing along the last one).

Then, `step = 4` because that's how many elements there are in the sum for each output cell and they are contiguous. Finally, `nstep = 12` (the number of rows to sum), which turns out to be the product of all the previous dimension sizes (2 * 3 * 2 = 12).
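The values for the last dimension can be derived and checked with a short sketch like the one below. It relies on `math.prod` returning 1 for an empty sequence, which is the "product of 0 elements" convention mentioned above; the flat list is again just a stand-in for the tensor's memory.

```python
import math

shape = (2, 3, 2, 4)
flat = list(range(1, math.prod(shape) + 1))
dim = len(shape) - 1

# No dimensions remain after the last one, and math.prod(()) is 1.
stride = math.prod(shape[dim + 1:])  # 1
nstride = shape[dim]                 # 4
step = shape[dim]                    # 4: each sum consumes one contiguous row
nstep = math.prod(shape[:dim])       # 2 * 3 * 2 = 12 rows

out = [
    sum(flat[i * step + k * stride] for k in range(nstride))
    for i in range(nstep)
]
print(out)  # [10, 26, 42, 58, 74, 90, 106, 122, 138, 154, 170, 186]
```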

---

The middle dimensions (`0 < dim < ndim-1`) need a new parameter, because the whole process described above has to be repeated multiple times.

In [8]:
t.sum(1, keepdim=True)

tensor([[[[ 27,  30,  33,  36],
          [ 39,  42,  45,  48]]],


        [[[ 99, 102, 105, 108],
          [111, 114, 117, 120]]]])

```
    [ 1  2  3  4]      [ 1  2  3  4]
    [ 5  6  7  8]      [ 5  6  7  8]
                             +
    [ 9 10 11 12]      [ 9 10 11 12]
    [13 14 15 16]      [13 14 15 16]
                             +
    [17 18 19 20]      [17 18 19 20]   [ 27  30  33  36]
    [21 22 23 24]      [21 22 23 24]   [ 39  42  45  48]
                  --->               =

    [25 26 27 28]      [25 26 27 28]   [ 99 102 105 108]
    [29 30 31 32]      [29 30 31 32]   [111 114 117 120]
                             +
    [33 34 35 36]      [33 34 35 36]
    [37 38 39 40]      [37 38 39 40]
                             +
    [41 42 43 44]      [41 42 43 44]
    [45 46 47 48]      [45 46 47 48]
```

`stride`, `nstride`, `step` and `nstep` are calculated the same way as in the first dimension:

- `stride = 8`
- `nstride = 3`
- `step = 1`
- `nstep = 8`

But following the steps outlined in the previous cases only yields the sums for the first (3, 2, 4) array of dimension 0 (the dimension before the one being summed along):

```
[ 27  30  33  36]
[ 39  42  45  48]


[  0   0   0   0]
[  0   0   0   0]
```

So the same process has to be repeated one more time. The way I imagine this is that `stride` and `step` let you work on a "window" of the source array, and now that window has to be shifted to the next area. Here the first array of dimension 0 has already been computed, so we have to do the exact same thing but offset by 24 elements (the product of the sizes of the summed dimension and all the dimensions after it, 3 * 2 * 4 = 24).

- `shift`: how many elements to shift the window that `stride` and `step` work on
- `nshift`: number of window positions to process

In this case `shift = 24` and `nshift = 2` (the product of the sizes of the dimensions before the summed one, which here is just the size of dimension 0).

Shifted windows:

```
            ┏━━━━━━━━━━━━━━━┓
            ┃ [ 1  2  3  4] ┃
            ┃ [ 5  6  7  8] ┃
            ┃               ┃
  window 1  ┃ [ 9 10 11 12] ┃  ┓
            ┃ [13 14 15 16] ┃  ┃
            ┃               ┃  ┃
            ┃ [17 18 19 20] ┃  ┃
            ┃ [21 22 23 24] ┃  ┃
            ┗━━━━━━━━━━━━━━━┛  v   +24 (shift)
            ┏━━━━━━━━━━━━━━━┓  ┃
            ┃ [25 26 27 28] ┃  ┃
            ┃ [29 30 31 32] ┃  ┃
            ┃               ┃  ┃
  window 2  ┃ [33 34 35 36] ┃  ┃
            ┃ [37 38 39 40] ┃  ┛
            ┃               ┃
            ┃ [41 42 43 44] ┃
            ┃ [45 46 47 48] ┃
            ┗━━━━━━━━━━━━━━━┛
```
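With the window added, the loop for dimension 1 might be sketched as three nested levels: windows, sums inside a window, and elements inside a sum. This is an illustrative sketch over a flat Python list, with hypothetical variable names, and is not meant to mirror the C implementation.

```python
flat = list(range(1, 49))  # row-major contents of the (2, 3, 2, 4) tensor

stride, nstride = 8, 3  # elements of one sum
step, nstep = 1, 8      # sums inside one window
shift, nshift = 24, 2   # window positions

out = []
for w in range(nshift):
    base = w * shift  # offset of the current window
    for i in range(nstep):
        start = base + i * step
        out.append(sum(flat[start + k * stride] for k in range(nstride)))

print(out[:8])  # [27, 30, 33, 36, 39, 42, 45, 48]
print(out[8:])  # [99, 102, 105, 108, 111, 114, 117, 120]
```

The second line of output is the half that was missing without the shift.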

In [5]:
t.sum(2, keepdim=True)

tensor([[[[ 6,  8, 10, 12]],

         [[22, 24, 26, 28]],

         [[38, 40, 42, 44]]],


        [[[54, 56, 58, 60]],

         [[70, 72, 74, 76]],

         [[86, 88, 90, 92]]]])

```
    [ 1  2  3  4]
    [ 5  6  7  8]

    [ 9 10 11 12]      [ 1  2  3  4] + [ 5  6  7  8]   [ 6,  8, 10, 12]
    [13 14 15 16]
                       [ 9 10 11 12] + [13 14 15 16]   [22, 24, 26, 28]
    [17 18 19 20]
    [21 22 23 24]      [17 18 19 20] + [21 22 23 24]   [38, 40, 42, 44]
                  --->                               = 

    [25 26 27 28]      [25 26 27 28] + [29 30 31 32]   [54, 56, 58, 60]
    [29 30 31 32]
                       [33 34 35 36] + [37 38 39 40]   [70, 72, 74, 76]
    [33 34 35 36]
    [37 38 39 40]      [41 42 43 44] + [45 46 47 48]   [86, 88, 90, 92]

    [41 42 43 44]
    [45 46 47 48]
```

Exactly the same steps as in the previous case:

- `stride = 4`
- `nstride = 2`
- `step = 1`
- `nstep = 4`
- `shift = 8`
- `nshift = 6`

Shifted windows:

```
            ┏━━━━━━━━━━━━━━━┓
  window 1  ┃ [ 1  2  3  4] ┃ ┓
            ┃ [ 5  6  7  8] ┃ ┃
            ┣━━━━━━━━━━━━━━━┫ v   +8 (shift)
  window 2  ┃ [ 9 10 11 12] ┃ ┛
            ┃ [13 14 15 16] ┃ ┓
            ┣━━━━━━━━━━━━━━━┫ ┃
  window 3  ┃ [17 18 19 20] ┃ v   +8 (shift)
            ┃ [21 22 23 24] ┃ ┛
            ┣━━━━━━━━━━━━━━━┫
  window 4  ┃ [25 26 27 28] ┃ ┓
            ┃ [29 30 31 32] ┃ ┃
            ┣━━━━━━━━━━━━━━━┫ v   +8 (shift)
  window 5  ┃ [33 34 35 36] ┃ ┛
            ┃ [37 38 39 40] ┃ ┓
            ┣━━━━━━━━━━━━━━━┫ ┃
  window 6  ┃ [41 42 43 44] ┃ v   +8 (shift)
            ┃ [45 46 47 48] ┃ ┛
            ┗━━━━━━━━━━━━━━━┛
```

---

To wrap up, here's a Python implementation of how to calculate each parameter (alternatively, check the `sum()` function in `tensor.c`):

In [14]:
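ndim = len(shape)
for dim in range(ndim):
    # Products of the dimension sizes before and after `dim` (an empty product is 1).
    mprev, mnext = math.prod(shape[:dim]), math.prod(shape[dim+1:])
    # One sum adds `nstride` elements that are `stride` apart.
    stride, nstride = mnext, shape[dim]
    # Last dimension: each sum is a contiguous row. Otherwise sums start on consecutive elements.
    step, nstep = (shape[dim], mprev) if dim == ndim-1 else (1, mnext)
    # Only the middle dimensions need more than one window.
    shift, nshift = (mnext*shape[dim], mprev) if 0 < dim < ndim-1 else (0, 1)
    print(f"dim = {dim}")
    print(f"\tstride = {stride}")
    print(f"\tnstride = {nstride}")
    print(f"\tstep = {step}")
    print(f"\tnstep = {nstep}")
    print(f"\tshift = {shift}")
    print(f"\tnshift = {nshift}")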

dim = 0
	stride = 24
	nstride = 2
	step = 1
	nstep = 24
	shift = 0
	nshift = 1
dim = 1
	stride = 8
	nstride = 3
	step = 1
	nstep = 8
	shift = 24
	nshift = 2
dim = 2
	stride = 4
	nstride = 2
	step = 1
	nstep = 4
	shift = 8
	nshift = 6
dim = 3
	stride = 1
	nstride = 4
	step = 4
	nstep = 12
	shift = 0
	nshift = 1
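
To check that these six parameters are enough to reproduce the sums, here is a rough sketch that plugs them into the three nested loops and compares the result with a brute-force reference. `sum_params`, `sum_along` and `reference_sum` are hypothetical helper names, and the input is assumed to be a contiguous row-major flat list rather than a real tensor.

```python
import math
from itertools import product

def sum_params(shape, dim):
    """Same parameter formulas as in the cell above."""
    ndim = len(shape)
    mprev, mnext = math.prod(shape[:dim]), math.prod(shape[dim + 1:])
    stride, nstride = mnext, shape[dim]
    step, nstep = (shape[dim], mprev) if dim == ndim - 1 else (1, mnext)
    shift, nshift = (mnext * shape[dim], mprev) if 0 < dim < ndim - 1 else (0, 1)
    return stride, nstride, step, nstep, shift, nshift

def sum_along(flat, shape, dim):
    """Sum a row-major flat list along `dim` using windows, steps and strides."""
    stride, nstride, step, nstep, shift, nshift = sum_params(shape, dim)
    out = []
    for w in range(nshift):
        for i in range(nstep):
            start = w * shift + i * step
            out.append(sum(flat[start + k * stride] for k in range(nstride)))
    return out

def reference_sum(flat, shape, dim):
    """Brute force: visit every output index in row-major order and add along `dim`."""
    strides = [math.prod(shape[d + 1:]) for d in range(len(shape))]
    out_shape = [1 if d == dim else n for d, n in enumerate(shape)]
    out = []
    for idx in product(*(range(n) for n in out_shape)):
        base = sum(i * s for i, s in zip(idx, strides))
        out.append(sum(flat[base + k * strides[dim]] for k in range(shape[dim])))
    return out

shape = (2, 3, 2, 4)
flat = list(range(1, math.prod(shape) + 1))
for dim in range(len(shape)):
    assert sum_along(flat, shape, dim) == reference_sum(flat, shape, dim)
print(sum_along(flat, shape, 2)[:4])  # [6, 8, 10, 12]
```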
