<a href="https://colab.research.google.com/github/galileo15640215/kikagaku/blob/master/kikagaku03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

機械学習 実践（ハイパーパラメータ調整）
https://www.kikagaku.ai/tutorial/basic_of_machine_learning/learn/machine_learning_hyperparameters

In [0]:
import numpy as np
import pandas as pd

In [0]:
# Load the breast cancer dataset bundled with scikit-learn
# (the features and labels are pulled out of `dataset` in the next cell)
from sklearn.datasets import load_breast_cancer
dataset = load_breast_cancer()

In [0]:
# x: feature matrix, t: target labels
x, t = dataset.data, dataset.target

In [0]:
# Shapes of the features and the labels
(x.shape, t.shape)

((569, 30), (569,))

In [0]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples as the test set; the remaining 80% is used for
# training + validation. random_state fixes the shuffle.
x_train_val, x_test, t_train_val, t_test = train_test_split(
    x, t,
    test_size=0.2,
    random_state=1,
)

In [0]:
# Split the train+val data again: validation : training = 30 : 70
x_train, x_val, t_train, t_val = train_test_split(
    x_train_val, t_train_val,
    test_size=0.3,
    random_state=1,
)

In [0]:
# Sizes of the training / validation / test splits
tuple(arr.shape for arr in (x_train, x_val, x_test))

((318, 30), (137, 30), (114, 30))

In [0]:
from sklearn.tree import DecisionTreeClassifier

# Decision tree with default hyperparameters; the seed is fixed for reproducibility
dtree = DecisionTreeClassifier(
    random_state=0,
)

In [0]:
# Fit the default tree (no depth limit, max_depth=None) on the training split only
dtree.fit(x_train, t_train)

DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=0, splitter='best')

In [0]:
# Compare accuracy on the data the tree was fit on vs. the held-out validation data
for label, (features, targets) in (
    ('train score : ', (x_train, t_train)),
    ('validation score : ', (x_val, t_val)),
):
    print(label, dtree.score(features, targets))

train score :  1.0
validation score :  0.927007299270073


In [0]:
# Define the model with explicitly chosen hyperparameters
dtree = DecisionTreeClassifier(
    max_depth=10,          # maximum depth of the tree
    min_samples_split=30,  # minimum number of samples required to split a node
    random_state=0,
)

dtree.fit(x_train, t_train)

DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=10, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=30,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=0, splitter='best')

In [0]:
# Re-check training vs. validation accuracy for the constrained tree
for label, (features, targets) in (
    ('train score : ', (x_train, t_train)),
    ('validation score : ', (x_val, t_val)),
):
    print(label, dtree.score(features, targets))

train score :  0.9308176100628931
validation score :  0.9562043795620438


In [0]:
# Final evaluation on the test split, which has not been used before this cell
test_accuracy = dtree.score(x_test, t_test)
print('test score :', test_accuracy)

test score : 0.9298245614035088


In [0]:
# GridSearchCV クラスのインポート
from sklearn.model_selection import GridSearchCV

In [0]:
# Define the algorithm to tune (seed fixed so every grid candidate is reproducible)
estimator = DecisionTreeClassifier(random_state=0)

In [0]:
# Hyperparameters to search and their candidate values
# (3 x 3 = 9 combinations, each scored by cross-validation)
param_grid = [
    dict(max_depth=[3, 20, 50], min_samples_split=[3, 20, 30]),
]

In [0]:
# Number of folds (K) for K-fold cross-validation
cv: int = 5

In [0]:
# Wrap the estimator in a grid search: every combination in param_grid is
# scored with cv-fold cross-validation (training scores are not recorded)
tuned_model = GridSearchCV(
    estimator,
    param_grid,
    cv=cv,
    return_train_score=False,
)

In [0]:
# Train & validate: cross-validate every grid combination on the train+val data
# (the separate validation split is not needed here; CV replaces it)
tuned_model.fit(x_train_val, t_train_val)

GridSearchCV(cv=5, error_score=nan,
             estimator=DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort='deprecated',
                                              random_state=0, splitter='best'),
             iid='deprecated', n_jobs=None,
             param_grid=[{'max_depth': [3, 20, 50],
                          'min_samples_split': [3, 20, 30]}],
             

In [0]:
# Inspect the cross-validation results (transposed: one column per candidate)
pd.DataFrame(tuned_model.cv_results_).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8
mean_fit_time,0.00465097,0.00363836,0.00363026,0.00439744,0.00440125,0.00429292,0.00441742,0.00438366,0.0043313
std_fit_time,0.000733465,2.46837e-05,1.63414e-05,0.00019559,0.000296823,0.000235866,0.000204605,0.000178314,0.000264515
mean_score_time,0.000431871,0.00031209,0.000304794,0.000304174,0.000319433,0.000306463,0.000294065,0.000318527,0.000303316
std_score_time,0.000121003,2.35105e-05,1.60097e-05,1.04583e-05,1.80805e-05,1.69338e-05,3.71076e-06,2.29517e-05,7.39697e-06
param_max_depth,3,3,3,20,20,20,50,50,50
param_min_samples_split,3,20,30,3,20,30,3,20,30
params,"{'max_depth': 3, 'min_samples_split': 3}","{'max_depth': 3, 'min_samples_split': 20}","{'max_depth': 3, 'min_samples_split': 30}","{'max_depth': 20, 'min_samples_split': 3}","{'max_depth': 20, 'min_samples_split': 20}","{'max_depth': 20, 'min_samples_split': 30}","{'max_depth': 50, 'min_samples_split': 3}","{'max_depth': 50, 'min_samples_split': 20}","{'max_depth': 50, 'min_samples_split': 30}"
split0_test_score,0.923077,0.912088,0.912088,0.956044,0.912088,0.912088,0.956044,0.912088,0.912088
split1_test_score,0.901099,0.901099,0.901099,0.912088,0.901099,0.901099,0.912088,0.901099,0.901099
split2_test_score,0.934066,0.934066,0.934066,0.923077,0.934066,0.934066,0.923077,0.934066,0.934066


In [0]:
# Second grid search: same estimator and fold count as before, new candidate values
estimator = DecisionTreeClassifier(random_state=0)
cv = 5
param_grid = [
    dict(max_depth=[5, 10, 15], min_samples_split=[10, 12, 15]),
]

In [0]:
# Define the grid search over the new candidate values
tuned_model = GridSearchCV(
    estimator,
    param_grid,
    cv=cv,
    return_train_score=False,
)

# Train: cross-validate every combination on the train+val data
tuned_model.fit(x_train_val, t_train_val)

GridSearchCV(cv=5, error_score=nan,
             estimator=DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None,
                                              criterion='gini', max_depth=None,
                                              max_features=None,
                                              max_leaf_nodes=None,
                                              min_impurity_decrease=0.0,
                                              min_impurity_split=None,
                                              min_samples_leaf=1,
                                              min_samples_split=2,
                                              min_weight_fraction_leaf=0.0,
                                              presort='deprecated',
                                              random_state=0, splitter='best'),
             iid='deprecated', n_jobs=None,
             param_grid=[{'max_depth': [5, 10, 15],
                          'min_samples_split': [10, 12, 15]}],
            

In [0]:
# Inspect the cross-validation results (transposed: one column per candidate)
pd.DataFrame(tuned_model.cv_results_).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8
mean_fit_time,0.0057682,0.0044908,0.00439858,0.00445137,0.004423,0.00435157,0.00447149,0.00441513,0.00439181
std_fit_time,0.0012571,0.000254556,0.000223073,0.000212777,0.000213973,0.000205486,0.000289485,0.000203764,0.000199727
mean_score_time,0.000544024,0.00035491,0.000335789,0.000340796,0.000315142,0.000336552,0.000350189,0.000316954,0.000332499
std_score_time,0.000197883,9.73098e-06,1.16143e-05,3.18367e-05,9.55508e-06,1.72194e-05,3.44432e-05,1.43445e-05,1.94761e-05
param_max_depth,5,5,5,10,10,10,15,15,15
param_min_samples_split,10,12,15,10,12,15,10,12,15
params,"{'max_depth': 5, 'min_samples_split': 10}","{'max_depth': 5, 'min_samples_split': 12}","{'max_depth': 5, 'min_samples_split': 15}","{'max_depth': 10, 'min_samples_split': 10}","{'max_depth': 10, 'min_samples_split': 12}","{'max_depth': 10, 'min_samples_split': 15}","{'max_depth': 15, 'min_samples_split': 10}","{'max_depth': 15, 'min_samples_split': 12}","{'max_depth': 15, 'min_samples_split': 15}"
split0_test_score,0.967033,0.923077,0.912088,0.967033,0.923077,0.912088,0.967033,0.923077,0.912088
split1_test_score,0.912088,0.901099,0.901099,0.912088,0.901099,0.901099,0.912088,0.901099,0.901099
split2_test_score,0.923077,0.934066,0.934066,0.923077,0.934066,0.934066,0.923077,0.934066,0.934066


In [0]:
# Hyperparameters with the highest mean cross-validation score
tuned_model.best_params_

{'max_depth': 5, 'min_samples_split': 10}

In [0]:
# Take over the best model. With GridSearchCV's default refit=True it has already
# been refit on all of x_train_val using best_params_.
best_model = tuned_model.best_estimator_

# Evaluate the model. Each score is labeled, as in the earlier cells: the first is
# on the data the model was trained on, the second on the unseen test split.
print('train_val score :', best_model.score(x_train_val, t_train_val))
print('test score :', best_model.score(x_test, t_test))

0.9934065934065934
0.956140350877193


In [0]:
# RandomizedSearchCV クラスのインポート
from sklearn.model_selection import RandomizedSearchCV

In [0]:
# Algorithm to tune with randomized search (seed fixed for reproducibility)
estimator = DecisionTreeClassifier(random_state=0)

In [0]:
# Demonstration: range(start, stop, step) excludes `stop`
[*range(1, 10, 2)]

[1, 3, 5, 7, 9]

In [0]:
# Candidate values to sample from:
#   max_depth:          odd values 5, 7, ..., 99
#   min_samples_split:  every integer 2, 3, ..., 49
param_distributions = dict(
    max_depth=[*range(5, 100, 2)],
    min_samples_split=[*range(2, 50)],
)

In [0]:
# Number of hyperparameter combinations to sample (number of trials)
n_iter: int = 100

In [0]:
# Number of folds (K) for K-fold cross-validation
cv: int = 5

In [0]:
# Randomized search: n_iter combinations are drawn from param_distributions and
# each is scored with cv-fold cross-validation; random_state=0 fixes the draws.
tuned_model = RandomizedSearchCV(
    estimator,
    param_distributions,
    n_iter=n_iter,
    cv=cv,
    random_state=0,
    return_train_score=False,
)

In [0]:
# Train & validate: cross-validate each sampled combination on the train+val data
tuned_model.fit(x_train_val, t_train_val)

RandomizedSearchCV(cv=5, error_score=nan,
                   estimator=DecisionTreeClassifier(ccp_alpha=0.0,
                                                    class_weight=None,
                                                    criterion='gini',
                                                    max_depth=None,
                                                    max_features=None,
                                                    max_leaf_nodes=None,
                                                    min_impurity_decrease=0.0,
                                                    min_impurity_split=None,
                                                    min_samples_leaf=1,
                                                    min_samples_split=2,
                                                    min_weight_fraction_leaf=0.0,
                                                    presort='deprecated',
                                                    random_state=0,
             

In [0]:
# Inspect the results, best-ranked candidates first (transposed: one column per candidate)
pd.DataFrame(tuned_model.cv_results_).sort_values(by='rank_test_score').transpose()

Unnamed: 0,47,77,82,90,42,19,28,12,11,62,69,39,70,3,96,29,6,68,43,34,9,48,45,33,91,32,25,37,44,46,36,52,54,57,59,61,63,66,76,75,...,97,78,83,92,79,89,88,80,81,87,86,85,93,0,49,71,2,5,7,16,17,20,21,22,23,26,72,30,35,38,40,41,98,50,55,58,60,67,31,99
mean_fit_time,0.00440059,0.00439644,0.00442047,0.00444942,0.00440087,0.00441413,0.00444965,0.00441675,0.00443063,0.00439563,0.00441728,0.0045197,0.00446839,0.00440903,0.00486445,0.00444055,0.00445151,0.00444665,0.00439844,0.00446701,0.00446858,0.00443587,0.00444555,0.00438561,0.00441499,0.00437269,0.0043612,0.004357,0.00433221,0.00433259,0.00444503,0.00433707,0.00434718,0.00435596,0.004386,0.00433688,0.00429759,0.00443082,0.00431967,0.00432754,...,0.00428357,0.00429215,0.00424848,0.00429354,0.00429106,0.00424395,0.0042582,0.00427666,0.00425286,0.00424485,0.00430832,0.00424695,0.00429783,0.00575585,0.00425987,0.00432897,0.00425711,0.00428023,0.00430202,0.00471916,0.00430641,0.00428996,0.00430493,0.00426722,0.00426445,0.00427494,0.00428519,0.00431919,0.00433512,0.00425086,0.00427585,0.00428262,0.00436959,0.00432982,0.0046957,0.0043416,0.00425773,0.00430794,0.00430884,0.00438004
std_fit_time,0.000238746,0.000215985,0.000167014,0.000241989,0.000216492,0.000213555,0.000219321,0.000203048,0.00016439,0.000240847,0.000215357,0.000330876,0.000212834,0.000203424,0.000931515,0.000192062,0.000241653,0.000205735,0.000226294,0.000201433,0.000220231,0.000199498,0.00021574,0.000209764,0.000212242,0.000209069,0.000220361,0.000218135,0.000186026,0.000203951,0.000300953,0.000188656,0.000247538,0.000195813,0.000195675,0.000197955,0.000226884,0.000292104,0.000200963,0.000199411,...,0.000220065,0.000186546,0.000234413,0.000240374,0.000217027,0.000238792,0.000237123,0.000229917,0.000248316,0.000253463,0.000251808,0.000227776,0.000240031,0.000997609,0.000237534,0.000217147,0.000240523,0.000228501,0.000253057,0.000942285,0.000242048,0.000242689,0.000252543,0.000242763,0.000232826,0.000214698,0.00023795,0.000278517,0.0002743,0.000231631,0.00024289,0.000239801,0.000289206,0.000210662,0.000716528,0.000283723,0.00024711,0.000246881,0.000266523,0.000244273
mean_score_time,0.000316048,0.000303888,0.000298738,0.000317097,0.000292253,0.000293732,0.000329542,0.000303125,0.00030179,0.000306654,0.000305748,0.000305939,0.000325012,0.000311613,0.00033493,0.000310469,0.000296211,0.000317669,0.000311947,0.000310564,0.000312281,0.000300026,0.000299406,0.000352287,0.000349092,0.000313473,0.000293064,0.000313807,0.000311565,0.000294209,0.000315809,0.000309324,0.000336552,0.000318241,0.000323915,0.000298452,0.000302744,0.000355721,0.000296736,0.000301218,...,0.00030241,0.000315857,0.000290632,0.000322199,0.000312948,0.000298405,0.000301981,0.000299072,0.000291443,0.000300264,0.000310326,0.000295162,0.000318527,0.000445271,0.000297785,0.000324345,0.000290632,0.000303173,0.000312042,0.000350237,0.000345898,0.000295925,0.000302792,0.000296879,0.000296736,0.000304174,0.000296354,0.000317192,0.000307226,0.000300407,0.000307798,0.000290632,0.000352144,0.000309038,0.000372171,0.00032649,0.000299263,0.000335646,0.00031476,0.000342989
std_score_time,1.8254e-05,2.026e-05,8.52059e-06,1.77816e-05,6.86943e-06,8.12304e-06,1.51687e-05,1.20241e-05,1.31959e-05,1.55579e-05,1.3887e-05,9.47263e-06,1.99358e-05,2.03554e-05,4.94582e-05,1.01202e-05,5.92239e-06,2.14172e-05,1.26578e-05,1.80636e-05,1.33714e-05,1.334e-05,5.7841e-06,5.83732e-05,3.83172e-05,1.3659e-05,6.63712e-06,1.22264e-05,1.14087e-05,1.34895e-05,6.71748e-06,6.85486e-06,3.23423e-05,1.82644e-05,2.22406e-05,9.70501e-06,8.49707e-06,2.17837e-05,7.55241e-06,4.57515e-06,...,9.35209e-06,2.02571e-05,1.24864e-05,8.62087e-06,1.0517e-05,9.9173e-06,9.296e-06,4.45683e-06,6.00626e-06,5.13038e-06,2.19706e-05,7.81053e-06,4.95085e-06,8.7557e-05,7.32221e-06,1.4407e-05,6.93138e-06,1.15986e-05,2.62665e-05,5.48035e-05,0.000102369,1.28032e-05,4.31004e-06,1.12885e-05,1.20731e-05,9.4978e-06,6.87704e-06,1.45233e-05,1.04215e-05,1.41461e-05,1.2988e-05,3.04206e-06,1.15607e-05,1.39359e-05,5.99049e-05,1.81127e-05,9.05036e-06,9.52768e-06,2.22159e-05,3.69143e-05
param_min_samples_split,10,10,4,4,7,9,11,2,8,7,4,2,2,2,4,6,8,4,9,5,5,5,5,13,12,12,12,13,14,16,14,14,24,14,20,16,23,23,15,16,...,29,39,44,36,27,35,36,31,48,43,31,39,42,30,38,27,37,40,36,40,39,27,27,43,41,27,30,42,27,43,49,31,45,27,43,36,36,47,44,39
param_max_depth,23,65,95,39,15,37,7,87,29,7,9,21,97,89,41,65,25,47,35,59,87,29,13,73,5,31,55,35,11,77,15,49,7,53,91,45,91,95,69,61,...,89,27,61,39,81,89,17,73,15,67,27,37,71,9,9,45,63,95,59,11,25,27,37,73,55,19,79,93,35,49,87,23,19,99,27,27,47,75,95,87
params,"{'min_samples_split': 10, 'max_depth': 23}","{'min_samples_split': 10, 'max_depth': 65}","{'min_samples_split': 4, 'max_depth': 95}","{'min_samples_split': 4, 'max_depth': 39}","{'min_samples_split': 7, 'max_depth': 15}","{'min_samples_split': 9, 'max_depth': 37}","{'min_samples_split': 11, 'max_depth': 7}","{'min_samples_split': 2, 'max_depth': 87}","{'min_samples_split': 8, 'max_depth': 29}","{'min_samples_split': 7, 'max_depth': 7}","{'min_samples_split': 4, 'max_depth': 9}","{'min_samples_split': 2, 'max_depth': 21}","{'min_samples_split': 2, 'max_depth': 97}","{'min_samples_split': 2, 'max_depth': 89}","{'min_samples_split': 4, 'max_depth': 41}","{'min_samples_split': 6, 'max_depth': 65}","{'min_samples_split': 8, 'max_depth': 25}","{'min_samples_split': 4, 'max_depth': 47}","{'min_samples_split': 9, 'max_depth': 35}","{'min_samples_split': 5, 'max_depth': 59}","{'min_samples_split': 5, 'max_depth': 87}","{'min_samples_split': 5, 'max_depth': 29}","{'min_samples_split': 5, 'max_depth': 13}","{'min_samples_split': 13, 'max_depth': 73}","{'min_samples_split': 12, 'max_depth': 5}","{'min_samples_split': 12, 'max_depth': 31}","{'min_samples_split': 12, 'max_depth': 55}","{'min_samples_split': 13, 'max_depth': 35}","{'min_samples_split': 14, 'max_depth': 11}","{'min_samples_split': 16, 'max_depth': 77}","{'min_samples_split': 14, 'max_depth': 15}","{'min_samples_split': 14, 'max_depth': 49}","{'min_samples_split': 24, 'max_depth': 7}","{'min_samples_split': 14, 'max_depth': 53}","{'min_samples_split': 20, 'max_depth': 91}","{'min_samples_split': 16, 'max_depth': 45}","{'min_samples_split': 23, 'max_depth': 91}","{'min_samples_split': 23, 'max_depth': 95}","{'min_samples_split': 15, 'max_depth': 69}","{'min_samples_split': 16, 'max_depth': 61}",...,"{'min_samples_split': 29, 'max_depth': 89}","{'min_samples_split': 39, 'max_depth': 27}","{'min_samples_split': 44, 'max_depth': 61}","{'min_samples_split': 36, 'max_depth': 39}","{'min_samples_split': 27, 
'max_depth': 81}","{'min_samples_split': 35, 'max_depth': 89}","{'min_samples_split': 36, 'max_depth': 17}","{'min_samples_split': 31, 'max_depth': 73}","{'min_samples_split': 48, 'max_depth': 15}","{'min_samples_split': 43, 'max_depth': 67}","{'min_samples_split': 31, 'max_depth': 27}","{'min_samples_split': 39, 'max_depth': 37}","{'min_samples_split': 42, 'max_depth': 71}","{'min_samples_split': 30, 'max_depth': 9}","{'min_samples_split': 38, 'max_depth': 9}","{'min_samples_split': 27, 'max_depth': 45}","{'min_samples_split': 37, 'max_depth': 63}","{'min_samples_split': 40, 'max_depth': 95}","{'min_samples_split': 36, 'max_depth': 59}","{'min_samples_split': 40, 'max_depth': 11}","{'min_samples_split': 39, 'max_depth': 25}","{'min_samples_split': 27, 'max_depth': 27}","{'min_samples_split': 27, 'max_depth': 37}","{'min_samples_split': 43, 'max_depth': 73}","{'min_samples_split': 41, 'max_depth': 55}","{'min_samples_split': 27, 'max_depth': 19}","{'min_samples_split': 30, 'max_depth': 79}","{'min_samples_split': 42, 'max_depth': 93}","{'min_samples_split': 27, 'max_depth': 35}","{'min_samples_split': 43, 'max_depth': 49}","{'min_samples_split': 49, 'max_depth': 87}","{'min_samples_split': 31, 'max_depth': 23}","{'min_samples_split': 45, 'max_depth': 19}","{'min_samples_split': 27, 'max_depth': 99}","{'min_samples_split': 43, 'max_depth': 27}","{'min_samples_split': 36, 'max_depth': 27}","{'min_samples_split': 36, 'max_depth': 47}","{'min_samples_split': 47, 'max_depth': 75}","{'min_samples_split': 44, 'max_depth': 95}","{'min_samples_split': 39, 'max_depth': 87}"
split0_test_score,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.956044,0.967033,0.967033,0.967033,0.956044,0.956044,0.956044,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.967033,0.923077,0.923077,0.923077,0.923077,0.923077,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,...,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088
split1_test_score,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.901099,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,...,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099,0.901099
split2_test_score,0.923077,0.923077,0.912088,0.912088,0.912088,0.912088,0.923077,0.923077,0.912088,0.912088,0.912088,0.923077,0.923077,0.923077,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.912088,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,0.934066,...,0.934066,0.945055,0.945055,0.945055,0.934066,0.945055,0.945055,0.934066,0.945055,0.945055,0.934066,0.945055,0.945055,0.934066,0.945055,0.934066,0.945055,0.945055,0.945055,0.945055,0.945055,0.934066,0.934066,0.945055,0.945055,0.934066,0.934066,0.945055,0.934066,0.945055,0.945055,0.934066,0.945055,0.934066,0.945055,0.945055,0.945055,0.945055,0.945055,0.945055


In [0]:
# Hyperparameters with the highest mean cross-validation score
tuned_model.best_params_

{'max_depth': 23, 'min_samples_split': 10}

In [0]:
# Take over the best model found by the randomized search
best_model = tuned_model.best_estimator_

In [0]:
# Evaluate the model. Each score is labeled, as in the earlier cells: the first is
# on the data the model was trained on, the second on the unseen test split.
print('train_val score :', best_model.score(x_train_val, t_train_val))
print('test score :', best_model.score(x_test, t_test))

0.9934065934065934
0.956140350877193


In [0]:
# optuna のインストール
!pip install optuna

Collecting optuna
[?25l  Downloading https://files.pythonhosted.org/packages/85/ee/2688cce5ced0597e12832d1ec4f4383a468f6bddff768eeaa3b5bf4f6500/optuna-1.3.0.tar.gz (163kB)
[K     |████████████████████████████████| 163kB 2.7MB/s 
[?25hCollecting alembic
[?25l  Downloading https://files.pythonhosted.org/packages/60/1e/cabc75a189de0fbb2841d0975243e59bde8b7822bacbb95008ac6fe9ad47/alembic-1.4.2.tar.gz (1.1MB)
[K     |████████████████████████████████| 1.1MB 41.8MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting cliff
[?25l  Downloading https://files.pythonhosted.org/packages/b9/17/57187872842bf9f65815b6969b515528ec7fd754137d2d3f49e3bc016175/cliff-3.1.0-py3-none-any.whl (80kB)
[K     |████████████████████████████████| 81kB 9.4MB/s 
[?25hCollecting cmaes
  Downloading https://files.pythonhosted.org/packages/03/de/6ed34ebc0e5c34ed371d898540bca36edb4afe5bb

In [0]:
import optuna

In [0]:
from sklearn.model_selection import cross_val_score

In [0]:
def objective(trial, x, t, cv, random_state=0):
    """Optuna objective: mean K-fold cross-validation accuracy of a decision tree.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        x: Feature matrix passed to ``cross_val_score``.
        t: Target labels passed to ``cross_val_score``.
        cv: Number of folds (or CV splitter) passed to ``cross_val_score``.
        random_state: Seed for the decision tree. It defaults to 0, matching the
            other estimators in this notebook. Without a fixed seed, scikit-learn
            permutes features randomly at each split, so the same hyperparameters
            could score differently from run to run.

    Returns:
        float: Mean score over the CV folds (the value Optuna maximizes).
    """
    # (1) Search range for each hyperparameter
    max_depth = trial.suggest_int('max_depth', 2, 100)
    min_samples_split = trial.suggest_int('min_samples_split', 2, 100)

    # (2) Algorithm to train; seeded so the score depends only on the hyperparameters
    estimator = DecisionTreeClassifier(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )

    # (3) Show the sampled parameters, then train and evaluate with K-fold CV
    print('Current_params : ', trial.params)
    accuracy = cross_val_score(estimator, x, t, cv=cv).mean()
    return accuracy

In [0]:
# Create the study object. direction='maximize' because objective() returns a
# CV accuracy (higher is better). The sampler is seeded so that re-running the
# notebook samples the same sequence of hyperparameters.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=0),
)

In [0]:
# K for K-fold cross-validation
cv = 5

# Run 10 trials; each one scores a sampled configuration with K-fold CV on the train+val data
study.optimize(
    lambda trial: objective(trial, x_train_val, t_train_val, cv),
    n_trials=10,
)

print(study.best_trial)

[32m[I 2020-05-08 04:55:11,071][0m Finished trial#0 with value: 0.9208791208791209 with parameters: {'max_depth': 9, 'min_samples_split': 43}. Best is trial#0 with value: 0.9208791208791209.[0m


Current_params :  {'max_depth': 9, 'min_samples_split': 43}
Current_params :  {'max_depth': 34, 'min_samples_split': 53}


[32m[I 2020-05-08 04:55:11,198][0m Finished trial#1 with value: 0.9186813186813187 with parameters: {'max_depth': 34, 'min_samples_split': 53}. Best is trial#0 with value: 0.9208791208791209.[0m
[32m[I 2020-05-08 04:55:11,322][0m Finished trial#2 with value: 0.9274725274725275 with parameters: {'max_depth': 29, 'min_samples_split': 18}. Best is trial#2 with value: 0.9274725274725275.[0m


Current_params :  {'max_depth': 29, 'min_samples_split': 18}
Current_params :  {'max_depth': 91, 'min_samples_split': 83}


[32m[I 2020-05-08 04:55:11,448][0m Finished trial#3 with value: 0.9186813186813187 with parameters: {'max_depth': 91, 'min_samples_split': 83}. Best is trial#2 with value: 0.9274725274725275.[0m
[32m[I 2020-05-08 04:55:11,574][0m Finished trial#4 with value: 0.9208791208791209 with parameters: {'max_depth': 98, 'min_samples_split': 94}. Best is trial#2 with value: 0.9274725274725275.[0m


Current_params :  {'max_depth': 98, 'min_samples_split': 94}
Current_params :  {'max_depth': 15, 'min_samples_split': 82}


[32m[I 2020-05-08 04:55:11,703][0m Finished trial#5 with value: 0.9186813186813187 with parameters: {'max_depth': 15, 'min_samples_split': 82}. Best is trial#2 with value: 0.9274725274725275.[0m
[32m[I 2020-05-08 04:55:11,829][0m Finished trial#6 with value: 0.9274725274725275 with parameters: {'max_depth': 14, 'min_samples_split': 18}. Best is trial#2 with value: 0.9274725274725275.[0m


Current_params :  {'max_depth': 14, 'min_samples_split': 18}
Current_params :  {'max_depth': 99, 'min_samples_split': 91}


[32m[I 2020-05-08 04:55:11,954][0m Finished trial#7 with value: 0.9186813186813187 with parameters: {'max_depth': 99, 'min_samples_split': 91}. Best is trial#2 with value: 0.9274725274725275.[0m
[32m[I 2020-05-08 04:55:12,085][0m Finished trial#8 with value: 0.9186813186813187 with parameters: {'max_depth': 11, 'min_samples_split': 96}. Best is trial#2 with value: 0.9274725274725275.[0m


Current_params :  {'max_depth': 11, 'min_samples_split': 96}
Current_params :  {'max_depth': 13, 'min_samples_split': 45}


[32m[I 2020-05-08 04:55:12,213][0m Finished trial#9 with value: 0.9208791208791209 with parameters: {'max_depth': 13, 'min_samples_split': 45}. Best is trial#2 with value: 0.9274725274725275.[0m


FrozenTrial(number=2, value=0.9274725274725275, datetime_start=datetime.datetime(2020, 5, 8, 4, 55, 11, 199274), datetime_complete=datetime.datetime(2020, 5, 8, 4, 55, 11, 322253), params={'max_depth': 29, 'min_samples_split': 18}, distributions={'max_depth': IntUniformDistribution(high=100, low=2, step=1), 'min_samples_split': IntUniformDistribution(high=100, low=2, step=1)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=2, state=TrialState.COMPLETE)


In [0]:
# Hyperparameters of the trial with the highest objective value
study.best_params

{'max_depth': 29, 'min_samples_split': 18}

In [0]:
# Define a model with the best hyperparameters found by Optuna.
# random_state is fixed (same value as the other estimators in this notebook)
# so the refit tree, and therefore the scores below, are reproducible.
best_model = DecisionTreeClassifier(**study.best_params, random_state=0)

# Train on all of the training + validation data
best_model.fit(x_train_val, t_train_val)

# Evaluate: the first score is on the training data, the second on the unseen test split
print('train_val score :', best_model.score(x_train_val, t_train_val))
print('test score :', best_model.score(x_test, t_test))

0.9868131868131869
0.9298245614035088
