In [28]:
import functorch.dim
import torch

# toy example
# this is to go from tokens not talking to each other to coupling tokens
# so that current token (say 5) talks to all past tokens (1, 2, 3, 4)
# info flows from previous context to current time. Cant get any info
# from future because about to try and predict future.

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels (num parallel contexts being processed, max context length, vocab size)
# means in this case we have up to 8 tokens in a batch
x = torch.randn(B, T, C)
print(x.shape)

torch.Size([4, 8, 2])


In [29]:
# going to use an average for now that's extremely lossy. 
# Use average of all preceding elements (get channels for 1, 2, 3, 4, and 5)
# if 5th token. Average up all channels. Gives us a feature vector that summerizes
# the 5th token in the context of its history. Does lose a ton of info about spatial arrangement of tokens, but using for now.

# We want x[b, t] = mean_{i <= t}  x[b, i].
# bow is bag of words (word stored on every one of the 8 locations)
# NEST FOR LOOPS VERY INEFFICIENT. WILL REWRITE w/ Matrix Mult
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [30]:
print(x[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])


In [31]:
xbow[0] # basically at each iteration of xbow, first x is ident to xbow row 0 because only an avg of itself
# but with each iteration of xbow, it is an average of all the rows up to that point.
# at the end xbow is an average of every row in the x table.

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [32]:
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b   # @ = matrix mult (so this is a dot b and a is a (3, 3)
# b is a (3, 2), and so c (the dot product) is a 3, 2) b/c (3,3) * (3, 2) take the outside terms and throw away inside to get final dims of cross prod
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)


a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [33]:
# do this again, but using a lower triangular matrix

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3,2)).float()
c = a @ b   # @ = matrix mult (so this is a dot b and a is a (3, 3)
# b is a (3, 2), and so c (the dot product) is a 3, 2) b/c (3,3) * (3, 2) take the outside terms and throw away inside to get final dims of cross prod
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [34]:
# do this again, but normalize rows so that all rows of 1 add to 1.
# so now all rows of a will sum to 1. 

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b   # @ = matrix mult (so this is a dot b and a is a (3, 3)
# b is a (3, 2), and so c (the dot product) is a 3, 2) b/c (3,3) * (3, 2) take the outside terms and throw away inside to get final dims of cross prod
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [35]:
# Create new version of average matrix called wei (short for weightts)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [36]:
# use weighted aggreg using matrix multiply. Weights specified in T x T array
# a is now average matrix called wei (short for weights)
# b is x
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x  # a dot b where a is (T, T) @ b (B, T, C) so ----> (B, T, T) @ (B, T, C) applies batching
# so for each batch element there will be a (T, T) @  (T, C) ---> (B, T, C) again. 
torch.allclose(xbow, xbow2)

False

In [37]:
print(xbow[0])
print(xbow2[0])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [38]:
from torch.nn import functional as F

# Final version of this: Softmax (also normalizes so each row = 1)
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

False

In [43]:
# version 4: Self-Attention (Self-Atten b/c key, queries, and vals all come from x
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels (# ind. parallel seqs, block size (context), vocab size)
x = torch.randn(B, T, C)

# single head performing self-attention
head_size = 16  # hyperparam
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # decoder block that prevents future nodes from talking to current or past nodes.
# delete above if you want all nodes to talk to each other in the batch. (i.e. all 8 nodes in the context block to talk even if you're at position 0
# in the batch. 
wei = F.softmax(wei, dim=-1)

v = value(x)  # this is the thing that gets aggregated between the different nodes (x is like private info to the node itself (say node 5). 
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [42]:
wei[0]  # data dependent wei where the token influences the weight (not uniform as before)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)