Memory expects incoming queries to be (b_s, seq_len, embed_dim)     (note: these queries are in fact keys; check this)

Memory returns results in (b_s, seq_len, top_k, 2, embed_dim)

In the video, the embed_dim is cut into pieces and each attn head gets a piece, so what needs to happen with the results is this:

1. Split out keys and values:
    keys: (b_s, seq_len, top_k, embed_dim)
    values: (b_s, seq_len, top_k, embed_dim)
2. Cut both in pieces to prep feeding to the attn heads:
    keys: (b_s, seq_len, top_k, heads, head_dim)
    values: (b_s, seq_len, top_k, heads, head_dim)
3. Now we need to pull the head dimension forward as we're going to feed it all into each head (except the batch):
    queries: (b_s, heads, seq_len, head_dim)
    keys: (b_s, heads, seq_len, head_dim, top_k) (we also need to pull head_dim forward)
    values: (b_s, heads, seq_len, top_k, head_dim)
    # This means we're also chopping up the memories and feeding the pieces to the attn heads

4. Then we also want to do self attention in each head with the pieces, where
    qk is obtained by (b_s, heads, seq_len, head_dim) @ (b_s, heads, seq_len, head_dim, top_k) -> (b_s, heads, seq_len, top_k)

[1, 2, 3]  @  [2,4] --> [2 + 8 + 18, 4 + 10 + 21] -> basically you get an attention score for each top_k, for each token in the sequence     
              [4,5]
              [6,7]

    then we must also multiply that with the value:

    (b_s, heads, seq_len, top_k) @ (b_s, heads, seq_len, top_k, head_dim) -> (b_s, heads, seq_len, head_dim)

   [2, 3] @ [1, 2, 3, 4, ...]  --> [2 + 15, 4 + 18, ... ] , we essentially add the pieces up 
            [5, 6, 7, 8, ...] 
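
Steps 1 to 4 can be sketched in a few lines of PyTorch. This is only an illustration under assumptions: the Memory output is replaced by a random stand-in tensor, the axis of size 2 is assumed to hold keys at index 0 and values at index 1, and the scaling plus softmax between the two products is the usual attention recipe, not something the steps above spell out.

```python
import torch

b_s, seq_len, top_k, heads, head_dim = 2, 5, 3, 4, 8
embed_dim = heads * head_dim

queries = torch.randn(b_s, seq_len, embed_dim)
# Stand-in for what Memory returns: (b_s, seq_len, top_k, 2, embed_dim)
memory_out = torch.randn(b_s, seq_len, top_k, 2, embed_dim)

# 1. Split out keys and values (ASSUMPTION: index 0 = keys, index 1 = values)
mem_keys, mem_values = memory_out.unbind(dim=3)  # each (b_s, seq_len, top_k, embed_dim)

# 2. Cut the embedding into one piece per attention head
q = queries.reshape(b_s, seq_len, heads, head_dim)
k = mem_keys.reshape(b_s, seq_len, top_k, heads, head_dim)
v = mem_values.reshape(b_s, seq_len, top_k, heads, head_dim)

# 3 + 4. einsum moves the heads axis forward and contracts over head_dim in one go
# b = batch, s = sequence position, h = head, d = head_dim, k = top_k
qk = torch.einsum('bshd,bskhd->bhsk', q, k)             # (b_s, heads, seq_len, top_k)

# ASSUMPTION: standard scaling and softmax over the top_k memories
weights = torch.softmax(qk * head_dim ** -0.5, dim=-1)

# Weighted sum of the top_k value pieces for every head and token
mem_context = torch.einsum('bhsk,bskhd->bhsd', weights, v)  # (b_s, heads, seq_len, head_dim)

print(qk.shape, mem_context.shape)
```

With einsum there is no need to physically permute keys to (b_s, heads, seq_len, head_dim, top_k); naming the axes is enough.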

                                                                                   
In our implementation, however, we feed each attention head the complete embed_dim, so we don't need to do this and can reuse the existing logic.

However, below is how it's done in the video with einsum (without top_k).


In [37]:
# This is the challenge setting:

import torch
from einops import rearrange, einsum

b_s = 8
seq_len = 512
num_attn_heads = 4
head_dim = 4

queries = torch.randn(b_s, seq_len, num_attn_heads * head_dim)
keys = torch.randn(b_s, seq_len, num_attn_heads * head_dim)

queries = rearrange(queries, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)
keys = rearrange(keys, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)
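
For comparison, the head split done by the rearrange pattern above can also be written in plain PyTorch. The sketch below is an illustration only (the variable `split` and the head index `h` are made up for the example); it shows that each head ends up with one contiguous slice of the embedding.

```python
import torch

b_s, seq_len, num_attn_heads, head_dim = 8, 512, 4, 4
queries = torch.randn(b_s, seq_len, num_attn_heads * head_dim)

# Plain-PyTorch counterpart of the pattern
# 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim':
# first split the last axis, then swap the seq_len and heads axes.
split = queries.view(b_s, seq_len, num_attn_heads, head_dim).transpose(1, 2)

print(split.shape)  # torch.Size([8, 4, 512, 4])

# Head h sees columns [h * head_dim, (h + 1) * head_dim) of the original embedding
h = 2
print(torch.equal(split[:, h], queries[:, :, h * head_dim:(h + 1) * head_dim]))  # True
```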


In [31]:
# Code examples from https://www.youtube.com/watch?v=pkVwUVEHmfI

# Rough rules
#
# - Repeated indices tell einsum to multiply on those dimensions
# - Differing indices define the shape of the output
# - Indices that are not present in the output are summed over
# - Output indices (dimensions) can be put in any order

import torch

x = torch.rand((2,3))

# Permutation of tensors
# ----------------------

# Note: transpose is flipping two dimensions, so it's a special case of permutation. So the below are equivalent:

batch_size, seq_len = 3, 5
a = torch.zeros((batch_size, seq_len))
a[0] = 1
a[1] = 2
a[2] = 3

print(a)

a1 = a.transpose(1,0)
a2 = a.permute(1,0)

# Note on the note: always remember that view() and reshape() do something fundamentally different from transpose()/permute():
# they keep the elements in memory order and only regroup them, so the results below differ from a1/a2

a3 = a.view(seq_len, -1)
a4 = a.reshape(seq_len, -1)

torch.einsum('rc->cr', a) # This is equivalent to the transpose/permute above

# Summation RULE: THE DIMENSION THAT'S OMITTED IN THE OUTPUT IS THE ONE OVER WHICH THE SUMMATION HAPPENS
# ----------------------

# 1. Sum up the whole matrix

torch.einsum('rc->', a)

# 2. Sum each row (the column index c is omitted, so we sum over the columns and get one value per row)

torch.einsum('rc->r',a)

# Matrix - vector multiplication (transformation)
# ----------------------

v = torch.rand((3,1))
v[0] = 2
v[1] = 2
v[2] = 2
print(v)
m = torch.rand((3,3))
m[0] = 1
m[1] = 2
m[2] = 3
print(m)

# Normally you need to transpose v (v.t() @ m), but with einsum you name the dimension over which the multiplication happens, so no transpose is needed:

result = torch.einsum('r v, r m -> v m', v, m)
result

# Matrix - matrix multiplication
# ----------------------

print(x.mm(x.t()))   # 2x3 @ 3x2 -> 2x2

# Again, with einsum you don't have to transpose, just specify the @ dimension:

torch.einsum('r c, d c -> r d', x,x)

# You can also single out rows in the input, here we are taking the dot product of the first row with itself:

torch.einsum('i,i->', x[0], x[0])

# Hadamard product / element-wise multiplication (so no summing)
# ----------------------

torch.einsum('ij, ij -> ij', x, x)

# Outer product
# ----------------------

a = torch.rand((3))
b = torch.rand((5)) # -> result must be 3x5 matrix

torch.einsum('r,c->rc', a, b)

# Batch matrix multiplication
# ----------------------

a = torch.rand((3,2,5))
b = torch.rand((3,5,3)) # -> we want to do 3 bmm's of 2x5 @ 5x3 -> 2x3

torch.einsum('bij, bjk -> bik', a, b)

# Matrix diagonal
# ----------------------

m = torch.rand((3,3))
print(m)
print(torch.einsum('ii->i', m)) # Map 0,0 to 0, 1,1 to 1, etc.


# Matrix trace ( == sum over the diagonal)
# ----------------------

torch.einsum('ii->', m)


tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])
tensor([[2.],
        [2.],
        [2.]])
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])
tensor([[1.6491, 0.7200],
        [0.7200, 0.8063]])
tensor([[0.5253, 0.9324, 0.2290],
        [0.0776, 0.8039, 0.3638],
        [0.7593, 0.1732, 0.5688]])
tensor([0.5253, 0.8039, 0.5688])


tensor(1.8979)
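
The rough rules above can be sanity-checked by comparing each einsum pattern with its native PyTorch counterpart. The snippet below is a small sketch of such a check; the tensor names and the `checks` dictionary are made up for the illustration.

```python
import torch

a = torch.arange(15, dtype=torch.float32).reshape(3, 5)
x = torch.rand(2, 3)
p = torch.rand(3, 2, 5)
q = torch.rand(3, 5, 3)
m = torch.rand(3, 3)

checks = {
    "transpose": torch.equal(torch.einsum('rc->cr', a), a.t()),
    "total sum": torch.allclose(torch.einsum('rc->', a), a.sum()),
    "row sums": torch.allclose(torch.einsum('rc->r', a), a.sum(dim=1)),
    "matmul": torch.allclose(torch.einsum('rc,dc->rd', x, x), x @ x.t()),
    "hadamard": torch.allclose(torch.einsum('ij,ij->ij', x, x), x * x),
    "outer": torch.allclose(torch.einsum('r,c->rc', x[0], a[0]), torch.outer(x[0], a[0])),
    "bmm": torch.allclose(torch.einsum('bij,bjk->bik', p, q), torch.bmm(p, q)),
    "diagonal": torch.equal(torch.einsum('ii->i', m), torch.diagonal(m)),
    "trace": torch.allclose(torch.einsum('ii->', m), torch.trace(m)),
    # view() keeps the memory order, so it is NOT a transpose
    "view is not transpose": not torch.equal(a.view(5, 3), a.t()),
}

for name, ok in checks.items():
    print(f"{name}: {ok}")
```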

In [35]:
# Back to coding a paper 

import torch
import numpy as np
from einops import rearrange, einsum

# Rearrange is not part of standard einsum. Essentially, it lets you reshape and transpose in one step:

# torch.rand vs torch.randn:
# torch.randn generates numbers from a normal distribution with a mean of 0 and a standard deviation of 1
# torch.rand generates numbers from a uniform distribution between 0 and 1.
# For these you *can* provide the dimensions either directly as parameters or as a tuple.

# np.random.rand and np.random.randn are similar, but they only accept the dimensions as separate parameters (no tuple)

x = torch.randn((1,2,3))
y = torch.randn(1,2,3)

#x1 = np.random.randn((1,2,3)) <<< not supported
y1 = np.random.randn(1,2,3)

x = torch.randn(24, 10, 15)

rearrange(x, '(a b) c (d e) -> (e c) a b d', a=6, e=5).shape # b and d are inferred from the products: b = 24 / 6 = 4, d = 15 / 5 = 3


torch.Size([50, 6, 4, 3])

In [2]:
# This is the challenge setting:

import torch
from einops import rearrange, einsum

b_s = 8
seq_len = 512
num_attn_heads = 4
head_dim = 4

queries = torch.randn(b_s, seq_len, num_attn_heads * head_dim)
keys = torch.randn(b_s, seq_len, num_attn_heads * head_dim)

queries = rearrange(queries, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)
keys = rearrange(keys, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)

# Now that we know all of the above, let's continue with the challenge

# For the q k dot product in each attn head, we multiply over the head_dim (d) for each element in the sequence, so that 
# we get a seq_len x seq_len result matrix - this is standard attention

queries_keys = einsum(queries, keys, 'b h s1 d, b h s2 d -> b h s1 s2')

queries_keys = queries_keys * (head_dim ** -0.5) # This happens across the heads, which works because all the head_dim's are the same

# 41:30 -> implement the forward() here as an example

# Here is a forward() at the level of the multi-head attention class, implemented with einsum:

def forward(self, input_data):  # What comes in in (b_s, seq_len, embed_dim)

    b_s, seq_len = input_data.shape[:2]  # All dims except dim #2 (the third one)

    # In the vid implementation the Wq/k/v transformations are done in the multi-head attention class, not in the attn heads

    queries = self.Wq(input_data)     
    keys = self.Wk(input_data)
    values = self.Wv(input_data)
    
    # Now we split into heads with rearrange (copied from the above, just added self. ...)
    queries = rearrange(queries, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=self.num_attn_heads)
    keys = rearrange(keys, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=self.num_attn_heads)
    
    values = rearrange(values, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=self.num_attn_heads)
    
    # Calculate attention scores for all heads and scale (all copied from above)

    queries_keys = einsum(queries, keys, 'b h s1 d, b h s2 d -> b h s1 s2')

    queries_keys = queries_keys * (queries.shape[-1] ** -0.5)  # queries.shape[-1] is head_dim; avoids relying on a global head_dim

    # Here is where you do queries_keys = rel_pos_values + queries_keys -> DO THIS BEFORE MASKING FOR SURE !!!!!!

    # Causal masking - at this point queries_keys is (b_s, num_attn_heads, seq_len, seq_len)

    # Vid takes last two dims separately - should they not always be the same? We write assert here to check it. Vid code below
    # in case this assert ever fires.

    i, j = queries_keys.shape[-2:]
    assert i == j, "Attention scores in attention head are not stored in a square matrix!"
    # mask = torch.ones((i,j), dtype=torch.bool).triu(j-i+1)
    # The line below assumes square matrix:
    mask = torch.ones((i, j), dtype=torch.bool, device=queries_keys.device).triu(diagonal=1)

    attn_scores = queries_keys.masked_fill(mask, float("-inf"))   # The (seq_len, seq_len) mask is broadcast over the batch and head dims

    attn_weights = torch.softmax(attn_scores, dim=-1)  # Softmax over the key positions, applied to all attn heads at once

    head_context_vectors = attn_weights @ values  # One context vector per head, so it's head_context_vectorS !!! (multiple)

    # Now we use rearrange to merge the heads back into a single embed_dim axis

    multihead_context_vector = rearrange(head_context_vectors, 'b h s d -> b s (h d)')

    # This is where you add the Memory if this is the last multi-head attn layer -> DO THIS BEFORE RETURNING Wo OUTPUT !!!

    multihead_context_vector = self.Wo(multihead_context_vector)

    return multihead_context_vector    

# 41:50
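
To try the forward() above end to end, it can be wrapped in a small module. The following is a minimal sketch, not the video's code: the class name, the constructor arguments and the bias-free Linear layers are assumptions, the relative-position and Memory hooks are left out, and plain torch.einsum plus view/transpose stand in for the einops calls so the block only depends on PyTorch.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    # Hypothetical container class; only the attribute names follow forward() above.
    def __init__(self, embed_dim, num_attn_heads):
        super().__init__()
        assert embed_dim % num_attn_heads == 0, "embed_dim must divide evenly over the heads"
        self.num_attn_heads = num_attn_heads
        self.head_dim = embed_dim // num_attn_heads
        # ASSUMPTION: bias-free projections, chosen only for this sketch
        self.Wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=False)

    def _split_heads(self, t):
        # (b_s, seq_len, embed_dim) -> (b_s, heads, seq_len, head_dim)
        b_s, seq_len, _ = t.shape
        return t.view(b_s, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2)

    def forward(self, input_data):
        b_s, seq_len, embed_dim = input_data.shape
        q = self._split_heads(self.Wq(input_data))
        k = self._split_heads(self.Wk(input_data))
        v = self._split_heads(self.Wv(input_data))

        # Scaled attention scores for all heads at once: (b_s, heads, seq_len, seq_len)
        scores = torch.einsum('bhid,bhjd->bhij', q, k) * self.head_dim ** -0.5

        # Causal mask: True above the diagonal marks future positions
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).triu(diagonal=1)
        weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

        # Weighted sum of values, then merge the heads back into embed_dim
        context = torch.einsum('bhij,bhjd->bhid', weights, v)
        merged = context.transpose(1, 2).reshape(b_s, seq_len, embed_dim)
        return self.Wo(merged)


torch.manual_seed(0)
mha = MultiHeadAttention(embed_dim=16, num_attn_heads=4)
x = torch.randn(2, 6, 16)
out = mha(x)
print(out.shape)  # torch.Size([2, 6, 16])

# Causality check: changing the last token must not change earlier outputs
x2 = x.clone()
x2[:, -1] += 1.0
out2 = mha(x2)
print(torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-6))
```

The causality check at the end is a quick way to confirm the mask is applied on the right side of the diagonal: only the output at the modified position should move.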
