# Keywords
- model selection: train_test_split, cross_val_score, GridSearchCV, RandomizedSearchCV
- classifier: KNeighborsClassifier

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Load iris and keep only the last two feature columns (petal length, petal width).
iris = datasets.load_iris()
X, y = iris.data[:, 2:], iris.target

# Stratified hold-out split: class proportions are preserved in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=7
)

# Untuned k-NN classifier and the candidate n_neighbors values (3..8) for the grid search.
knn_clf = KNeighborsClassifier()
param_grid = {'n_neighbors': [k for k in range(3, 9)]}

# GridSearchCV
gs = GridSearchCV(knn_clf, param_grid, cv=10, iid=False)
%timeit gs.fit(X_train, y_train)
print("GridSearchCV best_params", gs.best_params_)
print([pair for pair in zip(gs.cv_results_['params'], gs.cv_results_['mean_test_score'])])

# RandomizedSearchCV
param_dist = {'n_neighbors': list(range(3, 50, 1))}
rs = RandomizedSearchCV(knn_clf, param_dist, cv=10, n_iter=15, iid=False)
%timeit rs.fit(X_train, y_train)
print("RandomizedSearchCV best_params", rs.best_params_)
print([pair for pair in zip(rs.cv_results_['params'], rs.cv_results_['mean_test_score'])])

119 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
GridSearchCV best_params {'n_neighbors': 3}
[({'n_neighbors': 3}, 0.9566666666666667), ({'n_neighbors': 4}, 0.9400000000000001), ({'n_neighbors': 5}, 0.9566666666666667), ({'n_neighbors': 6}, 0.9483333333333333), ({'n_neighbors': 7}, 0.9566666666666667), ({'n_neighbors': 8}, 0.9566666666666667)]
298 ms ± 8.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
RandomizedSearchCV best_params {'n_neighbors': 12}
[({'n_neighbors': 45}, 0.95), ({'n_neighbors': 12}, 0.9566666666666667), ({'n_neighbors': 11}, 0.9566666666666667), ({'n_neighbors': 19}, 0.9416666666666667), ({'n_neighbors': 27}, 0.9566666666666667), ({'n_neighbors': 33}, 0.9400000000000001), ({'n_neighbors': 41}, 0.9400000000000001), ({'n_neighbors': 14}, 0.9566666666666667), ({'n_neighbors': 46}, 0.9400000000000001), ({'n_neighbors': 49}, 0.95), ({'n_neighbors': 9}, 0.9566666666666667), ({'n_neighbors': 25}, 0.9483333333333333), ({'n_neighbors': 18}, 0.95)