## 更多kNN中的超参数

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

In [2]:
digits_data = datasets.load_digits()
X = digits_data.data
y = digits_data.target

In [3]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [4]:
from sklearn.neighbors import KNeighborsClassifier

sk_knn_clf = KNeighborsClassifier(n_neighbors=4)
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.98055555555555551

In [5]:
sk_knn_clf = KNeighborsClassifier(n_neighbors=4, p=1)
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.97777777777777775

### weights

uniform 和 distance

In [6]:
sk_knn_clf = KNeighborsClassifier(n_neighbors=4, weights='distance')
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.98333333333333328

### algorithm

brute, kd_tree 和 ball_tree

In [7]:
sk_knn_clf = KNeighborsClassifier(n_neighbors=4, algorithm="kd_tree")
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.98055555555555551

In [8]:
sk_knn_clf = KNeighborsClassifier(n_neighbors=4, algorithm="ball_tree")
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.98055555555555551

In [9]:
sk_knn_clf = KNeighborsClassifier(n_neighbors=4, algorithm="kd_tree", leaf_size=10)
sk_knn_clf.fit(X_train, y_train)
sk_knn_clf.score(X_test, y_test)

0.98055555555555551

### Grid Search

In [10]:
param_grid = [
    {
        'algorithm': ['brute'], 
        'n_neighbors': [i for i in range(2, 11)], 
        'weights': ['uniform', 'distance'],
        'p': [1, 2]
    },
    {
        'algorithm': ['ball_tree', 'kd_tree'],
        'n_neighbors': [i for i in range(2, 11)], 
        'weights': ['uniform', 'distance'],
        'p': [1, 2]
    }
]

In [11]:
knn_clf = KNeighborsClassifier()

In [13]:
%%time
from sklearn.model_selection import GridSearchCV

grid_search = GridSearchCV(knn_clf, param_grid, verbose=1)
grid_search.fit(X_train, y_train)

Fitting 3 folds for each of 108 candidates, totalling 324 fits
CPU times: user 47.8 s, sys: 989 ms, total: 48.8 s
Wall time: 45.5 s


[Parallel(n_jobs=1)]: Done 324 out of 324 | elapsed:   45.4s finished


In [14]:
grid_search.best_estimator_

KNeighborsClassifier(algorithm='brute', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=2, p=2,
           weights='distance')

In [15]:
grid_search.best_score_

0.98747390396659707

In [16]:
clf = grid_search.best_estimator_
clf.score(X_test, y_test)

0.98611111111111116