In [1]:
import torch

# Aggregating values

In [34]:
# Fix the global RNG seed so the sampled example matrix is reproducible.
torch.manual_seed(1337)

# 3x2 matrix of random integers drawn from [0, 10), stored as float32.
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
b

tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

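Aggregating values can be done via matrix multiplication. For example, to sum each column of `b` (i.e., to sum over its rows), we can multiply it from the left by a matrix of ones; every row of the result then holds the column sums: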

In [36]:
# 3x3 matrix of ones: every output row will weight all rows of the input equally.
a = torch.ones(size=(3, 3))
a

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [37]:
# Each row of the product is the sum over all rows of b (i.e. the column sums).
torch.matmul(a, b)

tensor([[12., 10.],
        [12., 10.],
        [12., 10.]])

A cumulative sum over the rows works the same way, using a lower-triangular matrix of ones:

In [38]:
# Keep only the entries on and below the main diagonal; everything above becomes 0.
# NOTE: despite its name, `diag` is a lower-triangular matrix, not a diagonal one.
diag = a.tril()
diag

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [39]:
# Row i of the product adds up rows 0..i of b, i.e. a cumulative sum down the rows.
torch.matmul(diag, b)

tensor([[ 5.,  7.],
        [ 7.,  7.],
        [12., 10.]])

This is easily extended to a running average by normalizing each row of the matrix so that its weights sum to 1:

In [40]:
# Per-row totals, kept as a column (keepdim=True) so they broadcast across each row.
row_totals = diag.sum(dim=1, keepdim=True)
# In-place division: afterwards every row of `diag` sums to 1.
diag.div_(row_totals)
diag

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [41]:
# With row-normalized weights, row i of the product is the mean of rows 0..i of b.
torch.matmul(diag, b)

tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])

## Using softmax

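The same weight matrix can be obtained with softmax by exploiting the fact that $\exp(-\infty) = 0$: entries set to $-\infty$ receive zero weight, and the remaining (equal) entries in each row share the weight uniformly.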

In [42]:
# Lower-triangular 0/1 matrix marking which positions each row may use.
diag = torch.tril(torch.ones(3, 3))

# Start from equal scores and set every masked position (mask == 0) to -inf,
# so that it contributes exp(-inf) = 0 to the softmax.
a = torch.ones(3, 3).masked_fill(diag == 0, float("-inf"))

# Row-wise softmax turns the remaining equal scores into uniform weights.
a = a.softmax(dim=1)
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])