In [42]:
import numpy as np
import faiss
from tqdm import tqdm
import time

In [29]:
# Benchmark configuration:
#   key_sizes      - numbers of key vectors to consider (only key_sizes[0] is
#                    used to build the index below; the full list is swept in
#                    the memory-statistics cell)
#   query_size     - number of query vectors
#   num_neighbours - k for the k-nearest-neighbour search
key_sizes = [1_000, 10_000, 50_000, 100_000, 1_000_000]
query_size = 1_000
num_neighbours = 5

In [50]:
# Seed the global NumPy RNG so the generated keys/queries are reproducible.
np.random.seed(0)
# Draw the keys first and the queries second, consuming the global RNG stream
# in that order. Each vector has 4 * 64 * 64 components and is stored as float16.
# NOTE(review): faiss indexes take float32 vectors, so these are upcast when
# added/searched — confirm that float16 storage is intentional.
keys, queries = (
    np.random.standard_normal((n_rows, 4 * 64 * 64)).astype(np.float16, copy=False)
    for n_rows in (key_sizes[0], query_size)
)

In [51]:
# Exact (brute-force) L2 index over vectors of dimension 4 * 64 * 64,
# populated with every key vector generated above.
index = faiss.IndexFlatL2(4 * 64 * 64)
index.add(keys)
print(f"Index size: {index.ntotal}")

Index size: 1000


In [None]:
# Time one k-NN search per query and record the returned neighbour ids as the
# reference result of this exact (flat) index.
# faiss indexes take float32 vectors, so convert the queries once, up front.
# Otherwise every `search` call would convert its float16 input internally,
# and that conversion would be counted inside the timed region.
# (float16 -> float32 is exact, so the search results are unchanged.)
queries_f32 = np.ascontiguousarray(queries, dtype=np.float32)
total_time = 0.0
reference_negibours = []  # (sic) name kept so any later cells using it still work
for query in tqdm(queries_f32):
    # faiss expects a 2-D (n_queries, d) batch; a row is already 1-D, so no flatten copy.
    query_flat = query.reshape(1, -1)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    time_start = time.perf_counter()
    D, ids = index.search(query_flat, num_neighbours)
    time_end = time.perf_counter()
    nearest_ids = ids[0]
    total_time += time_end - time_start
    reference_negibours.append(nearest_ids)
# Average over the real number of queries (previously a hard-coded 1000),
# guarding against an empty query set.
avg_time = total_time / len(queries_f32) if len(queries_f32) else 0.0
print("Average search time:", avg_time)
print("Reference neighbours:", np.array(reference_negibours))

100%|██████████| 1000/1000 [00:02<00:00, 356.73it/s]

Average search time: 0.002718142032623291
Reference neighbours: [620 496 663 ... 820 239 200]





In [None]:
# Report the memory footprint a float16 key matrix would have at each
# configured size.
# The original cell had two problems:
#   - It iterated over an undefined name (`sizes`), raising NameError.
#   - It materialised full random arrays just to read their metadata. For the
#     largest size (1,000,000 x 16384), the float64 intermediate alone is over
#     100 GB.
# All reported values follow from shape and dtype, so compute them directly.
# This also no longer overwrites the global `keys` that the index was built from.
key_dim = 4 * 64 * 64
key_dtype = np.dtype(np.float16)
for size in key_sizes:
    key_shape = (size, key_dim)
    key_nbytes = size * key_dim * key_dtype.itemsize
    print(f"keys.shape: {key_shape}")
    print(f"keys.dtype: {key_dtype}")
    print(f"keys.nbytes: {key_nbytes}")
    print(f"keys.itemsize: {key_dtype.itemsize}")
    print(f"keys.nbytes / keys.itemsize: {key_nbytes / key_dtype.itemsize}")
    print(f"keys.nbytes / keys.shape[0]: {key_nbytes / key_shape[0]}")
    print()