In [49]:
import numpy as np
from operator import itemgetter

In [50]:
def euclidean(a, b):
    """Return the Euclidean (L2) distance between two points.

    Both arguments may be any broadcast-compatible array-likes (lists, tuples
    or NumPy arrays). They are converted to floating point before subtracting
    so that plain Python sequences are accepted and so that narrow or unsigned
    integer dtypes cannot wrap around during the subtraction or the squaring.

    Args:
        a: coordinates of the first point.
        b: coordinates of the second point.

    Returns:
        The distance as a NumPy floating-point scalar.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return np.sqrt(np.sum(np.square(diff)))

In [51]:
class KDNode:
    """A single node of a KD-tree.

    Attributes:
        split: index of the coordinate axis this node partitions on.
        sample: the data point stored at this node.
        parent: the parent node, or None for the root.
        left: the left child node, or None.
        right: the right child node, or None.
    """

    def __init__(self, split, sample, parent, left, right):
        self.split = split
        self.sample = sample
        self.parent = parent
        self.left = left
        self.right = right

    def is_root(self):
        """Return True when this node has no parent."""
        return self.parent is None

    def is_leaf(self):
        """Return True when this node has neither a left nor a right child."""
        return self.left is None and self.right is None

    def is_left(self):
        """Return True when this node is its parent's left child.

        Always returns a bool; a root node yields False.
        """
        return self.parent is not None and self.parent.left is self

    def is_right(self):
        """Return True when this node is its parent's right child.

        Always returns a bool; a root node yields False.
        """
        return self.parent is not None and self.parent.right is self

    def get_sibling(self):
        """Return the parent's other child.

        Returns None when this node is the root, when the parent has no other
        child, or when the parent does not reference this node at all.
        """
        if self.is_left():
            return self.parent.right
        if self.is_right():
            return self.parent.left
        return None

In [52]:
def create_kd_tree(data):
    """Recursively build a KD-tree from a 2-D collection of samples.

    At every level the axis with the largest variance is chosen as the split
    axis, the rows are ordered along that axis, and the middle row becomes the
    node's sample. Rows before the middle form the left subtree and rows after
    it form the right subtree.

    Args:
        data: array-like of shape (n_samples, n_features); a NumPy array or a
            nested sequence such as a list of rows.

    Returns:
        The root KDNode of the (sub)tree, or None when ``data`` is None or
        contains no elements.
    """
    if data is None:
        return None
    data = np.asarray(data)  # accept lists of rows as well as arrays
    if data.size == 0:
        return None

    split = np.argmax(np.var(data, axis=0))
    # A stable sort keeps rows with equal keys in their incoming order, which
    # matches the ordering Python's sorted() would produce.
    order = np.argsort(data[:, split], kind="stable")
    data = data[order]

    head_index = len(data) // 2
    left = create_kd_tree(data[:head_index])
    right = create_kd_tree(data[head_index + 1:])
    head = KDNode(split, data[head_index], None, left, right)

    # Children are built before the parent exists, so link them back here.
    for child in (left, right):
        if child is not None:
            child.parent = head
    return head

In [67]:
def find_nearest(kd_tree, target):
    """Return the tree node whose sample is closest to ``target``.

    The search is a depth-first traversal with pruning: at each node the child
    on the query's side of the splitting plane is explored first, and the
    opposite child is explored only if the plane is closer to the query than
    the best distance found so far. Whole subtrees are searched, so the result
    is the nearest stored sample rather than just the nearest node on the
    path from the root.

    Args:
        kd_tree: root node of the tree; nodes must expose ``split``,
            ``sample``, ``left`` and ``right``.
        target: query point as an array-like of coordinates (list, tuple or
            NumPy array).

    Returns:
        The nearest node, or -1 when ``kd_tree`` is None or ``target`` is
        None or empty.
    """
    if kd_tree is None or target is None:
        return -1
    # Converting up front also avoids truth-testing a multi-element array.
    query = np.asarray(target, dtype=float)
    if query.size == 0:
        return -1

    best_node = kd_tree
    best_dist = np.inf
    # Each entry pairs a subtree root with a lower bound on the distance from
    # the query to any sample stored inside that subtree.
    pending = [(kd_tree, 0.0)]
    while pending:
        node, bound = pending.pop()
        if bound >= best_dist:
            continue  # nothing in this subtree can beat the current best

        dist = euclidean(query, node.sample)
        if dist < best_dist:
            best_node, best_dist = node, dist

        # Signed offset of the query from this node's splitting plane.
        gap = query[node.split] - node.sample[node.split]
        if gap <= 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left

        # Push the far side first so the near side is popped (and can tighten
        # best_dist) before the far side's bound is tested.
        if far is not None:
            pending.append((far, max(bound, abs(gap))))
        if near is not None:
            pending.append((near, bound))
    return best_node

In [53]:
# Six 2-D sample points, one row per point.
data = np.array(
    [
        [2, 3],
        [5, 4],
        [9, 6],
        [4, 7],
        [8, 1],
        [7, 2],
    ]
)

In [54]:
# Build the KD-tree over the sample points; `kdtree` holds the root node.
kdtree = create_kd_tree(data)

In [79]:
# Search the tree for the query point (3, 7.1) and display the sample stored on the returned node.
find_nearest(kdtree, [3, 7.1]).sample

array([4, 7])