In [2]:
from typing import List
from collections import namedtuple
import time
#导入函数

class Point(namedtuple("Point", "x y")):      #定义一个point类，元素类型为tuple，包含‘x y’属性
    def __repr__(self) -> str:                #定义一个__repr__方法
        return f'Point{tuple(self)!r}'        #返回格式化字符串


class Rectangle(namedtuple("Rectangle", "lower upper")):        #定义一个Rectangle类，元素类型为tuple，包含‘lower upper’属性
    def __repr__(self) -> str:        #定义一个__repr__方法
        return f'Rectangle{tuple(self)!r}'      #返回格式化字符串

    def is_contains(self, p: Point) -> bool:       #定义一个is_contains方法，判断该点是否在Rectangle中
        return self.lower.x <= p.x <= self.upper.x and self.lower.y <= p.y <= self.upper.y
        #self,lower.x<=p.x<=self.upper.x并且self.lower.y<=p.y<=self.upper.y,则该点位于Rectangle中,返回True


class Node:        #定义一个Node类，类型为tuple，包含“location left right”属性
    """
    location: Point
    left: Node
    right: Node
    """
    def __init__(self, location, depth=0, left=None, right=None):
        self.location, self.depth, self.left, self.right = \
            (location, depth, left, right)

    def __repr__(self):#定义一个__repr__方法
        return f'{tuple(self)!r}'#返回格式化字符串


class KDTree:      #定义一个KDTree类
    """k-d tree"""

    def __init__(self):    #构造函数，有root和n两个属性
        self._root = None
        self._n = 0

    def insert(self, p: List[Point],depth=0):        #定义一个insert方法，用来插入元素
        """insert a list of points"""
        self._n = 2
        if(len(p)<=0):#递归出口
            return None
        else:
            length = len(p)
            # 选择切分的维度
            aim_axis = depth % (self._n)

            # 排序寻找中位数
            sorted_p = sorted(p, key=lambda item: item[aim_axis])
            # 记录该空间对应的节点
            mid = length // 2    
            node = Node(sorted_p[mid],depth=depth)

                # 记录根节点
            if(depth == 0):
                self._root = node

            node.left = self.insert(sorted_p[:mid], depth=depth+1)   #递归记录左子树的节点
            node.right = self.insert(sorted_p[mid+1:], depth=depth+1)   #递归记录右子树的节点
            return node
                
 
                        
        
        
        
    def range(self, rectangle: Rectangle) -> List[Point]:   #定义一个range方法，用来寻找在rectangle范围内的point
        """range query"""
        result = []
        def Judge(p, rectangle):
            if p == None:   #设置递归出口
                return
            elif rectangle.is_contains(p.location):
                result.append(p.location)                           #如果该点在矩形范围中则加入到结果中
                Judge(p.left, rectangle)                            #同时继续查找左子树的节点
                Judge(p.right, rectangle)                           #继续查找右子树的节点
            else:
                if(p.depth % 2 == 0):     #如果depth为偶数，则该点子节点是按第一维度大小划分的
                    if(p.location[0] > rectangle.upper.x):           #该节点x大于矩形最大值x则往左子树继续查找
                        Judge(p.left, rectangle)
                    elif(p.location[0] < rectangle.lower.x):        #该节点x小于矩形最小x则往右子树继续查找
                        Judge(p.right, rectangle)
                    else:       #特殊情况x在矩形范围中，而y值不在矩形范围中，故左右子树都有可能，所以左右子树都要继续查找
                        Judge(p.left, rectangle)
                        Judge(p.right, rectangle)
                elif(p.depth % 2 == 1):                             #depth为奇数，则该点子节点是按照第二维度大小划分的
                    if(p.location[1] > rectangle.upper.y):          #该节点y值大于矩形最大y则往其左子树继续查找
                        Judge(p.left, rectangle)
                    elif(p.location[1] < rectangle.lower.y):        #该节点y值小于矩形最小y则往其右子树继续查找
                        Judge(p.left, rectangle)
                    else:        #特殊情况，y在矩形范围内，而x不在矩形范围中
                        Judge(p.left, rectangle)
                        Judge(p.right, rectangle)
        Judge(self._root, rectangle)
        return result         
                
        
def range_test():          #定义一个测试range方法的函数
    points = [Point(7, 2), Point(5, 4), Point(9, 6), Point(4, 7), Point(8, 1), Point(2, 3)]    #定义一些点
    kd = KDTree()          #构造一个KDTree
    kd.insert(points,0)          #插入点
    result = kd.range(Rectangle(Point(0, 0), Point(6, 6)))#用range方法遍历KDtree中在（0,0）-（6,6）范围中的点
    assert sorted(result) == sorted([Point(2, 3), Point(5, 4)])#检验这些点是否有且仅有（2,3）和（5,4）以判断range是否有效


def performance_test():#定义一个性能测试函数
    points = [Point(x, y) for x in range(1000) for y in range(1000)]     #构造一个x、y in range（1000）的点集

    lower = Point(500, 500)
    upper = Point(504, 504)     #定义lower和upper
    rectangle = Rectangle(lower, upper)    #构造rectangle
    #  naive method
    start = int(round(time.time() * 1000))    #记录起始时间
    result1 = [p for p in points if rectangle.is_contains(p)]    #运用for循环逐一判断point是否在rectangle中
    end = int(round(time.time() * 1000))       #记录结束时间
    print(f'Naive method: {end - start}ms')        #输出naive method的运行时间

    kd = KDTree()       #构造KDTree
    kd.insert(points)         #插入points
    # k-d tree
    start = int(round(time.time() * 1000))     #记录起始时间
    result2 = kd.range(rectangle)           #运用range方法遍历对比
    end = int(round(time.time() * 1000))           #记录结束时间
    print(f'K-D tree: {end - start}ms')          #输出k-d tree方法的运行时间

    assert sorted(result1) == sorted(result2)      #判断两种方法的结果是否相同


if __name__ == '__main__':
    range_test()#测试range方法
    performance_test()#测试kd tree的性能

Naive method: 118ms
K-D tree: 0ms


In [None]:
3.time comlexity of range query
First, the worst case is that all elements are in the rectangle.
So the time complexity is O(n).
Second, for the best cases.
if the first element's x position is larger than the upper's x, then all point in its right can't be in the rectangle.for the next element is the smae.
so the best cases is that we only need to search the leftmost linked list.
and the time complexity equals to its height, so it time comlexity is O(logn)
In general the time comlexity range from O(logn) to O(n)

In [None]:
4.time performance

In [3]:
def performance_test():#定义一个性能测试函数
    points = [Point(x, y) for x in range(2000) for y in range(1000)]     #构造一个x、y in range（1000）的点集

    lower = Point(500, 500)
    upper = Point(504, 504)     #定义lower和upper
    rectangle = Rectangle(lower, upper)    #构造rectangle
    #  naive method
    start = int(round(time.time() * 1000))    #记录起始时间
    result1 = [p for p in points if rectangle.is_contains(p)]    #运用for循环逐一判断point是否在rectangle中
    end = int(round(time.time() * 1000))       #记录结束时间
    print(f'Naive method: {end - start}ms')        #输出naive method的运行时间

    kd = KDTree()       #构造KDTree
    kd.insert(points)         #插入points
    # k-d tree
    start = int(round(time.time() * 1000))     #记录起始时间
    result2 = kd.range(rectangle)           #运用range方法遍历对比
    end = int(round(time.time() * 1000))           #记录结束时间
    print(f'K-D tree: {end - start}ms')          #输出k-d tree方法的运行时间

    assert sorted(result1) == sorted(result2)      #判断两种方法的结果是否相同

if __name__ == '__main__':
    performance_test()#测试kd tree的性能


Naive method: 273ms
K-D tree: 0ms


In [6]:
def performance_test():#定义一个性能测试函数
    points = [Point(x, y) for x in range(2000) for y in range(2000)]     #构造一个x、y in range（1000）的点集

    lower = Point(500, 500)
    upper = Point(504, 504)     #定义lower和upper
    rectangle = Rectangle(lower, upper)    #构造rectangle
    #  naive method
    start = int(round(time.time() * 1000))    #记录起始时间
    result1 = [p for p in points if rectangle.is_contains(p)]    #运用for循环逐一判断point是否在rectangle中
    end = int(round(time.time() * 1000))       #记录结束时间
    print(f'Naive method: {end - start}ms')        #输出naive method的运行时间

    kd = KDTree()       #构造KDTree
    kd.insert(points)         #插入points
    # k-d tree
    start = int(round(time.time() * 1000))     #记录起始时间
    result2 = kd.range(rectangle)           #运用range方法遍历对比
    end = int(round(time.time() * 1000))           #记录结束时间
    print(f'K-D tree: {end - start}ms')          #输出k-d tree方法的运行时间

    assert sorted(result1) == sorted(result2)      #判断两种方法的结果是否相同

if __name__ == '__main__':
    performance_test()#测试kd tree的性能


Naive method: 491ms
K-D tree: 0ms


In [7]:
def performance_test():#定义一个性能测试函数
    points = [Point(x, y) for x in range(3000) for y in range(2000)]     #构造一个x、y in range（1000）的点集

    lower = Point(500, 500)
    upper = Point(504, 504)     #定义lower和upper
    rectangle = Rectangle(lower, upper)    #构造rectangle
    #  naive method
    start = int(round(time.time() * 1000))    #记录起始时间
    result1 = [p for p in points if rectangle.is_contains(p)]    #运用for循环逐一判断point是否在rectangle中
    end = int(round(time.time() * 1000))       #记录结束时间
    print(f'Naive method: {end - start}ms')        #输出naive method的运行时间

    kd = KDTree()       #构造KDTree
    kd.insert(points)         #插入points
    # k-d tree
    start = int(round(time.time() * 1000))     #记录起始时间
    result2 = kd.range(rectangle)           #运用range方法遍历对比
    end = int(round(time.time() * 1000))           #记录结束时间
    print(f'K-D tree: {end - start}ms')          #输出k-d tree方法的运行时间

    assert sorted(result1) == sorted(result2)      #判断两种方法的结果是否相同

if __name__ == '__main__':
    performance_test()#测试kd tree的性能

Naive method: 746ms
K-D tree: 0ms


In [None]:
we can see that if range of points increases, the time performance of naive method increases
but the time performance for kd-tree remain consistant

In [None]:
5.implement the nearest neighbor query

In [None]:
    def divDataToLeftOrRight(self, find_data):
        '''
        根据传入的数据将其分给左节点(0)或右节点(1)
        '''
        data_value = find_data[self.split]
        if data_value < self.range:
            return 0
        else:
            return 1

    def getSearchPath(self, ls_path, find_data):
        '''
        二叉查找到叶节点上
        '''
        now_node = ls_path[-1]
        if now_node == None:
            return ls_path
        now_split = now_node.divDataToLeftOrRight(find_data)
        if now_split == 0:
            next_node = now_node.left
        else:
            next_node = now_node.right
        while(next_node!=None):
            ls_path.append(next_node)
            next_split = next_node.divDataToLeftOrRight(find_data)
            if next_split == 0:
                next_node = next_node.left
            else:
                next_node = next_node.right
        return ls_path
    def getNestNode(self, find_data, min_dist, min_data):
        '''
        回溯查找目标点的最近邻距离
        '''
        ls_path = []
        ls_path.append(self)
        self.getSearchPath(ls_path, find_data)
        now_node = ls_path.pop()
        now_node.isinvted = True
        min_data = now_node.nodedata
        min_dist = np.linalg.norm(find_data-min_data)
        while(len(ls_path)!=0):
            back_node = ls_path.pop()   ### 向上回溯一个节点
            if back_node.isinvted == True:
                continue
            else:
                back_node.isinvted = True
            back_dist = np.linalg.norm(find_data-back_node.nodedata)
            if back_dist < min_dist:
                min_data = back_node.nodedata
                min_dist = back_dist
            if np.abs(find_data[back_node.split]-back_node.range) < min_dist:
                ls_path.append(back_node)
                if back_node.left.isinvted == True:
                    if back_node.right == None:
                        continue
                    ls_path.append(back_node.right)
                else:
                    if back_node.left == None:
                        continue
                    ls_path.append(back_node.left)
                ls_path = back_node.getSearchPath(ls_path, find_data)
                now_node = ls_path.pop()
                now_node.isinvted = True
                now_dist = np.linalg.norm(find_data-now_node.nodedata)
                if now_dist < min_dist:
                    min_data = now_node.nodedata
                    min_dist = now_dist
        print("min distance:{}  min data:{}".format(min_dist, min_data))
        return min_dist