In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math

In [2]:
def getSplit(data):
    """Split a table into features and target.

    Args:
        data: DataFrame whose last column holds the labels.

    Returns:
        (X, Y): X is every column except the last (positional ``iloc``),
        Y is the last column.
    """
    return data.iloc[:, :-1], data.iloc[:, -1]

In [3]:
# Tree node class
class node:
    """One node of a kd-tree.

    Args:
        dim: index of the dimension this node splits on.
        data: the point (feature vector X) stored at this node.
        label: class label of that point.
        rc: right child node (or None).
        lc: left child node (or None).

    ``parent`` is initialised to None here.
    """

    def __init__(self, dim, data, label, rc, lc):
        self.dim, self.data, self.label = dim, data, label
        self.rightChild, self.leftChild = rc, lc
        # No parent until the node is linked under another node.
        self.parent = None

In [4]:
# kd-tree class
class kdTree:
    """Balanced kd-tree over labelled points with a k-nearest-neighbour classifier.

    Each subtree splits on the dimension with the largest variance among its
    points, and the median point (after sorting on that dimension) becomes the
    subtree's root.

    Args:
        data: 2-D array-like of points, one point per row (lists, ndarrays
            and DataFrames are converted with ``np.asarray``).
        labels: 1-D array-like of class labels aligned with ``data``.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels):
        # root is the root node of the kd-tree (None for an empty tree)
        self.root = None
        self.__length = 0
        self.__createTree(data, labels)

    def getLength(self):
        """Return the number of points stored in the tree."""
        return self.__length

    @staticmethod
    def createNode(data, labels):
        """Recursively build the subtree for ``data``/``labels`` and return its root.

        Args:
            data: 2-D ndarray of points (one point per row).
            labels: 1-D ndarray of labels aligned with ``data``.

        Returns:
            The root ``node`` of the subtree, or None when ``data`` has no rows.
            Every child's ``parent`` attribute is pointed at its parent.
        """
        if data.shape[0] == 0:
            return None
        # Split on the dimension with the largest variance (same argmax as the
        # largest standard deviation).
        dim = np.argmax(np.var(data, axis=0))
        order = np.argsort(data[:, dim], axis=0)  # sort the points along that dimension
        sortedData = data[order, :]
        sortedLabels = labels[order]
        mid = data.shape[0] // 2  # median position: the split point

        cur = node(dim, sortedData[mid, :], sortedLabels[mid], None, None)
        # Points before the median form the left subtree, points after it the
        # right subtree; an empty slice yields a None child.
        lc = kdTree.createNode(sortedData[:mid, :], sortedLabels[:mid])
        rc = kdTree.createNode(sortedData[mid + 1:, :], sortedLabels[mid + 1:])
        for child in (lc, rc):
            if child is not None:
                child.parent = cur
        cur.leftChild = lc
        cur.rightChild = rc
        return cur

    def __createTree(self, data, labels):
        """Convert the inputs to ndarrays, validate them and build the tree."""
        data = np.asarray(data)
        labels = np.asarray(labels)
        if data.shape[0] != labels.shape[0]:
            raise ValueError(
                "data has {} rows but labels has {} entries".format(
                    data.shape[0], labels.shape[0]))
        # Length is taken after conversion so list input works too.
        self.__length = data.shape[0]
        self.root = self.createNode(data, labels)

    def KNNSearchMostNeighboorNode(self, node, rootNode):
        """Descend from ``rootNode`` towards the query and return the last node reached.

        At each node the descent goes left when the node's split value is
        >= the query's value on that dimension, otherwise right, and stops
        when the chosen child is None.

        Args:
            node: the query point (indexable by dimension).
            rootNode: root of the subtree to descend.

        Returns:
            The last non-None node on the path, or None if ``rootNode`` is None.
        """
        cur = rootNode
        parent = None
        while cur is not None:
            parent = cur
            # Plain `else` (not `elif <`) so an incomparable value such as NaN
            # still advances instead of looping forever.
            if cur.data[cur.dim] >= node[cur.dim]:
                cur = cur.leftChild
            else:
                cur = cur.rightChild
        return parent

    def __knnTools(self, data, neighboorNode, k, rootNode):
        """Return the k stored points nearest to ``data`` within ``rootNode``'s subtree.

        Every visited node is a candidate. The side of each splitting plane that
        contains the query is searched first. The other side is searched only
        while fewer than k candidates are known, or while the current k-th
        distance exceeds the distance to the plane. This includes the unexplored
        child of the node where a plain descent would stop.

        Args:
            data: query point (indexable by dimension, broadcastable against node.data).
            neighboorNode: already-known candidates as (label, distance, node)
                tuples; it is copied, not mutated.
            k: number of neighbours wanted.
            rootNode: root of the subtree to search.

        Returns:
            A new list of at most k (label, distance, node) tuples in ascending
            Euclidean distance (ties keep insertion order).
        """
        if k <= 0:
            return []
        best = sorted(neighboorNode, key=lambda x: x[1])[:k]

        def visit(cur):
            if cur is None:
                return
            dist = math.sqrt(((cur.data - data) ** 2).sum())
            if len(best) < k or best[-1][1] > dist:
                best.append((cur.label, dist, cur))
                best.sort(key=lambda x: x[1])
                del best[k:]
            # Same tie rule as KNNSearchMostNeighboorNode: go left when the
            # split value is >= the query's value.
            if cur.data[cur.dim] >= data[cur.dim]:
                near, far = cur.leftChild, cur.rightChild
            else:
                near, far = cur.rightChild, cur.leftChild
            visit(near)
            # Every point on the far side is at least the plane distance away.
            if len(best) < k or best[-1][1] > abs(data[cur.dim] - cur.data[cur.dim]):
                visit(far)

        visit(rootNode)
        return best

    def KNN(self, data, k):
        """Classify ``data`` by majority vote among its k nearest stored points.

        Prints how many of the neighbours belong to each class, then returns
        the most frequent label. Labels with equal counts are ordered by their
        nearest representative, so the label whose neighbour is closest wins.

        Args:
            data: query point with the same number of dimensions as the tree.
            k: number of neighbours to use; must be >= 1.

        Returns:
            The predicted label.

        Raises:
            ValueError: if k < 1 or the tree is empty.
        """
        if k < 1:
            raise ValueError("k must be at least 1, got {}".format(k))
        if self.root is None:
            raise ValueError("cannot run KNN on an empty kd-tree")
        # (label, distance, node) tuples, nearest first
        neighboorNode = self.__knnTools(data, [], k, self.root)
        label = dict()
        for item in neighboorNode:
            label[item[0]] = label.get(item[0], 0) + 1
        # sorted() is stable even with reverse=True, so tied counts keep
        # nearest-first insertion order.
        sortedKNN = sorted(label.items(), key=lambda x: x[1], reverse=True)
        for item in sortedKNN:
            print("类别为{}的有{}个".format(item[0], item[1]))
        return sortedKNN[0][0]

    def __toDict(self, rootNode):
        """Convert the subtree at ``rootNode`` into nested dicts (preorder).

        Each dict has a 'data' entry and, when the child exists, 'left'/'right'.
        """
        if rootNode is None:
            return None
        res = dict()
        res['data'] = rootNode.data
        ls = self.__toDict(rootNode.leftChild)
        if ls is not None:
            res['left'] = ls
        rs = self.__toDict(rootNode.rightChild)
        if rs is not None:
            res['right'] = rs
        return res

    def ToDict(self):
        """Return the whole tree as nested dicts (None for an empty tree)."""
        return self.__toDict(self.root)

In [5]:
def testTree():
    """Build the kd-tree example from p.55 of Li Hang's book and run a 2-NN query.

    Prints the top-level entries of the tree's dict form, then the label
    predicted for the query point (8, 6) with k=2.
    """
    # Tree-building example from p.55 of Li Hang's book; the resulting tree matches the book's.
    data = np.array([[7, 2], [5, 4], [9, 6], [2, 3], [4, 7], [8, 1]])
    label = np.array([1, 2, 1, 2, 1, 5])
    # kdTree is the class defined in this notebook, not a module, so it is
    # instantiated directly (``kdTree.kdTree`` raises AttributeError).
    rc = kdTree(data, label)
    s = rc.ToDict()
    for k, v in s.items():
        print(k, '\t', v)
    print("------------")
    res = rc.KNN((8, 6), 2)
    print(res)

In [6]:
# Load the abalone data set; getSplit uses the last column (年龄) as the label.
abalone = pd.read_table('abalone.txt', header=None)
abalone.columns = ['性别', '长度', '直径', '高度', '整体重量', '肉重量', '内脏重量', '壳重', '年龄']
data, label = getSplit(abalone)
# Build the tree on the first N_TRAIN rows only, so the query row (the last
# row) is held out — assumes the file has more than N_TRAIN rows.
N_TRAIN = 4175
rc = kdTree(data.iloc[:N_TRAIN, :], label.iloc[:N_TRAIN])
print("kd树的节点个数:", rc.getLength())
x = np.array(data.iloc[-1, :])
y = label.iloc[-1]  # true label of the query row, for comparison with the prediction
print(rc.KNN(x, 4))

kd树的节点个数: 4175
类别为11的有2个
类别为12的有1个
类别为10的有1个
11


In [7]:
# Rebuild the tree on every row. The query row is now inside the tree, so it
# is one of its own neighbours (distance 0).
rc = kdTree(data, label)
print("kd树的节点个数:", rc.getLength())
x = np.array(data.iloc[-1, :])
y = label.iloc[-1]  # true label of the query row
prediction = rc.KNN(x, 4)
print(prediction)

kd树的节点个数: 4177
类别为12的有2个
类别为11的有2个
12
