In [None]:
!pip install faiss-gpu
import faiss
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [None]:
# Exact (brute-force) L2 index for 64-dimensional vectors.
dim = 64
index = faiss.IndexFlatL2(dim)

IndexFlatL2 computes the L2 (Euclidean) distance between the query vector and every vector loaded into the index. It is simple and exact, but slow on large collections because every search is a brute-force scan.

In [None]:
# 10,000 random vectors in [0, 1); faiss expects float32 input, hence the cast.
vector_data = np.random.random((10000, dim)).astype('float32')

In [None]:
# Append the vectors to the index; IndexFlatL2 assigns sequential ids starting at 0.
index.add(vector_data)

In [None]:
# Number of vectors currently stored in the index.
index.ntotal

10000

In [None]:
# Remove ids 0-9; the return value is the number of vectors actually removed.
# NOTE(review): for flat indexes faiss compacts storage on removal, so the ids of the
# remaining vectors shift down -- confirm before treating ids as stable row numbers.
index.remove_ids(np.arange(10))

10

In [None]:
# 10 fewer vectors than before the removal.
index.ntotal

9990

In [None]:
# 10 random query vectors; retrieve the top_k nearest stored vectors for each.
query_data = np.random.random((10, dim)).astype('float32')
top_k = 2
distance, ids = index.search(query_data,top_k)

In [None]:
# (n_queries, top_k) distances, sorted ascending per row.
# IndexFlatL2 reports *squared* L2 distances.
distance

array([[5.827529 , 5.9235134],
       [5.26203  , 6.7408357],
       [5.938055 , 6.3166738],
       [6.046454 , 6.407394 ],
       [5.1537175, 5.591063 ],
       [4.8479753, 5.3954043],
       [6.462555 , 6.5777235],
       [6.5997653, 6.7846956],
       [4.966932 , 5.265863 ],
       [5.9377885, 6.090594 ]], dtype=float32)

In [None]:
# (n_queries, top_k) ids of the neighbours, aligned element-wise with `distance`.
ids

array([[1649, 2569],
       [3664, 7181],
       [2811, 1506],
       [9839, 5726],
       [2630, 2242],
       [7215, 8266],
       [6718, 7883],
       [7982, 3781],
       [ 234, 8131],
       [7394, 6117]])

In [None]:
# Like search(), but also returns the stored vectors: (distances, ids, vectors of shape (n_queries, top_k, dim)).
index.search_and_reconstruct(query_data,top_k)

(array([[5.827529 , 5.9235134],
        [5.26203  , 6.7408357],
        [5.938055 , 6.3166738],
        [6.046454 , 6.407394 ],
        [5.1537175, 5.591063 ],
        [4.8479753, 5.3954043],
        [6.462555 , 6.5777235],
        [6.5997653, 6.7846956],
        [4.966932 , 5.265863 ],
        [5.9377885, 6.090594 ]], dtype=float32),
 array([[1649, 2569],
        [3664, 7181],
        [2811, 1506],
        [9839, 5726],
        [2630, 2242],
        [7215, 8266],
        [6718, 7883],
        [7982, 3781],
        [ 234, 8131],
        [7394, 6117]]),
 array([[[0.23735897, 0.27604678, 0.6506493 , ..., 0.12610015,
          0.13045064, 0.14365897],
         [0.41333184, 0.5881936 , 0.9008604 , ..., 0.4965098 ,
          0.00142935, 0.44413137]],
 
        [[0.9049586 , 0.995561  , 0.3454757 , ..., 0.72844285,
          0.14323242, 0.91439706],
         [0.4093519 , 0.9812425 , 0.55505455, ..., 0.76743543,
          0.17521223, 0.43498406]],
 
        [[0.3151435 , 0.6045696 , 0.5940221

In [None]:

# On-disk float32 array with one (key, value) pair per row: shape (max_memories, 2, dim).
# mode='w+' creates the file, overwriting any existing file at this path.
db_filepath = "./memory.memmap"
max_memories = 10000
shape = (max_memories, 2, dim)
db = np.memmap(db_filepath, mode = 'w+', dtype = np.float32, shape = shape)

In [None]:
# A freshly created memmap is zero-filled.
db

memmap([[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

In [None]:
# Write one random (key, value) pair into row 1 (the slice keeps a (1, 2, dim) shape).
db[1:2] = np.random.rand(1,2,dim)

In [None]:
# Only row 1 is now non-zero.
db

memmap([[[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.49275523, 0.11181501, 0.22879675, ..., 0.9322479 ,
          0.8035381 , 0.56936264],
         [0.02991646, 0.44009668, 0.3851566 , ..., 0.7680023 ,
          0.37311602, 0.47453612]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        ...,

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]

In [None]:
# Write a random pair into row 0. Generate it with db[0]'s own (2, dim) shape and
# convert explicitly, instead of relying on numpy dropping a leading size-1 axis
# and implicitly converting the torch tensor (same number of draws, same values).
db[0] = torch.randn(2, dim).numpy()

In [None]:
# Integer-indexing a memmap returns another memmap (a view onto the same file).
type(db[0])

numpy.memmap

In [None]:
# Each row holds a (key, value) pair of dim-sized vectors.
db[1].shape

(2, 64)

In [None]:
# Configuration for the key/value memory experiment (smaller 10-dimensional vectors).
dim = 10
max_memories = 10000
batch_size = 16
top_k = 3

db_filepath = "./memory.memmap"
shape = (max_memories, 2, dim)


# create a fresh exact-L2 index for the new dimension
index = faiss.IndexFlatL2(dim)

# create database
# NOTE: same path as the earlier memmap and mode='w+', so that file is overwritten.
db = np.memmap(db_filepath, mode = 'w+', dtype = np.float32, shape = shape)

In [None]:
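# KNN database class -- planned operations:

# add keys to the faiss index
# add (key, value) pairs to the memmap database
# query the index
# retrieve the matching pairs from the database
# remove / clear entries from both the index and the database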

In [None]:
# Stacking key and value projections: random stand-in, flattened to one (2, dim) pair per token.
kv = np.random.rand(batch_size, 512, 2, dim).astype('float32') # b t 2 (hd)
kv = kv.reshape(-1, 2, dim)
kv.shape #8192 pairs of key value projections

(8192, 2, 10)

In [None]:
# batch_size * sequence length = number of flattened kv pairs above.
16 * 512

8192

In [None]:
# Keys are slot 0 of each pair; this slice is a strided (non-contiguous) view of kv.
k = kv[:,0,:]
k.shape

(8192, 10)

In [None]:
# faiss needs a C-contiguous array, so copy the strided key view before adding it.
index.add(np.ascontiguousarray(k))

In [None]:
# Place the kv pairs in consecutive database rows starting at db_offset, so that
# database row i lines up with faiss id i of the keys added above.
db_offset = 0
kv_len = kv.shape[0]
ids = np.arange(db_offset, db_offset + kv_len)
db_offset = db_offset + kv_len
db[ids] = kv

## Query and retrieve


In [None]:
# Random queries, one per token, flattened to (batch_size * 512, dim) for faiss.
query = np.random.rand(batch_size, 512, dim).astype('float32') # b t (hd)
query = query.reshape(-1, dim) #flatten
query.shape

(8192, 10)

In [None]:
# For each query, the top_k nearest stored keys: distance and ids are (n_queries, top_k).
distance, ids = index.search(query, top_k)
ids.shape

(8192, 3)

In [None]:
# Row ids into `db` (valid because keys and database rows were added in the same order).
ids

array([[2515, 7663, 7431],
       [ 522, 7214, 5771],
       [5083, 1415, 2169],
       ...,
       [3429, 4940, 6444],
       [2988, 7767, 3557],
       [ 824, 4563,  413]])

In [None]:
# Fancy-indexing the memmap with the (n_queries, top_k) id array gathers the pairs: (n_queries, top_k, 2, dim).
retrieved_kvs = db[ids]
retrieved_kvs.shape

(8192, 3, 2, 10)

### Remove / Clear / Database management

In [None]:
# 5120 = 10 segments of 512 tokens
# (the KNN test below sizes its memory as batch_size * seq_len * segments)

In [None]:

class KNN():
    """Fixed-capacity (key, value) memory: faiss L2 index over keys + on-disk memmap of pairs.

    Every added pair is written to row ``db_offset`` of a ``(max_memories, 2, dim)``
    float32 memmap (slot 0 = key, slot 1 = value), and its key is appended to an
    exact ``faiss.IndexFlatL2``. Both are filled in the same order and cleared
    together, so faiss id ``i`` refers to memmap row ``i``.

    Args:
        dim: size of each key / value vector.
        max_memories: capacity of the memmap, in (key, value) pairs.
        db_filepath: backing file for the memmap. It is opened with mode ``'w+'``,
            so any existing file at this path is overwritten.
            Defaults to ``"./memory.memmap"``.
    """
    def __init__(
        self,
        dim,
        max_memories,
        db_filepath = "./memory.memmap",
        ):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        self.db_filepath = db_filepath
        self.db = np.memmap(self.db_filepath, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.index = faiss.IndexFlatL2(dim)

    @staticmethod
    def _to_numpy(tensor):
        """Detached, CPU, C-contiguous float32 numpy array for a torch tensor."""
        # detach() so tensors that require grad (e.g. projection outputs) convert cleanly.
        return np.ascontiguousarray(tensor.detach().cpu().numpy(), dtype = np.float32)

    def add_to_db(self, new_data):
        """Write ``new_data`` (n, 2, dim) to the next ``n`` free rows and flush to disk.

        Raises:
            IndexError: if the pairs do not fit in the remaining capacity
                (nothing is written in that case).
        """
        new_data_len = new_data.shape[0]
        end = self.db_offset + new_data_len
        if end > self.max_memories:
            raise IndexError(
                f"KNN memory full: cannot add {new_data_len} pairs at offset "
                f"{self.db_offset} (capacity {self.max_memories})"
            )
        # Rows are filled sequentially, so a contiguous slice is enough.
        self.db[self.db_offset:end] = new_data
        self.db_offset = end
        # Write to file
        self.db.flush()

    def search_and_retrieve(self, query_vecs, topk):
        """Return the stored (key, value) pairs of the ``topk`` nearest keys.

        Args:
            query_vecs: (n, dim) C-contiguous float32 numpy array.
            topk: number of neighbours per query.

        Returns:
            (n, topk, 2, dim) float32 array. Slots for which faiss found no
            neighbour (id -1, e.g. fewer than ``topk`` stored keys) are zero-filled
            instead of silently reading the last memmap row via ``db[-1]``.
        """
        _, indices = self.index.search(query_vecs, topk)
        kvs = np.asarray(self.db[indices])
        kvs[indices == -1] = 0
        return kvs

    def add(self, new_data):
        """Add a batch of (key, value) pairs.

        Args:
            new_data: torch tensor of shape (b, n, 2, dim); slot 0 of the third
                axis is the key, slot 1 the value. All b * n pairs share one memory.
        """
        # Input is b n 2 d, flatten to (b n) 2 d
        flat = self._to_numpy(new_data.flatten(0, 1))
        # Add to db first: if it is full, the index is left untouched and stays in sync.
        self.add_to_db(flat)
        # Only keys are used in knn index; the slice is strided, so copy it contiguously.
        keys = np.ascontiguousarray(flat[:, 0, :])
        self.index.add(keys)

    def search(self, query_vecs, topk):
        """Look up the ``topk`` nearest memories for every query position.

        Args:
            query_vecs: torch tensor of shape (b, n, dim).
            topk: number of neighbours per query.

        Returns:
            torch float32 tensor of shape (b, n, topk, 2, dim).
        """
        query_batch_size, query_seq_len = query_vecs.shape[0], query_vecs.shape[1]
        # Input is b n d, flatten to (b n) d
        flat_queries = self._to_numpy(query_vecs.flatten(0, 1))
        kvs = self.search_and_retrieve(flat_queries, topk)
        # kvs are (b n) k 2 d, unflatten to b n k 2 d
        kvs = torch.from_numpy(kvs)
        return torch.unflatten(kvs, 0, (query_batch_size, query_seq_len))

    def clear(self): #empties out database
        """Empty both the faiss index and the memmap and reset the write offset."""
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0
        # Persist the zeroed rows, as add_to_db does for writes.
        self.db.flush()


In [None]:
# TODO: keep a separate memory for each batch element (KNN.add currently flattens the batch into one shared index)

In [None]:
#testing

# Test configuration: memory sized for `segments` segments of seq_len tokens for every batch element.
batch_size = 16
dim = 10
segments = 10
seq_len = 512
max_memories = batch_size * seq_len * segments
# NOTE: KNN opens "./memory.memmap" with mode='w+' (the same path as `db` above), overwriting that file.
knn = KNN(dim=dim, max_memories=max_memories)

In [None]:
# Random key/value pairs and queries sharing the same (batch, seq_len) layout.
kv = torch.randn(batch_size, seq_len, 2, dim) # b t 2 (hd)
query = torch.randn(batch_size, seq_len, dim) # b t (hd)

In [None]:
# Store all batch_size * seq_len pairs in the memmap and their keys in the index.
knn.add(kv)

In [None]:
# One indexed key per stored pair.
knn.index.ntotal

8192

In [None]:
# Last row written by the add above (8192 pairs -> rows 0..8191).
knn.db[8191]

memmap([[-0.09684315,  1.2812485 ,  0.05524765,  1.4415232 , -1.2765571 ,
          0.2632024 ,  1.8022029 ,  1.6431124 , -1.8025988 , -0.2768691 ],
        [ 0.10737767, -1.9741173 , -1.0190316 , -0.25465897, -1.1085768 ,
          0.2854817 ,  0.4937839 ,  0.67907417,  0.7290921 ,  0.44400656]],
       dtype=float32)

In [None]:
# First unused row, still zero.
knn.db[8192]

memmap([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [None]:
# Top-3 memories for every query position.
retrieved_kvs = knn.search(query, 3)

In [None]:
# (batch, seq_len, top_k, 2, dim)
retrieved_kvs.shape

torch.Size([16, 512, 3, 2, 10])

### KNN ATTENTION CLASS


In [None]:
class MHAttention(nn.Module):
    """Causal multi-head self-attention in plain PyTorch (KNN memory still TODO).

    Args:
        embedding_dimension: size of the last axis of the input ``x`` and of the output.
        heads: number of attention heads.
        head_dimension: size of each head's query/key/value vectors.
    """
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        # Stored on the module so forward() uses this instance's head size rather
        # than whatever `head_dimension` happens to be bound to at module level.
        self.head_dimension = head_dimension
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)


    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
    ):
        """Return the attention output, shape (batch_size, sequence_length, embedding_dimension)."""
        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        # Separate into heads for multi-head attention: (b, t, h*d) -> (b, t, h, d)
        head_shape = (batch_size, sequence_length, self.heads, self.head_dimension)
        q = queries.reshape(head_shape)
        k = keys.reshape(head_shape)
        v = values.reshape(head_shape)

        # Swap head and sequence length dimensions: (b, h, t, d)
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)
        # Keys to (b, h, d, t) so that q @ k contracts over d
        k = k.transpose(2,3)

        # Scaled scores, (b, h, t, t)
        qk = (q @ k) * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        # Causal mask: True where the key position lies in the future of the query.
        # Built on qk's device so the module also runs on GPU.
        i, j = qk.shape[-2:]
        mask = torch.ones((i, j), dtype = torch.bool, device = qk.device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        # Weighted sum of values, then merge heads back: (b, t, h*d)
        qkv = qk @ v
        qkv = qkv.transpose(1,2)
        qkv = qkv.reshape(batch_size, sequence_length, self.heads * self.head_dimension)

        ############
        # TODO
        # KNN Memory
        ############

        return self.output_matrix(qkv)

In [None]:
# make sure q is (b t (hd)) for searching in knn (reshape and transpose)
# knn returns (b t k 2 (hd))
# split into key and value, each (b t k (hd)) (unbind)
# convert k and v to (b t k h d) (reshape)
# change q to (b h t d) (transpose)
# change k to (b h t d k) (multiple transposes)
# change v to (b h t k d) (multiple transposes)
# get qk: q.unsqueeze(-2) is (b h t 1 d); (b h t 1 d) @ (b h t d k) -> (b h t 1 k), i.e. (b h t k)
# get qkv: (b h t 1 k) @ (b h t k d) -> (b h t 1 d), i.e. (b h t d)
# .....

In [None]:
# Random stand-ins for projected queries/keys: (batch, seq_len, heads * head_dim).
number_heads = 8
head_dimension = 10
q = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k.shape

torch.Size([16, 512, 80])

In [None]:
# Manually
# q: (b, t, h*d) -> (b, h, t, d); k: (b, t, h*d) -> (b, h, d, t), so q @ k is (b, h, t, t).

# Separate queries matrix into heads for multi-head attention
q = q.reshape(batch_size, seq_len, number_heads, head_dimension)
# Swap head and sequence axes -> (b, h, t, d)
q = q.transpose(1,2)
# Separate keys matrix into heads for multi-head attention
k = k.reshape(batch_size, seq_len, number_heads, head_dimension)
# Move to (b, h, d, t) so that q @ k contracts over d
k = k.permute(0,2,3,1)

manual_qk = q@k

print ("queries:", q.shape)
print ("keys:", k.shape)
print ("qk:", manual_qk.shape)

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 10, 512])
qk: torch.Size([16, 8, 512, 512])


### Einops

In [None]:
!pip install einops
from einops import rearrange, repeat, pack, unpack, einsum



In [None]:
# Fresh inputs: the previous cell rebound q and k to their reshaped versions.
q = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k.shape

torch.Size([16, 512, 80])

In [None]:
# With einsum (einops.einsum takes the tensors first and the pattern last); contracting d gives (b, h, i, j) scores
q =  rearrange(q, 'b t (h d) -> b h t d', h = number_heads)
k =  rearrange(k, 'b t (h d) -> b h t d', h = number_heads)
qk = einsum(q, k, 'b h i d, b h j d -> b h i j')

print ("queries:", q.shape)
print ("keys:", k.shape)
print ("qk:", qk.shape)

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 512, 10])
qk: torch.Size([16, 8, 512, 512])


In [None]:
class KNNAttention(nn.Module):
    """Causal multi-head self-attention written with einops (KNN memory still TODO).

    Args:
        embedding_dimension: size of the last axis of the input ``x`` and of the output.
        heads: number of attention heads.
        head_dimension: size of each head's query/key/value vectors.
    """
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        self.head_dimension = head_dimension
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)


    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
    ):
        """Return the attention output, shape (batch_size, sequence_length, embedding_dimension)."""
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        # Split heads: (b, t, h*d) -> (b, h, t, d)
        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        values  = rearrange(values, 'b t (h d) -> b h t d', h = self.heads)

        # Scaled dot-product scores, (b, h, i, j)
        qk = einsum(queries, keys, 'b h i d, b h j d -> b h i j') * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        # Causal mask: True where key position j lies in the future of query i.
        # Built on qk's device so the module also runs on GPU.
        i, j = qk.shape[-2:]
        mask = torch.ones((i, j), dtype = torch.bool, device = qk.device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        attention = F.softmax(qk, dim=-1)

        # Weighted sum of values, then merge heads back: (b, t, h*d)
        qkv = attention @ values
        qkv = rearrange(qkv, 'b h t d -> b t (h d)')

        ############
        # TODO
        # KNN Memory
        ############

        return self.output_matrix(qkv)


In [None]:
# Stand-ins for one memory lookup: per-head queries (b, h, t, d) and retrieved memories
# (b, t, top_k=3, 2, h*d). scale = 1 is a placeholder (KNNAttention uses head_dimension ** -0.5).
queries = torch.randn(batch_size, number_heads, seq_len, head_dimension)
mem_kv = torch.randn(batch_size, seq_len, 3, 2, number_heads*head_dimension)
scale = 1

In [None]:
# Merge heads back to (b, t, h*d), the layout knn.search takes.
queries = rearrange(queries, 'b h t d -> b t (h d)')
queries.shape

torch.Size([16, 512, 80])

In [None]:
# mem_kv is a random stand-in for a real lookup such as the commented line below.
# NOTE(review): `topk` is undefined here, and `knn` was built with dim=10 while these queries
# have number_heads * head_dimension = 80 features, so that call would need a KNN of matching dim.
# mem_kv = knn.search(queries, topk)
# Split the retrieved pairs into keys and values, each (b, t, k, h*d) ...
mem_k, mem_v = mem_kv.unbind(dim = -2)
# ... then split heads: (b, h, t, k, d)
mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=number_heads)
mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=number_heads)
mem_v.shape

torch.Size([16, 8, 512, 3, 10])

In [None]:
# Split heads again and score each query against its own k retrieved keys (contracting d).
queries = rearrange(queries, 'b t (h d) -> b h t d', h=number_heads)
mem_qk = einsum(queries, mem_k, 'b h t d, b h t k d -> b h t k') # d dimension
mem_qk.shape

torch.Size([16, 8, 512, 3])

In [None]:
# Scale the memory attention logits by `scale`.
mem_qk = mem_qk * scale

In [None]:
# Attention weights over the k memories, then the weighted sum of their values (contracting k).
mem_qk = F.softmax(mem_qk, dim=-1)
mem_qkv = einsum(mem_qk, mem_v, 'b h t k, b h t k d -> b h t d') # k dimension
mem_qkv.shape

torch.Size([16, 8, 512, 10])

In [None]:
# Learned per-head gate mixing memory attention with local attention; (h, 1, 1)
# broadcasts over (b, h, t, d). The raw parameter is unbounded, so pass it through
# a sigmoid to actually keep the mixing weight between 0 and 1.
gate_logits = nn.Parameter(torch.randn(number_heads, 1, 1))
gate = torch.sigmoid(gate_logits)

# Stand-ins so this cell runs on a fresh kernel: the local attention output and the
# output projection only exist inside the attention modules' forward/__init__.
qkv = torch.randn_like(mem_qkv)  # local attention output, (b, h, t, d)
output_matrix = nn.Linear(number_heads * head_dimension, number_heads * head_dimension)  # out dim stands in for embedding_dimension

combined_qkv = (mem_qkv * gate) + (qkv * (1 - gate))
# Merge heads back before the output projection: (b, h, t, d) -> (b, t, h*d)
combined_qkv = combined_qkv.transpose(1, 2).flatten(2)
out = output_matrix(combined_qkv)