In [1]:
import numpy as np

## KD-Tree 最近邻

In [2]:
class Node:
    def __init__(self, data, parent, dim):
        self.data = data
        self.parent = parent
        self.lChild = None
        self.rChild = None
        self.dim = dim
        # only track search_Up process
        self.up_traced = False

    def setLChild(self, lChild):
        self.lChild = lChild

    def setRChild(self, rChild):
        self.rChild = rChild

In [3]:
class KdTree:
    def __init__(self, train):
        self.root = self.__build(train, 1, None)

    def __build(self, train, depth, parent):  # 递归建树
        (m, k) = train.shape
        if m == 0:
            return None

        train = train[train[:, depth % k].argsort()]

        root = Node(train[m//2], parent, depth % k)
        root.setLChild(self.__build(train[:m//2, :], depth+1, root))
        root.setRChild(self.__build(train[m//2+1:, :], depth+1, root))
        return root

    def findNearestPointAndDistance(self, point):  # 查找与point距离最近的k个点
        point = np.array(point)
        node = self.__findSmallestSubSpace(point, self.root)
        print("Start node:", node.data)
        return self.__searchUp(point, node, node, np.linalg.norm(point - node.data))
    
    def printkdTree(self):
        self.__preOrderKdTree(self.root)
    
    def __preOrderKdTree(self,node):
        if(node):
            print(node.data,node.dim)
        if(node.lChild):
            self.__preOrderKdTree(node.lChild)
        if(node.rChild):
            self.__preOrderKdTree(node.rChild)


    def __searchUp(self, point, node, nearestPoint, nearestDistance):
        if node.parent is None:
            return [nearestPoint, nearestDistance]

        print("UP:", node.parent.data)
        node.parent.up_traced = True
        distance = np.linalg.norm(node.parent.data - point) #计算待求节点与其所在子空间的父节点距离
        if distance < nearestDistance:
            nearestDistance = distance
            nearestPoint = node.parent

        #以待求节点为圆心，当前最近距离为半径的圆，计算是否与其所在子空间父节点的另一个子空间相交
        distance = np.abs(node.parent.data[node.parent.dim] - point[node.parent.dim])
        if distance < nearestDistance:
            [p, d] = self.__searchDown(point, node.parent)
            if d < nearestDistance:
                nearestDistance = d
                nearestPoint = p

        [p, d] = self.__searchUp(point, node.parent, nearestPoint, nearestDistance)
        if d < nearestDistance:
            nearestDistance = d
            nearestPoint = p

        return [nearestPoint, nearestDistance]

    def __searchDown(self, point, node):
        nearestDistance = np.linalg.norm(node.data - point)
        nearestPoint = node

        print("DOWN:", node.data)
        if node.lChild is not None and node.lChild.up_traced is False:
            [p, d] = self.__searchDown(point, node.lChild)
            if d < nearestDistance:
                nearestDistance = d
                nearestPoint = p   

        if node.rChild is not None and node.rChild.up_traced is False:
            [p, d] = self.__searchDown(point, node.rChild)
            if d < nearestDistance:
                nearestDistance = d
                nearestPoint = p
            
        print("---- ", nearestPoint.data, nearestDistance)
        return [nearestPoint, nearestDistance]

    def __findSmallestSubSpace(self, point, node):  # 找到这个点所在的最小的子空间
        """
        从根节点出发，递归地向下访问kd树。如果point当前维的坐标小于切分点的坐标，则
        移动到左子节点，否则移动到右子节点。直到子节点为叶节点为止。
        """
        # New search: clean up up_traced flag for all up path nodes
        node.up_traced = False
        if point[node.dim] < node.data[node.dim]:
            if node.lChild is None:
                return node
            else:
                return self.__findSmallestSubSpace(point, node.lChild)
        else:
            if node.rChild is None:
                return node
            else:
                return self.__findSmallestSubSpace(point, node.rChild)
            

In [4]:
train = np.array([[2, 5], [3, 2], [3, 7], [8, 3], [6, 6], [1, 1], [1, 8]])
kdTree = KdTree(train)
target = np.array([2, 2])
print('target :', target)
[p, d] = kdTree.findNearestPointAndDistance(target)

print(p.data, d)
print('---------------------')

(m, k) = train.shape
for i in range(m):
    print(train[i], np.linalg.norm(train[i]-target))


target : [2 2]
Start node: [1 1]
UP: [3 2]
UP: [2 5]
[3 2] 1.0
---------------------
[2 5] 3.0
[3 2] 1.0
[3 7] 5.0990195135927845
[8 3] 6.082762530298219
[6 6] 5.656854249492381
[1 1] 1.4142135623730951
[1 8] 6.082762530298219


In [5]:
train = np.array([[2, 5], [3, 2], [3, 7], [8, 3], [6, 6], [1, 1], [1, 8]])
kdTree = KdTree(train)
target = np.array([6, 4])
print('target :', target)
[p, d] = kdTree.findNearestPointAndDistance(target)

print(p.data, d)
print('---------------------')

(m, k) = train.shape
for i in range(m):
    print(train[i], np.linalg.norm(train[i]-target))

target : [6 4]
Start node: [8 3]
UP: [3 2]
UP: [2 5]
DOWN: [2 5]
DOWN: [3 7]
DOWN: [1 8]
----  [1 8] 6.4031242374328485
DOWN: [6 6]
----  [6 6] 2.0
----  [6 6] 2.0
----  [6 6] 2.0
[6 6] 2.0
---------------------
[2 5] 4.123105625617661
[3 2] 3.605551275463989
[3 7] 4.242640687119285
[8 3] 2.23606797749979
[6 6] 2.0
[1 1] 5.830951894845301
[1 8] 6.4031242374328485


In [6]:
kdTree.printkdTree()

[2 5] 1


AttributeError: 'KdTree' object has no attribute 'preOrderKdTree'