In [1]:
import os
import time
import numpy as np
from snnpy import *
from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree
from bf_search import *

In [2]:
def bvecs_read(fname):
    """Load a ``.bvecs`` file into a 2-D ``uint8`` array, one record per row.

    The first int32 in the file is taken as the vector dimension. The file is
    then read as raw bytes, split into records of ``dimension + 4`` bytes, and
    the leading 4 bytes of every record (the int32 header) are dropped.
    Every record is assumed to share the dimension of the first one.

    Returns a fresh copy, so the result does not alias the raw byte buffer.
    """
    # Only the first int32 is needed here: it holds the vector dimension.
    dim = np.fromfile(fname, dtype=np.int32, count=1)[0]
    raw_bytes = np.fromfile(fname, dtype=np.uint8)
    header_len = 4  # bytes occupied by the int32 dimension header
    records = raw_bytes.reshape(-1, dim + header_len)
    return records[:, header_len:].copy()


def ivecs_read(fname):
    """Load an ``.ivecs`` file into a 2-D ``int32`` array, one record per row.

    The whole file is read as a flat int32 stream. Its first value is taken
    as the vector dimension, the stream is split into rows of
    ``dimension + 1`` values, and the leading column (the per-record
    dimension header) is dropped. Every record is assumed to share the
    dimension of the first one.

    Returns a fresh copy, so the result does not alias the flat buffer.
    """
    flat = np.fromfile(fname, dtype='int32')
    dim = flat[0]
    records = flat.reshape(-1, dim + 1)
    return records[:, 1:].copy()


def fvecs_read2(fname):
    """Load an ``.fvecs`` file as a 2-D ``float32`` array.

    Parsing is delegated to ``ivecs_read``; the int32 payload it returns is
    then reinterpreted bit-for-bit as float32 via ``view`` (no numeric
    conversion takes place).
    """
    int_payload = ivecs_read(fname)
    return int_payload.view('float32')

### GIST: radius-query benchmark (brute force, KD-tree, ball tree, SNN)

In [3]:
# Load the GIST data set: the vectors that get indexed and the query vectors.
gist_train = np.load("Euclidean_data/gist/train.npy")
gist_query = np.load("Euclidean_data/gist/queries.npy")

# Per-feature standardisation, currently disabled.
# NOTE(review): the result files written at the end of the notebook carry a
# "_norm" suffix even though this step is commented out -- confirm whether the
# .npy inputs are already normalised or the suffix is stale.
# mu = gist_train.mean(axis=0)
# scl = gist_train.std(axis=0)
# gist_train = (gist_train - mu) / scl
# gist_query = (gist_query - mu) / scl

# Index construction times (one entry per build).
gist_kdtree_index_timing = []
gist_balltree_index_timing = []
gist_sn_index_timing = []

# Query times (one entry per tested radius).
gist_bf_run_timing1 = []
gist_bf_run_timing2 = []
gist_kdtree_run_timing = []
gist_balltree_run_timing = []
gist_sn_run_timing = []

In [4]:
# Search radii swept by every benchmark in this notebook.
radius = [
    0.8,
    0.85,
    0.9,
    0.95,
    1,
]

In [5]:
# Brute force 1: scikit-learn NearestNeighbors with the exhaustive 'brute'
# algorithm.
#
# The estimator is created and fitted once, outside the timed region. The
# radius is supplied on every radius_neighbors() call, so fitting does not
# depend on R, and keeping it out of the measurement means the printed
# "query time" covers queries only -- matching the KD-tree, ball-tree and SNN
# sections, which time index construction separately from querying.
neigh = NearestNeighbors(algorithm='brute')
neigh.fit(gist_train)
for R in radius:
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    st = time.perf_counter()
    for j in range(gist_query.shape[0]):
        # Slice (j:j+1) rather than index so the query keeps its leading axis.
        # NOTE(review): this is the only section that requests distances
        # (return_distance=True); brute force 2, KD-tree and ball tree all
        # pass return_distance=False -- confirm the asymmetry is intended.
        ind = neigh.radius_neighbors(
            gist_query[j:j+1], radius=R, return_distance=True
        )
    et = time.perf_counter() - st
    gist_bf_run_timing1.append(et)
    print("brute force 1 query time:", et)
print()
    

# Brute force 2: bf_radius_fairness (presumably provided by the bf_search star
# import -- verify), called once per query vector for every radius.
for R in radius:
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    st = time.perf_counter()
    for j in range(gist_query.shape[0]):
        # Unlike the scikit-learn sections, the query is passed as a single
        # indexed row (gist_query[j]) rather than a one-row slice.
        ind = bf_radius_fairness(gist_query[j], gist_train, R, return_distance=False)
    et = time.perf_counter() - st
    gist_bf_run_timing2.append(et)
    print("brute force 2 query time:", et)
print()
    
    
# KD-tree (scikit-learn): build the index once, then sweep the radii.
# Timing uses perf_counter(), which is monotonic and high-resolution;
# time.time() follows the wall clock and can jump if the system clock is
# adjusted during these long runs.
st = time.perf_counter()
tree = KDTree(gist_train)
et = time.perf_counter() - st
gist_kdtree_index_timing.append(et)
print("kdtree build time:", et)
for R in radius:
    st = time.perf_counter()
    for j in range(gist_query.shape[0]):
        # Slice (j:j+1) rather than index so the query keeps its leading axis.
        ind = tree.query_radius(gist_query[j:j+1], r=R, return_distance=False)
    et = time.perf_counter() - st
    gist_kdtree_run_timing.append(et)
    print("kdtree query time:", et)
print()


# Ball tree (scikit-learn): build the index once, then sweep the radii.
# Note that `tree` is rebound here; from this point on it refers to the
# BallTree, not the KDTree built earlier in this cell.
# Timing uses perf_counter(), which is monotonic and high-resolution;
# time.time() follows the wall clock and can jump if the system clock is
# adjusted during these long runs.
st = time.perf_counter()
tree = BallTree(gist_train)
et = time.perf_counter() - st
gist_balltree_index_timing.append(et)
print("balltree build time:", et)

for R in radius:
    st = time.perf_counter()
    for j in range(gist_query.shape[0]):
        # Slice (j:j+1) rather than index so the query keeps its leading axis.
        ind = tree.query_radius(gist_query[j:j+1], r=R, return_distance=False)
    et = time.perf_counter() - st
    gist_balltree_run_timing.append(et)
    print("balltree query time:", et)

print()


# SNN: build the model once via build_snn_model (presumably provided by the
# snnpy star import -- verify), then sweep the radii.
# Timing uses perf_counter(), which is monotonic and high-resolution;
# time.time() follows the wall clock and can jump if the system clock is
# adjusted during these long runs.
st = time.perf_counter()
snn = build_snn_model(gist_train)
et = time.perf_counter() - st
gist_sn_index_timing.append(et)
print("snn build time:", et)


for R in radius:
    st = time.perf_counter()
    for j in range(gist_query.shape[0]):
        # The query is passed as a single indexed row (gist_query[j]), not a
        # one-row slice as in the scikit-learn sections.
        ind = snn.query_radius(gist_query[j], R)
    et = time.perf_counter() - st
    gist_sn_run_timing.append(et)
    print("snn query time:", et)

brute force 1 query time: 3954.7653596401215
brute force 1 query time: 3965.736777305603
brute force 1 query time: 3941.2481956481934
brute force 1 query time: 3817.1694226264954
brute force 1 query time: 3758.8757162094116

brute force 2 query time: 862.1556403636932
brute force 2 query time: 861.3777701854706
brute force 2 query time: 861.5746960639954
brute force 2 query time: 861.7276568412781
brute force 2 query time: 861.3657484054565

kdtree build time: 319.4303812980652
kdtree query time: 3144.453806400299
kdtree query time: 3182.3090012073517
kdtree query time: 3205.856163740158
kdtree query time: 3223.3404545783997
kdtree query time: 3237.025584459305

balltree build time: 297.88309502601624
balltree query time: 2159.7659151554108
balltree query time: 2164.154891014099
balltree query time: 2171.104510784149
balltree query time: 2178.4723556041718
balltree query time: 2183.157583475113

snn build time: 29.14320969581604
snn query time: 281.51042342185974
snn query time: 293.92

In [6]:
# Persist every timing series as its own .npy file under result/.
#
# NOTE(review): all file names carry a "_norm" suffix, yet the standardisation
# step in the data-loading cell is commented out -- confirm the suffix is
# still accurate before publishing these files.
result_dir = 'result'

# Create the output directory up front so a missing folder cannot make the
# save step fail after the (hours-long) benchmark has finished.
os.makedirs(result_dir, exist_ok=True)

# File-name stem -> timing list. Insertion order matches the original save
# order. The SNN files use "snn" in the name while the variables use "sn".
timing_series = {
    'gist_kdtree_index_timing': gist_kdtree_index_timing,
    'gist_balltree_index_timing': gist_balltree_index_timing,
    'gist_snn_index_timing': gist_sn_index_timing,
    'gist_bf_run_timing1': gist_bf_run_timing1,
    'gist_bf_run_timing2': gist_bf_run_timing2,
    'gist_kdtree_run_timing': gist_kdtree_run_timing,
    'gist_balltree_run_timing': gist_balltree_run_timing,
    'gist_snn_run_timing': gist_sn_run_timing,
}

for stem, series in timing_series.items():
    with open(f'{result_dir}/{stem}_norm.npy', 'wb') as f:
        np.save(f, np.array(series))
    
    