In [0]:
import os 
import struct
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def read_velodyne_bin(path):
    '''
    :param path:
    :return: homography matrix of the point cloud, N*3
    '''
    pc_list = []
    with open(path, 'rb') as f:
        content = f.read()
        pc_iter = struct.iter_unpack('ffff', content)
        for idx, point in enumerate(pc_iter):
            pc_list.append([point[0], point[1], point[2]])
    return np.asarray(pc_list, dtype=np.float32)

path='/content/drive/My Drive/lesson2code/000000.bin'
origindata=read_velodyne_bin(path)



In [0]:
import copy


class DistIndex:
    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        return self.distance < other.distance


class KNNResultSet:
    def __init__(self, capacity):
        self.capacity = capacity
        self.count = 0
        self.worst_dist = 1e10
        self.dist_index_list = []
        for i in range(capacity):
            self.dist_index_list.append(DistIndex(self.worst_dist, 0))

        self.comparison_counter = 0

    def size(self):
        return self.count

    def full(self):
        return self.count == self.capacity

    def worstDist(self):
        return self.worst_dist

    def add_point(self, dist, index):
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        i = self.count - 1
        while i > 0:
            if self.dist_index_list[i-1].distance > dist:
                self.dist_index_list[i] = copy.deepcopy(self.dist_index_list[i-1])
                i -= 1
            else:
                break

        self.dist_index_list[i].distance = dist
        self.dist_index_list[i].index = index
        self.worst_dist = self.dist_index_list[self.capacity-1].distance
        
    def __str__(self):
        output = ''
        for i, dist_index in enumerate(self.dist_index_list):
            output += '%d - %.2f\n' % (dist_index.index, dist_index.distance)
        output += 'In total %d comparison operations.' % self.comparison_counter
        return output


class RadiusNNResultSet:
    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.worst_dist = radius
        self.dist_index_list = []

        self.comparison_counter = 0

    def size(self):
        return self.count

    def worstDist(self):
        return self.radius

    def add_point(self, dist, index):
        self.comparison_counter += 1
        if dist > self.radius:
            return

        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        self.dist_index_list.sort()
        output = ''
        for i, dist_index in enumerate(self.dist_index_list):
            output += '%d - %.2f\n' % (dist_index.index, dist_index.distance)
        output += 'In total %d neighbors within %f.\nThere are %d comparison operations.' \
                  % (self.count, self.radius, self.comparison_counter)
        return output

In [0]:
import random
import math
import numpy as np
import time

def bruteSearch(db: np.ndarray,result_set:KNNResultSet, query: np.ndarray):
    
    diff = np.linalg.norm(np.expand_dims(query, 0) - db, axis=1)
    nn_idx = np.argsort(diff)
    nn_dist = diff[nn_idx]
    for idx, distindex in enumerate(result_set.dist_index_list):
        distindex.distance=nn_dist[idx]
        distindex.index=nn_idx[idx]
    return False
k=8
brute_time_sum=0
for ind ,p in enumerate(origindata):
    begin_t = time.time()
    result_set = KNNResultSet(capacity=k)
    bruteSearch(origindata,result_set,p)
    print(ind/len(origindata))
    
    
    brute_time_sum += time.time() - begin_t
print('brute:',brute_time_sum)

#耗时1634.29秒
#耗时1348.81秒-GPU

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
0.9599095196842814
0.9599175409888664
0.9599255622934514
0.9599335835980364
0.9599416049026214
0.9599496262072064
0.9599576475117914
0.9599656688163763
0.9599736901209612
0.9599817114255462
0.9599897327301312
0.9599977540347162
0.9600057753393012
0.9600137966438862
0.9600218179484712
0.9600298392530561
0.9600378605576411
0.9600458818622261
0.960053903166811
0.960061924471396
0.960069945775981
0.960077967080566
0.9600859883851509
0.9600940096897359
0.9601020309943209
0.9601100522989059
0.9601180736034909
0.9601260949080759
0.9601341162126609
0.9601421375172458
0.9601501588218307
0.9601581801264157
0.9601662014310007
0.9601742227355857
0.9601822440401707
0.9601902653447557
0.9601982866493406
0.9602063079539256
0.9602143292585106
0.9602223505630956
0.9602303718676806
0.9602383931722656
0.9602464144768506
0.9602544357814355
0.9602624570860204
0.9602704783906054
0.9602784996951904
0.9602865209997754
0.9602945423043604
0.9603025636089454
0.96031058491

In [0]:
from scipy import spatial
import time

k=8

def scipyKdtreeSearch(tree:spatial.KDTree,result_set:KNNResultSet,point: np.ndarray):
    scipy_nn_dis,scipy_nn_idx=tree.query(point,result_set.capacity)
    for idx, distindex in enumerate(result_set.dist_index_list):
        distindex.distance=scipy_nn_dis[idx]
        distindex.index=scipy_nn_idx[idx]
    return False


construction_time_sum = 0
knn_time_sum = 0

begin_t = time.time()
tree = spatial.KDTree(origindata)
construction_time_sum += time.time() - begin_t

for ind ,p in enumerate(origindata):
    begin_t = time.time()
    result_set = KNNResultSet(capacity=k)
    scipyKdtreeSearch(tree,result_set,p)
    knn_time_sum += time.time() - begin_t
    print(ind/len(origindata))
print('scipykdtreebuilding:',construction_time_sum)
print('scipykdtree:',knn_time_sum)

#耗时0.6428+113.64秒-GPU

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
0.9599175409888664
0.9599255622934514
0.9599335835980364
0.9599416049026214
0.9599496262072064
0.9599576475117914
0.9599656688163763
0.9599736901209612
0.9599817114255462
0.9599897327301312
0.9599977540347162
0.9600057753393012
0.9600137966438862
0.9600218179484712
0.9600298392530561
0.9600378605576411
0.9600458818622261
0.960053903166811
0.960061924471396
0.960069945775981
0.960077967080566
0.9600859883851509
0.9600940096897359
0.9601020309943209
0.9601100522989059
0.9601180736034909
0.9601260949080759
0.9601341162126609
0.9601421375172458
0.9601501588218307
0.9601581801264157
0.9601662014310007
0.9601742227355857
0.9601822440401707
0.9601902653447557
0.9601982866493406
0.9602063079539256
0.9602143292585106
0.9602223505630956
0.9602303718676806
0.9602383931722656
0.9602464144768506
0.9602544357814355
0.9602624570860204
0.9602704783906054
0.9602784996951904
0.9602865209997754
0.9602945423043604
0.9603025636089454
0.9603105849135304
0.96031860621

In [0]:
# kdtree的具体实现，包括构建和查找

import random
import math
import numpy as np
import time


# Node类，Node是tree的基本组成元素
class Node:
    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis
        self.value = value
        self.left = left
        self.right = right
        self.point_indices = point_indices

    def is_leaf(self):
        if self.value is None:
            return True
        else:
            return False

    def __str__(self):
        output = ''
        output += 'axis %d, ' % self.axis
        if self.value is None:
            output += 'split value: leaf, '
        else:
            output += 'split value: %.2f, ' % self.value
        output += 'point_indices: '
        output += str(self.point_indices.tolist())
        return output

# 功能：构建树之前需要对value进行排序，同时对一个的key的顺序也要跟着改变
# 输入：
#     key：键
#     value:值
# 输出：
#     key_sorted：排序后的键
#     value_sorted：排序后的值
def sort_key_by_vale(key, value):
    assert key.shape == value.shape
    assert len(key.shape) == 1
    sorted_idx = np.argsort(value)
    key_sorted = key[sorted_idx]
    value_sorted = value[sorted_idx]
    return key_sorted, value_sorted


def axis_round_robin(axis, dim):
    if axis == dim-1:
        return 0
    else:
        return axis + 1

# 功能：通过递归的方式构建树
# 输入：
#     root: 树的根节点
#     db: 点云数据
#     point_indices：排序后的键
#     axis: scalar
#     leaf_size: scalar
# 输出：
#     root: 即构建完成的树
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # determine whether to split into left and right
    if len(point_indices) > leaf_size:
        # --- get the split position ---
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])  # M
        
        # 作业1
        # 屏蔽开始
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_left_point_idx = point_indices_sorted[middle_left_idx]
        middle_left_point_value = db[middle_left_point_idx, axis]
        middle_right_idx = middle_left_idx + 1
        middle_right_point_idx = point_indices_sorted[middle_right_idx]
        middle_right_point_value = db[middle_right_point_idx, axis]
        root.value = (middle_left_point_value + middle_right_point_value) * 0.5


        root.left=kdtree_recursive_build(root.left,db,point_indices_sorted[0:middle_right_idx],axis_round_robin(axis, dim=db.shape[1]),leaf_size)
        root.right=kdtree_recursive_build(root.right,db,point_indices_sorted[middle_right_idx:],axis_round_robin(axis, dim=db.shape[1]),leaf_size)
        # 屏蔽结束
    return root


# 功能：翻转一个kd树
# 输入：
#     root：kd树
#     depth: 当前深度
#     max_depth：最大深度
def traverse_kdtree(root: Node, depth, max_depth):
    depth[0] += 1
    if max_depth[0] < depth[0]:
        max_depth[0] = depth[0]

    if root.is_leaf():
        print(root)
    else:
        traverse_kdtree(root.left, depth, max_depth)
        traverse_kdtree(root.right, depth, max_depth)

    depth[0] -= 1

# 功能：构建kd树（利用kdtree_recursive_build功能函数实现的对外接口）
# 输入：
#     db_np：原始数据
#     leaf_size：scale
# 输出：
#     root：构建完成的kd树
def kdtree_construction(db_np, leaf_size):
    N, dim = db_np.shape[0], db_np.shape[1]

    # build kd_tree recursively
    root = None
    root = kdtree_recursive_build(root,db_np,np.arange(N),axis=0,leaf_size=leaf_size)
    return root


# 功能：通过kd树实现knn搜索，即找出最近的k个近邻
# 输入：
#     root: kd树
#     db: 原始数据
#     result_set：搜索结果
#     query：索引信息
# 输出：
#     搜索失败则返回False
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    if root is None:
        return False

    if root.is_leaf():
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False

    # 作业2
    # 提示：仍通过递归的方式实现搜索
    # 屏蔽开始
    if query[root.axis]<=root.value:
        kdtree_knn_search(root.left,db,result_set,query)
        if math.fabs(query[root.axis]-root.value)<result_set.worstDist():
            kdtree_knn_search(root.right,db,result_set,query)
    else:
        kdtree_knn_search(root.right,db,result_set,query)
        if math.fabs(query[root.axis]-root.value)<result_set.worstDist():
            kdtree_knn_search(root.left,db,result_set,query)
    

    # 屏蔽结束

    return False

# 功能：通过kd树实现radius搜索，即找出距离radius以内的近邻
# 输入：
#     root: kd树
#     db: 原始数据
#     result_set:搜索结果
#     query：索引信息
# 输出：
#     搜索失败则返回False
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False

    if root.is_leaf():
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False
    
    # 作业3
    # 提示：通过递归的方式实现搜索
    # 屏蔽开始
    if query[root.axis]<=root.value:
        kdtree_knn_search(root.left,db,result_set,query)
        if math.fabs(query[root.axis]-root.value)<result_set.worstDist():
            kdtree_knn_search(root.right,db,result_set,query)
    else:
        kdtree_knn_search(root.right,db,result_set,query)
        if math.fabs(query[root.axis]-root.value)<result_set.worstDist():
            kdtree_knn_search(root.left,db,result_set,query)
    # 屏蔽结束

    return False



def main():
    # configuration
    leaf_size = 10
    k = 8

    db_np = origindata
    construction_time_sum = 0
    begin_t = time.time()
    root = kdtree_construction(db_np, leaf_size=leaf_size)
    construction_time_sum += time.time() - begin_t
    
    # depth = [0]
    # max_depth = [0]
    # traverse_kdtree(root, depth, max_depth)
    # print("tree max depth: %d" % max_depth[0])

    
    
    
    knn_time_sum = 0
    for ind ,p in enumerate(origindata):
        begin_t = time.time()
        result_set = KNNResultSet(capacity=k)
        kdtree_knn_search(root, db_np, result_set, p)
        knn_time_sum += time.time() - begin_t
        print(ind/len(origindata))
    print('kdtreebuilding:',construction_time_sum)
    print('kdtree:',knn_time_sum)
    


if __name__ == '__main__':
    main()
    #耗时0.3410+198.71秒-GPU

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
0.9599175409888664
0.9599255622934514
0.9599335835980364
0.9599416049026214
0.9599496262072064
0.9599576475117914
0.9599656688163763
0.9599736901209612
0.9599817114255462
0.9599897327301312
0.9599977540347162
0.9600057753393012
0.9600137966438862
0.9600218179484712
0.9600298392530561
0.9600378605576411
0.9600458818622261
0.960053903166811
0.960061924471396
0.960069945775981
0.960077967080566
0.9600859883851509
0.9600940096897359
0.9601020309943209
0.9601100522989059
0.9601180736034909
0.9601260949080759
0.9601341162126609
0.9601421375172458
0.9601501588218307
0.9601581801264157
0.9601662014310007
0.9601742227355857
0.9601822440401707
0.9601902653447557
0.9601982866493406
0.9602063079539256
0.9602143292585106
0.9602223505630956
0.9602303718676806
0.9602383931722656
0.9602464144768506
0.9602544357814355
0.9602624570860204
0.9602704783906054
0.9602784996951904
0.9602865209997754
0.9602945423043604
0.9603025636089454
0.9603105849135304
0.96031860621

In [18]:
a=[[2,3]]
print(a[0][1])
import numpy as np
diff=np.linalg.norm([0,0,0]-[1,1,1])

3


TypeError: ignored

In [58]:
# kdtree的修改的具体实现，包括构建和查找

import random
import math
import numpy as np
import time


# Node类，Node是tree的基本组成元素
class Node:
    def __init__(self, axis, valueindex, left, right, point_indices):
        self.axis = axis
        self.valueindex = valueindex
        self.left = left
        self.right = right
        self.point_indices = point_indices

    def is_leaf(self):
        if len(self.point_indices)==0:
            return True
        else:
            return False

    def __str__(self):
        output = ''
        output += 'axis %d, ' % self.axis
        if self.valueindex is None:
            output += 'split value: leaf, '
        else:
            output +=( 'split valueindex: ' + str(self.valueindex))
        output += 'point_indices: '
        output += str(self.point_indices.tolist())
        return output

# 功能：构建树之前需要对value进行排序，同时对一个的key的顺序也要跟着改变
# 输入：
#     key：键
#     value:值
# 输出：
#     key_sorted：排序后的键
#     value_sorted：排序后的值
def sort_key_by_vale(key, value):
    assert key.shape == value.shape
    assert len(key.shape) == 1
    sorted_idx = np.argsort(value)
    key_sorted = key[sorted_idx]
    value_sorted = value[sorted_idx]
    return key_sorted, value_sorted


def axis_round_robin(axis, dim):
    if axis == dim-1:
        return 0
    else:
        return axis + 1

# 功能：通过递归的方式构建树
# 输入：
#     root: 树的根节点
#     db: 点云数据
#     point_indices：排序后的键
#     axis: scalar
#     leaf_size: scalar
# 输出：
#     root: 即构建完成的树
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # determine whether to split into left and right
    if len(point_indices) > leaf_size:
        # --- get the split position ---
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])  # M
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_left_point_idx = point_indices_sorted[middle_left_idx]
        middle_right_idx = middle_left_idx + 1
        
        root.valueindex=middle_left_point_idx
        
        root.left=kdtree_recursive_build(root.left,db,point_indices_sorted[0:middle_left_idx],axis_round_robin(axis, dim=db.shape[1]),leaf_size)
        root.right=kdtree_recursive_build(root.right,db,point_indices_sorted[middle_right_idx:],axis_round_robin(axis, dim=db.shape[1]),leaf_size)
        
        
    return root


# 功能：翻转一个kd树
# 输入：
#     root：kd树
#     depth: 当前深度
#     max_depth：最大深度
def traverse_kdtree(root: Node, depth, max_depth):
    depth[0] += 1
    if max_depth[0] < depth[0]:
        max_depth[0] = depth[0]

    if root.is_leaf():
        print(root)
    else:
        traverse_kdtree(root.left, depth, max_depth)
        traverse_kdtree(root.right, depth, max_depth)

    depth[0] -= 1

# 功能：构建kd树（利用kdtree_recursive_build功能函数实现的对外接口）
# 输入：
#     db_np：原始数据
#     leaf_size：scale
# 输出：
#     root：构建完成的kd树
def kdtree_construction(db_np, leaf_size):
    N, dim = db_np.shape[0], db_np.shape[1]

    # build kd_tree recursively
    root = None
    root = kdtree_recursive_build(root,db_np,np.arange(N),axis=0,leaf_size=leaf_size)
    return root


# 功能：通过kd树实现knn搜索，即找出最近的k个近邻
# 输入：
#     root: kd树
#     db: 原始数据
#     result_set：搜索结果
#     query：索引信息
# 输出：
#     搜索失败则返回False
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    if root is None:
        return False

    
        
    diff=np.linalg.norm(query-db[root.valueindex])
    
    result_set.add_point(diff, root.valueindex)
    
    if root.is_leaf():
        return False

    
    if query[root.axis]<=db[root.valueindex][root.axis]:
        kdtree_knn_search(root.left,db,result_set,query)
        if math.fabs(query[root.axis]-db[root.valueindex][root.axis])<result_set.worstDist():
            kdtree_knn_search(root.right,db,result_set,query)
    else:
        kdtree_knn_search(root.right,db,result_set,query)
        if math.fabs(query[root.axis]-db[root.valueindex][root.axis])<result_set.worstDist():
            kdtree_knn_search(root.left,db,result_set,query)
    

    

    return False

# 功能：通过kd树实现radius搜索，即找出距离radius以内的近邻
# 输入：
#     root: kd树
#     db: 原始数据
#     result_set:搜索结果
#     query：索引信息
# 输出：
#     搜索失败则返回False
def kdtree_radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False

    
    diff=np.linalg.norm(query-db[root.valueindex])
    result_set.add_point(diff, root.valueindex)

    if root.is_leaf():
        return False
    
    if query[root.axis]<=db[root.valueindex][root.axis]:
        
        kdtree_knn_search(root.left,db,result_set,query)
        if math.fabs(query[root.axis]-db[root.valueindex][root.axis])<result_set.worstDist():
            kdtree_knn_search(root.right,db,result_set,query)
    else:
        kdtree_knn_search(root.right,db,result_set,query)
        if math.fabs(query[root.axis]-db[root.valueindex][root.axis])<result_set.worstDist():
            kdtree_knn_search(root.left,db,result_set,query)
    

    return False



def main():
    # configuration
    leaf_size = 0
    k = 8
    db_size = 64

    dim = 3
    db_np = np.random.rand(db_size, dim)
    root = kdtree_construction(db_np, leaf_size=leaf_size)


    query = np.asarray([0, 0, 0])
    result_set = KNNResultSet(capacity=k)
    kdtree_knn_search(root, db_np, result_set, query)
    print(result_set)
    diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    nn_idx = np.argsort(diff)
    nn_dist = diff[nn_idx]
    print(nn_idx[0:k])
    print(nn_dist[0:k])
    
    


if __name__ == '__main__':
    main()
    

49 - 0.30
1 - 0.30
57 - 0.42
62 - 0.42
42 - 0.48
7 - 0.50
22 - 0.51
46 - 0.53
In total 76 comparison operations.
[49  1 57 62 42  7 22 46]
[0.29596767 0.29824499 0.41981616 0.423367   0.48134431 0.4973639
 0.51408302 0.52594974]


In [0]:
# octree的具体实现，包括构建和查找

import random
import math
import numpy as np
import time


# 节点，构成OCtree的基本元素
class Octant:
    def __init__(self, children, center, extent, point_indices, is_leaf):
        self.children = children
        self.center = center
        self.extent = extent
        self.point_indices = point_indices
        self.is_leaf = is_leaf

    def __str__(self):
        output = ''
        output += 'center: [%.2f, %.2f, %.2f], ' % (self.center[0], self.center[1], self.center[2])
        output += 'extent: %.2f, ' % self.extent
        output += 'is_leaf: %d, ' % self.is_leaf
        output += 'children: ' + str([x is not None for x in self.children]) + ", "
        output += 'point_indices: ' + str(self.point_indices)
        return output

# 功能：翻转octree
# 输入：
#     root: 构建好的octree
#     depth: 当前深度
#     max_depth：最大深度
def traverse_octree(root: Octant, depth, max_depth):
    depth[0] += 1
    if max_depth[0] < depth[0]:
        max_depth[0] = depth[0]

    if root is None:
        pass
    elif root.is_leaf:
        print(root)
    else:
        for child in root.children:
            traverse_octree(child, depth, max_depth)
    depth[0] -= 1

# 功能：通过递归的方式构建octree
# 输入：
#     root：根节点
#     db：原始数据
#     center: 中心
#     extent: 当前分割区间
#     point_indices: 点的key
#     leaf_size: scale
#     min_extent: 最小分割区间
def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    if len(point_indices) == 0:
        return None

    if root is None:
        root = Octant([None for i in range(8)], center, extent, point_indices, is_leaf=True)

    # determine whether to split this octant
    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
    else:
        # 作业4
        # 屏蔽开始
        root.is_leaf=False
        children_point_indices = [[] for i in range(8)]
        for point_idx in point_indices:
            point_db = db[point_idx]
            morton_code = 0
            if point_db[0] > center[0]:
                morton_code = morton_code | 1
            if point_db[1] > center[1]:
                morton_code = morton_code | 2
            if point_db[2] > center[2]:
                morton_code = morton_code | 4
            children_point_indices[morton_code].append(point_idx)
        # create children

        factor = [-0.5, 0.5]
        for i in range(8):
            child_center_x = center[0] + factor[(i & 1) > 0] * extent
            child_center_y = center[1] + factor[(i & 2) > 0] * extent
            child_center_z = center[2] + factor[(i & 4) > 0] * extent
            child_extent = 0.5 * extent
            child_center = np.asarray([child_center_x, child_center_y, child_center_z])
            root.children[i] = octree_recursive_build(root.children[i],db,child_center,child_extent,children_point_indices[i],leaf_size,min_extent)
                                                                                                                                                              
        # 屏蔽结束
    return root

# 功能：判断当前query区间是否在octant内
# 输入：
#     query: 索引信息
#     radius：索引半径
#     octant：octree
# 输出：
#     判断结果，即True/False
def inside(query: np.ndarray, radius: float, octant:Octant):
    """
    Determines if the query ball is inside the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    query_offset_abs = np.fabs(query_offset)
    possible_space = query_offset_abs + radius
    return np.all(possible_space < octant.extent)

# 功能：判断当前query区间是否和octant有重叠部分
# 输入：
#     query: 索引信息
#     radius：索引半径
#     octant：octree
# 输出：
#     判断结果，即True/False
def overlaps(query: np.ndarray, radius: float, octant:Octant):
    """
    Determines if the query ball overlaps with the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    query_offset_abs = np.fabs(query_offset)

    # completely outside, since query is outside the relevant area
    max_dist = radius + octant.extent
    if np.any(query_offset_abs > max_dist):
        return False

    # if pass the above check, consider the case that the ball is contacting the face of the octant
    if np.sum((query_offset_abs < octant.extent).astype(np.int)) >= 2:
        return True

    # conside the case that the ball is contacting the edge or corner of the octant
    # since the case of the ball center (query) inside octant has been considered,
    # we only consider the ball center (query) outside octant
    x_diff = max(query_offset_abs[0] - octant.extent, 0)
    y_diff = max(query_offset_abs[1] - octant.extent, 0)
    z_diff = max(query_offset_abs[2] - octant.extent, 0)

    return x_diff * x_diff + y_diff * y_diff + z_diff * z_diff < radius * radius


# 功能：判断当前query是否包含octant
# 输入：
#     query: 索引信息
#     radius：索引半径
#     octant：octree
# 输出：
#     判断结果，即True/False
def contains(query: np.ndarray, radius: float, octant:Octant):
    """
    Determine if the query ball contains the octant
    :param query:
    :param radius:
    :param octant:
    :return:
    """
    query_offset = query - octant.center
    query_offset_abs = np.fabs(query_offset)

    query_offset_to_farthest_corner = query_offset_abs + octant.extent
    return np.linalg.norm(query_offset_to_farthest_corner) < radius

# 功能：在octree中查找信息
# 输入：
#    root: octree
#    db：原始数据
#    result_set: 索引结果
#    query：索引信息
def octree_radius_search_fast(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False

    # 作业5
    # 提示：尽量利用上面的inside、overlaps、contains等函数
    # 屏蔽开始
    if contains(query,result_set.worstDist(),root):
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)

    for c ,croot in enumerate(root.children):
        if croot==None:
            continue
        if overlaps(query,result_set.worstDist(),croot)==False:
            continue
        if octree_radius_search(croot, db, result_set, query):
            return True
    

    # 屏蔽结束

    return inside(query, result_set.worstDist(), root)


# 功能：在octree中查找radius范围内的近邻
# 输入：
#     root: octree
#     db: 原始数据
#     result_set: 搜索结果
#     query: 搜索信息
def octree_radius_search(root: Octant, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)

    # 作业6
    # 屏蔽开始
    point_db = query
    morton_code = 0
    if point_db[0] > root.center[0]:
        morton_code = morton_code | 1
    if point_db[1] > root.center[1]:
        morton_code = morton_code | 2
    if point_db[2] > root.center[2]:
        morton_code = morton_code | 4

    if octree_radius_search(root.children[morton_code],db,result_set,query):
        return True

    for c ,croot in enumerate(root.children):
        if c == morton_code or croot==None:
            continue
        if overlaps(query,result_set.worstDist(),croot)==False:
            continue
        if octree_radius_search(croot, db, result_set, query):
            return True
    # 屏蔽结束

    # final check of if we can stop search
    return inside(query, result_set.worstDist(), root)

# 功能：在octree中查找最近的k个近邻
# 输入：
#     root: octree
#     db: 原始数据
#     result_set: 搜索结果
#     query: 搜索信息
def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)

    # 作业7
    # 屏蔽开始
    point_db = query
    morton_code = 0
    if point_db[0] > root.center[0]:
        morton_code = morton_code | 1
    if point_db[1] > root.center[1]:
        morton_code = morton_code | 2
    if point_db[2] > root.center[2]:
        morton_code = morton_code | 4

    if octree_knn_search(root.children[morton_code],db,result_set,query):
        return True

    for c ,croot in enumerate(root.children):
        if c == morton_code or croot==None:
            continue
        if overlaps(query,result_set.worstDist(),croot)==False:
            continue
        if octree_knn_search(croot, db, result_set, query):
            return True

    # 屏蔽结束

    # final check of if we can stop search
    return inside(query, result_set.worstDist(), root)

# 功能：构建octree，即通过调用octree_recursive_build函数实现对外接口
# 输入：
#    dp_np: 原始数据
#    leaf_size：scale
#    min_extent：最小划分区间
def octree_construction(db_np, leaf_size, min_extent):
    N, dim = db_np.shape[0], db_np.shape[1]
    db_np_min = np.amin(db_np, axis=0)
    db_np_max = np.amax(db_np, axis=0)
    db_extent = np.max(db_np_max - db_np_min) * 0.5
    db_center = db_np_min + db_extent

    root = None
    root = octree_recursive_build(root, db_np, db_center, db_extent, list(range(N)),
                                  leaf_size, min_extent)

    return root

def main():
    # configuration
    leaf_size = 4
    min_extent = 0.0001
    k = 8

    db_np = origindata
    construction_time_sum = 0
    begin_t = time.time()
    root = octree_construction(db_np, leaf_size, min_extent)
    construction_time_sum += time.time() - begin_t

    # depth = [0]
    # max_depth = [0]
    # traverse_octree(root, depth, max_depth)
    # print("tree max depth: %d" % max_depth[0])


    knn_time_sum = 0
    for ind ,p in enumerate(origindata):
        begin_t = time.time()
        result_set = KNNResultSet(capacity=k)
        octree_knn_search(root, db_np, result_set, p)
        knn_time_sum += time.time() - begin_t
        print(ind/len(origindata))
    print('octreebuilding:',construction_time_sum)
    print('octree:',knn_time_sum)

    # query = np.asarray([0, 0, 0])
    # result_set = KNNResultSet(capacity=k)
    # octree_knn_search(root, db_np, result_set, query)
    # print(result_set)
    
    # diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    # nn_idx = np.argsort(diff)
    # nn_dist = diff[nn_idx]
    # print(nn_idx[0:k])
    # print(nn_dist[0:k])

    # begin_t = time.time()
    # print("Radius search normal:")
    # for i in range(100):
    #     query = np.random.rand(3)
    #     result_set = RadiusNNResultSet(radius=0.5)
    #     octree_radius_search(root, db_np, result_set, query)
    # # print(result_set)
    # print("Search takes %.3fms\n" % ((time.time() - begin_t) * 1000))

    # begin_t = time.time()
    # print("Radius search fast:")
    # for i in range(100):
    #     query = np.random.rand(3)
    #     result_set = RadiusNNResultSet(radius = 0.5)
    #     octree_radius_search_fast(root, db_np, result_set, query)
    # # print(result_set)
    # print("Search takes %.3fms\n" % ((time.time() - begin_t)*1000))



if __name__ == '__main__':
    main()
    #耗时10.97+567.97秒

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
0.9599175409888664
0.9599255622934514
0.9599335835980364
0.9599416049026214
0.9599496262072064
0.9599576475117914
0.9599656688163763
0.9599736901209612
0.9599817114255462
0.9599897327301312
0.9599977540347162
0.9600057753393012
0.9600137966438862
0.9600218179484712
0.9600298392530561
0.9600378605576411
0.9600458818622261
0.960053903166811
0.960061924471396
0.960069945775981
0.960077967080566
0.9600859883851509
0.9600940096897359
0.9601020309943209
0.9601100522989059
0.9601180736034909
0.9601260949080759
0.9601341162126609
0.9601421375172458
0.9601501588218307
0.9601581801264157
0.9601662014310007
0.9601742227355857
0.9601822440401707
0.9601902653447557
0.9601982866493406
0.9602063079539256
0.9602143292585106
0.9602223505630956
0.9602303718676806
0.9602383931722656
0.9602464144768506
0.9602544357814355
0.9602624570860204
0.9602704783906054
0.9602784996951904
0.9602865209997754
0.9602945423043604
0.9603025636089454
0.9603105849135304
0.96031860621

In [0]:
1 | 2

3