# 第 2 章 KNN（KD tree搜索）
## By LiuGang - 2018/11/13
## Reference Book - statistical learning method (Chinese)<br>Description: faster than traversal, more obvious in the case of more X and more y.
### 1:  Create some data

In [170]:
import numpy as np
#X = np.array([[1,2],[1,1],[2,1],[2,2],[1.5,3],[3,3],[2.5,0.5],[4,3]])
#y = np.array([0,0,0,0,1,1,1,1])
np.random.seed(5)
X = np.random.randint(1,10000,size=(100000, 3))
y = np.array([1 if i%2==0 else 0 for i in range(100000)])

[5521 1033  741]


### 2: Build KD Tree

In [102]:
class kd_node():
    def __init__(self,x):
        self.val = x
        self.parent = None
        self.left = None
        self.right = None
        self.dim = 1
        self.y = 0


In [166]:
class kd_tree():
    def __init__(self, data):
        self.data = data
        self.dim = data.shape[1]
        self.mindis = np.nan
        self.nearest_node = None
        self.k_point = []
        self.k_dis = []
    
    def init_use(self):
        self.mindis = np.nan
        self.nearest_node = None
        self.k_point = []
        self.k_dis = []
        
    def build_tree(self, data, y, sdim, head, parent):
        if len(data) == 0:
            return
        sdim = sdim%self.dim
        data = data[data[:,sdim].argsort()]
        medi = data.shape[0]//2
        
        head = kd_node(data[medi])
        head.dim = sdim
        head.y = y[medi]
        head.parent = parent
        head.left = self.build_tree(data[0:medi], y[0:medi], sdim+1,head.left, head)
        head.right = self.build_tree(data[medi+1:], y[medi+1:], sdim+1, head.right, head)
        return head
    
    def get_euc_dis(self,x1,x2):
        return np.sum(np.multiply(x1-x2,x1-x2))**(0.5)
    
    def kd_get_path(self, root, point, k):
        stack = []
        stack.append(root)
        self.nearest_node = root
        root_dis = self.get_euc_dis(root.val, point)
        self.mindis = min(root_dis, self.mindis)
        #save top-k value
        if len(self.k_point) < k:
            self.k_point.append(root)
            self.k_dis.append(root_dis)
        else:
            self.k_point.append(root)
            self.k_dis.append(root_dis)
            dis_arg = np.array(self.k_dis).argsort()
            self.k_dis = np.array(self.k_dis)[dis_arg].tolist()[0:k]
            self.k_point = np.array(self.k_point)[dis_arg].tolist()[0:k]
        
        while True:
            if point[root.dim] < root.val[root.dim]:
                root = root.left
            else:
                root = root.right
            if root:
                new_dis = self.get_euc_dis(root.val, point)
                #save top-k value
                if len(self.k_point) < k:
                    self.k_point.append(root)
                    self.k_dis.append(new_dis)
                else:
                    self.k_point.append(root)
                    self.k_dis.append(new_dis)
                    dis_arg = np.array(self.k_dis).argsort()
                    self.k_dis = np.array(self.k_dis)[dis_arg].tolist()[0:k]
                    self.k_point = np.array(self.k_point)[dis_arg].tolist()[0:k]
                
                if new_dis < self.mindis:
                    self.nearest_node = root
                    self.mindis = new_dis
                stack.append(root)
            else:
                break
        return stack, self.mindis, self.nearest_node
    
    def knn_search(self, root, point, k):
        self.init_use()
        stack, __m, __n = self.kd_get_path(root, point, k)
        while True:
            back_node = stack.pop()
            if back_node.parent == None:
                break
            if abs(point[back_node.parent.dim] - back_node.parent.val[back_node.parent.dim]) <= self.mindis:
                if back_node.parent.right and back_node.parent.left and back_node.parent.left == back_node:
                    back_stack, __m, __n = self.kd_get_path(back_node.parent.right, point, k)
                elif back_node.parent.right and back_node.parent.left and back_node.parent.right == back_node:
                    back_stack, __m, __n = self.kd_get_path(back_node.parent.left, point, k)
        return self.mindis, self.k_point, self.k_dis
    
    def predict_y(self, root, points, k):
        _result = []
        for point in points:
            mindis, point_node, dis = self.knn_search(root, point, k)
            _result.append(np.argmax(np.bincount(np.array([nn.y for nn in point_node]))))
        return _result
    
    def predict_point(self, root, points, k):
        _result = []
        for point in points:
            mindis, point_node, dis = self.knn_search(root, point, k)
            _result.append([nn.val for nn in point_node])
        return _result
   


### 3: Test

In [174]:
import time
t0 = time.clock()
kdtree = kd_tree(X)
root = kdtree.build_tree(X, y, 0, None, None)
yres = kdtree.predict_y(root, np.array([[1.95,0,0],[600,1000,200],[7,500,1568],[1.95,0,0],[600,1000,200],[7,500,1568]]), 5)
t1 = time.clock()
print(yres)
print(t1 - t0)

[1, 1, 0, 1, 1, 0]
1.7096930000000015
