## [Assignment Focus]
Learn how to use hyper-parameter search in Sklearn to find the best hyper-parameters.

### Assignment
Apply hyper-parameter search to a different dataset and see whether you can find the best combination of hyper-parameters.

In [1]:
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import cross_val_score

In [2]:
cancer = datasets.load_breast_cancer()

# Hold out 25% of the samples for testing; fixed seed for reproducibility
x_train, x_test, y_train, y_test = train_test_split(cancer.data, cancer.target, test_size=0.25, random_state=1)

clf = GradientBoostingClassifier(random_state=7, max_features='log2')

In [3]:
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)

In [4]:
acc = metrics.accuracy_score(y_test, y_pred)
print("Accuracy: ", acc)

Accuracy:  0.958041958041958
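As a quick sanity check on what this number means: `accuracy_score` is the fraction of test samples whose predicted label equals the true label. The breast cancer dataset has 569 samples, so `test_size=0.25` leaves ceil(0.25 × 569) = 143 test samples, and 0.958041958… corresponds to 137 correct predictions. The sketch below uses small hypothetical label arrays to show the equivalence.

```python
import math

import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical toy labels, only to illustrate the definition
y_true = np.array([0, 1, 1, 0, 1])
y_hat = np.array([0, 1, 0, 0, 1])

# Accuracy = (number of matching predictions) / (number of samples)
manual_acc = np.mean(y_true == y_hat)
print(manual_acc, accuracy_score(y_true, y_hat))  # both 0.8 (4 of 5 correct)

# train_test_split rounds a fractional test size up: ceil(0.25 * 569) = 143
n_test = math.ceil(0.25 * 569)
print(n_test, 137 / n_test)
```

By the same arithmetic, the 0.965034965… reported for the tuned model further down equals 138/143, i.e. exactly one more correctly classified test sample.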


In [5]:
print("estimators: ", clf.n_estimators_)

estimators:  100


In [6]:
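n_estimators = [100, 200, 300]
# Note: max_features='auto' is deprecated and has been removed in newer
# scikit-learn releases; there, drop it or use None (all features) instead.
max_features = ['auto', 'sqrt', 'log2']
param_grid = dict(n_estimators=n_estimators, max_features=max_features)
param_grid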

{'n_estimators': [100, 200, 300], 'max_features': ['auto', 'sqrt', 'log2']}
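`GridSearchCV` evaluates every combination in the Cartesian product of these value lists. scikit-learn's `ParameterGrid` enumerates the same combinations, which is a cheap way to check how many candidates a grid produces before fitting anything: here 3 × 3 = 9, matching the "9 candidates" in the log of the next cell. No model is fitted in this sketch, so the `'auto'` string is only a label.

```python
from sklearn.model_selection import ParameterGrid

# Same grid as in the cell above
param_grid = {"n_estimators": [100, 200, 300], "max_features": ["auto", "sqrt", "log2"]}

# ParameterGrid yields one dict per combination of the listed values
candidates = list(ParameterGrid(param_grid))
print(len(candidates))  # 9
for params in candidates:
    print(params)
```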

In [20]:
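# cv=3 made explicit: older scikit-learn defaulted to 3 folds (as the log below shows), newer versions default to 5
grid_search = GridSearchCV(clf, param_grid, scoring="accuracy", cv=3, n_jobs=-1, verbose=1)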
grid_result = grid_search.fit(x_train, y_train)

Fitting 3 folds for each of 9 candidates, totalling 27 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  27 out of  27 | elapsed:    1.3s finished
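The log reports 3 folds × 9 candidates = 27 fits. Conceptually, `GridSearchCV` loops over every candidate, scores it with k-fold cross-validation on the training set, and keeps the candidate with the highest mean score. The sketch below reproduces that loop by hand with `cross_val_score`; it is an illustration rather than scikit-learn's actual implementation, and it assumes a reduced 2 × 2 grid (without the deprecated `'auto'` value) so it runs quickly.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import ParameterGrid, cross_val_score, train_test_split

X, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

# Reduced grid (assumption: kept small so the sketch runs quickly)
param_grid = {"n_estimators": [100, 200], "max_features": ["sqrt", "log2"]}

best_score, best_params = -np.inf, None
scores = {}
for params in ParameterGrid(param_grid):
    model = GradientBoostingClassifier(random_state=7, **params)
    # cv=3 on a classifier means stratified 3-fold splits of the training set
    fold_scores = cross_val_score(model, x_train, y_train, cv=3, scoring="accuracy")
    mean_score = fold_scores.mean()
    scores[tuple(sorted(params.items()))] = mean_score
    if mean_score > best_score:  # strict '>' keeps the first candidate on ties
        best_score, best_params = mean_score, params

print(f"{len(scores)} candidates x 3 folds = {3 * len(scores)} fits")
print("Best mean CV accuracy: %f using %s" % (best_score, best_params))
```

For a classifier, an integer `cv=3` makes both `cross_val_score` and `GridSearchCV` use stratified 3-fold splits, so these mean scores are comparable to the `mean_test_score` values `GridSearchCV` reports for the same candidates.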


In [21]:
print("Best Accuracy: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

Best Accuracy: 0.964789 using {'max_features': 'sqrt', 'n_estimators': 200}
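Note that `best_score_` (0.964789 here) is the mean accuracy over the cross-validation folds of the training set, not an accuracy on `x_test`. `cv_results_` holds the per-candidate statistics behind it; reading `std_test_score` next to `mean_test_score` helps judge whether the gap between the top candidates is larger than the fold-to-fold noise. A minimal sketch, again assuming a reduced 2 × 2 grid so it runs quickly:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

# Reduced grid (assumption, for speed)
param_grid = {"n_estimators": [100, 200], "max_features": ["sqrt", "log2"]}
grid = GridSearchCV(GradientBoostingClassifier(random_state=7), param_grid,
                    scoring="accuracy", cv=3)
grid.fit(x_train, y_train)

results = grid.cv_results_
# Print candidates from best to worst rank
for i in np.argsort(results["rank_test_score"], kind="stable"):
    print(f"rank {results['rank_test_score'][i]}: "
          f"mean={results['mean_test_score'][i]:.4f} "
          f"std={results['std_test_score'][i]:.4f} {results['params'][i]}")

# best_score_ is the mean CV score of the best-ranked candidate
print("best_score_:", grid.best_score_)
```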


In [22]:
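# Rebuild the model with the best parameters found by the search.
# Note: the search itself used random_state=7 (inherited from clf), so this
# random_state=1 model is not identical to grid_result.best_estimator_.
clf_bestparam = GradientBoostingClassifier(random_state=1, **grid_result.best_params_)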

clf_bestparam.fit(x_train, y_train)

y_pred = clf_bestparam.predict(x_test)

In [23]:
acc = metrics.accuracy_score(y_test, y_pred)
print("Accuracy: ", acc)

Accuracy:  0.965034965034965
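Because `refit=True` is the default, `GridSearchCV` has already retrained the best candidate on the whole training set and stored it as `best_estimator_`; calling `predict` on the search object delegates to it. Reusing it avoids rebuilding the model by hand and keeps the search's own `random_state=7` (the manual rebuild above uses `random_state=1`, so its test accuracy is not guaranteed to match). A sketch, assuming the same reduced 2 × 2 grid for speed:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

# Reduced grid (assumption, for speed)
param_grid = {"n_estimators": [100, 200], "max_features": ["sqrt", "log2"]}
# refit=True (the default) retrains the best candidate on all of x_train
grid = GridSearchCV(GradientBoostingClassifier(random_state=7), param_grid,
                    scoring="accuracy", cv=3)
grid.fit(x_train, y_train)

best_model = grid.best_estimator_  # already fitted, keeps random_state=7
acc = accuracy_score(y_test, best_model.predict(x_test))
print("Best params:", grid.best_params_)
print("Test accuracy:", acc)
```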


In [19]:
# Suppress all warnings (per its In [19] counter, this cell ran before the grid search in In [20])
import warnings
warnings.filterwarnings('ignore')