In [1]:
import numpy as np
import random
import math

import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split
from collections import Counter

In [2]:
dataset = datasets.load_wine()

In [3]:
sample_rate = 0.5 # 采样率（可以选择小于1）
assert sample_rate>0 and sample_rate<=1

In [4]:
nSample = math.floor(dataset.data.shape[0]*sample_rate)

In [5]:
nSample

89

In [6]:
dataset

{'data': array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00,
         1.065e+03],
        [1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00,
         1.050e+03],
        [1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00,
         1.185e+03],
        ...,
        [1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00,
         8.350e+02],
        [1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00,
         8.400e+02],
        [1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00,
         5.600e+02]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [7]:
dataset.data.shape

(178, 13)

In [8]:
idx = random.sample(range(dataset.data.shape[0]), nSample)

In [9]:
idx

[116,
 109,
 79,
 69,
 141,
 17,
 103,
 26,
 28,
 89,
 1,
 38,
 4,
 99,
 165,
 49,
 74,
 53,
 85,
 67,
 77,
 149,
 96,
 145,
 100,
 34,
 52,
 58,
 156,
 10,
 65,
 148,
 133,
 152,
 20,
 119,
 168,
 92,
 111,
 117,
 83,
 159,
 86,
 136,
 142,
 113,
 8,
 41,
 97,
 13,
 120,
 121,
 150,
 73,
 112,
 122,
 0,
 146,
 36,
 60,
 39,
 102,
 22,
 37,
 32,
 81,
 71,
 78,
 18,
 55,
 80,
 173,
 138,
 64,
 14,
 19,
 61,
 50,
 31,
 151,
 175,
 107,
 177,
 62,
 125,
 163,
 124,
 153,
 23]

In [10]:
X = dataset.data[idx][:]
y = dataset.target[idx][:]
y_stat = Counter(y)

In [11]:
y_stat

Counter({1: 39, 2: 21, 0: 29})

In [12]:
print('======================= dataset information =======================')
print('Total sample number: %d, Feature dimension: %d, Category number: %d' % (X.shape[0], X.shape[1], len(y_stat)))

Total sample number: 89, Feature dimension: 13, Category number: 3


In [13]:
type(y_stat)

collections.Counter

In [14]:
for category in y_stat:
    print('category %d has %d samples' % (category, y_stat[category]))

category 1 has 39 samples
category 2 has 21 samples
category 0 has 29 samples


In [15]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
print('Training sample number: %d, Test sample number: %d' % (X_train.shape[0], X_test.shape[0]))

Training sample number: 71, Test sample number: 18


In [16]:
#KNN类定义
class KNN:
    def __init__(self, X_train, y_train, n_neighbors=3, p=2): # 通过n_neighbors修改k值
        """
        parameter: n_neighbors 临近点个数
        parameter: p 距离度量
        """
        self.n = n_neighbors
        self.p = p
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X):
        # 取出n个点
        knn_list = []
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))

        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])

        # 统计
        knn = [k[-1] for k in knn_list]
        count_pairs = Counter(knn)
#         max_count = sorted(count_pairs, key=lambda x: x)[-1]
        max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]
        return max_count

    def score(self, X_test, y_test):
        right_count = 0
        n = 10
        for X, y in zip(X_test, y_test):
            label = self.predict(X)
            if label == y:
                right_count += 1
        return right_count / len(X_test)

In [17]:
K=3
assert K<=X_train.shape[0]
clf = KNN(X_train, y_train, K)

In [18]:
print('precision rate: {:.2%}'.format(clf.score(X_test, y_test)))

precision rate: 55.56%


In [19]:
test_idx = 0
assert test_idx<X_test.shape[0]
test_point = X_test[0]


In [20]:
test_point

array([1.307e+01, 1.500e+00, 2.100e+00, 1.550e+01, 9.800e+01, 2.400e+00,
       2.640e+00, 2.800e-01, 1.370e+00, 3.700e+00, 1.180e+00, 2.690e+00,
       1.020e+03])

In [21]:
X_test

array([[1.307e+01, 1.500e+00, 2.100e+00, 1.550e+01, 9.800e+01, 2.400e+00,
        2.640e+00, 2.800e-01, 1.370e+00, 3.700e+00, 1.180e+00, 2.690e+00,
        1.020e+03],
       [1.324e+01, 2.590e+00, 2.870e+00, 2.100e+01, 1.180e+02, 2.800e+00,
        2.690e+00, 3.900e-01, 1.820e+00, 4.320e+00, 1.040e+00, 2.930e+00,
        7.350e+02],
       [1.161e+01, 1.350e+00, 2.700e+00, 2.000e+01, 9.400e+01, 2.740e+00,
        2.920e+00, 2.900e-01, 2.490e+00, 2.650e+00, 9.600e-01, 3.260e+00,
        6.800e+02],
       [1.406e+01, 1.630e+00, 2.280e+00, 1.600e+01, 1.260e+02, 3.000e+00,
        3.170e+00, 2.400e-01, 2.100e+00, 5.650e+00, 1.090e+00, 3.710e+00,
        7.800e+02],
       [1.208e+01, 2.080e+00, 1.700e+00, 1.750e+01, 9.700e+01, 2.230e+00,
        2.170e+00, 2.600e-01, 1.400e+00, 3.300e+00, 1.270e+00, 2.960e+00,
        7.100e+02],
       [1.279e+01, 2.670e+00, 2.480e+00, 2.200e+01, 1.120e+02, 1.480e+00,
        1.360e+00, 2.400e-01, 1.260e+00, 1.080e+01, 4.800e-01, 1.470e+00,
        4.80

In [25]:
X_train[0].shape

(13,)

In [26]:
X_test.shape

(18, 13)

In [27]:
X_test - X_train[0]

array([[-1.16e+00, -2.10e-01, -3.30e-01, -1.00e-01, -2.90e+01, -4.00e-01,
        -4.20e-01,  0.00e+00, -9.20e-01, -1.94e+00,  1.40e-01, -1.23e+00,
        -4.50e+01],
       [-9.90e-01,  8.80e-01,  4.40e-01,  5.40e+00, -9.00e+00,  0.00e+00,
        -3.70e-01,  1.10e-01, -4.70e-01, -1.32e+00,  0.00e+00, -9.90e-01,
        -3.30e+02],
       [-2.62e+00, -3.60e-01,  2.70e-01,  4.40e+00, -3.30e+01, -6.00e-02,
        -1.40e-01,  1.00e-02,  2.00e-01, -2.99e+00, -8.00e-02, -6.60e-01,
        -3.85e+02],
       [-1.70e-01, -8.00e-02, -1.50e-01,  4.00e-01, -1.00e+00,  2.00e-01,
         1.10e-01, -4.00e-02, -1.90e-01,  1.00e-02,  5.00e-02, -2.10e-01,
        -2.85e+02],
       [-2.15e+00,  3.70e-01, -7.30e-01,  1.90e+00, -3.00e+01, -5.70e-01,
        -8.90e-01, -2.00e-02, -8.90e-01, -2.34e+00,  2.30e-01, -9.60e-01,
        -3.55e+02],
       [-1.44e+00,  9.60e-01,  5.00e-02,  6.40e+00, -1.50e+01, -1.32e+00,
        -1.70e+00, -4.00e-02, -1.03e+00,  5.16e+00, -5.60e-01, -2.45e+00,
        -5.8

In [28]:
X_train[0]

array([1.423e+01, 1.710e+00, 2.430e+00, 1.560e+01, 1.270e+02, 2.800e+00,
       3.060e+00, 2.800e-01, 2.290e+00, 5.640e+00, 1.040e+00, 3.920e+00,
       1.065e+03])

In [29]:
X_test

array([[1.307e+01, 1.500e+00, 2.100e+00, 1.550e+01, 9.800e+01, 2.400e+00,
        2.640e+00, 2.800e-01, 1.370e+00, 3.700e+00, 1.180e+00, 2.690e+00,
        1.020e+03],
       [1.324e+01, 2.590e+00, 2.870e+00, 2.100e+01, 1.180e+02, 2.800e+00,
        2.690e+00, 3.900e-01, 1.820e+00, 4.320e+00, 1.040e+00, 2.930e+00,
        7.350e+02],
       [1.161e+01, 1.350e+00, 2.700e+00, 2.000e+01, 9.400e+01, 2.740e+00,
        2.920e+00, 2.900e-01, 2.490e+00, 2.650e+00, 9.600e-01, 3.260e+00,
        6.800e+02],
       [1.406e+01, 1.630e+00, 2.280e+00, 1.600e+01, 1.260e+02, 3.000e+00,
        3.170e+00, 2.400e-01, 2.100e+00, 5.650e+00, 1.090e+00, 3.710e+00,
        7.800e+02],
       [1.208e+01, 2.080e+00, 1.700e+00, 1.750e+01, 9.700e+01, 2.230e+00,
        2.170e+00, 2.600e-01, 1.400e+00, 3.300e+00, 1.270e+00, 2.960e+00,
        7.100e+02],
       [1.279e+01, 2.670e+00, 2.480e+00, 2.200e+01, 1.120e+02, 1.480e+00,
        1.360e+00, 2.400e-01, 1.260e+00, 1.080e+01, 4.800e-01, 1.470e+00,
        4.80

In [32]:
dist = np.linalg.norm(X_test - X_train[0], ord=2)

In [33]:
dist

1796.3683062950759

In [39]:
knn_list = []
for i in range(3):
    dist = np.linalg.norm(X_test - X_train[i], ord=2)
    knn_list.append((dist, y_train[i]))

In [40]:
knn_list

[(1796.3683062950759, 0), (1089.149930151267, 2), (1082.7171009190185, 1)]

In [41]:
y_train

array([0, 2, 1, 2, 1, 0, 1, 1, 2, 2, 2, 0, 0, 1, 2, 1, 0, 0, 1, 1, 2, 1,
       1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 2, 1, 1, 0, 2, 2, 0, 1, 0, 1,
       0, 1, 1, 2, 1, 1, 1, 0, 2, 1, 0, 0, 0, 1, 2, 2, 1, 2, 1, 1, 1, 1,
       1, 2, 1, 0, 0])

In [42]:
max(knn_list, key=lambda x: x[0])

(1796.3683062950759, 0)

In [43]:
knn_list.index(max(knn_list, key=lambda x: x[0]))

0

In [44]:
for i in range(3, len(X_train)):
    max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
    dist = np.linalg.norm(X_test - X_train[i], ord=2)
    if knn_list[max_index][0] > dist:
        knn_list[max_index] = (dist, y_train[i])

In [45]:
knn_list

[(1015.6544030171426, 2), (1010.5146543649649, 2), (1010.3701408393869, 1)]

In [46]:
knn = [k[-1] for k in knn_list]

In [47]:
knn

[2, 2, 1]

In [48]:
count_pairs = Counter(knn)

In [49]:
count_pairs

Counter({2: 2, 1: 1})

In [50]:
sorted(count_pairs.items(), key=lambda x: x[1])

[(1, 1), (2, 2)]

In [51]:
sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]

2

In [53]:
list(zip(X_test, y_test))

[(array([1.307e+01, 1.500e+00, 2.100e+00, 1.550e+01, 9.800e+01, 2.400e+00,
         2.640e+00, 2.800e-01, 1.370e+00, 3.700e+00, 1.180e+00, 2.690e+00,
         1.020e+03]), 0),
 (array([1.324e+01, 2.590e+00, 2.870e+00, 2.100e+01, 1.180e+02, 2.800e+00,
         2.690e+00, 3.900e-01, 1.820e+00, 4.320e+00, 1.040e+00, 2.930e+00,
         7.350e+02]), 0),
 (array([1.161e+01, 1.350e+00, 2.700e+00, 2.000e+01, 9.400e+01, 2.740e+00,
         2.920e+00, 2.900e-01, 2.490e+00, 2.650e+00, 9.600e-01, 3.260e+00,
         6.800e+02]), 1),
 (array([1.406e+01, 1.630e+00, 2.280e+00, 1.600e+01, 1.260e+02, 3.000e+00,
         3.170e+00, 2.400e-01, 2.100e+00, 5.650e+00, 1.090e+00, 3.710e+00,
         7.800e+02]), 0),
 (array([1.208e+01, 2.080e+00, 1.700e+00, 1.750e+01, 9.700e+01, 2.230e+00,
         2.170e+00, 2.600e-01, 1.400e+00, 3.300e+00, 1.270e+00, 2.960e+00,
         7.100e+02]), 1),
 (array([1.279e+01, 2.670e+00, 2.480e+00, 2.200e+01, 1.120e+02, 1.480e+00,
         1.360e+00, 2.400e-01, 1.260e+00, 1.0