Memory expects incoming queries to be (b_s, seq_len, embed_dim)     (note: the queries are in fact keys, check this)

Memory returns results in (b_s, seq_len, top_k, 2, embed_dim)

In the video, the embed_dim is cut into pieces and each attn head gets a piece, so what needs to happen with the results is this:

1. Split out keys and values:
    keys: (b_s, seq_len, top_k, embed_dim)
    values: (b_s, seq_len, top_k, embed_dim)
2. Cut both in pieces to prep feeding to the attn heads:
    keys: (b_s, seq_len, top_k, heads, head_dim)
    values: (b_s, seq_len, top_k, heads, head_dim)
3. Now we need to pull the head dimension forward as we're going to feed it all into each head (except the batch):
    queries: (b_s, heads, seq_len, head_dim)
    keys: (b_s, heads, seq_len, head_dim, top_k) (we also need to pull head_dim forward)
    values: (b_s, heads, seq_len, top_k, head_dim)
    # This means we're also chopping up the memories and feeding the pieces to the attn heads

4. Then we also want to do self attention in each head with the pieces, where
    qk is obtained by (b_s, heads, seq_len, head_dim) @ (b_s, heads, seq_len, head_dim, top_k) -> (b_s, heads, seq_len, top_k)
    (each query is treated as a 1 x head_dim row vector, so per token this is a vector-matrix product)

    [1, 2, 3]  @  [2, 4]  -->  [2 + 8 + 18, 4 + 10 + 21] = [28, 35]
                  [4, 5]
                  [6, 7]

    Basically you get an attention score for each top_k, for each token in the sequence.

    then we must also multiply that with the values:

    (b_s, heads, seq_len, top_k) @ (b_s, heads, seq_len, top_k, head_dim) -> (b_s, heads, seq_len, head_dim)

    [2, 3] @ [1, 2, 3, 4, ...]  -->  [2 + 15, 4 + 18, ...] = [17, 22, ...]
             [5, 6, 7, 8, ...]

    Each retrieved value is weighted by its score and the weighted values are summed.
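
The four steps above can be sketched in PyTorch as follows. This is a minimal sketch under assumed sizes: the tensors are random stand-ins for the projected queries and for what the memory returns, and the names `mem`, `mem_keys` and `mem_values` are hypothetical. Scaling and softmax are left out because the steps above do not mention them. The `@` operator only broadcasts here once the query is given a singleton row axis.

```python
import torch

# Small illustrative sizes (assumptions, not values from the notes).
b_s, seq_len, top_k, heads, head_dim = 2, 5, 3, 4, 8
embed_dim = heads * head_dim

# Random stand-ins: projected queries and the memory's return value.
queries = torch.randn(b_s, seq_len, embed_dim)
mem = torch.randn(b_s, seq_len, top_k, 2, embed_dim)

# Step 1: split keys and values along the axis of size 2.
mem_keys, mem_values = mem.unbind(dim=3)  # each (b_s, seq_len, top_k, embed_dim)

# Step 2: cut embed_dim into (heads, head_dim).
mem_keys = mem_keys.reshape(b_s, seq_len, top_k, heads, head_dim)
mem_values = mem_values.reshape(b_s, seq_len, top_k, heads, head_dim)
q = queries.reshape(b_s, seq_len, heads, head_dim)

# Step 3: pull the head axis forward (and head_dim ahead of top_k for the keys).
q = q.permute(0, 2, 1, 3)              # (b_s, heads, seq_len, head_dim)
k = mem_keys.permute(0, 3, 1, 4, 2)    # (b_s, heads, seq_len, head_dim, top_k)
v = mem_values.permute(0, 3, 1, 2, 4)  # (b_s, heads, seq_len, top_k, head_dim)

# Step 4: one score per retrieved memory, then a weighted sum of the values.
qk = (q.unsqueeze(-2) @ k).squeeze(-2)    # (b_s, heads, seq_len, top_k)
out = (qk.unsqueeze(-2) @ v).squeeze(-2)  # (b_s, heads, seq_len, head_dim)

# The two small worked examples from step 4, checked numerically.
scores = torch.tensor([1., 2., 3.]) @ torch.tensor([[2., 4.], [4., 5.], [6., 7.]])
mixed = torch.tensor([2., 3.]) @ torch.tensor([[1., 2., 3., 4.], [5., 6., 7., 8.]])
print(scores.tolist(), mixed.tolist())
```

In a real attention layer a softmax over the top_k axis would normally sit between the two products; that is standard attention practice rather than something the steps above state.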

                                                                                   
In our implementation, however, we feed each attention head the complete embed_dim, so we don't need to do this and can reuse the existing logic.

However, below is how it's done in the video with einsum (without top_k).


In [6]:
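import torch
from einops import rearrange, einsum

# Toy dimensions; embed_dim = num_attn_heads * head_dim
b_s = 8
seq_len = 512
num_attn_heads = 4
head_dim = 4

# Random stand-ins for the projected queries and keys: (b_s, seq_len, embed_dim)
queries = torch.randn(b_s, seq_len, num_attn_heads * head_dim)
keys = torch.randn(b_s, seq_len, num_attn_heads * head_dim)

# Split embed_dim into heads and move the head axis ahead of seq_len:
# (b_s, seq_len, heads * head_dim) -> (b_s, heads, seq_len, head_dim)
queries = rearrange(queries, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)
keys = rearrange(keys, 'b_s seq_len (num_attn_heads head_dim) -> b_s num_attn_heads seq_len head_dim', num_attn_heads=num_attn_heads)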


