<a href="https://colab.research.google.com/github/bythyag/nanoGPT/blob/main/gpt_testbook_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch

SEED = 42  # single place to change the global PyTorch RNG seed for the whole notebook
torch.manual_seed(SEED)

<torch._C.Generator at 0x7a2e44cedd90>

In [8]:
# Start with a 3x3 lower-triangular matrix of ones: row i has ones in
# columns 0..i and zeros to the right of the diagonal.
a = torch.ones(3, 3).tril()
print(a)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [9]:
# Normalize each row of `a` so that its entries sum to 1: divide every element
# by its row total. keepdim=True keeps the row sums as a (rows, 1) column so
# they broadcast across each row during the division.
a = a / a.sum(dim=1, keepdim=True)
print(a)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [10]:
# Draw a (3, 2) tensor of random integers in [0, 10), i.e. 0..9 since the
# upper bound is exclusive, then cast it to float32 to match `a`'s dtype for
# the matrix multiplication below.
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
print(b)

tensor([[0., 4.],
        [1., 2.],
        [5., 5.]])


In [11]:
# Matrix-multiply the row-normalized weights `a` (3x3) with `b` (3x2).
# Row i of `a` puts equal weight on rows 0..i, so row i of `c` is the
# running mean of the first i+1 rows of `b`.
c = torch.matmul(a, b)
print(c)

tensor([[0.0000, 4.0000],
        [0.5000, 3.0000],
        [2.0000, 3.6667]])


In [39]:
# Version 1: running (prefix) mean computed step by step with explicit loops.
# First build a random input: B sequences, each T time steps long, with
# C channels per time step.
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))
print(x.shape)
print("Original x tensor:\n", x)

torch.Size([4, 8, 2])
Original x tensor:
 tensor([[[-2.5850, -0.0240],
         [-0.1222, -0.7470],
         [ 1.7093,  0.0579],
         [ 0.8637, -0.5890],
         [ 0.7287,  0.9809],
         [ 0.4146,  1.1566],
         [ 0.2691, -0.0366],
         [-0.4808,  0.3163]],

        [[-0.5419, -0.4410],
         [-0.3136, -0.1293],
         [-0.7150, -0.0476],
         [ 0.5230,  0.9717],
         [ 0.9364,  0.7122],
         [-0.0318,  0.1016],
         [ 1.3433,  0.7133],
         [ 0.3463, -0.5402]],

        [[ 0.8337, -0.9585],
         [ 0.4536,  1.2461],
         [-2.3065, -1.2869],
         [ 0.2137, -1.2351],
         [-0.1341, -1.0408],
         [-0.7647, -0.0553],
         [ 1.2049, -0.9825],
         [ 0.3040,  0.9339]],

        [[ 1.0554, -1.4534],
         [ 0.4652,  0.3714],
         [-0.0047,  0.0795],
         [-0.4560, -0.0619],
         [-1.7237, -0.8435],
         [ 0.4351,  0.2659],
         [-0.5871,  0.0827],
         [ 0.1858, -0.9698]]])


In [40]:
# For each sequence in the batch and each time step t, average x over time
# steps 0..t (inclusive), one position at a time. avgX[batch_idx, t] is the
# mean of the channel vectors at all time steps up to and including t.
# The batch loop variable is `batch_idx` rather than `b` so that it does not
# overwrite the tensor `b` created in an earlier cell.
avgX = torch.zeros((B, T, C))
for batch_idx in range(B):
    for t in range(T):
        prevX = x[batch_idx, :t + 1]  # (t+1, C): time steps 0..t of this sequence
        avgX[batch_idx, t] = torch.mean(prevX, 0)  # average over time -> (C,)

In [41]:
# Show one sequence (batch index 2) next to its running mean: row t of the
# second tensor is the average of rows 0..t of the first.
(x[2], avgX[2])

(tensor([[ 0.8337, -0.9585],
         [ 0.4536,  1.2461],
         [-2.3065, -1.2869],
         [ 0.2137, -1.2351],
         [-0.1341, -1.0408],
         [-0.7647, -0.0553],
         [ 1.2049, -0.9825],
         [ 0.3040,  0.9339]]),
 tensor([[ 0.8337, -0.9585],
         [ 0.6437,  0.1438],
         [-0.3397, -0.3331],
         [-0.2014, -0.5586],
         [-0.1879, -0.6550],
         [-0.2840, -0.5551],
         [-0.0713, -0.6161],
         [-0.0244, -0.4224]]))

In [42]:
# NOTE(review): hand check of a two-element mean. The literals -0.5389 and
# -0.8719 do not appear in the `x` or `avgX` printed above (probably left over
# from an earlier run with different random values).
# TODO: replace with values taken from x/avgX.
(-0.5389 - 0.8719) / 2

-0.7054

In [44]:
# Version 2: the same running mean as version 1, expressed as a single
# weighted aggregation via matrix multiplication.
#
# Build a (T, T) lower-triangular matrix of ones: row t has ones in columns 0..t.
lowDiag = torch.tril(torch.ones(T, T))
print(lowDiag)

# Normalize each row to sum to 1, so row t holds equal weights 1/(t+1) over
# time steps 0..t.
lowDiag = lowDiag / lowDiag.sum(1, keepdim=True)
print("\nNormalised lower diagonal matrix:")
print(lowDiag)

# lowDiag is (T, T) and x is (B, T, C). matmul broadcasts the 2-D matrix over
# the batch dimension: (T, T) @ (B, T, C) -> (B, T, C). For every batch index,
# avgX2[i] == lowDiag @ x[i].
avgX2 = lowDiag @ x

# The printed label says "xAvg" but the comparison is against the loop result avgX.
print(f"\n is xAvg equal to avgX2? {torch.allclose(avgX, avgX2)}")
# Fail loudly instead of just printing False if the vectorized result diverges
# from the loop-based avgX (e.g. after out-of-order cell execution).
assert torch.allclose(avgX, avgX2), "vectorized running mean != loop-based avgX"

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

Normalised lower diagonal matrix:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

 is xAvg equal to avgX2? True
