In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat, pack, unpack, einsum

b_s = 16             # batch size
seq_len = 512        # tokens per segment
head_dim = 10        # width of each head's queries / keys / values
num_attn_heads = 8   # number of attention heads
embed_dim = 13       # width of the input / output token vectors
seed = 0             # RNG seed for the fake data and layer initialisation below

# Seed torch so torch.randn / nn.Linear init in the following cells are reproducible
# under Restart & Run All. The trailing ';' hides the returned Generator's repr.
torch.manual_seed(seed);


In [53]:
# Create some fake training data

input_data = torch.randn(b_s, seq_len, embed_dim)

Wq = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
Wk = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
Wv = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
# Output projection back to embed_dim; nn.Linear keeps its default bias=True here.
# TODO: look up why Wo can have a bias term while Wq/Wk/Wv do not.
Wo = nn.Linear(num_attn_heads * head_dim, embed_dim)

queries = Wq(input_data)
keys = Wk(input_data)
values = Wv(input_data)

# Everything above is plain self-attention; now we add fake recurrence.

# The recurrence holds the previous segment's keys and values stacked on dim -2:
# (b_s, seq_len, 2, num_attn_heads * head_dim). This is how it will come in.
recurrence = torch.randn(b_s, seq_len, 2, num_attn_heads * head_dim)

recurrence_keys, recurrence_values = recurrence.unbind(dim=-2)  # each (b_s, seq_len, num_attn_heads * head_dim)

# Prepend the recurrence keys/values along the seq_len dimension (-2).
# The recurrence is as long as the segment here, so this doubles the key/value length.
xl_keys = torch.cat((recurrence_keys, keys), dim=-2)
xl_values = torch.cat((recurrence_values, values), dim=-2)
# With the config above: queries (16, 512, 80), xl_keys / xl_values (16, 1024, 80)

# Same as before: pull the heads out and calculate attention scores per head with einsum.
queries = rearrange(queries, 'b s (h d) -> b h s d', h=num_attn_heads)
xl_keys = rearrange(xl_keys, 'b s (h d) -> b h s d', h=num_attn_heads)  # s is twice as long here
attn_scores = einsum(queries, xl_keys, 'b h s1 d, b h s2 d -> b h s1 s2')
# Scaled dot-product attention: divide by sqrt(head_dim) so softmax does not saturate as head_dim grows.
attn_scores = attn_scores * head_dim ** -0.5

# Causal masking: every query may see all recurrence positions, plus current positions up to itself.
# Pattern for 4 queries and 8 keys (first 4 keys = recurrence); "s" = score kept:
# s    s    s    s    s -inf -inf -inf
# s    s    s    s    s    s -inf -inf
# s    s    s    s    s    s    s -inf
# s    s    s    s    s    s    s    s
rows, cols = attn_scores.shape[-2:]
mask = torch.ones((rows, cols), dtype=torch.bool, device=attn_scores.device).triu(diagonal=(cols - rows) + 1)
attn_scores = attn_scores.masked_fill(mask, float("-inf"))

# dim=-1: normalise over the keys, so each query's row of weights sums to 1 (see the softmax intermezzo).
attn_weights = F.softmax(attn_scores, dim=-1)

# All the rest is regular self-attention again: split the heads out of the values too,
# take the weighted sum per head, merge the heads and project back.
xl_values = rearrange(xl_values, 'b s (h d) -> b h s d', h=num_attn_heads)  # s is twice as long here

head_context_vectors = attn_weights @ xl_values
multihead_context_vector = rearrange(head_context_vectors, 'b h s d -> b s (h d)')
out = Wo(multihead_context_vector)  # Back to embedding dimension

# Shape sanity checks for the walkthrough above
assert attn_weights.shape == (b_s, num_attn_heads, seq_len, 2 * seq_len)
assert xl_values.shape == (b_s, num_attn_heads, 2 * seq_len, head_dim)
assert head_context_vectors.shape == (b_s, num_attn_heads, seq_len, head_dim)
assert multihead_context_vector.shape == (b_s, seq_len, num_attn_heads * head_dim)
assert out.shape == (b_s, seq_len, embed_dim)

In [28]:
# Intermezzo on triu

import torch

attn_scores = torch.tensor([[0.,1.,2.,3.,4.,5.,6.,7.],
               [0.,1.,2.,3.,4.,5.,6.,7.],
               [0.,1.,2.,3.,4.,5.,6.,7.],
               [0.,1.,2.,3.,4.,5.,6.,7.]])

# Last two dims = (num_queries, num_keys); for this 2-D tensor that is simply the whole shape.
rows, cols = attn_scores.shape[-2:]
# print(attn_scores.shape) -> torch.Size([4, 8])

# triu(diagonal=k) keeps the entries on and above the k-th diagonal (True) and sets the rest to False.
# Here k = (cols - rows) + 1 = 5:
#   - the first cols - rows = 4 columns are recurrence positions that every query may see;
#   - from there on the usual causal pattern applies (query r sees current positions <= r).
# Without recurrence rows == cols, so k = 1: the plain causal mask that hides future tokens.
# Using (cols - rows) + 1 therefore works for any recurrence length, including none.
mask = torch.ones((rows, cols), dtype=torch.bool).triu(diagonal=(cols - rows) + 1)
print(mask)
attn_scores = attn_scores.masked_fill(mask, float("-inf"))  # -inf where mask is True, unchanged elsewhere.
print(attn_scores)

tensor([[False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])
tensor([[0., 1., 2., 3., 4., -inf, -inf, -inf],
        [0., 1., 2., 3., 4., 5., -inf, -inf],
        [0., 1., 2., 3., 4., 5., 6., -inf],
        [0., 1., 2., 3., 4., 5., 6., 7.]])


In [45]:
# Intermezzo on softmax
#
# Each row is 0..7 multiplied by a growing factor (1, 10, 100, 1000). The bigger the gaps
# between scores, the more softmax concentrates on the largest one: row 0 is still spread
# out, while the last row is effectively one-hot on the final column.
row_scale = torch.tensor([1., 10., 100., 1000.]).unsqueeze(-1)  # (4, 1)
attn_scores = row_scale * torch.arange(8.)                       # (4, 8) via broadcasting
attn_weights = F.softmax(attn_scores, dim = -1)

# dim=-1 is the innermost axis, i.e. each row: after softmax every row sums to 1.
# (attn_weights[0].sum() would sum just the first row.)
print(attn_weights)
print(attn_weights.shape)
print(attn_weights.shape[-1])
row_sums = torch.sum(attn_weights, dim=-1)
print(row_sums)


tensor([[5.7661e-04, 1.5674e-03, 4.2606e-03, 1.1582e-02, 3.1482e-02, 8.5577e-02,
         2.3262e-01, 6.3233e-01],
        [3.9753e-31, 8.7561e-27, 1.9287e-22, 4.2482e-18, 9.3572e-14, 2.0611e-09,
         4.5398e-05, 9.9995e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.7835e-44, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.0000e+00]])
torch.Size([4, 8])
8
tensor([1., 1., 1., 1.])


In [None]:
# We now add the recurrence principles from above to our existing MHSelfAttn
# (the einsum version from ExplainEinsum is the base).

class MHSelfAttn(nn.Module):
    """Causal multi-head self-attention with optional Transformer-XL-style recurrence.

    forward() lets the current segment attend to the previous segment's keys/values
    (``recurrence``) plus its own keys/values. It also returns the current segment's
    keys/values, so the caller can feed them back in as the next ``recurrence``.
    """

    def __init__(self, embed_dim, num_attn_heads=8, head_dim=32):
        """
        Args:
            embed_dim: width of the input and output token vectors.
            num_attn_heads: number of attention heads.
            head_dim: width of each head's queries, keys and values.
        """
        super().__init__()

        # Need these in forward()
        self.num_attn_heads = num_attn_heads
        self.head_dim = head_dim

        self.Wq = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        # Output projection back to embed_dim (nn.Linear's default bias=True is kept).
        # TODO: look up why Wo can have a bias term while Wq/Wk/Wv do not.
        self.Wo = nn.Linear(num_attn_heads * head_dim, embed_dim)

    def forward(self, input_data, recurrence=None):
        """Run causal self-attention over one segment.

        Args:
            input_data: (b_s, seq_len, embed_dim) tensor for the current segment.
            recurrence: optional (b_s, rec_len, 2, num_attn_heads * head_dim) tensor holding
                the previous segment's keys and values stacked on dim -2, i.e. the second
                value returned by an earlier call. ``None`` for the first segment, which
                then gets regular causal self-attention.

        Returns:
            out: (b_s, seq_len, embed_dim) attention output.
            current: (b_s, seq_len, 2, num_attn_heads * head_dim) keys and values of this
                segment only (never the recurrence), in the layout ``recurrence`` expects.
                NOTE(review): not detached here; Transformer-XL applies stop-gradient to
                cached states, so consider detaching at the call site before reuse.
        """
        queries = self.Wq(input_data)
        keys = self.Wk(input_data)
        values = self.Wv(input_data)

        # Capture this segment's keys/values *before* prepending the recurrence, so the
        # returned state is exactly the current segment for any recurrence length
        # (the inverse of the unbind below).
        current = torch.stack((keys, values), dim=-2)

        if recurrence is not None:
            recurrence_keys, recurrence_values = recurrence.unbind(dim=-2)
            # Prepend along the sequence axis (dim -2 of the (b, s, h*d) projections).
            keys = torch.cat((recurrence_keys, keys), dim=-2)
            values = torch.cat((recurrence_values, values), dim=-2)

        split_heads = 'b s (h d) -> b h s d'
        queries = rearrange(queries, split_heads, h=self.num_attn_heads)
        keys = rearrange(keys, split_heads, h=self.num_attn_heads)
        values = rearrange(values, split_heads, h=self.num_attn_heads)

        attn_scores = einsum(queries, keys, 'b h s1 d, b h s2 d -> b h s1 s2')
        # Scaled dot-product attention.
        attn_scores = attn_scores * (self.head_dim ** -0.5)

        # With i queries and j keys, the first j - i key columns are the recurrence: every
        # query may see those plus current positions up to itself -> diagonal offset j - i + 1.
        i, j = attn_scores.shape[-2:]
        mask = torch.ones((i, j), dtype=torch.bool, device=attn_scores.device).triu(diagonal=j - i + 1)
        masked_attn_scores = attn_scores.masked_fill(mask, float("-inf"))

        attn_weights = F.softmax(masked_attn_scores, dim=-1)
        head_context_vectors = attn_weights @ values

        multihead_context_vector = rearrange(head_context_vectors, 'b h s d -> b s (h d)')
        out = self.Wo(multihead_context_vector)

        return out, current


In [None]:
# Now we do the same for MHSelfAttnWithMem: local causal attention with recurrence, plus
# attention over the top-k keys/values retrieved from a kNN memory, mixed by a per-head gate.


class MHSelfAttnWithMem(nn.Module):
    """Causal multi-head self-attention with recurrence and kNN-memory attention.

    Each head computes two things:
      (a) causal attention over recurrence + current keys/values;
      (b) attention over the ``top_k`` memory entries retrieved for each query.
    A learned gate in (0, 1) mixes the two. The current segment's keys/values are
    returned (for recurrence) and added to the memory.
    """

    def __init__(self, embed_dim, num_attn_heads=8, head_dim=32, top_k=3, dropout=0.0):
        """
        Args:
            embed_dim: width of the input and output token vectors.
            num_attn_heads: number of attention heads.
            head_dim: width of each head's queries, keys and values.
            top_k: number of memory entries retrieved per query.
            dropout: dropout probability applied to the memory attention weights.
        """
        super().__init__()

        self.num_attn_heads = num_attn_heads
        self.head_dim = head_dim
        self.top_k = top_k

        self.Wq = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, num_attn_heads * head_dim, bias=False)
        # Output projection back to embed_dim (nn.Linear's default bias=True is kept).
        # TODO: look up why Wo can have a bias term while Wq/Wk/Wv do not.
        self.Wo = nn.Linear(num_attn_heads * head_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

        # One gate logit per head; shape (h, 1, 1) broadcasts over (b, h, s, d).
        # forward() squashes it with a sigmoid so the mix is a convex combination.
        self.gate = nn.Parameter(torch.randn(num_attn_heads, 1, 1))

    def forward(self, input_data, memory, recurrence=None):
        """Run gated local + memory attention over one segment.

        Args:
            input_data: (b_s, seq_len, embed_dim) tensor for the current segment.
            memory: kNN memory exposing ``query(queries, top_k)`` and ``add(kv)``.
                NOTE(review): assumes ``query`` takes (b_s, seq_len, num_attn_heads * head_dim)
                queries and returns (b_s, seq_len, top_k, 2, num_attn_heads * head_dim) stacked
                keys/values (that is what the unbind(-2) and rearrange below expect), and that
                ``add`` accepts a (b_s, seq_len, 2, num_attn_heads * head_dim) tensor; confirm
                against the memory implementation.
            recurrence: optional (b_s, rec_len, 2, num_attn_heads * head_dim) keys/values of
                the previous segment (the second value returned by an earlier call), or None.

        Returns:
            out: (b_s, seq_len, embed_dim) attention output.
            current: (b_s, seq_len, 2, num_attn_heads * head_dim) keys and values of this
                segment only.
        """
        queries = self.Wq(input_data)
        keys = self.Wk(input_data)
        values = self.Wv(input_data)

        # This segment's keys/values, captured before the recurrence is prepended, so the
        # returned state is exactly the current segment for any recurrence length.
        current = torch.stack((keys, values), dim=-2)

        if recurrence is not None:
            recurrence_keys, recurrence_values = recurrence.unbind(dim=-2)
            keys = torch.cat((recurrence_keys, keys), dim=-2)
            values = torch.cat((recurrence_values, values), dim=-2)

        # Same 1/sqrt(head_dim) scaling for local and memory scores, so the two softmaxes
        # the gate mixes have the same temperature.
        scale = self.head_dim ** -0.5
        split_heads = 'b s (h d) -> b h s d'
        head_queries = rearrange(queries, split_heads, h=self.num_attn_heads)
        keys = rearrange(keys, split_heads, h=self.num_attn_heads)
        values = rearrange(values, split_heads, h=self.num_attn_heads)

        # (a) Local causal attention over recurrence + current positions.
        attn_scores = einsum(head_queries, keys, 'b h s1 d, b h s2 d -> b h s1 s2') * scale
        i, j = attn_scores.shape[-2:]
        mask = torch.ones((i, j), dtype=torch.bool, device=attn_scores.device).triu(diagonal=j - i + 1)
        attn_weights = F.softmax(attn_scores.masked_fill(mask, float("-inf")), dim=-1)
        head_context_vectors = attn_weights @ values

        # (b) Memory attention: each query attends over its own top_k retrieved entries.
        # `queries` is still the flat (b, s, h*d) projection here.
        mem_keys_and_values = memory.query(queries, self.top_k)
        mem_keys, mem_values = mem_keys_and_values.unbind(-2)
        mem_keys = rearrange(mem_keys, 'b s k (h d) -> b h s k d', h=self.num_attn_heads)
        mem_values = rearrange(mem_values, 'b s k (h d) -> b h s k d', h=self.num_attn_heads)

        mem_scores = einsum(head_queries, mem_keys, 'b h s d, b h s k d -> b h s k') * scale
        mem_weights = self.dropout(F.softmax(mem_scores, dim=-1))
        mem_context_vectors = einsum(mem_weights, mem_values, 'b h s k, b h s k d -> b h s d')

        # Per-head convex mix of memory and local context.
        gate = torch.sigmoid(self.gate)
        head_mem_context_vectors = mem_context_vectors * gate + head_context_vectors * (1 - gate)

        multihead_context_vector = rearrange(head_mem_context_vectors, 'b h s d -> b s (h d)')
        out = self.Wo(multihead_context_vector)

        # Store this segment in the memory only after querying it, so this segment's queries
        # were answered from previously added entries. Detached so the stored tensors do not
        # keep this step's autograd graph alive.
        memory.add(current.detach())

        return out, current
