In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
!pip install einops
from einops import rearrange, repeat, pack, unpack, einsum



In [None]:

# Transformer-XL style recurrence, processed one segment at a time:
#   1st segment: compute current kv projections [kv1] and perform attention
#   2nd segment: concatenate old kv projections with current kv projections [kv1 + kv2] and perform attention
#   3rd segment: concatenate old kv projections with current kv projections [kv2 + kv3] and perform attention
#   4th segment: concatenate old kv projections with current kv projections [kv3 + kv4] and perform attention
#   ...
#
# PSEUDOCODE ONLY. The per-layer sketch below uses placeholder names that are
# never defined, so executing it raises NameError. It is kept as comments so
# the notebook still runs top to bottom on a fresh kernel.
#
# 1st segment:
#   seg_one_kv = [seg_1_layer_1_kv,
#                 seg_1_layer_2_kv,
#                 seg_1_layer_3_kv,
#                 ...]
#
# 2nd segment:
#   seg_two_kv = [concatenate(seg_1_layer_1_kv, seg_2_layer_1_kv),
#                 concatenate(seg_1_layer_2_kv, seg_2_layer_2_kv),
#                 concatenate(seg_1_layer_3_kv, seg_2_layer_3_kv),
#                 ...]
#
# 3rd segment:
#   seg_three_kv = [concatenate(seg_2_layer_1_kv, seg_3_layer_1_kv),
#                   concatenate(seg_2_layer_2_kv, seg_3_layer_2_kv),
#                   concatenate(seg_2_layer_3_kv, seg_3_layer_3_kv),
#                   ...]

In [None]:
# Seed PyTorch's global RNG so the random tensors and layer initialisations
# created in the cells below are identical on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Toy sizes used by the shape walkthrough.
batch_size = 16
seq_len = 512
head_dimension = 10
number_heads = 8
embedding_dimension = 13
scaling_factor = 1  # not referenced by any other visible cell

In [None]:
# Stand-in for one training batch: Gaussian noise laid out as
# (batch_size, seq_len, embedding_dimension).
input_data = torch.randn(batch_size, seq_len, embedding_dimension)
input_data.shape

torch.Size([16, 512, 13])

In [None]:
# Linear projections: q/k/v map the embedding to all heads at once
# (number_heads * head_dimension features); the output layer maps back.
query_matrix = nn.Linear(in_features=embedding_dimension, out_features=number_heads * head_dimension)
key_matrix = nn.Linear(in_features=embedding_dimension, out_features=number_heads * head_dimension)
value_matrix = nn.Linear(in_features=embedding_dimension, out_features=number_heads * head_dimension)
output_matrix = nn.Linear(in_features=number_heads * head_dimension, out_features=embedding_dimension)

In [None]:
# Project the batch into queries, keys and values (in that order).
queries, keys, values = (
    projection(input_data)
    for projection in (query_matrix, key_matrix, value_matrix)
)
values.shape

torch.Size([16, 512, 80])

In [None]:
# Fake cache from a "previous segment": keys and values stacked on axis -2,
# each with number_heads * head_dimension features per position.
xl_memory = torch.randn((batch_size, seq_len, 2, number_heads * head_dimension))
xl_memory.shape

torch.Size([16, 512, 2, 80])

In [None]:
# Undo the stacking: axis -2 holds (keys, values).
xl_keys, xl_values = torch.unbind(xl_memory, dim=-2)
xl_keys.shape

torch.Size([16, 512, 80])

In [None]:
# Prepend the cached keys/values along the sequence axis (axis 1 of these 3-D tensors).
# NOTE: this cell reassigns `keys`/`values`, so re-running it prepends the memory again.
keys = torch.cat([xl_keys, keys], dim=1)
values = torch.cat([xl_values, values], dim=1)
values.shape

torch.Size([16, 1024, 80])

In [None]:
# Queries come from the current segment only, so their length is unchanged.
queries.size()

torch.Size([16, 512, 80])

In [None]:
# Split the flat feature axis into heads, then take the dot product of every
# query position (i) with every key position (j) within each head.
queries = rearrange(queries, 'b t (h d) -> b h t d', h=number_heads)
keys = rearrange(keys, 'b t (h d) -> b h t d', h=number_heads)
qk = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

print(f"queries: {queries.shape}")
print(f"keys: {keys.shape}")
print(f"qk: {qk.shape}")

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 1024, 10])
qk: torch.Size([16, 8, 512, 1024])


In [None]:
# Regular Self Attention QK (4,4)
#
# [    1., -1000., -1000., -1000.]
# [    1.,     1., -1000., -1000.]
# [    1.,     1.,     1., -1000.]
# [    1.,     1.,     1.,     1.]



# Transformer XL Self Attention QK (4,8)
#
# [    1.,     1.,     1.,     1.,     1., -1000., -1000., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1., -1000., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1.,     1., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1.,     1.,     1.]

In [None]:
# i = number of query positions, j = number of key positions (memory + current).
i, j = qk.size(-2), qk.size(-1)
j

1024

In [None]:
# Boolean mask, True where attention is NOT allowed. Shifting the diagonal by
# (j - i) leaves the first j - i key columns visible to every query row.
mask = torch.triu(torch.ones(i, j, dtype=torch.bool), diagonal=j - i + 1)
mask.shape

torch.Size([512, 1024])

In [None]:
# Masked scores become -inf so that softmax assigns them zero weight.
qk = qk.masked_fill(mask, -float('inf'))

In [None]:
# Display the masked scores: the -inf entries mark the masked positions.
qk

tensor([[[[-1.4778e+00, -2.4345e+00, -1.4185e+00,  ...,        -inf,
                  -inf,        -inf],
          [-1.6760e+00,  1.4537e+00, -1.9621e+00,  ...,        -inf,
                  -inf,        -inf],
          [-8.0605e-01, -3.6405e+00, -6.1674e-01,  ...,        -inf,
                  -inf,        -inf],
          ...,
          [-2.1815e+00,  6.8134e-01, -2.3266e+00,  ...,  7.9539e-01,
                  -inf,        -inf],
          [-1.1277e+00, -1.0929e+00, -1.0921e-01,  ..., -2.3192e-01,
           -1.1430e+00,        -inf],
          [ 3.3872e+00,  1.5124e+00, -5.3877e-01,  ...,  2.1358e+00,
            3.5113e+00, -1.3868e+00]],

         [[ 5.4370e-01,  1.8310e+00, -1.4181e-01,  ...,        -inf,
                  -inf,        -inf],
          [ 6.9232e-01, -2.7681e-01,  1.6967e+00,  ...,        -inf,
                  -inf,        -inf],
          [-1.4152e+00, -3.1564e+00, -2.4268e+00,  ...,        -inf,
                  -inf,        -inf],
          ...,
     

In [None]:
# Normalise over the key axis; the sum below is a sanity check that one row adds up to 1.
qk = qk.softmax(dim=-1)
qk[0, 0, 0].sum()

tensor(1., grad_fn=<SumBackward0>)

In [None]:
# Give `values` the same (batch, heads, positions, head_dim) layout as the
# attention weights so the two can be matrix-multiplied head by head.
values = rearrange(values, 'b t (h d) -> b h t d', h=number_heads)
print(f"qk: {qk.shape}")
print(f"values: {values.shape}")

qk: torch.Size([16, 8, 512, 1024])
values: torch.Size([16, 8, 1024, 10])


In [None]:
# Weighted sum of values for every query position, per head.
qkv = torch.matmul(qk, values)
qkv.shape

torch.Size([16, 8, 512, 10])

In [None]:
# Fold the head axis back into the feature axis: (b, h, t, d) -> (b, t, h*d).
qkv = rearrange(qkv, 'b h t d -> b t (h d)')
qkv.size()

torch.Size([16, 512, 80])

In [None]:
# Display the output projection layer to check its in/out feature sizes.
output_matrix

Linear(in_features=80, out_features=13, bias=True)

In [None]:
# Project the concatenated heads back to the embedding dimension.
out = output_matrix(qkv)
out.size()

torch.Size([16, 512, 13])

In [None]:
class XLAttention(nn.Module):
    """Causal multi-head self-attention with a Transformer-XL style key/value memory.

    The keys and values cached from the previous segment are prepended to the
    current segment's keys and values, so every query can also attend to the
    whole cached segment. Queries are computed from the current segment only.

    Args:
        embedding_dimension: feature size of the input and of the output.
        heads: number of attention heads.
        head_dimension: feature size of each head.
    """

    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

    def _split_heads(self, tensor):
        # (batch, positions, heads * head_dim) -> (batch, heads, positions, head_dim)
        batch, positions, _ = tensor.shape
        return tensor.reshape(batch, positions, self.heads, -1).transpose(1, 2)

    @staticmethod
    def _merge_heads(tensor):
        # (batch, heads, positions, head_dim) -> (batch, positions, heads * head_dim)
        batch, heads, positions, head_dim = tensor.shape
        return tensor.transpose(1, 2).reshape(batch, positions, heads * head_dim)

    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        xl_memory = None
    ):
        """Attend over the current segment plus the optional cached segment.

        Args:
            x: input of shape (batch_size, sequence_length, embedding_dimension).
            xl_memory: optional cache of shape
                (batch_size, memory_length, 2, heads * head_dimension) holding
                keys and values stacked on axis -2. `memory_length` does not
                have to equal `sequence_length`.

        Returns:
            A tuple `(out, kv_to_add_xl)`:
            `out` has shape (batch_size, sequence_length, embedding_dimension);
            `kv_to_add_xl` has shape
            (batch_size, sequence_length, 2, heads * head_dimension) and holds
            the keys/values of the CURRENT segment only, ready to be passed as
            `xl_memory` for the next segment.
        """
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        # Capture this segment's own projections before the memory is
        # prepended. Slicing them back out afterwards with the memory length
        # (as was done before) is only right when both lengths are equal.
        # NOTE(review): the tensor stays attached to the autograd graph;
        # callers carrying it across training steps may want `.detach()`.
        kv_to_add_xl = torch.stack((keys, values), dim = -2) # (batch, sequence_len, 2, dimension)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # unstack
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory

        queries = self._split_heads(queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        # Scaled dot-product scores: (batch, heads, i queries, j keys)
        qk = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale

        # Causal mask shifted right by the memory length (j - i): every query
        # sees all of the memory plus the current positions up to itself.
        # Built on the scores' device so the module also works off-CPU.
        i, j = qk.shape[-2:]
        mask = torch.ones((i, j), dtype = torch.bool, device = qk.device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim = -1)

        qkv = self._merge_heads(torch.matmul(qk, values))
        out = self.output_matrix(qkv)

        return out, kv_to_add_xl


In [None]:
class KNN_XLAttention(nn.Module):
    """Causal attention with XL memory, mixed with attention over retrieved kNN memories.

    Two results are computed per head and blended with a learned per-head
    weight (`gate_bias`):
      * local attention over the current segment plus the optional XL memory;
      * attention over the `topk_retrieved_memories` key/value pairs that the
        `knn` object returns for each query position.

    Args:
        embedding_dimension: feature size of the input and of the output.
        heads: number of attention heads.
        head_dimension: feature size of each head.
        topk_retrieved_memories: number of memories requested from `knn.search`.
        dropout: dropout probability applied to the retrieved-memory attention
            weights; the default of 0 disables it.
    """

    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
        topk_retrieved_memories = 3,
        dropout = 0.,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

        self.gate_bias = nn.Parameter(torch.randn(self.heads, 1, 1))
        self.topk_retrieved_memories = topk_retrieved_memories
        # forward() applies self.dropout, which was previously never created.
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor):
        # (batch, positions, heads * head_dim) -> (batch, heads, positions, head_dim)
        batch, positions, _ = tensor.shape
        return tensor.reshape(batch, positions, self.heads, -1).transpose(1, 2)

    def _split_memory_heads(self, tensor):
        # (batch, positions, topk, heads * head_dim) -> (batch, heads, positions, topk, head_dim)
        batch, positions, topk, _ = tensor.shape
        return tensor.reshape(batch, positions, topk, self.heads, -1).permute(0, 3, 1, 2, 4)

    @staticmethod
    def _merge_heads(tensor):
        # (batch, heads, positions, head_dim) -> (batch, positions, heads * head_dim)
        batch, heads, positions, head_dim = tensor.shape
        return tensor.transpose(1, 2).reshape(batch, positions, heads * head_dim)

    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        knn,
        xl_memory = None
    ):
        """Run local + retrieved-memory attention and update the kNN store.

        Args:
            x: input of shape (batch_size, sequence_length, embedding_dimension).
            knn: object providing `search(queries, topk=...)` and `add(kv)`.
                `search` receives queries shaped
                (batch_size, sequence_length, heads * head_dimension) and is
                expected to return
                (batch_size, sequence_length, topk, 2, heads * head_dimension)
                with keys and values stacked on axis -2.
            xl_memory: optional cache of shape
                (batch_size, memory_length, 2, heads * head_dimension).

        Returns:
            A tuple `(out, current_kv)`:
            `out` has shape (batch_size, sequence_length, embedding_dimension);
            `current_kv` has shape
            (batch_size, sequence_length, 2, heads * head_dimension), holds the
            current segment's keys/values, and is the same tensor that was
            handed to `knn.add`.
        """
        flat_queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        # This segment's own projections, captured before the XL memory is
        # prepended so the result is right for any memory length.
        current_kv = torch.stack((keys, values), dim = -2) # (batch, sequence_len, 2, dimension)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # unstack
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory

        ### LOCAL ATTENTION

        queries = self._split_heads(flat_queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        qk = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        # Causal mask shifted right by the memory length (j - i), created on
        # the scores' device so the module also works off-CPU.
        i, j = qk.shape[-2:]
        mask = torch.ones((i, j), dtype = torch.bool, device = qk.device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim = -1)

        # Kept per head, (batch, heads, positions, head_dim), so it can be
        # gated against the retrieved-memory result below.
        qkv = torch.matmul(qk, values)

        ### KNN ATTENTION

        # The store is searched with the flat (batch, positions, heads * head_dim)
        # queries, before this segment is added to it.
        mem_kv = knn.search(flat_queries, topk = self.topk_retrieved_memories) # returns b t k 2 d
        mem_k, mem_v = mem_kv.unbind(dim = -2)
        mem_k = self._split_memory_heads(mem_k)
        mem_v = self._split_memory_heads(mem_v)

        # Each query position attends only over its own k retrieved memories.
        mem_qk = torch.einsum('bhtd,bhtkd->bhtk', queries, mem_k)
        mem_qk = mem_qk * self.scale

        mem_qk = F.softmax(mem_qk, dim = -1)
        mem_qk = self.dropout(mem_qk)
        mem_qkv = torch.einsum('bhtk,bhtkd->bhtd', mem_qk, mem_v)

        # Combined attentions
        # gate_bias is (heads, 1, 1) and broadcasts over batch, positions and features.
        # NOTE(review): the raw parameter is used as the mixing weight, so it is
        # not confined to [0, 1]; consider torch.sigmoid(self.gate_bias) if a
        # convex blend is intended.
        combined_qkv = mem_qkv * self.gate_bias + qkv * (1 - self.gate_bias)
        combined_qkv = self._merge_heads(combined_qkv)
        out = self.output_matrix(combined_qkv)

        # New XL memories: store this segment's keys/values for later retrieval.
        knn.add(current_kv)

        return out, current_kv