统计学习方法第三章 k近邻法
笔记见http://jieguangzhou.github.io

In [1]:
import numpy as np
import pandas as pd
import os
import random
from collections import Counter

In [2]:
# 用于打印数据
def print_var(var, name='x'):
    """Print ``var`` under a ``name:`` header, followed by its shape if it
    has a ``shape`` attribute (arrays, frames), otherwise its type."""
    if hasattr(var, 'shape'):
        descriptor = var.shape
    else:
        descriptor = type(var)
    header = name + ':\n'
    print(header, var, '\nshape_or_type:', descriptor)

In [3]:
# Load the iris dataset into a pandas DataFrame: one column per feature plus a
# 'Label' column holding the species name of each row.
from sklearn.datasets import load_iris
iris = load_iris()
target_names = iris['target_names']
# Fancy-index the name array with the integer class ids, then turn it into a list.
target = list(target_names[iris['target']])
data = pd.DataFrame(iris['data'], columns=iris['feature_names'])
data['Label'] = target

print(len(data))
data.head()

150


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),Label
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [4]:
# Shuffle the row positions, then use the first 70% for training and the rest for testing.
train_ratio = 0.7
indexs = list(range(len(data)))
random.shuffle(indexs)
train_data_num = int(len(data) * train_ratio)

train_idxs, test_idxs = indexs[:train_data_num], indexs[train_data_num:]
train_data, test_data = data.iloc[train_idxs], data.iloc[test_idxs]

In [5]:
# Number of training rows, followed by a preview of the first five.
print(train_data.shape[0])
train_data.head(5)

105


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),Label
40,5.0,3.5,1.3,0.3,setosa
135,7.7,3.0,6.1,2.3,virginica
56,6.3,3.3,4.7,1.6,versicolor
38,4.4,3.0,1.3,0.2,setosa
115,6.4,3.2,5.3,2.3,virginica


In [6]:
# Number of test rows, followed by a preview of the first five.
print(test_data.shape[0])
test_data.head(5)

45


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),Label
148,6.2,3.4,5.4,2.3,virginica
25,5.0,3.0,1.6,0.2,setosa
96,5.7,2.9,4.2,1.3,versicolor
90,5.5,2.6,4.4,1.2,versicolor
130,7.4,2.8,6.1,1.9,virginica


将DataFrame转为ndarray用于矩阵计算

In [7]:
# Training labels as an ndarray (one species name per row) ...
train_label = train_data['Label'].values
# ... and the training features as an ndarray (one row per sample).
train_matrix = train_data[iris['feature_names']].values
print_var(train_matrix[:5], 'train_matrix')
print_var(train_label[:5], 'train_label')

train_matrix:
 [[5.  3.5 1.3 0.3]
 [7.7 3.  6.1 2.3]
 [6.3 3.3 4.7 1.6]
 [4.4 3.  1.3 0.2]
 [6.4 3.2 5.3 2.3]] 
shape_or_type: (5, 4)
train_label:
 ['setosa' 'virginica' 'versicolor' 'setosa' 'virginica'] 
shape_or_type: (5,)


In [8]:
# Test features (one row per sample) and test labels as ndarrays.
test_matrix = test_data[iris['feature_names']].values
test_label = test_data['Label'].values
# Print under the test names; the previous labels were copied from the train cell.
print_var(test_matrix[:5], 'test_matrix')
print_var(test_label[:5], 'test_label')

train_matrix:
 [[6.2 3.4 5.4 2.3]
 [5.  3.  1.6 0.2]
 [5.7 2.9 4.2 1.3]
 [5.5 2.6 4.4 1.2]
 [7.4 2.8 6.1 1.9]] 
shape_or_type: (5, 4)
train_label:
 ['virginica' 'setosa' 'versicolor' 'versicolor' 'virginica'] 
shape_or_type: (5,)


In [9]:
# Take a single test sample (position 0) to walk through one prediction by hand.
sample_idx = 0
sample_test, sample_label = test_matrix[sample_idx, :], test_label[sample_idx]
print_var(sample_test, 'sample_test')
print()
print_var(sample_label, 'sample_label')

sample_test:
 [6.2 3.4 5.4 2.3] 
shape_or_type: (4,)

sample_label:
 virginica 
shape_or_type: ()


Lp距离计算公式
$$
L_p(x_i, x_j)=(\sum_{l=1}^n |x_i^{(l)} - x_j^{(l)}|^p)^{\frac{1}{p}}
$$
其中 p ≥ 1。
当p=2时，为欧氏距离
$$
L_2(x_i, x_j)=(\sum_{l=1}^n |x_i^{(l)} - x_j^{(l)}|^2)^{\frac{1}{2}}
$$
当p=1时，为曼哈顿距离
$$
L_1(x_i, x_j)=\sum_{l=1}^n |x_i^{(l)} - x_j^{(l)}|
$$
当 p=∞时，它是各个坐标距离的最大值

$$
L_\infty(x_i, x_j)=\max_l \;|x_i^{(l)} - x_j^{(l)}|
$$


In [10]:
# 计算距离
def distance(train_matrix, array, p=2):
    """Compute the L_p distance from ``array`` to every row of ``train_matrix``.

    train_matrix: training-set matrix, one sample per row.
    array: feature vector of the sample to predict; it is broadcast against
        every row of ``train_matrix``.
    p: the p of the L_p distance, p >= 1. ``np.inf`` gives the largest
        absolute per-coordinate difference.

    Returns an array with one distance per row of ``train_matrix``.
    Raises ValueError if p < 1 or p is NaN.
    """
    # Use an explicit check instead of `assert`, which `python -O` strips.
    # Writing it as `not p >= 1` also rejects NaN.
    if not p >= 1:
        raise ValueError('p must be >= 1, got {!r}'.format(p))
    diff = np.abs(train_matrix - array)

    if p == np.inf:
        # L_inf: the maximum absolute coordinate difference per row.
        return np.max(diff, axis=1)

    # (sum_l |x_i^(l) - x_j^(l)|^p)^(1/p), summed across the feature columns.
    p_sum = np.sum(np.power(diff, p), axis=1)
    return np.power(p_sum, 1 / p)

In [11]:
# 当p=2时，样例计算结果
exp_l_p = distance(train_matrix, sample_test, 2)
exp_l_p

array([4.71805044, 1.70293864, 1.        , 4.96185449, 0.3       ,
       4.99099189, 0.83666003, 1.50996689, 1.21243557, 1.52643375,
       2.48797106, 2.96479342, 2.20227155, 2.20454077, 4.58911756,
       0.69282032, 0.9       , 1.40356688, 4.74973683, 4.75078941,
       2.02731349, 4.64112055, 4.71911009, 1.37477271, 4.26028168,
       4.56179789, 2.463737  , 1.3190906 , 0.55677644, 0.9486833 ,
       1.46628783, 5.20576603, 0.81853528, 1.2489996 , 4.411349  ,
       1.56843871, 0.67082039, 2.3430749 , 1.45945195, 4.66047208,
       0.86023253, 4.3150898 , 1.27279221, 3.03644529, 1.3453624 ,
       1.21243557, 1.28452326, 1.43178211, 1.21243557, 4.56398948,
       2.06397674, 4.26497362, 1.22474487, 0.78740079, 4.67867503,
       2.1330729 , 4.25440948, 1.7       , 1.79164729, 0.78740079,
       4.52990066, 0.6244998 , 1.43874946, 1.8973666 , 2.15870331,
       4.4855323 , 4.47995536, 1.40712473, 1.55563492, 1.02956301,
       0.98488578, 1.14455231, 1.05356538, 0.78740079, 4.67546

In [12]:
def k_most_near(l_p, train_label, k=5, verbose=True):
    """Majority-vote label among the ``k`` nearest training samples.

    l_p: distances from the query to each training sample (one per sample).
    train_label: labels aligned with ``l_p`` (ndarray or list).
    k: initial number of neighbours, must be >= 1.
    verbose: when True, print each tied top-two count followed by the
        final label (the same output the original recursive version produced).

    If the top two labels tie, k is increased by one and the vote is retried.
    Once every training sample is included, no further neighbours can break
    the tie. The winner is then the tied label seen first among the
    neighbours sorted by distance.

    Raises ValueError if k < 1.
    """
    if k < 1:
        raise ValueError('k must be >= 1, got {!r}'.format(k))
    labels = np.asarray(train_label)
    # Sort once; each retry only widens the prefix of this ordering.
    order = np.asarray(l_p).argsort()
    n_samples = len(order)

    tied_counts = []
    while True:
        # Count the labels of the k closest samples and look at the top two.
        counter = Counter(labels[order[:k]].tolist())
        top_two = counter.most_common(2)
        is_tie = len(top_two) > 1 and top_two[0][1] == top_two[1][1]
        # Stop when there is a clear winner, or when all samples are used.
        # The original recursed forever in the all-samples-used case.
        if not is_tie or k >= n_samples:
            most_label = top_two[0][0]
            break
        tied_counts.append(top_two)
        k += 1

    if verbose:
        # The recursive version printed the innermost tie first, so replay in reverse.
        for counts in reversed(tied_counts):
            print(counts)
            print(most_label)

    return most_label

In [13]:
# Predict the sample with the default k, then show prediction, truth and whether they match.
most_label = k_most_near(exp_l_p, train_label)
for shown in (most_label, sample_label, most_label == sample_label):
    print(shown)

virginica
virginica
True


In [14]:
# 预测，并计算准确率
p = 2
result = []
for array, label in zip(test_matrix, test_label):
    l_p = distance(train_matrix, array, p)
    pre_label = k_most_near(l_p, train_label)
    result.append((pre_label, label, pre_label==label))
accuary = sum([i for _, _, i in result]) / len(result)

print('accuary', accuary)
result

accuary 0.9555555555555556


[('virginica', 'virginica', True),
 ('setosa', 'setosa', True),
 ('versicolor', 'versicolor', True),
 ('versicolor', 'versicolor', True),
 ('virginica', 'virginica', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('versicolor', 'versicolor', True),
 ('virginica', 'versicolor', False),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('virginica', 'virginica', True),
 ('versicolor', 'versicolor', True),
 ('virginica', 'virginica', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('setosa', 'setosa', True),
 ('versicolor', 'versicolor', True),
 ('virginica', 'virginica', True),
 ('setosa', 'setosa', True),
 ('virginica', 'virginica', True),
 ('versicolor', 'versicolor', True),
 ('versicolor', 'versicolor', True),
 ('versicolor', 'versicolor', True),
 ('setosa', 'setosa', True),
 ('virginica', 'virginica', True),
 ('virginica', '