***统计学习方法 ***`P41 - P43`

***在机器学习实战中， 作者给出KNN算法的雏形，其中所用到的数据结构是一个很简单的线性数组（Numpy），当然其内部已经进行了某种优化，但在这里我们假设它搜索最近邻居时是逐个搜索，也就是O（N）时间复杂度。在统计学习方法中，作者提出了KD-tree这一数据结构，使得搜索的时间复杂度可以达到O（logN）。下面主要就是KD-tree的实现。***

In [1]:
import numpy as np
import ipdb
from time import clock
from random import random
import sys

In [29]:
#Kd_tree主要有左孩子，右孩子，自身的坐标，以及坐标分类所依据的维度四部分组成
class KD_Node(object):
    def __init__(self, point, dim=None, left=None, right=None):
        self.point = point#坐标
        self.dim = dim#分类依据的维度
        self.left = left
        self.right = right
        
class KD_tree(object):
    root = None
    def __init__(self, data): 
        #构建KD树， 使用递归
        def create_tree(data):
            max_std = 0
            index = 0
            if len(data) == 0:
                return ;
            for i in np.arange(data.shape[1]):
                std = data[:,i].std()
                if std > max_std:
                    max_std = std
                    index = i
            data = data[data[:, index].argsort()]
            root = KD_Node(data[int(len(data) / 2)], index)
            #ipdb.set_trace()
            
            left = data[:int(len(data) / 2)]
            right = data[np.array([int(len(data) / 2)+1, len(data)]).min():]

            root.left = create_tree(left)
            root.right = create_tree(right)
            return root
        self.root = create_tree(data)
        
def pre_order(root):
    print(root.point)
    if root.left:
        pre_order(root.left)
    if root.right:
        pre_order(root.right)

In [30]:
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
kd = KD_tree(data)
pre_order(kd.root)

[7 2]
[5 4]
[2 3]
[4 7]
[9 6]
[8 1]


In [51]:
def nearest_neighbour(root, node):
    def My_print():
        print("Distance = ", distance, "\nkN_Point = ",final_point.point)
    if root is None:
        return float("inf"), None
    nodes = []
    #沿着Kd树一直向下搜寻最近点， 直至达到叶节点
    #ipdb.set_trace()
    further_node = None
    while ((root.left != None) or (root.right != None)):
        nodes.append(root)
        if (root.left == None):
            further_node = root.left
            root = root.right
        elif (root.right == None):
            further_node = root.right
            root = root.left
        elif (root.point[root.dim] > node[root.dim]):
            further_node = root.right
            root = root.left
        else:
            further_node = root.left
            root = root.right
    nodes.append(root)
    final_point = root
    distance = np.sqrt(((root.point - np.array(node))**2).sum())
    i = len(nodes) - 2
    My_print()
    #return nodes
    #沿着所经过的节点一路回朔， 如果 到节点边界的距离 < distance， 那么进入另一分支进行迭代
    while i != -1:
        root = nodes[i]
        #print(root.point)
        temp_dis = np.sqrt(((root.point - np.array(node))**2).sum())
        if temp_dis < distance:
            distance = temp_dis
            final_point = root
            My_print()
        if abs(root.point[root.dim] - node[root.dim]) < distance:
            #ipdb.set_trace()
            if (further_node != None):
                temp_dis, temp_point = nearest_neighbour(further_node, node)
            else:
                pass
            if temp_dis < distance:
                distance = temp_dis
                final_point = temp_point
        i -= 1
    return distance, final_point

In [52]:
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
kd = KD_tree(data)
distance, ret = nearest_neighbour(kd.root, [3,4.5])
print(ret.point)

Distance =  2.692582403567252 
kN_Point =  [4 7]
Distance =  2.0615528128088303 
kN_Point =  [5 4]
Distance =  1.8027756377319946 
kN_Point =  [2 3]
[2 3]


In [76]:

def random_point(k):
    return [random() for _ in range(k)]

def random_points(k, n):
    return [random_point(k) for _ in range(n)]  

N = 10000
t0 = clock()
kd2 = KD_tree(np.array(random_points(3, N)))           
distance, ret2 = nearest_neighbour(kd2.root, [0.1,0.5,0.8])
t1 = clock()
print("time: ",t1-t0, "s")
print(ret2)

Distance =  0.08083255216865692 
kN_Point =  [0.07772661 0.42486948 0.78017065]
Distance =  0.054850931150903365 
kN_Point =  [0.05211767 0.48524768 0.7776779 ]
Distance =  0.02431545418476312 
kN_Point =  [0.11146914 0.47856045 0.80021467]
time:  0.6916060000000002 s
<__main__.KD_Node object at 0x7f81845a7ac8>
