In [1]:
import torch

In [2]:
B,T,C = 4,8,2

In [3]:
x = torch.randn(B,T,C) # (B,T,C)
x.shape

torch.Size([4, 8, 2])

In [10]:
x[0]

tensor([[-0.0585, -1.4621],
        [-0.1025, -0.9740],
        [-0.0708,  0.0954],
        [-0.5929, -0.4518],
        [ 1.2725, -0.8707],
        [-0.3340,  0.8821],
        [-0.2869,  1.3561],
        [-1.4222,  0.1716]])

# # version 1

In [9]:
xbow = torch.zeros(B,T,C)
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1,:] # (t,C)
        xbow[b,t] = torch.mean(xprev, dim=0, keepdims=True)
xbow[0]

tensor([[-0.0585, -1.4621],
        [-0.0805, -1.2181],
        [-0.0773, -0.7803],
        [-0.2062, -0.6981],
        [ 0.0896, -0.7327],
        [ 0.0190, -0.4635],
        [-0.0247, -0.2036],
        [-0.1994, -0.1567]])

# # version 2

In [35]:
# toy example

a = torch.ones(3,3)
a = torch.tril(a) # makes the tensor a lower triangle matrix
a = a / torch.sum(a, dim=1, keepdim=True) # normalizing the lower triangle matrix to create averaging effect while information sharing
b = torch.randint(0,10, (3,2)).float()
c = a @ b

print(a)
print('----------')
print(b)
print('----------')
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----------
tensor([[7., 7.],
        [0., 0.],
        [0., 6.]])
----------
tensor([[7.0000, 7.0000],
        [3.5000, 3.5000],
        [2.3333, 4.3333]])


In [36]:
# doing the above in batched manner
# weighted aggregation - wei will be the weights to do the aggregation
wei = torch.tril(torch.ones(T,T))
wei = wei / torch.sum(wei, dim=1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [37]:
xbow2 = wei @ x # (T,T) @ (B,T,C) --> (B,T,C)
xbow2[0]

tensor([[-0.0585, -1.4621],
        [-0.0805, -1.2181],
        [-0.0773, -0.7803],
        [-0.2062, -0.6981],
        [ 0.0896, -0.7327],
        [ 0.0190, -0.4635],
        [-0.0247, -0.2036],
        [-0.1994, -0.1567]])

In [38]:
torch.allclose(xbow, xbow2)

True

# # version 3

In [44]:
tril = torch.tril(torch.ones(T,T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [42]:
# using Softmax

wei = torch.zeros(T,T)
wei = wei.masked_fill(tril==0, float('-inf'))
wei = torch.nn.functional.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [43]:
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True