In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

In [2]:
from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import time 
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat, reduce

def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [3]:
from brainle.models.architectures.knn import KNNBlock

block = KNNBlock(
    features = 64
)

# Insertion 
start = time.time()
for i in range(1000):
    block.push(torch.rand(10_000, 64))
print(f"Inserted 10M elements in {time.time() - start}s", )       

Inserted 10M elements in 2.7754080295562744s


In [4]:
# Search 
x = torch.rand(1000, 64)
start = time.time()
out = block(x, k=5)
print(f"Searched 1K elements in {time.time() - start}s", )       
print(out['embedding'].shape)

Searched 1K elements in 5.253890037536621s
torch.Size([1000, 5, 64])


In [5]:
# Comparison with (smaller) full matrix product 
start = time.time()
(torch.rand([5_000_000, 64]) @ torch.rand([64, 1000])).shape
print(time.time() - start)

11.796071290969849
