***统计学习方法 ***`P41 - P43`

***在机器学习实战中， 作者给出KNN算法的雏形，其中所用到的数据结构是一个很简单的线性数组（Numpy），当然其内部已经进行了某种优化，但在这里我们假设它搜索最近邻居时是逐个搜索，也就是O（N）时间复杂度。在统计学习方法中，作者提出了KD-tree这一数据结构，使得搜索的时间复杂度可以达到O（logN）。下面主要就是KD-tree的实现。***

In [180]:
import numpy as np
import ipdb
from time import clock
from random import random

In [102]:
#Kd_tree主要有左孩子，右孩子，自身的坐标，以及坐标分类所依据的维度四部分组成
class KD_Node(object):
    def __init__(self, point, dim=None, left=None, right=None):
        self.point = point#坐标
        self.dim = dim#分类依据的维度
        self.left = left
        self.right = right
        
class KD_tree(object):
    root = None
    def __init__(self, data): 
        #构建KD树， 使用递归
        def create_tree(data):
            max_std = 0
            index = 0
            #ipdb.set_trace()
            for i in np.arange(data.shape[1]):
                std = data[:,i].std()
                if std > max_std:
                    max_std = std
                    index = i
            data = data[data[:, index].argsort()]
            root = KD_Node(data[int(len(data) / 2)], index)

            left = data[:int(len(data) / 2)]
            right = data[np.array([int(len(data) / 2)+1, len(data)]).min():]

            if len(left) > 0:
                root.left = create_tree(left)
            if len(right) > 0:
                root.right = create_tree(right)
            return root
        self.root = create_tree(data)
        
def pre_order(root):
    print(root.point)
    if root.left:
        pre_order(root.left)
    if root.right:
        pre_order(root.right)

In [161]:
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
kd = KD_tree(data)
pre_order(kd.root)

[7 2]
[5 4]
[2 3]
[4 7]
[9 6]
[8 1]


In [202]:
def nearest_neighbour(root, node):
    def My_print():
        print("Distance = ", distance, "\nkN_Point = ",final_point.point)
    nodes = []
    #沿着Kd树一直向下搜寻最近点， 直至达到叶节点
    while ((root.left != None) | (root.right != None)):
        nodes.append(root)
        if root.point[root.dim] > node[root.dim]:
            root = root.left
        else:
            root = root.right
    nodes.append(root)
    final_point = root
    distance = np.sqrt(((root.point - np.array(node))**2).sum())
    i = len(nodes) - 2
    My_print()
    #return nodes
    #沿着所经过的节点一路回朔， 如果 到节点边界的距离 < distance， 那么进入另一分支进行迭代
    while i != -1:
        root = nodes[i]
        #print(root.point)
        temp_dis = np.sqrt(((root.point - np.array(node))**2).sum())
        if temp_dis < distance:
            distance = temp_dis
            final_point = root
            My_print()
        if abs(root.point[root.dim] - node[root.dim]) < distance:
            if (root.left == nodes[len(nodes) - 1]):
                temp_dis, temp_point = nearest_neighbour(root.right, node)
            else:
                temp_dis, temp_point = nearest_neighbour(root.left, node)
            if temp_dis < distance:
                distance = temp_dis
                final_point = temp_point
        i -= 1
    return distance, final_point

In [169]:
nearest_neighbour(kd.root, [3, 4.5])

(1.8027756377319946, <__main__.KD_Node at 0x7fa61c0289e8>)

In [207]:
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
kd = KD_tree(data)
distance, ret = nearest_neighbour(kd.root, [3,4.5])
print(ret.point)

def random_point(k):
    return [random() for _ in range(k)]

def random_points(k, n):
    return [random_point(k) for _ in range(n)]  

N = 10
t0 = clock()
kd2 = KD_tree(np.array(random_points(3, N)))           # 构建包含四十万个3维空间样本点的kd树
pre_order(kd2.root)
distance, ret2 = nearest_neighbour(kd2.root, [0.1,0.5,0.8])     # 四十万个样本点中寻找离目标最近的点
t1 = clock()
print("time: ",t1-t0, "s")
print(ret2)

Distance =  2.692582403567252 
kN_Point =  [4 7]
Distance =  2.0615528128088303 
kN_Point =  [5 4]
Distance =  1.8027756377319946 
kN_Point =  [2 3]
[2 3]
[0.59611687 0.01513444 0.01185983]
[0.38323696 0.62628713 0.2556212 ]
[0.3070893  0.89402261 0.22716096]
[0.0437612  0.65950089 0.19557503]
[0.44856427 0.80008483 0.5858697 ]
[0.25999713 0.61400155 0.37587771]
[0.83671964 0.79485535 0.70678786]
[0.92265523 0.08156015 0.5699779 ]
[0.8464582  0.45193156 0.09022666]
[0.97785555 0.49459712 0.77342025]


AttributeError: 'NoneType' object has no attribute 'left'