In [48]:
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C)
print(x.shape)
x[0]

torch.Size([4, 8, 2])


tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [41]:
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # t,C
        xbow[b, t] = torch.mean(xprev, 0) 

xbow[B-1]

tensor([[ 1.6455, -0.8030],
        [ 1.4985, -0.5395],
        [ 0.4954,  0.3420],
        [ 1.0623, -0.1802],
        [ 1.1401, -0.4462],
        [ 1.0870, -0.4071],
        [ 1.0430, -0.1299],
        [ 1.1138, -0.1641]])

In [42]:
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print(f"a={a}")
print(f"b={b}")
print(f"c={c}")

a=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [43]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3), diagonal=0)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print(f"a={a}")
print(f"b={b}")
print(f"c={c}")

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [44]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3), diagonal=0)
print(torch.sum(a, 1, keepdim=False))
print(torch.sum(a, 1, keepdim=True))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print(f"a={a}")
print(f"b={b}")
print(f"c={c}")

tensor([1., 2., 3.])
tensor([[1.],
        [2.],
        [3.]])
a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [45]:
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (T, T) @ (B, T, C) -> (B, T, T) @ (B, T, C) --> (B, T, C)
print(xbow[1])
print(xbow2[1])
torch.allclose(xbow, xbow2)

tensor([[ 1.3488, -0.1396],
        [ 0.8173,  0.4127],
        [-0.1342,  0.4395],
        [ 0.2711,  0.4774],
        [ 0.2421,  0.0694],
        [ 0.0084,  0.0020],
        [ 0.0712, -0.1128],
        [ 0.2527,  0.2149]])
tensor([[ 1.3488, -0.1396],
        [ 0.8173,  0.4127],
        [-0.1342,  0.4395],
        [ 0.2711,  0.4774],
        [ 0.2421,  0.0694],
        [ 0.0084,  0.0020],
        [ 0.0712, -0.1128],
        [ 0.2527,  0.2149]])


False

In [46]:


tril = torch.tril(torch.ones(T,T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
torch.allclose(xbow2, xbow3)

True

In [61]:
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C)

# single head operation for self-attention
head_size = 16

key = nn.Linear(C, head_size, bias=False) # keys I've
query = nn.Linear(C, head_size, bias=False) # query I want to run
value = nn.Linear(C, head_size, bias=False) # If you find me interesting, this is my value

k = key(x) # B, T, 16
q = query(x) # B, T, 16

wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril[:T,:T] == 0, float('-inf'))
wei = F.softmax(wei, dim=2)

v = value(x)

out = wei @ v

In [63]:
print(wei.shape)
print(v.shape)

torch.Size([4, 8, 8])
torch.Size([4, 8, 16])


In [None]:
a = torch.randn((4, 8, 8))
b = torch.randn((4, 8, 16))

a @ b

In [73]:
torch.manual_seed(42)
a = torch.randint(0, 10, (4, 3, 2)).float()
b = torch.arange(3)
print(a)
print(b)
a + b

tensor([[[2., 7.],
         [6., 4.],
         [6., 5.]],

        [[0., 4.],
         [0., 3.],
         [8., 4.]],

        [[0., 4.],
         [1., 2.],
         [5., 5.]],

        [[7., 6.],
         [9., 6.],
         [3., 1.]]])
tensor([0, 1, 2])


RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2