In [6]:
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
## 梯度提升樹算法 https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html
## 梯度提升用法 補充資料: https://sklearn.apachecn.org/docs/master/12.html 
from sklearn.ensemble import GradientBoostingRegressor

In [7]:
# 讀取資料集
diabetes = datasets.load_diabetes()

In [8]:
## 查看數據集
diabetes

{'data': array([[ 0.03807591,  0.05068012,  0.06169621, ..., -0.00259226,
          0.01990749, -0.01764613],
        [-0.00188202, -0.04464164, -0.05147406, ..., -0.03949338,
         -0.06833155, -0.09220405],
        [ 0.08529891,  0.05068012,  0.04445121, ..., -0.00259226,
          0.00286131, -0.02593034],
        ...,
        [ 0.04170844,  0.05068012, -0.01590626, ..., -0.01107952,
         -0.04688253,  0.01549073],
        [-0.04547248, -0.04464164,  0.03906215, ...,  0.02655962,
          0.04452873, -0.02593034],
        [-0.04547248, -0.04464164, -0.0730303 , ..., -0.03949338,
         -0.00422151,  0.00306441]]),
 'target': array([151.,  75., 141., 206., 135.,  97., 138.,  63., 110., 310., 101.,
         69., 179., 185., 118., 171., 166., 144.,  97., 168.,  68.,  49.,
         68., 245., 184., 202., 137.,  85., 131., 283., 129.,  59., 341.,
         87.,  65., 102., 265., 276., 252.,  90., 100.,  55.,  61.,  92.,
        259.,  53., 190., 142.,  75., 142., 155., 225.,  59

In [9]:
# 切分訓練集/測試集
x_train, x_test, y_train, y_test = train_test_split(diabetes.data, diabetes.target, test_size=0.25, random_state=42)

# 建立模型
clf = GradientBoostingRegressor(random_state=7)

# 先看看使用預設參數得到的結果，約為 8.379 的 MSE
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)
print(y_pred)
print(metrics.mean_squared_error(y_test, y_pred))

[150.68112664 178.65623672 164.29556533 264.1458946  119.34914439
 100.0066001  262.61805636 193.75687766 156.9277513  141.94009502
  93.66018541 203.16978624  91.81382551 231.1475259  101.70234301
 104.3002564  204.87906413 260.25456712 173.41098469 228.65621133
 182.76128625  92.09131258  57.71263035 206.34374662 149.27440239
 200.66978189 230.9655656  208.70344588  65.66211676 111.34052639
 190.95486696 113.63465037 151.36081426 200.35069462 145.56152671
 213.13001427 115.78197705 120.39914176 182.02414273  70.56445464
  56.42472183  88.83125248 194.51853483 180.53436082 185.70157169
  68.92573667  90.2953634  128.71045645  66.36018314 154.51602553
 127.33762419  74.53385563 145.80730988  95.52467116 204.96544395
 130.89073069  94.14609549 214.42734668  84.46167615  98.45931564
 168.05991836 185.30055574 129.42491838  88.77780802 118.7357578
 247.2121936  182.73525531 199.37105778 159.853519   120.27426071
 159.02689489 188.50473103 215.13206188  87.64949931  71.42656571
 224.028878

In [10]:
# 設定要訓練的超參數組合
n_estimators = [100, 200, 300, 400, 500]
max_depth = [1, 3, 5, 7, 9]
param_grid = dict(n_estimators=n_estimators, max_depth=max_depth)

## 建立搜尋物件，放入模型及參數組合字典 (n_jobs=-1 會使用全部 cpu 平行運算)
## GridSearchCV:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
## scoring選擇 https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
grid_search = GridSearchCV(clf, param_grid, scoring="neg_mean_squared_error", n_jobs=-1, verbose=1)
# 開始搜尋最佳參數
grid_result = grid_search.fit(x_train, y_train)
# 預設會跑 5-fold cross-validadtion，總共 9 種參數組合，總共要 train 27 次模型

Fitting 5 folds for each of 25 candidates, totalling 125 fits


In [11]:
# 印出最佳結果與最佳參數
print("Best Accuracy: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

Best Accuracy: -3247.214455 using {'max_depth': 1, 'n_estimators': 200}


In [12]:
grid_result.best_params_

{'max_depth': 1, 'n_estimators': 200}

In [13]:
# 使用最佳參數重新建立模型
clf_bestparam = GradientBoostingRegressor(max_depth=grid_result.best_params_['max_depth'],
                                           n_estimators=grid_result.best_params_['n_estimators'])

# 訓練模型
clf_bestparam.fit(x_train, y_train)

# 預測測試集
y_pred = clf_bestparam.predict(x_test)
y_pred

array([147.62791737, 169.93514815, 145.96103907, 296.37797451,
       113.11578163,  94.65616707, 286.24163386, 190.51760893,
       141.45659258, 135.78217337, 101.13954016, 171.58696384,
       101.78116202, 238.34180325, 115.71284898, 103.57807589,
       206.07659343, 279.91670858, 189.70019724, 222.10920399,
       208.28689141,  99.85134863,  65.65623649, 203.8306259 ,
       142.10330313, 185.68347753, 190.3440274 , 189.29927462,
        67.30301071, 106.78929635, 175.96589711, 106.72501955,
       142.89331852, 180.06826006, 175.45516721, 233.06425568,
       126.09880562, 112.17785364, 168.94913986,  71.59918232,
        60.30622124, 102.89829965, 176.5078026 , 166.40050282,
       165.86413562,  72.88737385, 100.41272618, 110.01502495,
        81.51111436, 151.44233779, 130.17625418,  87.92769734,
       140.50611114, 104.22992208, 200.82153355, 129.54403737,
        92.07138239, 197.16217893,  95.2669606 ,  94.65317351,
       185.85489618, 211.69914617, 120.23131756, 116.48

In [14]:
# 調整參數後約可降至 8.30 的 MSE
print(metrics.mean_squared_error(y_test, y_pred))

2812.9857279113457
