In [13]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from sklearn.datasets.samples_generator import make_blobs

In [2]:
class TreeNone:
    def __init__(self):
        self.data = None
        self.left_subnode = None
        self.right_subnode = None
        self.parent_node = None
        self.is_calculated = False
        self.is_leaf = False
        self.depth = None
    

In [75]:
class KD_Tree:
    def __init__(self, n_samples=500):
        self.n_samples = n_samples
        self.training_x = self.generate_data()
    
    def generate_data(self):
        x, _ = make_blobs(n_samples=self.n_samples, n_features=2, centers=1)
        test = [[6.27,5.50],[1.24,-2.86],[17.05,-12.79],[-6.88,-5.40],[-2.96,-0.50],[7.75,-22.68],
                [10.80,-5.03],[-4.60,-10.55],[-4.96,12.61],[1.75,12.26],[15.31,-13.16],[7.83,15.70],[14.63,-0.35]]
        return np.array(test)
        return x
    
    def construct_kd_tree(self, training_samples, cur_node=None, depth=0):
        index = (depth % len(training_samples[0]))
        
        # 计算切分节点
        split_node = self.get_split_node(training_samples, index, cur_node)
        
        # 切分左右子树
        left_subtree = [x for x in training_samples if x[index] < split_node.data[index]]
        right_subtree = [x for x in training_samples if x[index] > split_node.data[index]]
        
        # 继续切分左右子树
        cur_node = split_node
        cur_node.depth = depth
        if len(left_subtree) != 0:
            if len(left_subtree) == 1:
                node = TreeNone()
                node.data = left_subtree[0]
                node.parent_node = cur_node
                node.is_leaf = True
                node.depth = depth + 1
                cur_node.left_subnode = node
            else:
                cur_node.left_subnode = self.construct_kd_tree(left_subtree, cur_node, depth+1)
        
        if len(right_subtree) != 0:
            if len(right_subtree) == 1:
                node = TreeNone()
                node.data = right_subtree[0]
                node.parent_node = cur_node
                node.is_leaf = True
                node.depth = depth + 1
                cur_node.right_subnode = node
            else:
                cur_node.right_subnode = self.construct_kd_tree(right_subtree, cur_node, depth+1)

        return cur_node
    
    def get_split_node(self, samples, index, cur_node):
        samples = sorted(samples, key=lambda x: x[index])
        num = len(samples)
        
        node = TreeNone()
        node.parent_node = cur_node
        node.data = samples[int((num+1)/2)-1] if num % 2 != 0 else samples[int(num/2)-1]
        
        return node
    
    def print_tree(self, node, depth=0):
        if node == None:
            print('None', 'True')
            return
        if node.is_leaf:
            print(node.depth, node.data, node.is_leaf)
        else:
            print(node.depth, node.data, node.is_leaf)
            self.print_tree(node.left_subnode, depth+1)
            self.print_tree(node.right_subnode, depth+1)
            
    def normal_predict(self, test_data):
        distance = [(d, np.sum(np.square(test_data-d))) for d in self.training_x]
        rank = sorted(distance, key=lambda x:x[-1])
        k_list = [s[0] for s in rank[:3]]
        print(k_list)
    
    def search_nodes(self, root, test_data):
        def calculate_distance(p1, p2):
            return np.sum(np.square(p1-p2))
        
        def get_max_distance(k_list, test_data):
            max_dis = np.max([np.sum(np.square(p-test_data)) for p in k_list])
            max_index = np.argmax([np.sum(np.square(p-test_data)) for p in k_list])
            return max_dis, max_index
            
        def update_list(k_list, new_point, test_data):
            max_dis, max_index = get_max_distance(k_list, test_data)
            new_dis = calculate_distance(new_point, test_data)
            if new_dis < max_dis:
                del k_list[max_index]
                k_list.append(new_point)

        def find_leaf(subroot, k_list):
            node = subroot
            while True:
                if node.is_leaf:
                    if len(k_list) < 3:
                        k_list.append(node.data)
                    else:
                        update_list(k_list, node.data, test_data)
                    node.is_calculated = True
                    break
                index = node.depth % len(node.data)
                if test_data[index] < node.data[index]:
                    node = node.left_subnode
                else:
                    node = node.right_subnode
                    
            return node
        
        k_list = []
        node = find_leaf(root, k_list)
    
        while node.parent_node != None:
            node = node.parent_node
            if node.is_calculated:
                continue
            
            if len(k_list) < 3:
                k_list.append(node.data)
            else:
                update_list(k_list, node.data, test_data)
            node.is_calculated = True
            
            # 寻找当前节点下未被访问过的分支
            # 搜索过程是从子节点向上回溯，因此最多只可能有一个未访问的分支
            branch = None
            if (node.left_subnode != None and node.left_subnode.is_calculated == False):
                brach = node.left_subnode
            elif (node.right_subnode != None and node.right_subnode.is_calculated == False):
                branch = node.right_subnode
            # 如果当前节点没有未访问过的分支 或 预测点到切分线距离不小于所有候选点，不用继续向下寻找
            if branch == None:
                continue            
            # 计算当前节点的切分线
            index = node.depth % len(node.data)
            split_line = np.zeros((node.data.shape[-1]))
            split_line[index] = node.data[index]
            point2line = calculate_distance(test_data, split_line)
            max_dis, _ = get_max_distance(k_list, test_data)
            if point2line >= max_dis:
                continue
            else:
                # 再次找到当前分支下符合条件的叶节点，同时更新候选列表
                node = find_leaf(branch, k_list)
        
        print(k_list)

In [76]:
kd = KD_Tree(n_samples=5)
root = kd.construct_kd_tree(training_samples=kd.training_x)
# kd.print_tree(root)
kd.search_nodes(root, [-1, -5])
kd.normal_predict([-1, -5])

[array([-6.88, -5.4 ]), array([ 1.24, -2.86]), array([-2.96, -0.5 ])]
[array([ 1.24, -2.86]), array([-2.96, -0.5 ]), array([-6.88, -5.4 ])]
