In [3]:
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
dim = 64
index = faiss.IndexFlatL2(dim)

In [5]:
index

<faiss.swigfaiss.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x130035410> >

In [6]:
vector_data = torch.randn((10000, dim), dtype=torch.float32)

In [7]:
index.add(vector_data)

In [22]:
index.ntotal

10000

In [23]:
index.remove_ids(np.arange(10))

10

In [24]:
index.ntotal

9990

In [25]:
# FAISS operates on float32 vectors and the index was populated with float32
# data, so build the queries as float32 as well rather than float64.
query_data = torch.randn((10, dim), dtype=torch.float32)

In [27]:
topk = 2
distance, ids = index.search(query_data, topk)

In [28]:
distance

array([[93.641426, 94.49693 ],
       [68.84827 , 69.310295],
       [70.8887  , 71.146706],
       [70.90154 , 74.26372 ],
       [65.32275 , 69.598114],
       [55.525856, 61.573452],
       [66.89134 , 71.34758 ],
       [58.615837, 64.854164],
       [61.58206 , 63.413498],
       [63.78425 , 68.54495 ]], dtype=float32)

In [29]:
ids

array([[8411, 6728],
       [7724, 5542],
       [7227, 9551],
       [8064,  213],
       [7146, 8178],
       [6320, 1132],
       [ 974, 5225],
       [ 615, 1005],
       [5187,  330],
       [5415, 5944]])

In [30]:
distance, ids, original = index.search_and_reconstruct(query_data, topk)

In [32]:
original.shape

(10, 2, 64)

In [33]:
# Disk-backed array used to look values up from keys:
# one (topk, dim) float32 slab per memory slot.
max_memories = 10000
db_filepath = "./memory.memmap"
shape = (max_memories, topk, dim)
db = np.memmap(
    filename=db_filepath,
    dtype=np.float32,
    mode='w+',  # create the file, or overwrite an existing one
    shape=shape,
)

In [36]:
# Fill slot 1 with random data. The trailing dimensions are taken from the
# memmap itself so this keeps working when topk / dim change.
db[1:2] = np.random.rand(1, *db.shape[1:])

In [37]:
db

memmap([[[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.4050407 , 0.35399655, 0.3129091 , ..., 0.5684382 ,
          0.1647372 , 0.05455185],
         [0.76077986, 0.3255453 , 0.92109597, ..., 0.52519727,
          0.8003544 , 0.69297594]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        ...,

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]

In [2]:
n_embed = 32
key_db = faiss.IndexFlatL2(n_embed)

In [3]:
mem_size = 100000
value_db = np.memmap("./shit.memmap", mode='w+', dtype=np.float32, shape=(mem_size, n_embed))

In [2]:
# Toy batch: b sequences of t tokens, each token with a c-dimensional key and value.
b, t = 16, 100
c = n_embed
key = torch.randn(b, t, c, dtype=torch.float32)
value = torch.randn(b, t, c, dtype=torch.float32)
# Number of rows already written to the key/value stores.
total_offset = 0

NameError: name 'n_embed' is not defined

In [5]:
offset = key.shape[0] * key.shape[1]
offset

1600

In [6]:
ids = torch.arange(total_offset, total_offset + offset)
ids

tensor([   0,    1,    2,  ..., 1597, 1598, 1599])

In [7]:
key_db.add(key.flatten(0, 1))

In [8]:
key_db.ntotal

1600

In [9]:
value_db[ids] = value.flatten(0, 1)

In [10]:
value_db

memmap([[ 0.7134852 , -0.98897314,  1.1582507 , ..., -0.65572065,
         -0.38709787, -0.8826125 ],
        [-0.6550964 ,  1.6092925 ,  0.80030733, ...,  0.8556267 ,
          1.2299192 ,  0.04382041],
        [-1.2351928 ,  1.5633427 , -0.836643  , ...,  2.033948  ,
         -1.4471961 , -0.24193352],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ]], dtype=float32)

In [11]:
value_db.flush()

In [12]:
total_offset += offset

In [13]:
value_db[0]

memmap([ 0.7134852 , -0.98897314,  1.1582507 , -0.7211242 , -0.9199839 ,
        -0.8324986 , -0.42878613,  0.18412289,  1.3613797 ,  1.5028625 ,
        -1.9431353 ,  0.10089584, -0.56640095, -0.43664274,  0.1166168 ,
         1.8329364 ,  1.1545871 , -1.0967777 ,  0.30457202,  0.10313164,
         0.41399816,  0.21535783, -0.08292428,  0.6627333 , -1.6125681 ,
         0.1864786 ,  0.6997727 , -0.7951304 ,  0.25472143, -0.65572065,
        -0.38709787, -0.8826125 ], dtype=float32)

In [14]:
topk = 2
search_key = torch.randn((b,t,c), dtype=torch.float32)

In [15]:
_, ids, original = key_db.search_and_reconstruct(search_key.flatten(0,1), topk)

In [16]:
ids.shape

(1600, 2)

In [17]:
original.shape

(1600, 2, 32)

In [18]:
res_values = value_db[ids]
res_values.shape

(1600, 2, 32)

In [None]:
a = torch.tensor(original)

In [1]:
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Memory:
    """External key/value memory for a single attention head.

    Keys live in a FAISS ``IndexFlatL2``; values live in a float32 ``np.memmap``.
    Row ``i`` of the memmap holds the value that was stored together with the
    ``i``-th key added to the index, so the ids returned by a FAISS search are
    used directly as memmap row numbers (this relies on the flat index
    assigning sequential ids in insertion order).
    """

    def __init__(self, mem_size, n_embed, head_id) -> None:
        """
        There are 2 components of this transformer memory:
        - A vector db storing the keys of the data, supports similarity search
        - A memmap storing the values

        For each head, we will have a separate memory.

        mem_size: maximum number of (key, value) rows the memory can hold
        n_embed: embedding dimension of every key and value
        head_id: only used to build the memmap file name
        """
        self.key_db = faiss.IndexFlatL2(n_embed)
        value_db_filepath = f"./memory_head_{head_id}.memmap"
        # Todo: How to decide the mem_size?
        # mode='w+' creates the backing file, or overwrites an existing one.
        self.value_db = np.memmap(value_db_filepath, mode='w+', dtype=np.float32, shape=(mem_size, n_embed))
        self.total_offset = 0

    @staticmethod
    def _flat_rows(x):
        """Turn a (B, T, C) tensor into a contiguous float32 (B * T, C) ndarray.

        ``detach().cpu()`` makes this work for tensors that require grad or
        live on another device, which cannot be converted to numpy directly.
        """
        return x.detach().cpu().flatten(0, 1).to(torch.float32).contiguous().numpy()

    def store(self, key, value):
        """
        Append a batch of (key, value) pairs to the memory.

        key: (batch, sequence_length, embed_dim)  (B, T, C)
        value: (batch, sequence_length, embed_dim)  (B, T, C)

        Returns a 1-D tensor with the B * T row ids assigned to the new entries.

        Raises ValueError if key and value do not contain the same number of
        rows, and IndexError if the batch does not fit into the remaining
        capacity. Both checks run before anything is written, so a failed call
        leaves the key index and the value store in sync.
        """
        n_new = key.shape[0] * key.shape[1]
        n_values = value.shape[0] * value.shape[1]
        if n_values != n_new:
            raise ValueError(
                f"key has {n_new} rows but value has {n_values} rows"
            )

        start = self.total_offset
        end = start + n_new
        capacity = self.value_db.shape[0]
        if end > capacity:
            raise IndexError(
                f"memory is full: cannot store {n_new} rows at offset {start} "
                f"(capacity {capacity})"
            )

        ids = torch.arange(start, end)
        if n_new == 0:
            return ids

        # Add the key to FAISS. Flatten the key to (batch * sequence_length, embed_dim)
        self.key_db.add(self._flat_rows(key))

        # Add the value to the memmap (the new ids are one contiguous range)
        self.value_db[start:end] = self._flat_rows(value)
        self.value_db.flush()
        self.total_offset = end
        return ids

    def retrieve(self, query, topk=2):
        """
        Look up the topk nearest stored keys for every query vector.

        query: (batch, sequence, embed_dim)  (B, T, C)

        Returns (matched_keys, matched_values), both of shape (B, T, topk, C).
        When fewer than topk keys are stored, FAISS reports the missing
        neighbours with id -1; their value rows are filled with NaN instead of
        being read from the memmap (where -1 would select the last row).
        The key rows for such slots are whatever FAISS reconstructs for them.
        """
        batch, seq_len = query.shape[0], query.shape[1]

        # matched_keys & matched_values: (batch * sequence_length, topk, embed_dim)
        _, ids, matched_keys = self.key_db.search_and_reconstruct(self._flat_rows(query), topk)

        found = ids >= 0
        matched_values = np.full(
            ids.shape + (self.value_db.shape[1],), np.nan, dtype=self.value_db.dtype
        )
        matched_values[found] = self.value_db[ids[found]]

        matched_keys = torch.from_numpy(matched_keys).unflatten(0, (batch, seq_len))
        matched_values = torch.from_numpy(matched_values).unflatten(0, (batch, seq_len))
        # returning (B, T, topk, C)  Todo: is this appropriate?
        return matched_keys, matched_values

    def clear(self):
        """
        Forget all stored entries.

        The key index is emptied and the write offset rewinds to 0. The memmap
        contents are not zeroed; old rows stay in the file until later store()
        calls overwrite them.
        """
        self.key_db.reset()
        self.value_db.flush()
        self.total_offset = 0


In [2]:
# Build the memory for a single attention head.
n_embed = 32
mem_size = 10000
head_id = 1
mem = Memory(mem_size=mem_size, n_embed=n_embed, head_id=head_id)

In [9]:
# Toy batch: b sequences of t tokens, each token with c-dimensional
# key, value and query vectors.
b, t = 16, 100
c = n_embed
key = torch.randn(b, t, c, dtype=torch.float32)
value = torch.randn(b, t, c, dtype=torch.float32)
query = torch.randn(b, t, c, dtype=torch.float32)

In [10]:
ids = mem.store(key, value)

In [11]:
ids

tensor([1600, 1601, 1602,  ..., 3197, 3198, 3199])

In [12]:
mem.key_db.ntotal

3200

In [13]:
mem.value_db

memmap([[ 0.3710723 , -0.00485805,  0.89239454, ...,  0.32404032,
          0.54582256,  0.26578963],
        [ 0.6924881 , -1.073629  ,  0.74572736, ..., -0.7672913 ,
         -0.086476  , -0.4977963 ],
        [-2.5152192 ,  0.5555816 ,  1.4097614 , ..., -1.2520555 ,
          0.7043623 , -0.5531658 ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ]], dtype=float32)

In [15]:
memk, memv = mem.retrieve(query)

In [16]:
memk.shape

torch.Size([16, 100, 2, 32])

In [17]:
memv.shape

torch.Size([16, 100, 2, 32])

In [18]:
mem.clear()

In [19]:
mem.key_db.ntotal

0

In [20]:
mem.value_db

memmap([[ 0.3710723 , -0.00485805,  0.89239454, ...,  0.32404032,
          0.54582256,  0.26578963],
        [ 0.6924881 , -1.073629  ,  0.74572736, ..., -0.7672913 ,
         -0.086476  , -0.4977963 ],
        [-2.5152192 ,  0.5555816 ,  1.4097614 , ..., -1.2520555 ,
          0.7043623 , -0.5531658 ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ]], dtype=float32)

In [54]:
# Tiny shapes so the einsum results are easy to inspect.
b = 2
t = 3
c = 4
topk = 5
# q: one query per position; k / v: topk candidates per position.
q = torch.randn((b, t, c))
k, v = (torch.randn((b, t, c, topk)) for _ in range(2))

In [55]:
q.shape

torch.Size([2, 3, 4])

In [56]:
k.shape

torch.Size([2, 3, 4, 5])

In [62]:
# Dot each query with its topk candidate keys (contracting over c), then
# normalise the topk scores of every position into weights that sum to 1.
wei = F.softmax(
    torch.einsum('btc,btck->btk', q, k),
    dim=-1,
)

In [63]:
wei.shape

torch.Size([2, 3, 5])

In [65]:
out = torch.einsum('btk,btck->btc', wei, v)

In [66]:
out.shape

torch.Size([2, 3, 4])

In [67]:
m = torch.randn(b, t, c)

In [82]:
gate_bias = nn.Parameter(torch.zeros((1)))

In [83]:
gate = torch.sigmoid(gate_bias)

In [84]:
gate.shape

torch.Size([1])

In [85]:
shit = m * gate + out * (1 - gate)

In [86]:
shit.shape

torch.Size([2, 3, 4])

In [13]:
# Transformer XL
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

b = 32
t = 80
c = 8

In [14]:
xl_mem = torch.randn((b, t, 2, c), dtype=torch.float32)

In [15]:
xl_mem.shape

torch.Size([32, 80, 2, 8])

In [16]:
xl_keys, xl_values = xl_mem.unbind(dim=-2)

In [17]:
xl_keys.shape

torch.Size([32, 80, 8])

In [18]:
key = torch.randn((b, t, c), dtype=torch.float32)
value = torch.randn((b, t, c), dtype=torch.float32)
query = torch.randn((b,t,c), dtype=torch.float32)

In [19]:
query.shape

torch.Size([32, 80, 8])

In [20]:
augmented_key = torch.cat((xl_keys, key), dim=1)
augmented_key.shape

torch.Size([32, 160, 8])

In [21]:
augmented_value = torch.cat((xl_values, value), dim=1)
augmented_value.shape

torch.Size([32, 160, 8])

In [22]:
qk = query @ augmented_key.transpose(-2, -1)
qk.shape

torch.Size([32, 80, 160])

In [23]:
# Use a small 3x3 size so the triangular mask is easy to inspect.
# (The size of the real score matrix would be: i, j = qk.shape[1:];
# that assignment was immediately overwritten, so it is dropped here.)
i, j = 3, 3

In [2]:
tensor = torch.ones((i, j), dtype=torch.bool).triu()

NameError: name 'torch' is not defined

In [1]:
shit = tensor.triu()

NameError: name 'tensor' is not defined