In [1]:
import os
import time
import warnings

import numpy as np
from snnpy import *
from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree
from bf_search import *
from tqdm import tqdm

# Leaf size handed to the tree-based (kd_tree / ball_tree) NearestNeighbors
# estimators built in the benchmark below.
leaf_size = 30
# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

### Dimensions: index and query time as the number of features grows

In [2]:
# Benchmark of radius-neighbour search as the number of features grows.
#
# For each dimensionality in `xrange`, 2*n_samples random points are drawn,
# standardised per feature, and split into an index set `X` and a query set
# `Query`. Every method is then timed for every radius in `Rlist`, answering
# the queries one at a time, and the per-radius average is recorded.
n_samples = 10000
unit_inc = 30
Rlist = [0.5, 2, 3.5, 5, 6.5]
rng = np.random.RandomState(0)  # fixed seed so the generated data is repeatable

# Average index-construction time (seconds) per dimensionality.
index_time_kd = list()
index_time_bl = list()
index_time_sn = list()

# Average time (seconds) to answer all queries, per dimensionality.
run_time_bf1 = list()
run_time_bf2 = list()
run_time_kd = list()
run_time_bl = list()
run_time_sn = list()

# Ten dimensionalities: 2, 2 + unit_inc, ..., 2 + 9*unit_inc.
xrange = np.arange(2, 2 + 10*unit_inc, unit_inc)
n_radii = len(Rlist)

for n_features in tqdm(xrange):
    data = rng.random_sample((2*n_samples, n_features))
    # z-score each feature using statistics of the combined index+query sample
    data = (data - data.mean(axis=0))/data.std(axis=0)
    X = data[:n_samples]
    Query = data[n_samples:]
    n_queries = Query.shape[0]

    # Accumulated elapsed times over all radii for this dimensionality.
    bf1et = 0
    bf2et = 0
    kd_id_et = 0
    kd_qy_et = 0
    bl_id_et = 0
    bl_qy_et = 0
    sn_id_et = 0
    sn_qy_et = 0

    for R in Rlist:
        # Brute force 1: scikit-learn's brute-force estimator.
        # Construction and fit are kept outside the timed region so that the
        # reported number is a pure query time, like the other methods.
        neigh = NearestNeighbors(radius=R, algorithm='brute')
        neigh.fit(X)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        st = time.perf_counter()
        for j in range(n_queries):
            ind = neigh.radius_neighbors(
                Query[j:j+1], radius=R, return_distance=False
            )
        bf1et += time.perf_counter() - st

        # Brute force 2: bf_radius_fairness from bf_search (no index step).
        st = time.perf_counter()
        for j in range(n_queries):
            ind = bf_radius_fairness(Query[j], X, R, return_distance=False)
        bf2et += time.perf_counter() - st

        # KD-tree: index construction, then queries.
        st = time.perf_counter()
        kdtree = NearestNeighbors(radius=R, algorithm='kd_tree', leaf_size=leaf_size)
        kdtree.fit(X)
        kd_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            kdind = kdtree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        kd_qy_et += time.perf_counter() - st

        # Ball tree: index construction, then queries.
        st = time.perf_counter()
        bltree = NearestNeighbors(radius=R, algorithm='ball_tree', leaf_size=leaf_size)
        bltree.fit(X)
        bl_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            blind = bltree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        bl_qy_et += time.perf_counter() - st

        # SNN: index construction, then queries.
        st = time.perf_counter()
        snn = build_snn_model(X)
        sn_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            sind = snn.query_radius(Query[j], R)
        sn_qy_et += time.perf_counter() - st

    # Average over the radii once, then report and record the same numbers.
    bf1_avg = bf1et / n_radii
    bf2_avg = bf2et / n_radii
    kd_id_avg = kd_id_et / n_radii
    kd_qy_avg = kd_qy_et / n_radii
    bl_id_avg = bl_id_et / n_radii
    bl_qy_avg = bl_qy_et / n_radii
    sn_id_avg = sn_id_et / n_radii
    sn_qy_avg = sn_qy_et / n_radii

    print("brute force 1 query time:", bf1_avg)
    print("brute force 2 query time:", bf2_avg)
    print("KDtree index:", kd_id_avg)
    print("KDtree query time:", kd_qy_avg)
    print("Balltree index:", bl_id_avg)
    print("Balltree query time:", bl_qy_avg)
    print("SNN index:", sn_id_avg)
    print("SNN query time:", sn_qy_avg)

    index_time_kd.append(kd_id_avg)
    run_time_bf1.append(bf1_avg)
    run_time_bf2.append(bf2_avg)
    run_time_kd.append(kd_qy_avg)
    index_time_bl.append(bl_id_avg)
    run_time_bl.append(bl_qy_avg)
    index_time_sn.append(sn_id_avg)
    run_time_sn.append(sn_qy_avg)
    print()


 10%|█         | 1/10 [01:30<13:32, 90.25s/it]

brute force 1 query time: 8.53269100189209
brute force 2 query time: 1.272822904586792
KDtree index: 0.004906415939331055
KDtree query time: 3.7898655414581297
Balltree index: 0.0038990497589111326
Balltree query time: 3.6771608352661134
SNN index: 0.0025220870971679687
SNN query time: 0.764811372756958



 20%|██        | 2/10 [04:50<20:40, 155.07s/it]

brute force 1 query time: 12.444607543945313
brute force 2 query time: 4.175926685333252
KDtree index: 0.03530750274658203
KDtree query time: 12.847333812713623
Balltree index: 0.027465391159057616
Balltree query time: 8.932817459106445
SNN index: 0.0074117183685302734
SNN query time: 1.613000249862671



 30%|███       | 3/10 [10:32<28:03, 240.45s/it]

brute force 1 query time: 18.221064805984497
brute force 2 query time: 7.95151653289795
KDtree index: 0.0632136344909668
KDtree query time: 22.867919731140137
Balltree index: 0.04878444671630859
Balltree query time: 16.474570274353027
SNN index: 0.013378047943115234
SNN query time: 2.7625435829162597



 40%|████      | 4/10 [17:57<32:05, 320.99s/it]

brute force 1 query time: 22.65805697441101
brute force 2 query time: 10.594118690490722
KDtree index: 0.09270281791687011
KDtree query time: 31.195228576660156
Balltree index: 0.06921191215515136
Balltree query time: 20.643914461135864
SNN index: 0.02019810676574707
SNN query time: 3.6060754299163817



 50%|█████     | 5/10 [28:39<36:24, 436.84s/it]

brute force 1 query time: 31.71458215713501
brute force 2 query time: 15.39807448387146
KDtree index: 0.1293653964996338
KDtree query time: 43.7151563167572
Balltree index: 0.09855179786682129
Balltree query time: 31.892113733291627
SNN index: 0.027449417114257812
SNN query time: 5.461703586578369



 60%|██████    | 6/10 [42:27<37:59, 569.93s/it]

brute force 1 query time: 42.35680632591247
brute force 2 query time: 20.742682647705077
KDtree index: 0.15974302291870118
KDtree query time: 53.65612258911133
Balltree index: 0.12131915092468262
Balltree query time: 40.40652112960815
SNN index: 0.03767037391662598
SNN query time: 8.157477617263794



 70%|███████   | 7/10 [59:42<36:06, 722.00s/it]

brute force 1 query time: 51.908102130889894
brute force 2 query time: 28.730365562438966
KDtree index: 0.20204539299011232
KDtree query time: 63.77945642471313
Balltree index: 0.1554805278778076
Balltree query time: 49.96784505844116
SNN index: 0.04706859588623047
SNN query time: 12.206254959106445



 80%|████████  | 8/10 [1:20:12<29:27, 883.51s/it]

brute force 1 query time: 62.548076486587526
brute force 2 query time: 34.53612232208252
KDtree index: 0.2367098808288574
KDtree query time: 73.60799751281738
Balltree index: 0.18109521865844727
Balltree query time: 59.48930602073669
SNN index: 0.057553672790527345
SNN query time: 15.186445045471192



 90%|█████████ | 9/10 [1:43:15<17:19, 1039.66s/it]

brute force 1 query time: 67.84930491447449
brute force 2 query time: 39.375402736663816
KDtree index: 0.27951903343200685
KDtree query time: 83.27975225448608
Balltree index: 0.21146087646484374
Balltree query time: 67.6963939666748
SNN index: 0.06776084899902343
SNN query time: 17.810234832763673



100%|██████████| 10/10 [2:09:51<00:00, 779.13s/it] 

brute force 1 query time: 83.23849830627441
brute force 2 query time: 47.029201650619505
KDtree index: 0.3003885746002197
KDtree query time: 92.39968962669373
Balltree index: 0.230979061126709
Balltree query time: 74.60061655044555
SNN index: 0.08022751808166503
SNN query time: 21.314986276626588






In [3]:
### save the data

# Turn the per-dimension timing lists into NumPy arrays (the names are rebound
# so that code running after this cell sees arrays).
index_time_kd = np.array(index_time_kd)
index_time_bl = np.array(index_time_bl)
index_time_sn = np.array(index_time_sn)

run_time_bf1 = np.array(run_time_bf1)
run_time_bf2 = np.array(run_time_bf2)
run_time_kd = np.array(run_time_kd)
run_time_bl = np.array(run_time_bl)
run_time_sn = np.array(run_time_sn)

# Create the output directory up front: without it every open() below raises
# FileNotFoundError and the benchmark results would be lost.
result_dir = 'result/dim'
os.makedirs(result_dir, exist_ok=True)

# File stem -> array; each array is written to '<result_dir>/<stem>.npy'.
timings = {
    'index_time_kd': index_time_kd,
    'index_time_bl': index_time_bl,
    'index_time_sn': index_time_sn,
    'run_time_bf1': run_time_bf1,
    'run_time_bf2': run_time_bf2,
    'run_time_kd': run_time_kd,
    'run_time_bl': run_time_bl,
    'run_time_sn': run_time_sn,
}

for stem, values in timings.items():
    with open(result_dir + '/' + stem + '.npy', 'wb') as f:
        np.save(f, values)
    