In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [34]:
# build a row-normalised lower-triangular matrix: row i ends up holding
# equal weights 1/(i+1) on columns 0..i and zeros everywhere else
ones = torch.ones(3, 3)
print(ones)

# zero out everything above the main diagonal
lower = torch.tril(ones)
print(lower)

# divide each row by its own sum so that every row adds up to 1
a = lower / lower.sum(dim=1, keepdim=True)
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [36]:
# apply the averaging matrix `a` (built in the cell above) to example data
# NOTE(review): no seed is set in the cells shown here, so `b` (and therefore
# the printed values) will presumably differ on every run
b_int = torch.randint(low=0, high=10, size=(3, 3))
b = b_int.float()
print(b)

c = torch.matmul(a, b)
print(c)

# row i of c is the mean of rows 0..i of b, i.e. each column of c is a
# running average taken downward over the rows of b

tensor([[4., 8., 0.],
        [8., 5., 3.],
        [9., 9., 8.]])
tensor([[4.0000, 8.0000, 0.0000],
        [6.0000, 6.5000, 1.5000],
        [7.0000, 7.3333, 3.6667]])


In [32]:
# the same averaging weights can also be produced with a masked softmax
T = 8

# lower-triangular matrix of ones: row i has ones in columns 0..i
tril = torch.tril(torch.ones(T, T))
print(tril)

# replace the zeros above the diagonal with -inf so softmax gives them weight 0
scores = tril.masked_fill(tril == 0, float('-inf'))
print(scores)

# softmax along dim=1 normalises across the columns of each row, so every row
# sums to 1; the unmasked entries in a row are all equal, so they share the
# weight evenly (dim=0 would instead normalise down each column)
w = torch.softmax(scores, dim=1)
print(w)


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,