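From cc238fa2ecae71280b6ed7f8b73f2d182ba9c0f0 Mon Sep 17 00:00:00 2001 From: DrDavidS Date: Wed, 18 Sep 2019 14:24:25 +0800 Subject: [PATCH] Fix the description of the Chebyshev distance when p is infinite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When p=\infty, this should be the special case of the Minkowski distance, namely the Chebyshev distance. The original text appears to have confused the two. --- .../3.KNearestNeighbors-checkpoint.ipynb" | 1304 +++++++++++++++++ .../3.KNearestNeighbors.ipynb" | 4 +- 2 files changed, 1306 insertions(+), 2 deletions(-) create mode 100644 "\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/.ipynb_checkpoints/3.KNearestNeighbors-checkpoint.ipynb" diff --git "a/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/.ipynb_checkpoints/3.KNearestNeighbors-checkpoint.ipynb" "b/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/.ipynb_checkpoints/3.KNearestNeighbors-checkpoint.ipynb" new file mode 100644 index 0000000..c474547 --- /dev/null +++ "b/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/.ipynb_checkpoints/3.KNearestNeighbors-checkpoint.ipynb" @@ -0,0 +1,1304 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Chapter 3: The k-Nearest Neighbor Method" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. The $k$-nearest neighbor method is a basic and simple method for classification and regression. Its basic procedure is: given the training instances and an input instance, first find the $k$ training instances nearest to the input instance, then predict the class of the input instance by the majority class among those $k$ training instances.\n", + "\n", + "2. The $k$-nearest neighbor model corresponds to a partition of the feature space based on the training data set. In the $k$-nearest neighbor method, once the training set, the distance metric, the value of $k$, and the classification decision rule are fixed, the result is uniquely determined.\n", + "\n", + "3. The three elements of the $k$-nearest neighbor method are the distance metric, the choice of $k$, and the classification decision rule. Commonly used distance metrics are the Euclidean distance and the more general $L_p$ distance. A small $k$ gives a more complex model; a large $k$ gives a simpler model. The choice of $k$ reflects a trade-off between approximation error and estimation error, and the optimal $k$ is usually chosen by cross-validation.\n", + "\n", + "The commonly used classification decision rule is majority voting, which corresponds to empirical risk minimization.\n", + "\n", + "4. An implementation of the $k$-nearest neighbor method must consider how to search for the $k$ nearest neighbors quickly. The **kd** tree is a data structure that supports fast retrieval of data in $k$-dimensional space. A **kd** tree is a binary tree that represents a partition of $k$-dimensional space; each node corresponds to a hyperrectangular region of that partition. Using a **kd** tree avoids searching most of the data points, which reduces the computation needed for the search." + ] + }, + { + "cell_type": "markdown", + 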
"metadata": {}, + "source": [ + "### 距离度量" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "设特征空间$x$是$n$维实数向量空间 ,$x_{i}, x_{j} \\in \\mathcal{X}$,$x_{i}=\\left(x_{i}^{(1)}, x_{i}^{(2)}, \\cdots, x_{i}^{(n)}\\right)^{\\mathrm{T}}$,$x_{j}=\\left(x_{j}^{(1)}, x_{j}^{(2)}, \\cdots, x_{j}^{(n)}\\right)^{\\mathrm{T}}$\n", + ",则:$x_i$,$x_j$的$L_p$距离定义为:\n", + "\n", + "\n", + "$L_{p}\\left(x_{i}, x_{j}\\right)=\\left(\\sum_{i=1}^{n}\\left|x_{i}^{(i)}-x_{j}^{(l)}\\right|^{p}\\right)^{\\frac{1}{p}}$\n", + "\n", + "- $p= 1$ 曼哈顿距离\n", + "- $p= 2$ 欧氏距离\n", + "- $p= \\infty$ 切比雪夫距离" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from itertools import combinations" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def L(x, y, p=2):\n", + " # x1 = [1, 1], x2 = [5,1]\n", + " if len(x) == len(y) and len(x) > 1:\n", + " sum = 0\n", + " for i in range(len(x)):\n", + " sum += math.pow(abs(x[i] - y[i]), p)\n", + " return math.pow(sum, 1 / p)\n", + " else:\n", + " return 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 课本例3.1" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "x1 = [1, 1]\n", + "x2 = [5, 1]\n", + "x3 = [4, 4]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4.0, '1-[5, 1]')\n", + "(4.0, '1-[5, 1]')\n", + "(3.7797631496846193, '1-[4, 4]')\n", + "(3.5676213450081633, '1-[4, 4]')\n" + ] + } + ], + "source": [ + "# x1, x2\n", + "for i in range(1, 5):\n", + " r = {'1-{}'.format(c): L(x1, c, p=i) for c in [x2, x3]}\n", + " print(min(zip(r.values(), r.keys())))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "python实现,遍历所有数据点,找出$n$个距离最近的点的分类情况,少数服从多数" + ] + }, + { + "cell_type": 
"code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from collections import Counter" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# data\n", + "iris = load_iris()\n", + "df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", + "df['label'] = iris.target\n", + "df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", + "# data = np.array(df.iloc[:100, [0, 1, -1]])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " sepal length sepal width petal length petal width label\n", + "0 5.1 3.5 1.4 0.2 0\n", + "1 4.9 3.0 1.4 0.2 0\n", + "2 4.7 3.2 1.3 0.2 0\n", + "3 4.6 3.1 1.5 0.2 0\n", + "4 5.0 3.6 1.4 0.2 0\n", + "5 5.4 3.9 1.7 0.4 0\n", + "6 4.6 3.4 1.4 0.3 0\n", + "7 5.0 3.4 1.5 0.2 0\n", + "8 4.4 2.9 1.4 0.2 0\n", + "9 4.9 3.1 1.5 0.1 0\n", + "10 5.4 3.7 1.5 0.2 0\n", + "11 4.8 3.4 1.6 0.2 0\n", + "12 4.8 3.0 1.4 0.1 0\n", + "13 4.3 3.0 1.1 0.1 0\n", + "14 5.8 4.0 1.2 0.2 0\n", + "15 5.7 4.4 1.5 0.4 0\n", + "16 5.4 3.9 1.3 0.4 0\n", + "17 5.1 3.5 1.4 0.3 0\n", + "18 5.7 3.8 1.7 0.3 0\n", + "19 5.1 3.8 1.5 0.3 0\n", + "20 5.4 3.4 1.7 0.2 0\n", + "21 5.1 3.7 1.5 0.4 0\n", + "22 4.6 3.6 1.0 0.2 0\n", + "23 5.1 3.3 1.7 0.5 0\n", + "24 4.8 3.4 1.9 0.2 0\n", + "25 5.0 3.0 1.6 0.2 0\n", + "26 5.0 3.4 1.6 0.4 0\n", + "27 5.2 3.5 1.5 0.2 0\n", + "28 5.2 3.4 1.4 0.2 0\n", + "29 4.7 3.2 1.6 0.2 0\n", + ".. ... ... ... ... ...\n", + "120 6.9 3.2 5.7 2.3 2\n", + "121 5.6 2.8 4.9 2.0 2\n", + "122 7.7 2.8 6.7 2.0 2\n", + "123 6.3 2.7 4.9 1.8 2\n", + "124 6.7 3.3 5.7 2.1 2\n", + "125 7.2 3.2 6.0 1.8 2\n", + "126 6.2 2.8 4.8 1.8 2\n", + "127 6.1 3.0 4.9 1.8 2\n", + "128 6.4 2.8 5.6 2.1 2\n", + "129 7.2 3.0 5.8 1.6 2\n", + "130 7.4 2.8 6.1 1.9 2\n", + "131 7.9 3.8 6.4 2.0 2\n", + "132 6.4 2.8 5.6 2.2 2\n", + "133 6.3 2.8 5.1 1.5 2\n", + "134 6.1 2.6 5.6 1.4 2\n", + "135 7.7 3.0 6.1 2.3 2\n", + "136 6.3 3.4 5.6 2.4 2\n", + "137 6.4 3.1 5.5 1.8 2\n", + "138 6.0 3.0 4.8 1.8 2\n", + "139 6.9 3.1 5.4 2.1 2\n", + "140 6.7 3.1 5.6 2.4 2\n", + "141 6.9 3.1 5.1 2.3 2\n", + "142 5.8 2.7 5.1 1.9 2\n", + "143 6.8 3.2 5.9 2.3 2\n", + "144 6.7 3.3 5.7 2.5 2\n", + "145 6.7 3.0 5.2 2.3 2\n", + "146 6.3 2.5 5.0 1.9 2\n", + "147 6.5 3.0 5.2 2.0 2\n", + "148 6.2 3.4 5.4 2.3 2\n", + "149 5.9 3.0 5.1 1.8 2\n", + "\n", + "[150 rows x 5 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": 
"code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')\n", + "plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')\n", + "plt.xlabel('sepal length')\n", + "plt.ylabel('sepal width')\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.array(df.iloc[:100, [0, 1, -1]])\n", + "X, y = data[:,:-1], data[:,-1]\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class KNN:\n", + " def __init__(self, X_train, y_train, n_neighbors=3, p=2):\n", + " \"\"\"\n", + " parameter: n_neighbors 临近点个数\n", + " parameter: p 距离度量\n", + " \"\"\"\n", + " self.n = n_neighbors\n", + " self.p = p\n", + " self.X_train = X_train\n", + " self.y_train = y_train\n", + "\n", + " def predict(self, X):\n", + " # 取出n个点\n", + " knn_list = []\n", + " for i in range(self.n):\n", + " dist = np.linalg.norm(X - self.X_train[i], ord=self.p)\n", + " knn_list.append((dist, self.y_train[i]))\n", + "\n", + " for i in range(self.n, len(self.X_train)):\n", + " max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))\n", + " dist = np.linalg.norm(X - self.X_train[i], ord=self.p)\n", + " if knn_list[max_index][0] > dist:\n", + " knn_list[max_index] = (dist, self.y_train[i])\n", + "\n", + " # 统计\n", + " knn = [k[-1] for k in knn_list]\n", + " count_pairs = Counter(knn)\n", + "# max_count = sorted(count_pairs, key=lambda x: x)[-1]\n", + " max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]\n", + " 
return max_count\n", + "\n", + " def score(self, X_test, y_test):\n", + " right_count = 0\n", + " n = 10\n", + " for X, y in zip(X_test, y_test):\n", + " label = self.predict(X)\n", + " if label == y:\n", + " right_count += 1\n", + " return right_count / len(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "clf = KNN(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.95" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clf.score(X_test, y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Point: 1.0\n" + ] + } + ], + "source": [ + "test_point = [6.0, 3.0]\n", + "print('Test Point: {}'.format(clf.predict(test_point)))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')\n", + "plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')\n", + "plt.plot(test_point[0], test_point[1], 'bo', label='test_point')\n", + "plt.xlabel('sepal length')\n", + "plt.ylabel('sepal width')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### scikit-learn实例" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.neighbors import KNeighborsClassifier" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", + " metric_params=None, n_jobs=None, n_neighbors=5, p=2,\n", + " weights='uniform')" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clf_sk = KNeighborsClassifier()\n", + "clf_sk.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.95" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clf_sk.score(X_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "### sklearn.neighbors.KNeighborsClassifier\n", + "\n", + "- n_neighbors: 临近点个数\n", + "- p: 距离度量\n", + "- algorithm: 近邻算法,可选{'auto', 'ball_tree', 'kd_tree', 'brute'}\n", + "- weights: 确定近邻的权重" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### kd树" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**kd**树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。\n", + "\n", + "**kd**树是二叉树,表示对$k$维空间的一个划分(partition)。构造**kd**树相当于不断地用垂直于坐标轴的超平面将$k$维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个$k$维超矩形区域。\n", + "\n", + "构造**kd**树的方法如下:\n", + "\n", + "构造根结点,使根结点对应于$k$维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对$k$维空间进行切分,生成子结点。在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域\n", + "(子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。\n", + "\n", + "通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数\n", + "(median)为切分点,这样得到的**kd**树是平衡的。注意,平衡的**kd**树搜索时的效率未必是最优的。\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 构造平衡kd树算法\n", + "输入:$k$维空间数据集$T=\\{x_1,x_2,…,x_N\\}$,\n", + "\n", + "其中$x_{i}=\\left(x_{i}^{(1)}, x_{i}^{(2)}, \\cdots, x_{i}^{(k)}\\right)^{\\mathrm{T}}$ ,$i=1,2,…,N$;\n", + "\n", + "输出:**kd**树。\n", + "\n", + 
"(1)开始:构造根结点,根结点对应于包含$T$的$k$维空间的超矩形区域。\n", + "\n", + "选择$x^{(1)}$为坐标轴,以T中所有实例的$x^{(1)}$坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴$x^{(1)}$垂直的超平面实现。\n", + "\n", + "由根结点生成深度为1的左、右子结点:左子结点对应坐标$x^{(1)}$小于切分点的子区域, 右子结点对应于坐标$x^{(1)}$大于切分点的子区域。\n", + "\n", + "将落在切分超平面上的实例点保存在根结点。\n", + "\n", + "(2)重复:对深度为$j$的结点,选择$x^{(1)}$为切分的坐标轴,$l=j(modk)+1$,以该结点的区域中所有实例的$x^{(1)}$坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴$x^{(1)}$垂直的超平面实现。\n", + "\n", + "由该结点生成深度为$j+1$的左、右子结点:左子结点对应坐标$x^{(1)}$小于切分点的子区域,右子结点对应坐标$x^{(1)}$大于切分点的子区域。\n", + "\n", + "将落在切分超平面上的实例点保存在该结点。\n", + "\n", + "(3)直到两个子区域没有实例存在时停止。从而形成**kd**树的区域划分。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# kd-tree每个结点中主要包含的数据结构如下\n", + "class KdNode(object):\n", + " def __init__(self, dom_elt, split, left, right):\n", + " self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)\n", + " self.split = split # 整数(进行分割维度的序号)\n", + " self.left = left # 该结点分割超平面左子空间构成的kd-tree\n", + " self.right = right # 该结点分割超平面右子空间构成的kd-tree\n", + "\n", + "\n", + "class KdTree(object):\n", + " def __init__(self, data):\n", + " k = len(data[0]) # 数据维度\n", + "\n", + " def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNode\n", + " if not data_set: # 数据集为空\n", + " return None\n", + " # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较\n", + " # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号\n", + " #data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序\n", + " data_set.sort(key=lambda x: x[split])\n", + " split_pos = len(data_set) // 2 # //为Python中的整数除法\n", + " median = data_set[split_pos] # 中位数分割点\n", + " split_next = (split + 1) % k # cycle coordinates\n", + "\n", + " # 递归的创建kd树\n", + " return KdNode(\n", + " median,\n", + " split,\n", + " CreateNode(split_next, data_set[:split_pos]), # 创建左子树\n", + " CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树\n", + "\n", + " self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点\n", + "\n", + "\n", + "# KDTree的前序遍历\n", + "def 
preorder(root):\n", + "    print(root.dom_elt)\n", + "    if root.left:  # node is not empty\n", + "        preorder(root.left)\n", + "    if root.right:\n", + "        preorder(root.right)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# Search the constructed kd-tree for the sample point nearest to the target point:\n", + "from math import sqrt\n", + "from collections import namedtuple\n", + "\n", + "# A namedtuple holding the nearest point, the nearest distance, and the number of nodes visited\n", + "result = namedtuple(\"Result_tuple\",\n", + "                    \"nearest_point nearest_dist nodes_visited\")\n", + "\n", + "\n", + "def find_nearest(tree, point):\n", + "    k = len(point)  # data dimensionality\n", + "\n", + "    def travel(kd_node, target, max_dist):\n", + "        if kd_node is None:\n", + "            return result([0] * k, float(\"inf\"),\n", + "                          0)  # float(\"inf\") and float(\"-inf\") represent positive and negative infinity in Python\n", + "\n", + "        nodes_visited = 1\n", + "\n", + "        s = kd_node.split  # splitting dimension\n", + "        pivot = kd_node.dom_elt  # splitting \"pivot\"\n", + "\n", + "        if target[s] <= pivot[s]:  # target's s-th coordinate is at most the pivot's (target is nearer to the left subtree)\n", + "            nearer_node = kd_node.left  # visit the root of the left subtree next\n", + "            further_node = kd_node.right  # and remember the right subtree\n", + "        else:  # target is nearer to the right subtree\n", + "            nearer_node = kd_node.right  # visit the root of the right subtree next\n", + "            further_node = kd_node.left\n", + "\n", + "        temp1 = travel(nearer_node, target, max_dist)  # descend to the region containing the target point\n", + "\n", + "        nearest = temp1.nearest_point  # take this leaf node as the \"current nearest point\"\n", + "        dist = temp1.nearest_dist  # update the nearest distance\n", + "\n", + "        nodes_visited += temp1.nodes_visited\n", + "\n", + "        if dist < max_dist:\n", + "            max_dist = dist  # the nearest point lies within the hypersphere centered at the target with radius max_dist\n", + "\n", + "        temp_dist = abs(pivot[s] - target[s])  # distance from the target to the splitting hyperplane along dimension s\n", + "        if max_dist < temp_dist:  # does the hypersphere intersect the hyperplane?\n", + "            return result(nearest, dist, nodes_visited)  # no intersection: return immediately\n", + "\n", + "        #----------------------------------------------------------------------\n", + "        # Euclidean distance between the target and the splitting point\n", + "        temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))\n", + "\n", + "        if temp_dist < dist:  # if it is closer\n", + "            nearest = pivot  # update the nearest point\n", + 
" dist = temp_dist # 更新最近距离\n", + " max_dist = dist # 更新超球体半径\n", + "\n", + " # 检查另一个子结点对应的区域是否有更近的点\n", + " temp2 = travel(further_node, target, max_dist)\n", + "\n", + " nodes_visited += temp2.nodes_visited\n", + " if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离\n", + " nearest = temp2.nearest_point # 更新最近点\n", + " dist = temp2.nearest_dist # 更新最近距离\n", + "\n", + " return result(nearest, dist, nodes_visited)\n", + "\n", + " return travel(tree.root, point, float(\"inf\")) # 从根节点开始递归" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 例3.2" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[7, 2]\n", + "[5, 4]\n", + "[2, 3]\n", + "[4, 7]\n", + "[9, 6]\n", + "[8, 1]\n" + ] + } + ], + "source": [ + "data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]\n", + "kd = KdTree(data)\n", + "preorder(kd.root)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from time import clock\n", + "from random import random\n", + "\n", + "# 产生一个k维随机向量,每维分量值在0~1之间\n", + "def random_point(k):\n", + " return [random() for _ in range(k)]\n", + " \n", + "# 产生n个k维随机向量 \n", + "def random_points(k, n):\n", + " return [random_point(k) for _ in range(n)] " + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)\n" + ] + } + ], + "source": [ + "ret = find_nearest(kd, [3,4.5])\n", + "print (ret)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "time: 5.4623788 s\n", + "Result_tuple(nearest_point=[0.09929288205798159, 0.4954936771850429, 0.8005722800665575], nearest_dist=0.004597223680778027, nodes_visited=42)\n" + 
] + } + ], + "source": [ + "N = 400000\n", + "t0 = clock()\n", + "kd2 = KdTree(random_points(3, N)) # 构建包含四十万个3维空间样本点的kd树\n", + "ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十万个样本点中寻找离目标最近的点\n", + "t1 = clock()\n", + "print (\"time: \",t1-t0, \"s\")\n", + "print (ret2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "----\n", + "参考代码:https://github.com/wzyonggege/statistical-learning-method\n", + "\n", + "中文注释制作:机器学习初学者\n", + "\n", + "微信公众号:ID:ai-start-com\n", + "\n", + "配置环境:python 3.5+\n", + "\n", + "代码全部测试通过。\n", + "![gongzhong](../gongzhong.jpg)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git "a/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/3.KNearestNeighbors.ipynb" "b/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/3.KNearestNeighbors.ipynb" index 99cd355..c474547 100644 --- "a/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/3.KNearestNeighbors.ipynb" +++ "b/\347\254\25403\347\253\240 k\350\277\221\351\202\273\346\263\225/3.KNearestNeighbors.ipynb" @@ -41,7 +41,7 @@ "\n", "- $p= 1$ 曼哈顿距离\n", "- $p= 2$ 欧氏距离\n", - "- $p= inf$ 闵式距离minkowski_distance " + "- $p= \\infty$ 切比雪夫距离" ] }, { @@ -1296,7 +1296,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.4" + "version": "3.6.8" } }, "nbformat": 4,