In [1]:
import os
import time

import numpy as np
# The relative order of the lines below is kept as-is: the two star imports can
# shadow names, so moving them could change which definition wins.
from snnpy import *
from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KDTree
from bf_search import *

In [2]:
def bvecs_read(fname):
    """Load byte vectors from a file of fixed-width records.

    The file is treated as a sequence of records, each made of a 4-byte
    header followed by ``d`` uint8 components, where ``d`` is the int32
    value stored at the very start of the file. Every record is assumed to
    have the same width as the first one.

    Parameters
    ----------
    fname : str or path-like
        File to read.

    Returns
    -------
    numpy.ndarray
        A freshly allocated uint8 array with one row per record and ``d``
        columns (the 4 header bytes of each record are dropped).
    """
    # The leading int32 gives the number of components per vector.
    dim = np.fromfile(fname, dtype=np.int32, count=1)[0]
    raw_bytes = np.fromfile(fname, dtype=np.uint8)
    # One row per record: 4 header bytes + dim payload bytes.
    records = raw_bytes.reshape(-1, dim + 4)
    # Copy so the result does not keep the full raw buffer alive.
    return records[:, 4:].copy()


def ivecs_read(fname):
    """Load int32 vectors from a file of fixed-width records.

    The whole file is read as int32 values. The first value ``d`` is taken
    as the number of components per vector, and the data is then split into
    rows of ``d + 1`` values (one header value followed by ``d`` components).
    Every record is assumed to have the same width as the first one.

    Parameters
    ----------
    fname : str or path-like
        File to read.

    Returns
    -------
    numpy.ndarray
        A freshly allocated int32 array with one row per record and ``d``
        columns (the leading header value of each record is dropped).
    """
    flat = np.fromfile(fname, dtype='int32')
    dim = flat[0]
    # One row per record: 1 header value + dim payload values.
    table = flat.reshape(-1, dim + 1)
    return table[:, 1:].copy()


def fvecs_read2(fname):
    """Load float32 vectors from a file of fixed-width records.

    The file is parsed with ``ivecs_read`` (int32 header per record, then the
    components) and the resulting int32 payload is reinterpreted bit-for-bit
    as float32; no numeric conversion takes place.

    Parameters
    ----------
    fname : str or path-like
        File to read.

    Returns
    -------
    numpy.ndarray
        float32 array with one row per record.
    """
    int_payload = ivecs_read(fname)
    return int_payload.view('float32')

### Fashion-MNIST

In [3]:
# Data set to index and the query points to search with.
fmn_train = np.load("Euclidean_data/fashion_mnist/train.npy")
fmn_query = np.load("Euclidean_data/fashion_mnist/queries.npy")

# Index construction times (one entry per build).
fmn_kdtree_index_timing = []
fmn_balltree_index_timing = []
fmn_snn_index_timing = []

# Query times (one entry per tested radius).
fmn_bf_run_timing1 = []
fmn_bf_run_timing2 = []
fmn_kdtree_run_timing = []
fmn_balltree_run_timing = []
fmn_snn_run_timing = []


In [4]:
# Query radii to benchmark: 800, 900, 1000, 1100, 1200.
radius = list(range(800, 1201, 100))

In [5]:
# Radius-query benchmark on the Fashion-MNIST data: two brute-force baselines,
# KD-tree, ball tree and SNN. Every method answers the queries one at a time,
# so per-call overhead is part of what is measured.
#
# Intervals are measured with time.perf_counter(), the monotonic
# high-resolution clock meant for timing; time.time() follows the system clock
# and can jump during a run that lasts hours.


def _time_index_build(label, build_index, timings):
    """Build an index once, record the build time and return the index.

    Parameters
    ----------
    label : str
        Text printed in front of the elapsed time.
    build_index : callable
        Zero-argument callable that constructs and returns the index.
    timings : list
        List the elapsed time (seconds) is appended to.

    Returns
    -------
    object
        Whatever ``build_index`` returned.
    """
    st = time.perf_counter()
    index = build_index()
    et = time.perf_counter() - st
    timings.append(et)
    print(label, et)
    return index


def _time_radius_queries(label, query_one, timings):
    """Time one-at-a-time radius queries for every value in ``radius``.

    Reads the notebook globals ``radius`` (radii to test) and ``fmn_query``
    (query points). For each radius, all queries are issued sequentially and
    the total elapsed time for that radius is recorded.

    Parameters
    ----------
    label : str
        Text printed in front of each elapsed time.
    query_one : callable
        Called as ``query_one(j, R)`` with the query row index ``j`` and the
        radius ``R``; its return value is discarded.
    timings : list
        List the elapsed time (seconds) per radius is appended to.
    """
    for R in radius:
        st = time.perf_counter()
        for j in range(fmn_query.shape[0]):
            query_one(j, R)
        et = time.perf_counter() - st
        timings.append(et)
        print(label, et)


# Start from empty result lists so that re-running this cell does not append a
# second set of measurements to the ones that get saved afterwards.
for _timings in (
    fmn_kdtree_index_timing,
    fmn_balltree_index_timing,
    fmn_snn_index_timing,
    fmn_bf_run_timing1,
    fmn_bf_run_timing2,
    fmn_kdtree_run_timing,
    fmn_balltree_run_timing,
    fmn_snn_run_timing,
):
    _timings.clear()

# Brute force 1
# Kept as an explicit loop: the estimator is constructed and fitted inside the
# timed region for every radius, so that cost is included in the query time
# (there is no separate index-timing list for brute force).
for R in radius:
    st = time.perf_counter()
    neigh = NearestNeighbors(radius=R, algorithm='brute')
    neigh.fit(fmn_train)
    for j in range(fmn_query.shape[0]):
        neigh.radius_neighbors(
            fmn_query[j:j+1], radius=R, return_distance=False
        )
    et = time.perf_counter() - st
    fmn_bf_run_timing1.append(et)
    print("brute force 1 query time:", et)
print()

# Brute force 2
# bf_radius_fairness is not defined in this notebook; presumably it comes from
# the bf_search star import. It receives a single 1-D query vector.
_time_radius_queries(
    "brute force 2 query time:",
    lambda j, R: bf_radius_fairness(fmn_query[j], fmn_train, R, return_distance=False),
    fmn_bf_run_timing2,
)
print()

# KDtree
tree = _time_index_build(
    "kdtree index time:", lambda: KDTree(fmn_train), fmn_kdtree_index_timing
)
# scikit-learn trees expect a 2-D query array, hence the j:j+1 slice.
_time_radius_queries(
    "kdtree query time:",
    lambda j, R: tree.query_radius(fmn_query[j:j+1], r=R, return_distance=False),
    fmn_kdtree_run_timing,
)
print()

# Balltree
# `tree` is rebound only after the KD-tree sweep above has finished.
tree = _time_index_build(
    "ball tree index time:", lambda: BallTree(fmn_train), fmn_balltree_index_timing
)
_time_radius_queries(
    "ball tree query time:",
    lambda j, R: tree.query_radius(fmn_query[j:j+1], r=R, return_distance=False),
    fmn_balltree_run_timing,
)
print()

# SNN
# build_snn_model is not defined in this notebook; presumably it comes from
# the snnpy star import. The model is queried with a single 1-D vector.
snn = _time_index_build(
    "snn index time:", lambda: build_snn_model(fmn_train), fmn_snn_index_timing
)
_time_radius_queries(
    "snn query time:",
    lambda j, R: snn.query_radius(fmn_query[j], R),
    fmn_snn_run_timing,
)

brute force 1 query time: 3027.9109239578247
brute force 1 query time: 2444.403444290161
brute force 1 query time: 2185.027276277542
brute force 1 query time: 2173.101660490036
brute force 1 query time: 2162.1529512405396

brute force 2 query time: 439.9359426498413
brute force 2 query time: 439.56174492836
brute force 2 query time: 440.9798436164856
brute force 2 query time: 442.83657813072205
brute force 2 query time: 443.2322974205017

kdtree index time: 9.035333633422852
kdtree query time: 1463.209626197815
kdtree query time: 1522.1630582809448
kdtree query time: 1571.509311914444
kdtree query time: 1605.4885742664337
kdtree query time: 1632.9131729602814

ball tree index time: 7.88179612159729
ball tree query time: 1103.4436783790588
ball tree query time: 1106.9124047756195
ball tree query time: 1111.954965353012
ball tree query time: 1114.8868584632874
ball tree query time: 1107.6689455509186

snn index time: 1.3347787857055664
snn query time: 77.65121459960938
snn query time: 86

In [6]:
# Persist every timing list as a .npy array under result/.
# Create the output directory first: without it the first open() raises
# FileNotFoundError, and that would only surface after the whole benchmark
# above has finished running.
os.makedirs('result', exist_ok=True)

# (output path, timing list) pairs, written in this order.
_timing_files = (
    ('result/fmn_kdtree_index_timing_norm.npy', fmn_kdtree_index_timing),
    ('result/fmn_balltree_index_timing_norm.npy', fmn_balltree_index_timing),
    ('result/fmn_snn_index_timing_norm.npy', fmn_snn_index_timing),
    ('result/fmn_bf_run_timing1_norm.npy', fmn_bf_run_timing1),
    ('result/fmn_bf_run_timing2_norm.npy', fmn_bf_run_timing2),
    ('result/fmn_kdtree_run_timing_norm.npy', fmn_kdtree_run_timing),
    ('result/fmn_balltree_run_timing_norm.npy', fmn_balltree_run_timing),
    ('result/fmn_snn_run_timing_norm.npy', fmn_snn_run_timing),
)

for _path, _timings in _timing_files:
    with open(_path, 'wb') as f:
        np.save(f, np.array(_timings))
    
