In [24]:
import os
import sys
# Make the parent of the current working directory importable (the notebook's own
# directory when launched from Jupyter), so that `brainle` resolves. Guarded so that
# re-running this cell does not stack duplicate entries on sys.path.
PROJECT_ROOT = os.path.abspath(os.pardir)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [25]:
from typing import List

import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat

def count_parameters(model: nn.Module):
    """Total number of scalars in `model`'s trainable parameters (requires_grad=True)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def count_parameters_all(model: nn.Module):
    """Total number of scalars in all of `model`'s parameters, trainable or frozen."""
    sizes = [param.numel() for param in model.parameters()]
    return sum(sizes)

In [3]:
from brainle.models.architectures.attention import AttentionBase, SABlock, RABlock, DMABlock, FeedForwardBlock, TransformerBlock, PatcherBlock, UnpatcherBlock, ConvTention, ConvTeNet
        
# Attention with 10 queries over 20 key/value tokens.
att = AttentionBase(
    in_features = 12,
    out_features = 24,
    num_heads = 4,
)
q = torch.rand(2, 10, 12)
k = torch.rand(2, 20, 12)
v = torch.rand(2, 20, 24)
out = att(q, k, v)
print(out.shape)
print(f"Params: {count_parameters(att)}")
# One output token per query, with out_features channels.
assert out.shape == (2, 10, 24)

torch.Size([2, 10, 24])
Params: 600


In [4]:
block = SABlock(
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

out = block(torch.rand(2, 10, 12))
print(out.shape)
print(f"Params: {count_parameters(block)}")
# Sequence length is preserved; features are projected to out_features.
assert out.shape == (2, 10, 24)

torch.Size([2, 10, 24])
Params: 1176


In [5]:
block = RABlock(
    in_tokens = 10,
    out_tokens = 5,
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

out = block(torch.rand(2, 10, 12))
print(out.shape)
print(f"Params: {count_parameters(block)}")
# Token count goes from in_tokens to out_tokens; features to out_features.
assert out.shape == (2, 5, 24)

torch.Size([2, 5, 24])
Params: 1092


In [6]:
block = DMABlock(
    memory_size = 512,
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

out = block(torch.rand(2, 10, 12))
print(out.shape)
print(f"Params: {count_parameters(block)}")
assert out.shape == (2, 10, 24)

torch.Size([2, 10, 24])
Params: 19176


In [7]:
block = FeedForwardBlock(
    features = 512,
    multiplier = 4,
    dropout = 0.1
)
out = block(torch.rand(2, 10, 512))
print(out.shape)
print(f"Params: {count_parameters(block)}")
# The feed-forward block is shape-preserving.
assert out.shape == (2, 10, 512)

torch.Size([2, 10, 512])
Params: 2100736


In [8]:
block = TransformerBlock(
    features = 256,
    num_heads = 2,
    dropout_attention = 0.1,
    dropout_mlp = 0.1,
    mlp_multiplier = 4
)
out = block(torch.rand(2, 10, 256))
print(out.shape)
print(f"Params: {count_parameters(block)}")
# The transformer block is shape-preserving.
assert out.shape == (2, 10, 256)

torch.Size([2, 10, 256])
Params: 788992


In [9]:
patcher = PatcherBlock(
    kernel_size = 4,
    stride = 2,
    padding = 1
)

unpatcher = UnpatcherBlock(
    kernel_size = 4,
    stride = 2,
    padding = 1
)
# Six tokens of three features each; token i is filled with the value i so the
# patch layout is easy to read in the printed output.
x = torch.tensor([[ [1,1,1], [2,2,2], [3,3,3], [4,4,4], [5,5,5], [6,6,6] ]]).float()
print(x, x.shape)
# Distinct names per stage keep the input intact and make the cell re-runnable.
x_patches = patcher(x)
print(x_patches, x_patches.shape)
x_restored = unpatcher(x_patches)
print(x_restored, x_restored.shape)
# Unpatching must exactly invert patching.
assert x_patches.shape == (1, 3, 4, 3)
assert x_restored.shape == x.shape
assert torch.allclose(x_restored, x)

tensor([[[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.],
         [6., 6., 6.]]]) torch.Size([1, 6, 3])
tensor([[[[0., 0., 0.],
          [1., 1., 1.],
          [2., 2., 2.],
          [3., 3., 3.]],

         [[2., 2., 2.],
          [3., 3., 3.],
          [4., 4., 4.],
          [5., 5., 5.]],

         [[4., 4., 4.],
          [5., 5., 5.],
          [6., 6., 6.],
          [0., 0., 0.]]]]) torch.Size([1, 3, 4, 3])
tensor([[[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.],
         [6., 6., 6.]]]) torch.Size([1, 6, 3])


In [10]:
encode = ConvTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 2,
    kernel_size = 4,
    stride = 4,
    padding = 0,
    memory_size = 512,
    dropout = 0.1
)
decode = ConvTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 4,
    kernel_size = 2,
    stride = 2,
    padding = 0,
    memory_size = 512,
    dropout = 0.1
)

out = encode(torch.rand(2, 1024, 256))
print(out.shape)
# With these settings the encoder halves the token count...
assert out.shape == (2, 512, 256)
out = decode(out)
print(out.shape)
# ...and the decoder restores it.
assert out.shape == (2, 1024, 256)
print(f"Params: {count_parameters(encode)}")

torch.Size([2, 512, 256])
torch.Size([2, 1024, 256])
Params: 3747840


In [11]:
net = ConvTeNet(
    vocabulary_size = 800,
    embedding_dim = 256,
    num_layers = 7,
    num_heads = 8,
    use_skip = True
)

x = torch.randint(low=0, high=800, size=(2, 2048))
out = net(x)
print(out.shape)
print(f"Params: {count_parameters(net)}")
# One vocabulary-sized output vector per input token.
assert out.shape == (2, 2048, 800)

torch.Size([2, 2048, 800])
Params: 47382816


In [12]:
from brainle.models.architectures.attention import KVMemory
  
import tempfile

# Build memory
memory = KVMemory(k_features=3, v_features=2, memory_size=6, items_per_query=4)

# Insert and search
k = torch.tensor([[1,1,1], [2,2,2], [3,3,3], [4,4,4]], dtype=torch.float)
v = torch.tensor([[1,1], [2,2], [3,3], [4,4]], dtype=torch.float)
memory.insert(k, v)
# The new keys fill the tail of the buffer; the remaining slots are still zero.
assert torch.equal(memory.k_memory[-4:], k)
q = torch.tensor([[1,1,1], [0,0,0]], dtype=torch.float)
k, v = memory(q)
print("Memory:",memory.k_memory)
print("K,V", k,v)

# Insert again (notice that it's FIFO)
k = torch.tensor([[5,5,5], [6,6,6], [7,7,7], [8,8,8]], dtype=torch.float)
v = torch.tensor([[5,5], [6,6], [7,7], [8,8]], dtype=torch.float)
memory.insert(k, v)
# The four oldest rows were evicted: [3,3,3] and [4,4,4] shift to the front.
assert torch.equal(memory.k_memory[:2], torch.tensor([[3., 3., 3.], [4., 4., 4.]]))
assert torch.equal(memory.k_memory[-4:], k)
q = torch.tensor([[1,1,1],[8,8,8]], dtype=torch.float)
k, v = memory(q)
print("Memory:",memory.k_memory)
print("K,V", k,v)
# Keep the pre-save results so the reloaded memory can be checked against them.
k_memory_saved, k_saved, v_saved = memory.k_memory.clone(), k.clone(), v.clone()

# Check state dict stores memory. A temporary directory keeps the cwd clean and is
# removed even if saving or loading raises.
with tempfile.TemporaryDirectory() as tmp_dir:
    file = os.path.join(tmp_dir, 'memory.pt')
    torch.save(memory.state_dict(), file)
    memory = KVMemory(k_features=3, v_features=2, memory_size=6, items_per_query=4)
    memory.load_state_dict(torch.load(file))
# Query should return same result
q = torch.tensor([[1,1,1],[8,8,8]], dtype=torch.float)
k, v = memory(q)
print("Memory:",memory.k_memory)
print("K,V", k,v)
assert torch.equal(memory.k_memory, k_memory_saved)
assert torch.equal(k, k_saved) and torch.equal(v, v_saved)


Memory: tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.]])
K,V tensor([[4., 4., 4.],
        [3., 3., 3.],
        [2., 2., 2.],
        [1., 1., 1.],
        [0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [0., 0., 0.]]) tensor([[4., 4.],
        [3., 3.],
        [2., 2.],
        [1., 1.],
        [0., 0.],
        [1., 1.],
        [2., 2.],
        [0., 0.]])
Memory: tensor([[3., 3., 3.],
        [4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.],
        [7., 7., 7.],
        [8., 8., 8.]])
K,V tensor([[8., 8., 8.],
        [7., 7., 7.],
        [6., 6., 6.],
        [5., 5., 5.],
        [8., 8., 8.],
        [7., 7., 7.],
        [6., 6., 6.],
        [5., 5., 5.]]) tensor([[8., 8.],
        [7., 7.],
        [6., 6.],
        [5., 5.],
        [8., 8.],
        [7., 7.],
        [6., 6.],
        [5., 5.]])
Memory: tensor([[3., 3., 3.],
        [4., 4., 4.],
        [5., 5

In [13]:
memory = KVMemory(k_features=256, v_features=128, memory_size=100_000, items_per_query=16)

# perf_counter is the monotonic, high-resolution clock intended for timing code;
# time.time() is wall-clock time and can jump (e.g. NTP adjustments).
start = time.perf_counter()
k, v = memory(torch.rand(300, 256))
print(time.perf_counter() - start)
print(k.shape, v.shape)
# 300 queries x 16 neighbours each, flattened into one row per retrieved item.
assert k.shape == (300 * 16, 256) and v.shape == (300 * 16, 128)

0.03985595703125
torch.Size([4800, 256]) torch.Size([4800, 128])


In [14]:
from brainle.models.architectures.attention import MABlock
  
block = MABlock(
    in_features = 512,
    out_features = 256,
    num_heads = 8,
    memory_size = 50_000,
    memory_items_per_query = 16
)

out = block(torch.rand(32, 10, 512))
print(out.shape)
print(f"Params: {count_parameters(block)}")
# Sequence length is preserved; features are projected to out_features.
assert out.shape == (32, 10, 256)

torch.Size([32, 10, 256])
Params: 852224


In [15]:
from brainle.models.architectures.attention import MemoformerBlock

block = MemoformerBlock(
    features = 256,
    num_heads = 2,
    dropout_attention = 0.1,
    dropout_mlp = 0.1,
    mlp_multiplier = 4,
    memory_size = 50_000,
    memory_items_per_query = 16
)
out = block(torch.rand(2, 10, 256))
print(out.shape)
print(f"Params: {count_parameters(block)}")
assert out.shape == (2, 10, 256)

torch.Size([2, 10, 256])
Params: 854528


In [16]:
from brainle.models.architectures.attention import ConvMemoTention

encode = ConvMemoTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 2,
    kernel_size = 4,
    stride = 4,
    padding = 0,
    memory_size = 50_000,
    memory_items_per_query = 16,
    dropout = 0.1
)
decode = ConvMemoTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 4,
    kernel_size = 2,
    stride = 2,
    padding = 0,
    memory_size = 50_000,
    memory_items_per_query = 16,
    dropout = 0.1
)

out = encode(torch.rand(2, 1024, 256))
print(out.shape)
# With these settings the encoder halves the token count...
assert out.shape == (2, 512, 256)
out = decode(out)
print(out.shape)
# ...and the decoder restores it.
assert out.shape == (2, 1024, 256)
print(f"Params: {count_parameters(encode)}")

torch.Size([2, 512, 256])
torch.Size([2, 1024, 256])
Params: 3616512


In [20]:
from brainle.models.architectures.attention import ConvMeNet

net = ConvMeNet(
    vocabulary_size = 800,
    embedding_dim = 256,
    num_layers = 7,
    num_heads = 8,
    num_attention_layers = 4,
    window_size = 4,
    use_skip = True,
    memory_size = 50_000,
    memory_items_per_query = 8,
)

x = torch.randint(low=0, high=800, size=(2, 2048))
out = net(x)
print(out.shape)
print(f"Params: {count_parameters(net)}")
# One vocabulary-sized output vector per input token.
assert out.shape == (2, 2048, 800)

torch.Size([2, 2048, 800])
Params: 51052832


In [34]:
import faiss
import faiss.contrib.torch_utils

class KVMemory(nn.Module):

    """Key value memory with FIFO replacement strategy and inner-product KNN lookup.

    Keys and values are kept in the buffers `k_memory` [memory_size, k_features] and
    `v_memory` [memory_size, v_features], so they are saved in / restored from the
    state dict. A faiss flat inner-product index mirrors `k_memory` row for row; that
    is what lets the ids returned by a search be used directly as row indices into
    both buffers.
    """

    def __init__(
        self, k_features: int, v_features: int, memory_size: int, items_per_query: int
    ):
        """
        Args:
            k_features: Dimensionality of keys (and of queries).
            v_features: Dimensionality of values.
            memory_size: Number of (key, value) slots; all start as zeros.
            items_per_query: Number of nearest keys returned per query.
        """
        super().__init__()
        # faiss pads results with id -1 when asked for more neighbours than it holds;
        # -1 would then silently select the last buffer row.
        assert items_per_query <= memory_size, "items_per_query must not exceed memory_size"
        self.k_features = k_features
        self.v_features = v_features
        self.memory_size = memory_size
        self.items_per_query = items_per_query
        # Initialize index for KNN search (replicated on all GPUs when available)
        if torch.cuda.is_available():
            index_cpu = faiss.IndexFlatIP(k_features)
            self.index = faiss.index_cpu_to_all_gpus(index_cpu)
        else:
            self.index = faiss.IndexFlatIP(k_features)
        self.register_buffer("k_memory", torch.zeros(memory_size, k_features))
        self.register_buffer("v_memory", torch.zeros(memory_size, v_features))
        self._rebuild_index()

    def _rebuild_index(self):
        """Makes the faiss index hold exactly the rows of `k_memory`, in buffer order."""
        self.index.reset()
        self.index.add(self.k_memory.detach().contiguous())

    def insert(self, k: Tensor, v: Tensor):
        """Appends `m` (key, value) pairs, evicting the `m` oldest slots (FIFO).

        Args:
            k: Keys of shape [m, k_features].
            v: Values of shape [m, v_features].
        """
        (m, kd), (mv, vd) = k.shape, v.shape
        assert m == mv, "Expected same number of keys and values"
        assert m <= self.memory_size, "More items inserted than memory size"
        assert kd == self.k_features, "Expected k of shape [m, k_features]"
        assert vd == self.v_features, "Expected v of shape [m, v_features]"
        # Match the buffers' dtype/device; torch.cat would otherwise promote (e.g. to
        # float64), changing the buffer dtype and feeding non-float32 data to faiss.
        k = k.detach().to(self.k_memory)
        v = v.detach().to(self.v_memory)
        # Update memory (with FIFO strategy): drop the m oldest rows from the front,
        # append the new rows at the back.
        self.k_memory = torch.cat([self.k_memory[m:], k])
        self.v_memory = torch.cat([self.v_memory[m:], v])
        # Re-sync the index with the buffer. Rebuilding copies O(memory_size) rows,
        # roughly what IndexFlat.remove_ids costs, and avoids remove_ids, which GPU
        # flat indexes may not implement.
        self._rebuild_index()

    def forward(self, q: Tensor):
        """Parses memory with query and returns keys, values.

        Args:
            q: Queries of shape [n, k_features].

        Returns:
            (k, v): Retrieved keys [n * items_per_query, k_features] and values
            [n * items_per_query, v_features], grouped per query (all neighbours of
            query 0 first) in the order faiss ranks them.
        """
        # Dimensionality check
        n, d = q.shape
        assert d == self.k_features, (
            f"Expected tensor of shape [n, k_features], got {tuple(q.shape)}"
        )
        # KNN search with `items_per_query` neighbours. Only the ids are needed (keys
        # and values are gathered from the buffers), so plain `search` suffices.
        _, indices = self.index.search(
            q.detach().float().contiguous(), self.items_per_query
        )
        # faiss ids are int64 (ndarray or tensor). Convert straight to long: routing
        # them through q's float dtype is inexact above 2**24 and warns for tensors.
        indices = torch.as_tensor(
            indices, dtype=torch.long, device=self.k_memory.device
        ).reshape(-1)
        # Extract keys and values from memory
        return self.k_memory[indices], self.v_memory[indices]

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Restores the buffers, then rebuilds the faiss index from the loaded keys.

        Hooking here instead of overriding `load_state_dict` keeps the inherited
        `load_state_dict(state_dict, strict=..., ...)` signature/return value and also
        rebuilds the index when this memory is loaded as a submodule of a larger model
        (nn.Module.load_state_dict calls `_load_from_state_dict` on every submodule).
        """
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
        self._rebuild_index()
        
        
        
# Build memory
memory = KVMemory(k_features=3, v_features=2, memory_size=6, items_per_query=4)

# Insert and search
k = torch.tensor([[1,1,1], [2,2,2], [3,3,3], [4,4,4]], dtype=torch.float)
v = torch.tensor([[1,1], [2,2], [3,3], [4,4]], dtype=torch.float)
memory.insert(k, v)
# FIFO: the four new keys fill the tail; the two remaining slots are still zero.
assert torch.equal(memory.k_memory[-4:], k)
assert torch.equal(memory.k_memory[:2], torch.zeros(2, 3))
q = torch.tensor([[1,1,1], [0,0,0]], dtype=torch.float)
k, v = memory(q)
print("Memory:",memory.k_memory)
print("K,V", k,v)
# For query [1,1,1] the inner products 12 > 9 > 6 > 3 rank the keys 4, 3, 2, 1.
assert torch.equal(k[:4], torch.tensor([[4., 4., 4.], [3., 3., 3.], [2., 2., 2.], [1., 1., 1.]]))

Memory: tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.]])
K,V tensor([[4., 4., 4.],
        [3., 3., 3.],
        [2., 2., 2.],
        [1., 1., 1.],
        [0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [0., 0., 0.]]) tensor([[4., 4.],
        [3., 3.],
        [2., 2.],
        [1., 1.],
        [0., 0.],
        [1., 1.],
        [2., 2.],
        [0., 0.]])


  indices = torch.tensor(indices).to(q)
