In [32]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import numpy as np

# Load the MNIST dataset and keep a random subsample so the later computations stay fast
mnist = fetch_openml("mnist_784")
X, y = mnist.data.to_numpy(), mnist.target.to_numpy().astype('int')

# Draw the subsample with a seeded generator so a fresh "Restart & Run All"
# selects the same rows, and size the draw from the data actually loaded
# instead of a hard-coded row count.
rng = np.random.default_rng(42)
idx = rng.choice(X.shape[0], 5000, replace=False)
X = X[idx]
y = y[idx]

# Split the subsample into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape, y_train.shape)


(4000, 784) (4000,)


# 基于sklearn的kd树

In [33]:
# Build a KNeighborsClassifier that searches neighbors with a KD-tree and
# fit it on the training split (fit returns the fitted estimator itself)
knn = KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree').fit(X_train, y_train)

# Predict labels for the held-out test split
y_pred = knn.predict(X_test)

# Report the share of correctly classified test samples
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f"KNN Accuracy: {accuracy * 100:.2f}%")

KNN Accuracy: 92.60%


# 自定义kd树

In [39]:
# KD-tree node
class Node:
    """A single KD-tree node: one stored data point plus optional child subtrees."""

    def __init__(self, data, left=None, right=None):
        """Store the node's data point and its (possibly missing) children."""
        self.data = data    # data point held by this node
        self.left = left    # left subtree, or None
        self.right = right  # right subtree, or None

    def __iter__(self):
        """Iterate over the node, yielding only its own data point (children are not visited)."""
        yield from (self.data,)

# Build the KD-tree recursively

def build_kd_tree(X, depth=0):
    """Recursively build a KD-tree from the rows of the 2-D array ``X``.

    The splitting axis cycles with the recursion depth (``depth % X.shape[1]``).
    The rows are sorted along that axis; the row at the middle index becomes
    the node, the rows before it form the left subtree and the rows after it
    form the right subtree.

    Returns the root ``Node``, or ``None`` when ``X`` has no rows.
    """
    n_points = len(X)
    if n_points == 0:
        return None

    split_axis = depth % X.shape[1]  # splitting dimension for this depth
    ordered = X[np.argsort(X[:, split_axis])]
    mid = n_points // 2  # index of the row that splits the data in two

    left_subtree = build_kd_tree(ordered[:mid], depth + 1)
    right_subtree = build_kd_tree(ordered[mid + 1:], depth + 1)
    return Node(data=ordered[mid], left=left_subtree, right=right_subtree)

# Distance between two points (Euclidean)
def euclidean_distance(x1, x2):
    """Return the Euclidean distance between ``x1`` and ``x2``.

    Both inputs are converted to float64 before subtracting. Without that,
    unsigned or narrow integer arrays (for example uint8 pixel values) wrap
    around during the subtraction and squaring and yield a wrong distance.
    Array-likes such as plain lists are accepted as well.
    """
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return np.sqrt(np.sum(diff ** 2))

# Search the KD-tree
def search_kd_tree(tree, target, k=3):
    """Return the data points of the ``k`` nodes in ``tree`` closest to ``target``.

    Parameters
    ----------
    tree : Node or None
        Root of a KD-tree whose splitting axis at a given depth is
        ``depth % len(target)``.
    target : 1-D array-like
        Query point.
    k : int
        Number of neighbors to collect.

    Returns
    -------
    list
        Up to ``k`` ``node.data`` objects (fewer when the tree holds fewer
        nodes), in no particular order. Empty when ``tree`` is None or
        ``k <= 0``.
    """
    if tree is None or k <= 0:
        return []

    # Do all arithmetic on float copies so coordinate differences never wrap
    # around for unsigned integer data.
    query = np.asarray(target, dtype=float)
    n_dims = query.shape[0]

    best_points = []  # node.data objects of the current candidates
    best_dists = []   # best_dists[i] is the distance from best_points[i] to the query

    # Stack entries are (node, depth, plane_gap). plane_gap is a lower bound on
    # the distance from the query to any point in that subtree: the distance to
    # the splitting plane for a "far" subtree, 0.0 for the root and "near" ones.
    stack = [(tree, 0, 0.0)]
    while stack:
        node, depth, plane_gap = stack.pop()
        if node is None:
            continue

        # Prune at pop time, i.e. after the sibling on the near side has been
        # searched: once k candidates are held, a subtree whose lower bound is
        # not below the worst candidate cannot contribute a strictly closer point.
        if len(best_points) == k and plane_gap >= max(best_dists):
            continue

        point = np.asarray(node.data, dtype=float)
        distance = euclidean_distance(query, point)
        if len(best_points) < k:
            # Candidate list not full yet: accept the node unconditionally.
            best_points.append(node.data)
            best_dists.append(distance)
        else:
            # Replace the farthest candidate when this node is strictly closer.
            worst = int(np.argmax(best_dists))
            if distance < best_dists[worst]:
                best_points[worst] = node.data
                best_dists[worst] = distance

        axis = depth % n_dims  # splitting dimension at this depth
        axis_diff = query[axis] - point[axis]
        if axis_diff <= 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left

        # Push the far side first so the near side is popped (searched) first.
        if far is not None:
            stack.append((far, depth + 1, abs(axis_diff)))
        if near is not None:
            stack.append((near, depth + 1, 0.0))

    return list(best_points)

# Classify with the KNN algorithm
def knn_classifier(X_train, y_train, X_test, k=3, tree=None):
    """Predict a label for every row of ``X_test`` by majority vote of its ``k`` nearest training points.

    Parameters
    ----------
    X_train : 2-D array
        Training points, one per row.
    y_train : 1-D array
        Labels aligned with the rows of ``X_train``; they are counted with
        ``np.bincount``, so they must be non-negative integers.
    X_test : 2-D array
        Points to classify, one per row.
    k : int
        Number of neighbors that vote.
    tree : Node, optional
        Prebuilt KD-tree over ``X_train``. When omitted the tree is built here
        from ``X_train``, so the result never depends on a notebook-level
        variable that may be stale or belong to other data.

    Returns
    -------
    list
        One predicted label per row of ``X_test``.
    """
    if tree is None:
        tree = build_kd_tree(X_train)

    y_pred = []
    for test_point in X_test:
        k_nearest = search_kd_tree(tree, test_point, k)
        # The search returns raw data points, so recover each neighbor's label
        # from the first training row equal to it.
        labels = [y_train[np.where((X_train == point).all(axis=1))[0][0]] for point in k_nearest]
        # The most frequent label among the neighbors is the prediction.
        counts = np.bincount(labels)
        y_pred.append(np.argmax(counts))
    return y_pred

# Number of neighbors that vote on each prediction
k_neighbors = 3

# Build the KD-tree from the training split
kd_tree = build_kd_tree(X_train)

# Classify every test sample with the custom KNN implementation
y_pred = knn_classifier(X_train, y_train, X_test, k=k_neighbors)

# Evaluate classification performance
accuracy = accuracy_score(y_test, y_pred)
print(f"KNN Accuracy: {accuracy * 100:.2f}%")

KNN Accuracy: 92.60%
