In [1]:
# Import order below is intentionally left as-is: the two star imports can
# shadow (or be shadowed by) neighbouring names, so regrouping could change
# which definition wins.
import os
import time
import warnings

import numpy as np
from snnpy import *
from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree
from bf_search import *
from tqdm import tqdm

# Leaf size handed to the 'kd_tree' and 'ball_tree' NearestNeighbors
# estimators built in the benchmark cells below.
leaf_size = 30
# Silence every warning for the rest of the kernel session (presumably to keep
# the benchmark output uncluttered).
# NOTE(review): this is a blanket filter, so library deprecation or numerical
# warnings are hidden as well — confirm that is intended.
warnings.filterwarnings("ignore")

### Timing vs. dataset size (n_dim = 2)

In [2]:
# Experiment: cost of radius queries as the indexed set grows (n_dim fixed at 2).
#
# For every dataset size, 2*i points are drawn; the first i are indexed (X) and
# the remaining i are used as queries (Query), issued one at a time. Each method
# is run once per radius in Rlist and the summed wall-clock seconds are averaged
# over the radii before being recorded.
#
# Intervals are measured with time.perf_counter(): it is monotonic and uses the
# highest-resolution clock available, whereas time.time() can jump when the
# system clock is adjusted and may be too coarse for millisecond-scale index
# builds.
n_samples = 2000   # smallest indexed-set size
unit_inc = 2000    # growth of the indexed set between consecutive runs
n_dim = 2
Rlist = [0.02, 0.05, 0.08, 0.11, 0.14]
n_radii = len(Rlist)


# Per-size averages; one entry is appended for each dataset size.
index_time_kd = list()
index_time_bl = list()
index_time_sn = list()

run_time_bf1 = list()
run_time_bf2 = list()
run_time_kd = list()
run_time_bl = list()
run_time_sn = list()

rng = np.random.RandomState(0)
xrange = np.arange(n_samples, n_samples + 10*unit_inc, unit_inc)
for i in tqdm(xrange):
    data = rng.random_sample((2*i, n_dim))
    X = data[:i]
    Query = data[i:]
    n_queries = Query.shape[0]

    # Accumulators (seconds), summed over all radii for this dataset size.
    bf1et = 0
    bf2et = 0
    kd_id_et = 0
    kd_qy_et = 0
    bl_id_et = 0
    bl_qy_et = 0
    sn_id_et = 0
    sn_qy_et = 0

    # The query loops are deliberately written inline (no shared helper or
    # callback) so that no extra call overhead ends up inside the timed region.
    for R in Rlist:
        # Brute force 1: scikit-learn's 'brute' algorithm.
        # Note: this timer also covers estimator construction and fit();
        # no separate index time is recorded for brute force.
        st = time.perf_counter()
        neigh = NearestNeighbors(radius=R, algorithm='brute')
        neigh.fit(X)
        for j in range(n_queries):
            ind = neigh.radius_neighbors(
               Query[j:j+1], radius=R, return_distance=False
            )
        bf1et += time.perf_counter() - st

        # Brute force 2: bf_radius_fairness (brought in by a star import).
        st = time.perf_counter()
        for j in range(n_queries):
            ind = bf_radius_fairness(Query[j], X, R, return_distance=False)
        bf2et += time.perf_counter() - st

        # KDtree: index build and queries are timed separately.
        st = time.perf_counter()
        kdtree = NearestNeighbors(radius=R, algorithm='kd_tree', leaf_size=leaf_size)
        kdtree.fit(X)
        kd_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            kdind = kdtree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        kd_qy_et += time.perf_counter() - st

        # Balltree: index build and queries are timed separately.
        st = time.perf_counter()
        bltree = NearestNeighbors(radius=R, algorithm='ball_tree', leaf_size=leaf_size)
        bltree.fit(X)
        bl_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            blind = bltree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        bl_qy_et += time.perf_counter() - st

        # snn: the model is rebuilt for every radius so that its build time is
        # averaged over Rlist in the same way as the tree indexes.
        st = time.perf_counter()
        snn = build_snn_model(X)
        sn_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            sind = snn.query_radius(Query[j], R)
        sn_qy_et += time.perf_counter() - st

    # Average over the radii once, then report and record the same numbers.
    bf1_avg = bf1et / n_radii
    bf2_avg = bf2et / n_radii
    kd_id_avg = kd_id_et / n_radii
    kd_qy_avg = kd_qy_et / n_radii
    bl_id_avg = bl_id_et / n_radii
    bl_qy_avg = bl_qy_et / n_radii
    sn_id_avg = sn_id_et / n_radii
    sn_qy_avg = sn_qy_et / n_radii

    print("brute force 1 query time:", bf1_avg)
    print("brute force 2 query time:", bf2_avg)
    print("KDtree index:", kd_id_avg)
    print("KDtree query time:", kd_qy_avg)
    print("Balltree index:", bl_id_avg)
    print("Balltree query time:", bl_qy_avg)
    print("SNN index:", sn_id_avg)
    print("SNN query time:", sn_qy_avg)

    index_time_kd.append(kd_id_avg)
    index_time_bl.append(bl_id_avg)
    index_time_sn.append(sn_id_avg)
    run_time_bf1.append(bf1_avg)
    run_time_bf2.append(bf2_avg)
    run_time_kd.append(kd_qy_avg)
    run_time_bl.append(bl_qy_avg)
    run_time_sn.append(sn_qy_avg)
    print()

 10%|█         | 1/10 [00:12<01:49, 12.22s/it]

brute force 1 query time: 0.892121410369873
brute force 2 query time: 0.07711148262023926
KDtree index: 0.0013301849365234375
KDtree query time: 0.7333163738250732
Balltree index: 0.0011136531829833984
Balltree query time: 0.6782093048095703
SNN index: 0.0009559154510498047
SNN query time: 0.05891270637512207



 20%|██        | 2/10 [00:38<02:42, 20.31s/it]

brute force 1 query time: 2.0887484550476074
brute force 2 query time: 0.22319679260253905
KDtree index: 0.002316188812255859
KDtree query time: 1.3855984687805176
Balltree index: 0.0015722751617431641
Balltree query time: 1.3602072238922118
SNN index: 0.0013588905334472657
SNN query time: 0.13034820556640625



 30%|███       | 3/10 [01:20<03:31, 30.14s/it]

brute force 1 query time: 3.530639886856079
brute force 2 query time: 0.44072113037109373
KDtree index: 0.002806997299194336
KDtree query time: 2.110543441772461
Balltree index: 0.0022358417510986326
Balltree query time: 2.070222187042236
SNN index: 0.0016164779663085938
SNN query time: 0.2099766254425049



 40%|████      | 4/10 [02:19<04:10, 41.79s/it]

brute force 1 query time: 5.270484685897827
brute force 2 query time: 0.7302399635314941
KDtree index: 0.0036507606506347655
KDtree query time: 2.8403017044067385
Balltree index: 0.0030890464782714843
Balltree query time: 2.7780088901519777
SNN index: 0.001802825927734375
SNN query time: 0.2979975700378418



 50%|█████     | 5/10 [03:39<04:37, 55.45s/it]

brute force 1 query time: 7.310607242584228
brute force 2 query time: 1.0969375610351562
KDtree index: 0.004346990585327148
KDtree query time: 3.6016801834106444
Balltree index: 0.0035769462585449217
Balltree query time: 3.5104534149169924
SNN index: 0.0020571231842041017
SNN query time: 0.40345449447631837



 60%|██████    | 6/10 [05:16<04:38, 69.58s/it]

brute force 1 query time: 9.124963808059693
brute force 2 query time: 1.4341964244842529
KDtree index: 0.005086231231689453
KDtree query time: 4.242967748641968
Balltree index: 0.0045298576354980465
Balltree query time: 4.130775499343872
SNN index: 0.0025084972381591796
SNN query time: 0.4574014186859131



 70%|███████   | 7/10 [07:19<04:20, 87.00s/it]

brute force 1 query time: 11.734852504730224
brute force 2 query time: 2.054579257965088
KDtree index: 0.006420707702636719
KDtree query time: 5.147843265533448
Balltree index: 0.005020713806152344
Balltree query time: 5.053049755096436
SNN index: 0.0031085491180419924
SNN query time: 0.5654738426208497



 80%|████████  | 8/10 [09:45<03:31, 105.98s/it]

brute force 1 query time: 14.159282350540161
brute force 2 query time: 2.6852327823638915
KDtree index: 0.007737588882446289
KDtree query time: 5.979711055755615
Balltree index: 0.006150531768798828
Balltree query time: 5.802154779434204
SNN index: 0.0032407283782958985
SNN query time: 0.6812399387359619



 90%|█████████ | 9/10 [12:38<02:06, 126.80s/it]

brute force 1 query time: 16.774963235855104
brute force 2 query time: 3.4916654586791993
KDtree index: 0.00843186378479004
KDtree query time: 6.827892541885376
Balltree index: 0.006991004943847657
Balltree query time: 6.581945991516113
SNN index: 0.0034593582153320313
SNN query time: 0.8169749259948731



100%|██████████| 10/10 [15:55<00:00, 95.55s/it] 

brute force 1 query time: 19.454589557647704
brute force 2 query time: 4.113681507110596
KDtree index: 0.008959150314331055
KDtree query time: 7.552644205093384
Balltree index: 0.007610940933227539
Balltree query time: 7.330803251266479
SNN index: 0.003923463821411133
SNN query time: 0.9372890949249267






In [3]:
### save the data

# Freeze the per-size timing lists as arrays (kept under the same names so
# later cells can keep using them).
index_time_kd = np.array(index_time_kd)
index_time_bl = np.array(index_time_bl)
index_time_sn = np.array(index_time_sn)

run_time_bf1 = np.array(run_time_bf1)
run_time_bf2 = np.array(run_time_bf2)
run_time_kd = np.array(run_time_kd)
run_time_bl = np.array(run_time_bl)
run_time_sn = np.array(run_time_sn)

# Make sure the target directory exists before writing: the benchmark above
# takes minutes to run, and a missing directory would otherwise raise
# FileNotFoundError on the first open() and leave nothing on disk.
out_dir = 'result/query_r/size'
os.makedirs(out_dir, exist_ok=True)

# Each array is written to <out_dir>/<name>_d2.npy.
timings_d2 = [
    ('index_time_kd', index_time_kd),
    ('index_time_bl', index_time_bl),
    ('index_time_sn', index_time_sn),
    ('run_time_bf1', run_time_bf1),
    ('run_time_bf2', run_time_bf2),
    ('run_time_kd', run_time_kd),
    ('run_time_bl', run_time_bl),
    ('run_time_sn', run_time_sn),
]
for name, values in timings_d2:
    with open('{}/{}_d2.npy'.format(out_dir, name), 'wb') as f:
        np.save(f, values)
    

### Timing vs. dataset size (n_dim = 50)

In [4]:
# Experiment: cost of radius queries as the indexed set grows (n_dim fixed at 50).
#
# For every dataset size, 2*i points are drawn; the first i are indexed (X) and
# the remaining i are used as queries (Query), issued one at a time. Each method
# is run once per radius in Rlist and the summed wall-clock seconds are averaged
# over the radii before being recorded.
#
# Intervals are measured with time.perf_counter(): it is monotonic and uses the
# highest-resolution clock available, whereas time.time() can jump when the
# system clock is adjusted and may be too coarse for millisecond-scale index
# builds.
n_samples = 2000   # smallest indexed-set size
unit_inc = 2000    # growth of the indexed set between consecutive runs
n_dim = 50
Rlist = [2, 2.1, 2.2, 2.3, 2.4]
n_radii = len(Rlist)


# Per-size averages; one entry is appended for each dataset size.
index_time_kd = list()
index_time_bl = list()
index_time_sn = list()

run_time_bf1 = list()
run_time_bf2 = list()
run_time_kd = list()
run_time_bl = list()
run_time_sn = list()

rng = np.random.RandomState(0)
xrange = np.arange(n_samples, n_samples + 10*unit_inc, unit_inc)
for i in tqdm(xrange):
    data = rng.random_sample((2*i, n_dim))
    X = data[:i]
    Query = data[i:]
    n_queries = Query.shape[0]

    # Accumulators (seconds), summed over all radii for this dataset size.
    bf1et = 0
    bf2et = 0
    kd_id_et = 0
    kd_qy_et = 0
    bl_id_et = 0
    bl_qy_et = 0
    sn_id_et = 0
    sn_qy_et = 0

    # The query loops are deliberately written inline (no shared helper or
    # callback) so that no extra call overhead ends up inside the timed region.
    for R in Rlist:
        # Brute force 1: scikit-learn's 'brute' algorithm.
        # Note: this timer also covers estimator construction and fit();
        # no separate index time is recorded for brute force.
        st = time.perf_counter()
        neigh = NearestNeighbors(radius=R, algorithm='brute')
        neigh.fit(X)
        for j in range(n_queries):
            ind = neigh.radius_neighbors(
               Query[j:j+1], radius=R, return_distance=False
            )
        bf1et += time.perf_counter() - st

        # Brute force 2: bf_radius_fairness (brought in by a star import).
        st = time.perf_counter()
        for j in range(n_queries):
            ind = bf_radius_fairness(Query[j], X, R, return_distance=False)
        bf2et += time.perf_counter() - st

        # KDtree: index build and queries are timed separately.
        st = time.perf_counter()
        kdtree = NearestNeighbors(radius=R, algorithm='kd_tree', leaf_size=leaf_size)
        kdtree.fit(X)
        kd_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            kdind = kdtree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        kd_qy_et += time.perf_counter() - st

        # Balltree: index build and queries are timed separately.
        st = time.perf_counter()
        bltree = NearestNeighbors(radius=R, algorithm='ball_tree', leaf_size=leaf_size)
        bltree.fit(X)
        bl_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            blind = bltree.radius_neighbors(Query[j:j+1], radius=R, return_distance=False)
        bl_qy_et += time.perf_counter() - st

        # snn: the model is rebuilt for every radius so that its build time is
        # averaged over Rlist in the same way as the tree indexes.
        st = time.perf_counter()
        snn = build_snn_model(X)
        sn_id_et += time.perf_counter() - st

        st = time.perf_counter()
        for j in range(n_queries):
            sind = snn.query_radius(Query[j], R)
        sn_qy_et += time.perf_counter() - st

    # Average over the radii once, then report and record the same numbers.
    bf1_avg = bf1et / n_radii
    bf2_avg = bf2et / n_radii
    kd_id_avg = kd_id_et / n_radii
    kd_qy_avg = kd_qy_et / n_radii
    bl_id_avg = bl_id_et / n_radii
    bl_qy_avg = bl_qy_et / n_radii
    sn_id_avg = sn_id_et / n_radii
    sn_qy_avg = sn_qy_et / n_radii

    print("brute force 1 query time:", bf1_avg)
    print("brute force 2 query time:", bf2_avg)
    print("KDtree index:", kd_id_avg)
    print("KDtree query time:", kd_qy_avg)
    print("Balltree index:", bl_id_avg)
    print("Balltree query time:", bl_qy_avg)
    print("SNN index:", sn_id_avg)
    print("SNN query time:", sn_qy_avg)

    index_time_kd.append(kd_id_avg)
    index_time_bl.append(bl_id_avg)
    index_time_sn.append(sn_id_avg)
    run_time_bf1.append(bf1_avg)
    run_time_bf2.append(bf2_avg)
    run_time_kd.append(kd_qy_avg)
    run_time_bl.append(bl_qy_avg)
    run_time_sn.append(sn_qy_avg)
    print()

 10%|█         | 1/10 [00:21<03:15, 21.77s/it]

brute force 1 query time: 1.2857616901397706
brute force 2 query time: 0.2525928020477295
KDtree index: 0.00824422836303711
KDtree query time: 1.600658655166626
Balltree index: 0.005670595169067383
Balltree query time: 1.0759779930114746
SNN index: 0.002515840530395508
SNN query time: 0.12198481559753419



 20%|██        | 2/10 [01:26<06:18, 47.29s/it]

brute force 1 query time: 3.5274044990539553
brute force 2 query time: 1.0507651805877685
KDtree index: 0.018450832366943358
KDtree query time: 5.024897718429566
Balltree index: 0.013147830963134766
Balltree query time: 2.896449565887451
SNN index: 0.004449701309204102
SNN query time: 0.4945737361907959



 30%|███       | 3/10 [03:26<09:21, 80.26s/it]

brute force 1 query time: 6.638609409332275
brute force 2 query time: 2.4020076751708985
KDtree index: 0.027284955978393553
KDtree query time: 8.526308155059814
Balltree index: 0.019746160507202147
Balltree query time: 5.166051435470581
SNN index: 0.0070989131927490234
SNN query time: 1.1100894927978515



 40%|████      | 4/10 [07:07<13:35, 135.96s/it]

brute force 1 query time: 10.501354312896728
brute force 2 query time: 4.17706151008606
KDtree index: 0.04130067825317383
KDtree query time: 18.912963533401488
Balltree index: 0.03014826774597168
Balltree query time: 8.64463996887207
SNN index: 0.009214019775390625
SNN query time: 1.9505250930786133



 50%|█████     | 5/10 [13:35<18:53, 226.65s/it]

brute force 1 query time: 16.42732768058777
brute force 2 query time: 6.4816522121429445
KDtree index: 0.05551347732543945
KDtree query time: 31.5589777469635
Balltree index: 0.04117393493652344
Balltree query time: 19.844099760055542
SNN index: 0.012276601791381837
SNN query time: 3.0650165557861326



 60%|██████    | 6/10 [22:57<22:43, 340.93s/it]

brute force 1 query time: 22.218090629577638
brute force 2 query time: 9.717243337631226
KDtree index: 0.06725263595581055
KDtree query time: 49.45512132644653
Balltree index: 0.049594688415527347
Balltree query time: 26.61669478416443
SNN index: 0.014846372604370116
SNN query time: 4.41048846244812



 70%|███████   | 7/10 [36:52<25:07, 502.45s/it]

brute force 1 query time: 31.027120113372803
brute force 2 query time: 13.046793556213379
KDtree index: 0.08338637351989746
KDtree query time: 67.5608229637146
Balltree index: 0.06746654510498047
Balltree query time: 48.83416194915772
SNN index: 0.017702627182006835
SNN query time: 6.354003000259399



 80%|████████  | 8/10 [54:20<22:31, 675.87s/it]

brute force 1 query time: 41.97265753746033
brute force 2 query time: 17.197084617614745
KDtree index: 0.1019866943359375
KDtree query time: 91.94804821014404
Balltree index: 0.07225623130798339
Balltree query time: 49.25440516471863
SNN index: 0.01941704750061035
SNN query time: 8.871907186508178



 90%|█████████ | 9/10 [1:15:27<14:20, 860.73s/it]

brute force 1 query time: 48.655680131912234
brute force 2 query time: 22.084027147293092
KDtree index: 0.11615700721740722
KDtree query time: 107.80477318763732
Balltree index: 0.08374629020690919
Balltree query time: 62.53867907524109
SNN index: 0.02195148468017578
SNN query time: 12.127765130996703



100%|██████████| 10/10 [1:42:06<00:00, 612.65s/it] 

brute force 1 query time: 64.09627637863159
brute force 2 query time: 28.675651597976685
KDtree index: 0.1340017795562744
KDtree query time: 131.5664942264557
Balltree index: 0.09935789108276367
Balltree query time: 79.32814283370972
SNN index: 0.023569345474243164
SNN query time: 15.905593109130859






In [5]:
### save the data

# Freeze the per-size timing lists as arrays (kept under the same names so
# later cells can keep using them).
index_time_kd = np.array(index_time_kd)
index_time_bl = np.array(index_time_bl)
index_time_sn = np.array(index_time_sn)

run_time_bf1 = np.array(run_time_bf1)
run_time_bf2 = np.array(run_time_bf2)
run_time_kd = np.array(run_time_kd)
run_time_bl = np.array(run_time_bl)
run_time_sn = np.array(run_time_sn)

# NOTE(review): the 2-D results earlier in this notebook are written to
# 'result/query_r/size/<name>_d2.npy', whereas these go to 'result/size/<name>.npy'
# with no dimension suffix. The paths are left unchanged here because other
# code may read them — confirm which layout is intended.
out_dir = 'result/size'

# Make sure the target directory exists before writing: the benchmark above
# takes well over an hour, and a missing directory would otherwise raise
# FileNotFoundError on the first open() and leave nothing on disk.
os.makedirs(out_dir, exist_ok=True)

# Each array is written to <out_dir>/<name>.npy.
timings = [
    ('index_time_kd', index_time_kd),
    ('index_time_bl', index_time_bl),
    ('index_time_sn', index_time_sn),
    ('run_time_bf1', run_time_bf1),
    ('run_time_bf2', run_time_bf2),
    ('run_time_kd', run_time_kd),
    ('run_time_bl', run_time_bl),
    ('run_time_sn', run_time_sn),
]
for name, values in timings:
    with open('{}/{}.npy'.format(out_dir, name), 'wb') as f:
        np.save(f, values)
    