# 作業
請使用不同的資料集，並使用 hyper-parameter search 的方式，看能不能找出最佳的超參數組合

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.style.use('bmh')
import warnings
warnings.simplefilter('ignore')

In [18]:
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
print(X.shape, y.shape)

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.25, random_state=0)

(1797, 64) (1797,)


In [19]:
def get_best_model_and_accuracy(model, params, X, y):
    from sklearn.model_selection import GridSearchCV
    grid = GridSearchCV(model, params, error_score=0, cv=5, n_jobs=-1)
    grid.fit(X, y)
    print(f"Best accuracy: {grid.best_score_:.3f}")
    print(f"Best parameters: {grid.best_params_}")
    print(f"Avg. time to fit: {grid.cv_results_['mean_fit_time'].mean():.3f}")
    print(f"Avg. time to predict: {grid.cv_results_['mean_score_time'].mean():.3f}")

In [24]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
import time

clf = RandomForestClassifier()

model = Pipeline([
    ('clf', clf)
])

params = {
    'clf__n_estimators': [100, 200, 300],
    'clf__max_depth': [20, 30, 40],
    'clf__min_samples_split': [2, 3, 4],
    'clf__min_samples_leaf': [1, 2, 3],
}

time_start = time.time()
get_best_model_and_accuracy(model, params, X, y)
print(f"Time elapsed = {time.time() - time_start} (sec)")

Best accuracy: 0.942
Best parameters: {'clf__max_depth': 40, 'clf__min_samples_leaf': 1, 'clf__min_samples_split': 2, 'clf__n_estimators': 200}
Avg. time to fit: 0.670
Avg. time to predict: 0.036
Time elapsed = 40.53124213218689 (sec)


In [25]:
from sklearn.metrics import accuracy_score

clf = RandomForestClassifier(n_estimators=200, max_depth=40, min_samples_leaf=1, min_samples_split=2)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print(f"Testing accuracy = {accuracy_score(y_test, y_pred):.3f}")

Testing accuracy = 0.978
