In [36]:
import torch
import torch.nn.functional as F
import torch.nn as nn
torch.manual_seed(1337)
B, T, C = 4, 8, 2 #batch, time, channels
x = torch.randn(B, T, C)
x.shape


torch.Size([4, 8, 2])

## Version 1: Averaging with For Loops



cons: computationally intensive


pros: easy to understand, each character will have the context of the previous tokens in the sequence

In [22]:
#the goal is x[b,t] = x[b,t-1] + x[b,t-2] + x[b,t-3] + x[b,t-4]
#we can do this with a for loop
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] #all the previous tokens in the sequence
        xbow[b, t] = torch.mean(xprev, dim=0) #average the previous tokens


## Version 2: Matrix Multiplication as Weighted Aggregation

In [23]:
#example
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('c=')
print(c)



a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
b=
tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])
c=
tensor([[17., 12.],
        [17., 12.],
        [17., 12.]])


In [24]:
#using tril to get the lower triangular part of the matrix
a = torch.tril(torch.ones(3,3))
print('a=')
print(a)



a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [25]:
#depending on how many 1s and zeros we have, we are doing a sum of the rows and depositing into c
#we can do incremnetal average update to c as well

a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, dim=1, keepdim=True)
print('a='  )
print(a)
print('b=')
print(b)
#now we can do the weighted aggregation
c = a @ b
print('c=')
print(c)




a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[8., 6.],
        [5., 2.],
        [4., 4.]])
c=
tensor([[8.0000, 6.0000],
        [6.5000, 4.0000],
        [5.6667, 4.0000]])


In [26]:
#look how the first row of c is the same as the first row of b
#this is because we are averaging the previous tokens
#our a matrix defines the averaging scheme

#now we can do the same thing with the transformer

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = wei @ x #this creates (B, T, T) @ (B, T, C) = (B, T, C)
torch.allclose(xbow, xbow2)

#this is a much more efficient way to do the weighted aggregation
#we are using matrix multiplication to do the weighted aggregation
#this is the same as the for loop, but it is much more efficient

True

In [27]:
xbow[0], xbow2[0]
#they are identical. we used batch matrix multiplication to do this weighted aggregation (sepcified in the T x T array)


(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

## Version 3: Using Soft-max

1. These weights begin with 0 but this will not always be the case. How much of each token from the past do we want to aggregate?

These tokens will start looking at each other. They will have affinities.

2. The past still cannot communicate with the future.

3. Normalise and..

4. Sum to aggregate values depending on how interesting they find each other.

In [30]:
tril  = torch.tril(torch.ones(T,T))
wei = torch.zeros((T, T)) #1. 
wei = wei.masked_fill(tril == 0, float('-inf')) #2. for all the elements where tril is 0 make them -inf
wei = F.softmax(wei, dim=-1) #3. on every row do a softmax
xbow3 = wei @ x #4. multiply the weights by the values
torch.allclose(xbow, xbow3)



True

## Version 4: Self-Attention


In [39]:
torch.manual_seed(1337)
#single head self attention. we want the wei to be data-dependent
B, T, C = 4, 8, 32 #batch, time, channels (from 2 to 32 now)
x = torch.randn(B, T, C)

# query (what am I looking for), key (what am I looking at)
# we want to compute the attention between the query and the key. if the key and query are similar they will have a high attention score
#wei = torch.zeros((T, T)) #1. single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x) #(B, T, 16) -> (B, T, head_size)
q = query(x) #(B, T, 16) -> (B, T, head_size)

wei = q @ k.transpose(-2, -1) #(B, T, head_size) @ (B, head_size, T) -> (B, T, T)
#for every row of B we will a T x T matrix of the affinities between the query and the key
tril  = torch.tril(torch.ones(T,T))

#wei = wei.masked_fill(tril == 0, float('-inf')) #2. for all the elements where tril is 0 make them -inf
#wei = F.softmax(wei, dim=-1) #3. on every row do a softmax
out = wei @ x #4. multiply the weights by the values

out.shape


torch.Size([4, 8, 32])

In [None]:
#looking at the raw wei matrix, we can see that the attention is being computed between the query and the key
#the wei matrix is now data-dependent.
wei

In [43]:
#we can now mask the wei matrix so that the future cannot communicate with the past
wei = wei.masked_fill(tril == 0, float('-inf'))
#wei = F.softmax(wei, dim=-1)

out = wei @ x

wei[0]


tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496,    -inf,    -inf,    -inf,    -inf],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,    -inf,    -inf,    -inf],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,    -inf,    -inf],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,    -inf],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [47]:
#and apply a softmax to the wei matrix

wei = wei.masked_fill(tril == 0, float('-inf'))

wei = F.softmax(wei, dim=-1)

wei[0]

#This tells us how much information to aggregate from the past.


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4592, 0.5408, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3280, 0.3267, 0.3452, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2559, 0.2477, 0.2487, 0.2477, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1983, 0.1988, 0.1984, 0.1983, 0.2061, 0.0000, 0.0000, 0.0000],
        [0.1659, 0.1670, 0.1659, 0.1658, 0.1697, 0.1658, 0.0000, 0.0000],
        [0.1429, 0.1437, 0.1426, 0.1426, 0.1427, 0.1430, 0.1425, 0.0000],
        [0.1248, 0.1249, 0.1249, 0.1252, 0.1249, 0.1249, 0.1252, 0.1252]],
       grad_fn=<SelectBackward0>)

In [48]:
#we don't aggregate the raw x we create a new representation of x called 'v' - single head aggregation between nodes
#repeat code:
torch.manual_seed(1337)
#single head self attention. we want the wei to be data-dependent
B, T, C = 4, 8, 32 #batch, time, channels (from 2 to 32 now)
x = torch.randn(B, T, C)

# query (what am I looking for), key (what am I looking at)
# we want to compute the attention between the query and the key. if the key and query are similar they will have a high attention score
#wei = torch.zeros((T, T)) #1. single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) #(B, T, 16) -> (B, T, head_size)
q = query(x) #(B, T, 16) -> (B, T, head_size)

wei = q @ k.transpose(-2, -1) #(B, T, head_size) @ (B, head_size, T) -> (B, T, T)
#for every row of B we will a T x T matrix of the affinities between the query and the key
tril  = torch.tril(torch.ones(T,T))

wei = wei.masked_fill(tril == 0, float('-inf')) #2. for all the elements where tril is 0 make them -inf
wei = F.softmax(wei, dim=-1) #3. on every row do a softmax

v = value(x) #(B, T, 16) -> (B, T, head_size)
out = wei @ v #4. multiply the weights by the values

out.shape
