# Find closest neighbors (KDTree and BallTree)

In [None]:
# Work around autocomplete problems by turning off IPython's jedi-based completer
%config Completer.use_jedi = False

# autoreload (mode 2): re-import modules before each cell runs, so edits to
# library code on disk are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline, below the cell that creates them
%matplotlib inline

In [None]:
from geotree import gtree
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# instantiate gtree; the base points, the queries and the trees created in the
# following cells are all added to / built on this single object
mytree = gtree()

## Define the first set of points (the `base`)

In [None]:
# Draw a random set of `base` points on an integer lon/lat grid.
# A seeded generator is used so that the neighbor results and the plots below
# are identical after every "Restart Kernel -> Run All".
npoints = 200
base_rng = np.random.default_rng(42)
lons = base_rng.integers(-180, 180, npoints)  # upper bound is exclusive: -180..179
lats = base_rng.integers(-90, 90, npoints)    # upper bound is exclusive: -90..89
depths = np.zeros(npoints)                    # every base point sits at depth 0

In [None]:
# Register the first (base) set of points: longitudes, latitudes and depths
mytree.add_lonlatdep(lons=lons, lats=lats, depths=depths)

## Define queries

In [None]:
# Draw a random set of query points on an integer lon/lat grid.
# This cell owns its own seeded generator, so re-running it on its own always
# reproduces the same queries (and therefore the same neighbors and plots).
q_npoints = 10
query_rng = np.random.default_rng(7)
q_lons = query_rng.integers(-150, 150, q_npoints)  # upper bound is exclusive: -150..149
q_lats = query_rng.integers(-70, 70, q_npoints)    # upper bound is exclusive: -70..69
q_depths = np.zeros(q_npoints)                     # every query sits at depth 0

In [None]:
# Register the query points: longitudes, latitudes and depths
mytree.add_lonlatdep_query(lons=q_lons, lats=q_lats, depths=q_depths)

## Find neighbors, KDTree

In [None]:
# Create KDTree (kdt); must run before `query_kdt` is called in the next cell
# NOTE(review): presumably built from the base points added above -- confirm in geotree docs
mytree.create_kdt()

In [None]:
# Choose the desired number of neighbors (and upper bound for distance, if needed):
# here 3 neighbors are requested per query, and `np.inf` leaves the distance unbounded
mytree.query_kdt(num_neighs=3, distance_upper=np.inf)

In [None]:
# After the query, the results are stored on `mytree`, one row per query:
# the distances to the closest `base` neighbors and the indices of those neighbors.

# distances to the closest `base` neighbors
# NOTE(review): distance units are not visible in this notebook -- check the geotree docs
mytree.dists2query

In [None]:
# indices of the closest `base` neighbors; they are used below to index the
# base `lons` / `lats` arrays directly
mytree.indxs2query

### Plot the results

In [None]:
plt.figure(figsize=(15, 7))

# one Spectral color per query; all neighbor lines of a query share its color
colormap = plt.cm.Spectral
list_colors = [colormap(frac) for frac in np.linspace(0, 1, len(q_lons))]

# queries: blue crosses
plt.scatter(q_lons, q_lats, c="b", marker="x", label="queries")

# --- connect every query to each of its neighbours
# lons/lats of the neighboring base points, looked up through the stored indices
b_lons = lons[mytree.indxs2query]
b_lats = lats[mytree.indxs2query]
for q_idx, line_color in enumerate(list_colors):
    for n_idx in range(len(mytree.indxs2query[q_idx])):
        segment_lons = (q_lons[q_idx], b_lons[q_idx, n_idx])
        segment_lats = (q_lats[q_idx], b_lats[q_idx, n_idx])
        plt.plot(segment_lons, segment_lats, c=line_color)

# base points: black dots, drawn on top of the lines (high zorder)
plt.scatter(lons, lats, c="k", marker="o", label="base", zorder=100)

# legend placed in a thin box just above the axes
plt.legend(
    bbox_to_anchor=(0., 1.01, 1., .05),
    loc="right",
    ncol=2,
    fontsize=16,
    borderaxespad=0.,
)

plt.grid()
plt.tight_layout()
plt.show()

## Project results using `cartopy`

This section requires `cartopy` to be installed; see the installation instructions at https://scitools.org.uk/cartopy/docs/latest/index.html

In [None]:
# cartopy has to be installed separately, so it is imported in this cell only
from cartopy import crs

fig = plt.figure(figsize=(10, 7))

# map axes in the Interrupted Goode Homolosine projection, with coastlines
ax = fig.add_subplot(1, 1, 1,
                     projection=crs.InterruptedGoodeHomolosine())
ax.coastlines(color="black")

# all coordinates are given as lon/lat, hence the PlateCarree transform on every artist
data_crs = crs.PlateCarree()

# Recompute the per-query colors here so that this cell does not rely on
# `list_colors` having been defined by the previous plotting cell.
colormap = plt.cm.Spectral
list_colors = [colormap(i) for i in np.linspace(0, 1, len(q_lons))]

# plot queries
ax.scatter(
    mytree.lons_q,
    mytree.lats_q,
    transform=data_crs,
    c="b",
    marker="x")

# --- plot a line between one query and its neighbours
# neighboring base lons/lats, looked up through the stored neighbor indices
b_lons = lons[mytree.indxs2query]
b_lats = lats[mytree.indxs2query]
for i in range(len(q_lons)):
    for j in range(len(mytree.indxs2query[i])):
        ax.plot((q_lons[i], b_lons[i, j]),
                (q_lats[i], b_lats[i, j]),
                transform=data_crs,
                c=list_colors[i])

# plot base points
ax.scatter(
    mytree.lons,
    mytree.lats,
    transform=data_crs,
    color="k")

fig.tight_layout()
plt.show()

## Find neighbors, BallTree

In [None]:
# Create the Ball tree; must run before `query_balltree` is called in the next cell
mytree.create_balltree()

In [None]:
# Choose the desired number of neighbors (here: 3 per query):
mytree.query_balltree(num_neighs=3)

In [None]:
# After the query, the results are stored on `mytree`, one row per query:
# the distances to the closest `base` neighbors and the indices of those neighbors.
# NOTE(review): these are the same attributes read after the KDTree query, so they
# presumably now hold the Ball tree results instead -- confirm against geotree.

# distances to the closest `base` neighbors
mytree.dists2query

In [None]:
# indices of the closest `base` neighbors (one row per query)
mytree.indxs2query