In [18]:
import numpy as np
import random

In [19]:
class Tree:
    def __init__(self, data=None, parent=None):
        self.left = None
        self.right = None
        self.data = data
        self.parent = parent
        self.neighbor = None
    
    def traverse(self):
        print(self.data)
        
        if self.left:
            self.left.traverse()
        if self.right:
            self.right.traverse()
    
    def __getattribute__(self, item):
        if item == 'neighbor':
            if self.parent == None:
                return None
            else:
                if self == self.parent.left:
                    return self.parent.right
                else:
                    return self.parent.left
        else:
            return super().__getattribute__(item)

In [20]:
tree = Tree(1)
left_node = Tree(2, tree)
right_node = Tree(3,tree)

In [21]:
tree.left = left_node
tree.right = right_node
left_node.parent = tree
right_node.parent = tree

In [22]:
right_node.neighbor.data

2

In [23]:
tree.traverse()

1
2
3


In [167]:
class KNN:
    def __init__(self, x=None, y=None, k=1):
        self.train_x = x
        self.train_y = y
        self.k = k
        self.tree = Tree()
        self.cur_node = None
        self.min_dis = float('inf')
        
    def clean(self):
        self.cur_node = None
        self.min_dis = float('inf')
        
    def median(self, x_choose):
        x_set = sorted(x_choose)
        
        return x_set[int(len(x_set) / 2)]
    
    def kv_tree(self, x_set:np.array, root, deep):
        
        if len(x_set) <= 1:
            return
        
        #x_set = x_set[:, 0:-1]
        
        l = (deep % self.k) + 1
        
        x_choose = x_set[:, l - 1]
        med = self.median(x_choose)

        left_data = [x for x in x_set if x[l - 1] < med]
        right_data = [x for x in x_set if x[l - 1] > med]
        
        root.data = np.array([x for x in x_set if x[l - 1] == med])

        #left_data = np.where(len(left_data) > 0, np.array(left_data), None)
        #right_data = np.where(len(right_data) > 0, np.array(right_data), None)
        left_data = np.array(left_data)
        left_node = Tree(left_data)
        left_node.parent = root
        root.left = left_node

        right_data = np.array(right_data)
        right_node = Tree(right_data)
        right_node.parent = root
        root.right = right_node

        self.kv_tree(left_data, root.left, deep + 1)
        self.kv_tree(right_data, root.right, deep + 1)
        '''
        这段的目的是不构造空叶子节点，
        但是缺少空节点有个问题，kv树搜索的时候，不从叶子节点开始，会导致某些节点没有被搜索到
        if len(left_data) > 0:
            left_data = np.array(left_data)
            left_node = Tree(left_data)
            root.left = left_node
            left_node.parent = root
            self.kv_tree(left_data, root.left, deep + 1)
        else:
            root.left = None
        
        if len(right_data) > 0:
            right_data = np.array(right_data)
            right_node = Tree(right_data)
            root.right = right_node
            right_node.parent = root
            self.kv_tree(right_data, root.right, deep + 1)
        else:
            root.right = None  
        '''
        
    def fit(self, x, y):
        self.train_x = x
        self.train_y = y.reshape(-1,1)
        
        self.train_x = np.append(self.train_x, self.train_y, axis=1)

        self.kv_tree(self.train_x, self.tree, 0)    
        
    def predict(self, x):
        self.clean()
        y = []
        
        for _x in x:
            leaf, deep = self.search_leaf(self.tree, _x, 0)
        
            #self.cur_node = leaf
        
            self.search_back(leaf, _x, deep)
            
            y.append(self.cur_node[-1])
            
            self.clean()
        return np.array(y)
        
        #return self.cur_node, self.min_dis

    
    def update_min_dis(self, node, x):
        #print(len(node.data))
        if len(node.data) <= 0:
            return
        
        dis = [np.linalg.norm(n[0:-1] - x) for n in node.data]
        min_dis = np.argmin(dis)
        if dis[min_dis] < self.min_dis:
            #str_data = lambda d : 'None' if d == None else d.data
            #print("update node {} => {}".format(str_data(self.cur_node), node.data))
            #print("update dis {} => {}".format(self.min_dis, dis[min_dis]))
            self.min_dis = dis[min_dis]
            self.cur_node = node.data[min_dis]
    
    def search_neighbor(self, node, x, deep):
        
        nodes = [[node, deep]]
        
        while len(nodes):
            cur_node, cur_deep = nodes.pop(0)
            #print("search neighbor: {}".format(cur_node.data))
            self.update_min_dis(cur_node, x)
            if cur_node.left:
                nodes += [[cur_node.left, cur_deep - 1]]
            if cur_node.right:
                nodes += [[cur_node.right, cur_deep - 1]]
                 
    def search_back(self, node, x, deep):
        
        if deep <= 0:
            return
        
        #print("----- {}".format(node.data))
        parent = node.parent
        #print("parent {}".format(parent.data))
        
        l = (deep % self.k) + 1 - 1
        
        # intersect
        if x[l] - parent.data[0][l] <= self.min_dis:
            # check if min distance in parent node
            self.update_min_dis(parent, x)
                
            # check neighborhood
            self.search_neighbor(node.neighbor, x, deep)
            
            self.search_back(parent, x, deep - 1)
            
        else:
            return self.search_back(parent, x, deep - 1)
    
    def search_leaf(self, tree, x, deep):
        #print(tree.data)
        if len(tree.data) <= 0:
            #self.update_min_dis(tree.parent, x)
            return tree, deep
        
        l = (deep % self.k) + 1 - 1       
        
        if x[0] < tree.data[0][l]:
            if tree.left == None:
                self.update_min_dis(tree, x)
                return tree, deep
            else:
                return self.search_leaf(tree.left, x, deep + 1)
        else:
            if tree.right == None:  
                self.update_min_dis(tree, x)
                return tree, deep
            else:
                return self.search_leaf(tree.right, x, deep + 1)

In [38]:
train_x = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
train_x = np.array(train_x)
train_x.shape

(6, 2)

In [39]:
train_y = np.array([0, 1, 1, 0, 1, 0])
train_y = train_y.reshape(-1,1)
train_y

array([[0],
       [1],
       [1],
       [0],
       [1],
       [0]])

In [40]:
knn = KNN(train_x, train_y, 2)
knn.fit(train_x, train_y)
knn.tree.traverse()

[[7 2 0]]
[[5 4 1]]
[[2 3 0]]
[[4 7 0]]
[[9 6 1]]
[[8 1 1]]
[]


In [41]:
test_x = np.array([4,3])
test_x

array([4, 3])

In [42]:
knn.search_leaf(knn.tree, test_x, 0)[0].data

update dis inf => 4.0


array([[4, 7, 0]])

In [43]:
knn.predict([test_x])

update dis inf => 4.0
update dis 4.0 => 1.4142135623730951


array([1])

In [44]:
def test_knn(train_x, test_x):
    dis = [np.linalg.norm(n - test_x) for n in train_x]
    #print(dis)
    return np.array([train_x[np.argmin(dis)]]), min(dis)

In [45]:
test_knn(train_x, test_x)

(array([[5, 4]]), 1.4142135623730951)

In [46]:
def random_test():
    test_x = np.random.randint(10, size=(100, 2))
    #test_x = [[5,4]]

    for x in test_x:
        knn_dis = knn.predict(x)[1]
        test_dis = test_knn(train_x, x)[1]
        if knn_dis != test_dis:
            print(x)
            print("dis knn {}, test {}".format(knn_dis, test_dis))
            print(knn.predict(x)[0].data)
            print(test_knn(train_x, x)[0])
    
    pass

KV树在构建的时候，有几个问题
首先它的中位数，在遇到数据个数是偶数的的时候，不采用中位数平均的形势，因为平均可能会导致root节点没有任何数据
为了保证root有数据，所以不取平均

问题2，当train_x特别稀疏的时候，比如`train_x[0,:1]`都是0的话，这样在构建kv树的时候，就有可能所有的x全都被放到根节点，
算法的复杂度就跌落到$O(n)$了

一个很奇怪的问题，如果不归一化处理，norm算出来的距离和sklearn knn的距离不一样(欧几里得距离)，很奇怪，归一化以后就正常了

In [48]:
import tensorflow as tf

In [168]:
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

In [169]:
train_x = train_x.reshape(train_x.shape[0], -1) / 255.0
test_x = test_x.reshape(test_x.shape[0], -1) / 255.0

In [170]:
train_x_sample = train_x[:100]
train_y_sampe = train_y[:100]

In [171]:
knn = KNN(train_x, train_y, 2)
knn.fit(train_x, train_y)

In [173]:
knn_y = knn.predict(test_x[:10])

In [174]:
knn_y

array([7., 2., 1., 0., 4., 1., 4., 9., 5., 9.])

In [175]:
from sklearn.neighbors import KNeighborsClassifier

In [176]:
neigh = KNeighborsClassifier(n_neighbors=1,metric='euclidean')

In [177]:
neigh.fit(train_x, train_y)

KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='euclidean',
           metric_params=None, n_jobs=None, n_neighbors=1, p=2,
           weights='uniform')

In [180]:
sk_knn_y = neigh.predict(test_x)

In [182]:
from sklearn.metrics import accuracy_score

In [183]:
accuracy_score(sk_knn_y, test_y)

0.9691