In [2]:
import numpy as np
import random

In [3]:
class Tree:
    def __init__(self, data=None, parent=None):
        self.left = None
        self.right = None
        self.data = data
        self.parent = parent
        self.neighbor = None
    
    def traverse(self):
        print(self.data)
        
        if self.left:
            self.left.traverse()
        if self.right:
            self.right.traverse()
    
    def __getattribute__(self, item):
        if item == 'neighbor':
            if self.parent == None:
                return None
            else:
                if self == self.parent.left:
                    return self.parent.right
                else:
                    return self.parent.left
        else:
            return super().__getattribute__(item)

In [4]:
tree = Tree(1)
left_node = Tree(2, tree)
right_node = Tree(3,tree)

In [5]:
tree.left = left_node
tree.right = right_node
left_node.parent = tree
right_node.parent = tree

In [6]:
right_node.neighbor.data

2

In [7]:
tree.traverse()

1
2
3


In [213]:
class KNN:
    def __init__(self, x, y, k=1):
        self.train_x = x
        self.train_y = y
        self.k = k
        self.tree = Tree()
        self.cur_node = None
        self.min_dis = float('inf')
        
    def clean(self):
        self.cur_node = None
        self.min_dis = float('inf')
        
    def median(self, x_choose):
        x_set = sorted(x_choose)
        
        return x_set[int(len(x_set) / 2)]
    
    def kv_tree(self, x_set:np.array, root, deep):
        
        if len(x_set) <= 1:
            return
        
        l = (deep % self.k) + 1
            
        x_choose = x_set[:, l - 1]
        med = self.median(x_choose)

        left_data = [x for x in x_set if x[l - 1] < med]
        right_data = [x for x in x_set if x[l - 1] > med]
        
        root.data = np.array([x for x in x_set if x[l - 1] == med])

        #left_data = np.where(len(left_data) > 0, np.array(left_data), None)
        #right_data = np.where(len(right_data) > 0, np.array(right_data), None)
        left_data = np.array(left_data)
        left_node = Tree(left_data)
        left_node.parent = root
        root.left = left_node
        
        right_data = np.array(right_data)
        right_node = Tree(right_data)
        right_node.parent = root
        root.right = right_node
        
        self.kv_tree(left_data, root.left, deep + 1)
        self.kv_tree(right_data, root.right, deep + 1)
        '''
        这段的目的是不构造空叶子节点，
        但是缺少空节点有个问题，kv树搜索的时候，不从叶子节点开始，会导致某些节点没有被搜索到
        if len(left_data) > 0:
            left_data = np.array(left_data)
            left_node = Tree(left_data)
            root.left = left_node
            left_node.parent = root
            self.kv_tree(left_data, root.left, deep + 1)
        else:
            root.left = None
        
        if len(right_data) > 0:
            right_data = np.array(right_data)
            right_node = Tree(right_data)
            root.right = right_node
            right_node.parent = root
            self.kv_tree(right_data, root.right, deep + 1)
        else:
            root.right = None  
        '''
      
        
    def predict(self, x):
        self.clean()
        
        leaf, deep = self.search_leaf(self.tree, x, 0)
        
        self.cur_node = leaf
        
        self.search_back(leaf, x, deep)
        
        return self.cur_node, self.min_dis

    
    def update_min_dis(self, node, x):
        #print(len(node.data))
        if len(node.data) <= 0:
            return
        
        dis = [np.linalg.norm(n - x) for n in node.data]
        min_dis = min(dis)
        if min_dis < self.min_dis:
            #str_data = lambda d : 'None' if d == None else d.data
            #print("update node {} => {}".format(str_data(self.cur_node), node.data))
            #print("update dis {} => {}".format(self.min_dis, dis))
            self.min_dis = min_dis
            self.cur_node = node
    
    def search_neighbor(self, node, x, deep):
        
        nodes = [[node, deep]]
        
        while len(nodes):
            cur_node, cur_deep = nodes.pop(0)
            #print("search neighbor: {}".format(cur_node.data))
            self.update_min_dis(cur_node, x)
            if cur_node.left:
                nodes += [[cur_node.left, cur_deep - 1]]
            if cur_node.right:
                nodes += [[cur_node.right, cur_deep - 1]]
                 
    def search_back(self, node, x, deep):
        
        if deep <= 0:
            return
        
        #print("----- {}".format(node.data))
        parent = node.parent
        #print("parent {}".format(parent.data))
        
        l = (deep % self.k) + 1 - 1
        
        # intersect
        if x[l] - parent.data[0][l] <= self.min_dis:
            # check if min distance in parent node
            self.update_min_dis(parent, x)
                
            # check neighborhood
            self.search_neighbor(node.neighbor, x, deep)
            
            self.search_back(parent, x, deep - 1)
            
        else:
            return self.search_back(parent, x, deep - 1)
    
    def search_leaf(self, tree, x, deep):
        #print(tree.data)
        if len(tree.data) <= 0:
            #self.update_min_dis(tree.parent, x)
            return tree, deep
        
        l = (deep % self.k) + 1 - 1       
        
        if x[0] < tree.data[0][l]:
            if tree.left == None:
                self.update_min_dis(tree, x)
                return tree, deep
            else:
                return self.search_leaf(tree.left, x, deep + 1)
        else:
            if tree.right == None:  
                self.update_min_dis(tree, x)
                return tree, deep
            else:
                return self.search_leaf(tree.right, x, deep + 1)

In [169]:
train_x = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
train_x = np.array(train_x)
train_x

array([[2, 3],
       [5, 4],
       [9, 6],
       [4, 7],
       [8, 1],
       [7, 2]])

In [175]:
train_y = np.array([0, 1, 1, 0, 1, 0])
train_y

array([0, 1, 1, 0, 1, 0])

In [214]:
knn = KNN(train_x, train_y, 2)
knn.kv_tree(knn.train_x, knn.tree, 0)
knn.tree.traverse()

[[7 2]]
[[5 4]]
[[2 3]]
[[4 7]]
[[9 6]]
[[8 1]]
[]


In [201]:
test_x = np.array([1,1])
test_x

array([1, 1])

In [215]:
knn.search_leaf(knn.tree, test_x, 0)[0].data

array([[2, 3]])

In [216]:
knn.predict(test_x)[0].data

array([[2, 3]])

In [217]:
def test_knn(train_x, test_x):
    dis = [np.linalg.norm(n - test_x) for n in train_x]
    #print(dis)
    return np.array([train_x[np.argmin(dis)]]), min(dis)

In [194]:
test_knn(train_x, test_x)

(array([[8, 1]]), 1.4142135623730951)

In [199]:
def random_test():
    test_x = np.random.randint(10, size=(100, 2))
    #test_x = [[5,4]]

    for x in test_x:
        print(x)
        knn_dis = knn.predict(x)[1]
        test_dis = test_knn(train_x, x)[1]
        if knn_dis != test_dis:
            print(x)
            print("dis knn {}, test {}".format(knn_dis, test_dis))
            print(knn.predict(x)[0].data)
            print(test_knn(train_x, x)[0])
    
    pass

In [200]:
random_test()

[1 1]


ValueError: min() arg is an empty sequence

KV树在构建的时候，有几个问题
首先它的中位数，在遇到数据个数是偶数的的时候，不采用中位数平均的形势，因为平均可能会导致root节点没有任何数据
为了保证root有数据，所以不取平均