In [49]:
import numpy as np
from operator import itemgetter

In [50]:
def euclidean(a, b):
    return np.sqrt(np.sum(np.square(a-b)))

In [51]:
class KDNode():
    def __init__(self, split, sample, parent, left, right):
        self.split = split
        self.sample = sample
        self.parent = parent
        self.left = left
        self.right = right
    
    def is_root(self):
        return not self.parent
    
    def is_leaf(self):
        return not (self.left or self.right)
        
    def is_left(self):
        return self.parent and self.parent.left is self

    def is_right(self):
        return self.parent and self.parent.right is self

    def get_sibling(self):
        if self.parent and self.parent.left is self:
            return self.parent.right
        elif self.parent and self.parent.right is self:
            return self.parent.left

In [52]:
def create_kd_tree(data):
    if data is None or data.shape[0] == 0 or data.shape[1] == 0:
        return None
    split = np.argmax(np.var(data, axis=0))
    data = np.array(sorted(data, key=lambda x: x[split]))
    head_index = len(data) // 2
    left = create_kd_tree(data[:head_index])
    right = create_kd_tree(data[head_index+1:])
    head = KDNode(split, data[head_index], None, left, right)
    if left:
        left.parent = head
    if right:
        right.parent = head
    return head

In [99]:
def find_nearest(kd_tree, target):
    if len(kd_tree.sample) != len(target):
        raise ValueError('dim not match!')
    if not kd_tree or target is None:
        return -1
    head = kd_tree
    while not head.is_leaf():
        if target[head.split] <= head.sample[head.split]:
            head = head.left
        else:
            head = head.right
    curr_node = head
    curr_dis = euclidean(target, curr_node.sample)
    while not head.is_root(): 
        if np.abs(head.parent.sample[head.split] - target[head.parent.split]) < curr_dis:
            sbiling = head.get_sibling()
            if sbiling:
                dis = euclidean(target, sbiling.sample)
                if dis < curr_dis:
                    curr_node = sbiling
                    curr_dis = dis
            dis = euclidean(target, head.parent.sample)
            if dis < curr_dis:
                curr_node = head.parent
                curr_dis = dis
        head = head.parent
    return curr_node

In [113]:
from time import time

In [121]:
data = np.random.randn(1000000, 100)

In [122]:
start = time()
kdtree = create_kd_tree(data)
print(time()-start)

107.47628879547119


In [123]:
start = time()
x = find_nearest(kdtree, np.random.randn(100,)).sample
print(time()-start)
x

0.019930124282836914


array([-0.56049827, -1.77639888,  0.83050868, -1.02105641, -1.19375309,
       -0.2280448 , -1.94537018,  0.78539984, -0.80579798, -0.74345352,
       -0.05623168, -0.83911542, -1.84112504, -1.0023634 ,  0.6602706 ,
       -0.3654255 ,  1.81669523,  0.5323542 , -1.6162134 , -0.4957659 ,
        2.59221906,  0.18844354,  0.28398069,  0.54893566, -1.06128046,
        0.3568645 ,  0.21530722, -0.99837141,  0.52801384, -0.14903611,
        0.58808552,  0.41907573, -0.44009569, -1.78553887, -0.41225327,
        0.39835787,  0.71154818, -0.2332933 , -0.09840185, -1.12056705,
        0.41064583,  0.78216058,  1.58941708,  1.69764421, -1.14925329,
        1.02655942, -0.99259864, -1.03153328,  2.26730468, -0.9296867 ,
        1.26807766,  0.85190048, -0.9050304 , -0.16492928,  0.30493217,
        0.42288897,  1.44158378, -0.04736664,  1.3579324 ,  0.57138329,
        1.18463809,  0.69940215,  0.50192297,  0.60986733, -0.88743104,
        1.55667818, -0.78796919, -0.35689132,  0.95136294,  0.75

In [108]:
a

NameError: name 'a' is not defined