In [1]:
import os
import time
import numpy as np
from snnpy import *
from sklearn.preprocessing import Normalizer
from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree
from bf_search import *

### glove

In [2]:
# Load the GloVe train and query arrays from disk (train first, then queries).
glo_train, glo_query = (
    np.load("Angular_data/glove/train.npy"),
    np.load("Angular_data/glove/queries.npy"),
)


# Index-construction timings: one independent, initially empty list per
# indexed method.
glo_kdtree_index_timing, glo_balltree_index_timing, glo_snn_index_timing = [], [], []

# Query timings: one independent, initially empty list per method; the
# benchmark cell appends one entry per tested radius.
glo_bf_run_timing1, glo_bf_run_timing2 = [], []
glo_kdtree_run_timing, glo_balltree_run_timing, glo_snn_run_timing = [], [], []


In [3]:
# Rescale every row of the train and query arrays with scikit-learn's
# Normalizer (default settings). fit_transform is fit followed by transform
# on the same array, so the result matches the separate fit/transform calls.
transformer = Normalizer()
glo_train = transformer.fit_transform(glo_train)
glo_query = transformer.transform(glo_query)


In [4]:
radius = np.pi * np.array([0.3, 0.31, 0.32, 0.33, 0.34])  # query radii to benchmark (multiples of pi)

In [5]:
# Benchmark radius queries on the normalized GloVe data, issuing one query
# point at a time, for every radius in `radius`. Each method appends its total
# per-radius query time (seconds) to the matching glo_*_run_timing list; the
# tree and SNN methods also append their build time to glo_*_index_timing.
#
# Intervals are measured with time.perf_counter(): it is monotonic, so runs
# that last over an hour per radius cannot be skewed by system clock
# adjustments the way time.time() differences can.

# Brute force 1: scikit-learn NearestNeighbors with algorithm='brute'.
# NOTE(review): fit() sits inside the timed region and is repeated for every
# radius, so this figure includes estimator setup, unlike the tree/SNN
# sections below where construction is timed separately — confirm intended.
for R in radius:
    st = time.perf_counter()
    neigh = NearestNeighbors(radius=R, algorithm='brute')
    neigh.fit(glo_train)
    for j in range(glo_query.shape[0]):
        # The j:j+1 slice keeps the query as a single-row 2-D batch.
        ind = neigh.radius_neighbors(
            glo_query[j:j+1], radius=R, return_distance=False
        )
    et = time.perf_counter() - st
    glo_bf_run_timing1.append(et)
    print("brute force 1 query time:", et)
print()

# Brute force 2: bf_radius_fairness (from bf_search), called directly on the
# training array with no separate build step.
for R in radius:
    st = time.perf_counter()
    for j in range(glo_query.shape[0]):
        ind = bf_radius_fairness(glo_query[j], glo_train, R, return_distance=False)
    et = time.perf_counter() - st
    glo_bf_run_timing2.append(et)
    print("brute force 2 query time:", et)
print()

# KDTree, then BallTree: the two sections differ only in the tree class, the
# printed label and the result lists, so they share one loop. `tree` ends up
# bound to the BallTree, exactly as when the sections were written out twice.
for label, tree_cls, index_timing, run_timing in (
    ("kdtree", KDTree, glo_kdtree_index_timing, glo_kdtree_run_timing),
    ("ball tree", BallTree, glo_balltree_index_timing, glo_balltree_run_timing),
):
    # Time index construction on its own.
    st = time.perf_counter()
    tree = tree_cls(glo_train)
    et = time.perf_counter() - st
    index_timing.append(et)
    print(label + " index time:", et)

    # Time all queries for each radius against the already-built tree.
    for R in radius:
        st = time.perf_counter()
        for j in range(glo_query.shape[0]):
            ind = tree.query_radius(glo_query[j:j+1], r=R, return_distance=False)
        et = time.perf_counter() - st
        run_timing.append(et)
        print(label + " query time:", et)
    print()

# SNN: build the model once (timed separately), then time the queries.
st = time.perf_counter()
snn = build_snn_model(glo_train)
et = time.perf_counter() - st
glo_snn_index_timing.append(et)
print("snn index time:", et)

for R in radius:
    st = time.perf_counter()
    for j in range(glo_query.shape[0]):
        # Unlike the sklearn calls above, the query is passed as a 1-D row.
        ind = snn.query_radius(glo_query[j], R)
    et = time.perf_counter() - st
    glo_snn_run_timing.append(et)
    print("snn query time:", et)

brute force 1 query time: 5169.211848974228
brute force 1 query time: 5140.920110702515
brute force 1 query time: 5147.184922456741
brute force 1 query time: 5201.149308681488
brute force 1 query time: 5220.306466817856

brute force 2 query time: 1272.5856957435608
brute force 2 query time: 1268.7173070907593
brute force 2 query time: 1267.8216750621796
brute force 2 query time: 1265.3503448963165
brute force 2 query time: 1278.3153648376465

kdtree index time: 41.20866847038269
kdtree query time: 6715.308529138565
kdtree query time: 6732.466593742371
kdtree query time: 6705.652412414551
kdtree query time: 6748.568563699722
kdtree query time: 6745.901397943497

ball tree index time: 39.80410122871399
ball tree query time: 5674.9054915905
ball tree query time: 5617.704619169235
ball tree query time: 5648.804111242294
ball tree query time: 5609.530064821243
ball tree query time: 5621.5708796978

snn index time: 1.5493049621582031
snn query time: 783.789053440094
snn query time: 794.70098

In [6]:
# Persist every GloVe timing series as result/<name>_norm.npy.

# Create the output directory up front so that a missing result/ folder cannot
# raise FileNotFoundError and discard timings that took hours to collect.
os.makedirs('result', exist_ok=True)

# Maps each output file stem to its timing list; files are written in this
# order.
glo_timings = {
    'glo_kdtree_index_timing': glo_kdtree_index_timing,
    'glo_balltree_index_timing': glo_balltree_index_timing,
    'glo_snn_index_timing': glo_snn_index_timing,
    'glo_bf_run_timing1': glo_bf_run_timing1,
    'glo_bf_run_timing2': glo_bf_run_timing2,
    'glo_kdtree_run_timing': glo_kdtree_run_timing,
    'glo_balltree_run_timing': glo_balltree_run_timing,
    'glo_snn_run_timing': glo_snn_run_timing,
}

for name, timings in glo_timings.items():
    with open(f'result/{name}_norm.npy', 'wb') as f:
        np.save(f, np.array(timings))
    
