# Tree build and query times: KDTree vs. BallTree comparison

In [None]:
# Disable jedi-based completion to work around an autocomplete issue in IPython.
%config Completer.use_jedi = False

# Automatically reload imported modules before executing code, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
from geotree import gtree

## Build time

In [None]:
# Benchmark sizes: number of points indexed by the tree (10k, 100k, 1M, 10M).
million = 1000000
list_npoints = [million // 100, million // 10, 1 * million, 10 * million]

# Seed NumPy's global RNG so that every "Restart & Run All" draws the same
# random points; all later cells sample through np.random.randint.
SEED = 42
np.random.seed(SEED)

### KDTree

In [None]:
# KDTree build benchmark: for each dataset size, generate random points, load
# them into a fresh gtree, and time ONLY the create_kdt() call.
build_times_kdt = []

for npoints in list_npoints:
    print(npoints)  # progress: current dataset size
    mytree = gtree()

    # Random integer lon/lat values with all depths set to 0.
    # NOTE(review): randint excludes its upper bound, so lons fall in
    # [-180, 179] and lats in [-90, 89].
    lons = np.random.randint(-180, 180, npoints)
    lats = np.random.randint(-90, 90, npoints)
    depths = np.zeros(npoints)

    mytree.add_lonlatdep(lons=lons,
                         lats=lats,
                         depths=depths)

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    t1 = time.perf_counter()
    mytree.create_kdt()
    build_times_kdt.append(time.perf_counter() - t1)

### BallTree

In [None]:
# BallTree build benchmark: for each dataset size, generate random points,
# load them into a fresh gtree, and time ONLY the create_balltree() call.
build_times_ball = []
for npoints in list_npoints:
    print(npoints)  # progress: current dataset size
    mytree = gtree()

    # Random integer lon/lat values with all depths set to 0.
    # NOTE(review): randint excludes its upper bound, so lons fall in
    # [-180, 179] and lats in [-90, 89].
    lons = np.random.randint(-180, 180, npoints)
    lats = np.random.randint(-90, 90, npoints)
    depths = np.zeros(npoints)

    mytree.add_lonlatdep(lons=lons,
                         lats=lats,
                         depths=depths)

    # This is the BallTree benchmark, so build the BallTree (not the KDTree).
    # perf_counter is the monotonic, high-resolution clock for elapsed time.
    t1 = time.perf_counter()
    mytree.create_balltree()
    build_times_ball.append(time.perf_counter() - t1)

## Query time, KDTree

In [None]:
# Reference point set for the query benchmarks (one million random points).
# The KDTree and BallTree query sections below both load these same arrays.
npoints = million
lons, lats = (
    np.random.randint(-180, 180, npoints),
    np.random.randint(-90, 90, npoints),
)
depths = np.zeros(npoints)

In [None]:
# Query-set sizes: 1k, 10k, 100k and 1M query points.
list_queries = [million // 10**k for k in (3, 2, 1, 0)]

In [None]:
# Fresh tree holding the shared reference points; it is used for the KDTree
# query benchmark.
mytree = gtree()
mytree.add_lonlatdep(lons=lons, lats=lats, depths=depths)

In [None]:
# Build the KDTree once; the next cell times only the queries against it.
mytree.create_kdt()

In [None]:
# KDTree query benchmark: for each query-set size, register a batch of random
# query points on the prebuilt tree and time ONLY the query_kdt() call.
query_times_kdt = []
for q_npoints in list_queries:
    print(q_npoints)  # progress: current query-set size

    q_lons = np.random.randint(-180, 180, q_npoints)
    q_lats = np.random.randint(-90, 90, q_npoints)
    q_depths = np.zeros(q_npoints)

    # Add lons/lats/depths of queries.
    # NOTE(review): assumes add_lonlatdep_query replaces the previous batch
    # rather than appending to it; TODO confirm in geotree.
    mytree.add_lonlatdep_query(lons=q_lons,
                               lats=q_lats,
                               depths=q_depths)

    # perf_counter is monotonic and high-resolution, so it suits
    # elapsed-time measurement better than time.time().
    t1 = time.perf_counter()
    mytree.query_kdt()
    query_times_kdt.append(time.perf_counter() - t1)

## Query time, BallTree

In [None]:
# A new tree over the same reference points, this time for the BallTree
# query benchmark.
mytree = gtree()
mytree.add_lonlatdep(
    lons=lons,
    lats=lats,
    depths=depths,
)

In [None]:
# Build the BallTree once; the next cell times only the queries against it.
mytree.create_balltree()

In [None]:
# BallTree query benchmark: for each query-set size, register a batch of
# random query points on the prebuilt tree and time ONLY query_balltree().
query_times_ball = []
for q_npoints in list_queries:
    print(q_npoints)  # progress: current query-set size

    q_lons = np.random.randint(-180, 180, q_npoints)
    q_lats = np.random.randint(-90, 90, q_npoints)
    q_depths = np.zeros(q_npoints)

    # Add lons/lats/depths of queries.
    # NOTE(review): assumes add_lonlatdep_query replaces the previous batch
    # rather than appending to it; TODO confirm in geotree.
    mytree.add_lonlatdep_query(lons=q_lons,
                               lats=q_lats,
                               depths=q_depths)

    # perf_counter is monotonic and high-resolution, so it suits
    # elapsed-time measurement better than time.time().
    t1 = time.perf_counter()
    mytree.query_balltree()
    query_times_ball.append(time.perf_counter() - t1)

## Plot results

In [None]:
# Side-by-side log-log comparison of KDTree (black) vs BallTree (red):
# left panel = build time vs number of indexed points,
# right panel = query time vs number of query points.
fig, (ax_build, ax_query) = plt.subplots(1, 2, figsize=(15, 5))

panels = [
    # (axes, x values, KDTree times, BallTree times, x label, title)
    (ax_build, list_npoints, build_times_kdt, build_times_ball,
     "#Points", "Build time"),
    (ax_query, list_queries, query_times_kdt, query_times_ball,
     "#Queries", "Query time"),
]

for ax, xs, times_kdt, times_ball, xlabel, title in panels:
    ax.plot(xs, times_kdt, marker="o", c="k", lw=3, label="KDTree")
    ax.plot(xs, times_ball, marker="o", c="r", lw=3, label="BallTree")

    ax.legend(fontsize=18)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.tick_params(labelsize=18)
    ax.grid()

    ax.set_xlabel(xlabel, size=20)
    ax.set_ylabel("Time (sec)", size=20)
    ax.set_title(title, size=24)

fig.tight_layout()
plt.show()