<img src="image/3/kd_tree_image.png" alt="KD-tree diagram" style="width: 30%; height: auto;">


In [5]:
import math
# Each kd-tree node holds the data described below.
class KdNode:
    """One node of a kd-tree.

    Attributes:
        dom_elt: the k-dimensional vector (one sample point) stored here.
        split: integer index of the coordinate axis this node splits on.
        parent: the parent node (``KdTree`` passes ``None`` for the root).
        left: subtree on the left side of this node's splitting hyperplane.
        right: subtree on the right side of this node's splitting hyperplane.
    """

    def __init__(self, dom_elt, split, parent, left, right):
        # Point and split axis first, then the three tree links.
        self.dom_elt, self.split = dom_elt, split
        self.parent = parent
        self.left, self.right = left, right


class KdTree(object):
    """A kd-tree built by recursive median splits, cycling through the axes.

    Attributes:
        root: the top ``KdNode`` of the tree, or ``None`` when *data* is empty.

    Every node records its ``parent``, so a search can walk back up the tree.
    """

    def __init__(self, data):
        """Build the tree from *data*, a sequence of equal-length points.

        The dimension ``k`` is taken from ``data[0]``. *data* itself is never
        reordered: every level works on a sorted copy.
        """
        # Always define ``root`` so callers can test ``kd.root is None``
        # instead of hitting AttributeError on empty input.
        self.root = None
        if not data:
            return
        k = len(data[0])  # dimensionality of the points

        def create_node(parent_node, split, data_set):
            """Return the subtree for *data_set*, splitting first on axis *split*."""
            if not data_set:  # empty partition -> no subtree
                return None
            # Sort a copy (stable, same order as list.sort) so the caller's
            # list is left untouched.
            ordered = sorted(data_set, key=lambda point: point[split])
            split_pos = len(ordered) // 2  # integer division
            median = ordered[split_pos]  # median point becomes the split node
            split_next = (split + 1) % k  # cycle through the coordinate axes

            now_node = KdNode(median, split, parent_node, None, None)
            # Children are built after the node exists so they can link back to it.
            now_node.left = create_node(now_node, split_next, ordered[:split_pos])
            now_node.right = create_node(now_node, split_next, ordered[split_pos + 1:])
            return now_node

        # Start splitting on axis 0 at the root.
        self.root = create_node(None, 0, data)


        
     


# Pre-order traversal of a KdTree.
def preorder(root):
    """Print ``dom_elt`` and ``split`` of every node, parent before children.

    Visits the left subtree before the right one. An empty tree
    (``root is None``) prints nothing instead of raising.
    """
    if root is None:
        return
    print(root.dom_elt, root.split)
    preorder(root.left)
    preorder(root.right)
 

In [7]:
# Ten points on the 3-D diagonal: [1, 1, 1], [2, 2, 2], ..., [10, 10, 10].
data = [[value] * 3 for value in range(1, 11)]
kd = KdTree(data)
# Print every node's point and split axis, parent before children.
preorder(kd.root)

[6, 6, 6] 0
[3, 3, 3] 1
[2, 2, 2] 2
[1, 1, 1] 0
[5, 5, 5] 2
[4, 4, 4] 0
[9, 9, 9] 1
[8, 8, 8] 2
[7, 7, 7] 0
[10, 10, 10] 2


In [8]:
# The root's left child stores the root as its parent, so this prints the root's point.
print(kd.root.left.parent.dom_elt)

[6, 6, 6]


<img src="image/3/search_kd_tree_image.png" alt="KD-tree nearest-neighbor search diagram" style="width: 30%; height: auto;">

In [36]:
class Searcher:
    """Nearest-neighbour query of a single point against a kd-tree.

    Attributes:
        new_point: the query point (a sequence of coordinates).
        kd_tree: the tree being searched; only its ``root`` attribute is read,
            and only when ``search`` is not given a start node.
        nearest_point: the closest tree *node* found by the last ``search``
            call (a node, not a bare point), or ``None`` before any search.
        nearest_distance: Euclidean distance from ``new_point`` to
            ``nearest_point``; ``inf`` before any search.
    """

    def __init__(self, new_point, kd_tree):
        self.new_point = new_point
        self.kd_tree = kd_tree
        self.nearest_point = None
        self.nearest_distance = float("inf")

    # Distance metrics
    def caculate_Euclidean_distance(self, point1, point2):
        """Return the Euclidean (L2) distance between two points.

        Returns ``None`` if either point is ``None`` or their lengths differ.
        """
        if point1 is None or point2 is None:
            return None
        if len(point1) != len(point2):
            return None
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(point1, point2)))

    def caculate_Manhattan_distance(self, point1, point2):
        """Return the Manhattan (L1) distance between two points.

        Returns ``None`` if either point is ``None`` or their lengths differ.
        """
        if point1 is None or point2 is None:
            return None
        if len(point1) != len(point2):
            return None
        return sum(math.fabs(a - b) for a, b in zip(point1, point2))

    def find_pre_nearest_point(self, root):
        """Descend from *root* to the region that contains ``new_point``.

        At each node the query goes left when its coordinate on the node's
        split axis is ``<=`` the node's coordinate, and right otherwise. The
        descent stops at a leaf, or when the chosen child is missing.

        Returns:
            ``(node, distance)`` for the node where the descent stopped, with
            its Euclidean distance to ``new_point``; ``None`` if *root* is
            ``None``.
        """
        new_point = self.new_point
        node = root
        if node is None:
            return None
        while True:
            axis = node.split
            if new_point[axis] <= node.dom_elt[axis]:
                child = node.left
            else:
                child = node.right
            if child is None:
                return node, self.caculate_Euclidean_distance(node.dom_elt, new_point)
            node = child

    def search(self, now_node, nearest_node, nearest_distance):
        """Find the nearest neighbour of ``new_point`` in the kd-tree.

        Parameters:
            now_node: any node of the tree, typically the parent of the node
                returned by ``find_pre_nearest_point``. The search climbs from
                it to the root through ``parent`` links. If it is ``None``,
                ``self.kd_tree.root`` is used instead.
            nearest_node: best candidate known so far (may be ``None``).
            nearest_distance: that candidate's distance to ``new_point``. It is
                the initial pruning radius; ``None`` is treated as infinity.

        Returns:
            ``(nearest_node, nearest_distance)``. The same pair is stored on
            ``self.nearest_point`` and ``self.nearest_distance``.
        """
        # Closer points can sit in subtrees that are not ancestors of
        # now_node, so restart a pruned descent from the root. The incoming
        # candidate only tightens the initial bound.
        root = now_node
        while root is not None and root.parent is not None:
            root = root.parent
        if root is None:
            root = getattr(self.kd_tree, "root", None)
        if nearest_distance is None:
            nearest_distance = float("inf")

        nearest_node, nearest_distance = self._search_subtree(
            root, nearest_node, nearest_distance
        )
        self.nearest_distance = nearest_distance
        self.nearest_point = nearest_node
        return nearest_node, nearest_distance

    def _search_subtree(self, node, best_node, best_distance):
        """Branch-and-bound search below *node*; return the improved (node, distance)."""
        if node is None:
            return best_node, best_distance
        point = self.new_point
        distance = self.caculate_Euclidean_distance(node.dom_elt, point)
        if distance < best_distance:
            best_node, best_distance = node, distance

        axis = node.split
        diff = point[axis] - node.dom_elt[axis]
        # Visit the half-space containing the query first; it is the likeliest
        # to shrink best_distance before we consider the other side.
        if diff <= 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        best_node, best_distance = self._search_subtree(near, best_node, best_distance)
        # Every point across the splitting hyperplane is at least |diff| away
        # on this axis alone, so that side only matters if |diff| beats the best.
        if abs(diff) < best_distance:
            best_node, best_distance = self._search_subtree(far, best_node, best_distance)
        return best_node, best_distance

In [37]:
# Query point (3-D, matching the tree's data).
new_point = [6.5, -1,9]

searcher=Searcher(new_point,kd)
# Descend from the root to the node whose region contains new_point;
# returns that node and its Euclidean distance to new_point.
a,b=searcher.find_pre_nearest_point(kd.root)
print(a.dom_elt)
print(b)

[8, 8, 8]
9.17877987534291


In [42]:
new_point = [3,7,4]

searcher=Searcher(new_point,kd)
# Step 1: descend to the region containing new_point to get an initial
# candidate node (a) and its distance (b).
a,b=searcher.find_pre_nearest_point(kd.root)
# Step 2: refine the candidate with Searcher.search, starting from a's parent.
nearest_point,nearest_distance =searcher.search(a.parent,a,b)
# These print the initial candidate from step 1, not the refined result.
print(a.dom_elt)
print(b)

[4, 4, 4]
3.1622776601683795


In [43]:

# search() both returns the result and stores it on the searcher,
# so these lines print the same node and distance two ways.
print(nearest_point.dom_elt,nearest_distance)
print(searcher.nearest_point.dom_elt)
print(searcher.nearest_distance)

[5, 5, 5] 3.0
[5, 5, 5]
3.0
