In [1]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "decision_trees"

def image_path(fig_id):
    return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)

def save_fig(fig_id, tight_layout=True):
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(image_path(fig_id) + ".png", format='png', dpi=300)

In [2]:
# 演習6-10-7

# moonsデータセットの生成
from sklearn.datasets import make_moons

X, y = make_moons(n_samples=10000, noise=0.4)

In [3]:
# テストセットの分割
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    random_state=42)

In [4]:
# グリッドサーチ

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

param_grid = [
    {'max_leaf_nodes': list(range(2, 20))}
]

decision_clf = DecisionTreeClassifier()

grid_search = GridSearchCV(decision_clf, param_grid, cv=3
                          ,n_jobs=-1)

grid_search.fit(X_train, y_train)

GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'),
       fit_params=None, iid='warn', n_jobs=-1,
       param_grid=[{'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

In [5]:
grid_search.best_params_

{'max_leaf_nodes': 17}

In [6]:
from sklearn.metrics import accuracy_score

final_predictions = grid_search.best_estimator_.predict(X_test)

score = accuracy_score(y_test, final_predictions)
score

0.8695

In [7]:
# 演習6-10-8

# 100個のインスタンスを無作為に選択した
# 1000個のサブセットを作成

subsets = []

from sklearn.model_selection import ShuffleSplit

ss = ShuffleSplit(n_splits=1000, test_size=0.2, random_state=42)

for train_index, test_index in ss.split(X_train):
    print(train_index)
    print(test_index)
    X_train_n = X_train[train_index]
    y_train_n = y_train[train_index]
    subsets.append((X_train_n, y_train_n))

[1467 5768 5714 ...  860 7603 7270]
[2215 2582 1662 ... 1115 6093 6832]
[3643 2054 7218 ...  595 7653 7210]
[6130 4282 4667 ...  851 6289 3332]
[6125 4220 5705 ... 4919 6233 3550]
[6046 5218 7163 ... 1440 3279 7303]
[7752 4753 2253 ... 2865 1041 1614]
[2629 4050  962 ... 1023 4353  632]
[1170 1848 6564 ... 2591 6823 1049]
[2631 6782 4896 ... 1975 7778 4997]
[ 870 3665 2716 ... 4849 6421  449]
[2398  445 3782 ... 1900  794 6221]
[ 700 6697 5668 ... 1450 5416 4160]
[7939 6485 1726 ... 2589 1304 3155]
[7444 2129  315 ... 6272 3924 5274]
[4509 3234 5381 ...  834 6520 1570]
[4934 7169 7436 ...  520 3843 3618]
[1363 3586 2179 ... 2401 4670 7338]
[ 681 1176 3478 ...  994 5929 5283]
[3762 3523 7002 ... 1212 6317 4343]
[1391 5517 6703 ... 5529 5831 2130]
[ 144 3690 6715 ... 1053 2306 4796]
[4916 6722 2776 ... 2471 2787 7052]
[6549 1801 1595 ... 5605 7085 1003]
[4883 3080 4792 ... 2653 2929 1453]
[6640 5810 1809 ... 5583 4337 3652]
[7082 3905 6753 ... 4498 2142 1727]
[ 397 5634 6830 ... 7959 341

[5098 3860 2391 ... 4124 4977  672]
[2844 7323 7848 ... 2312  165  252]
[4909  833 1569 ... 3560 1904 5650]
[7730  735 7299 ... 4319  519 6830]
[4668 1487  861 ... 6619 5258 5420]
[2356 2238 6349 ... 5885 2663 2677]
[3367 3200 7921 ... 6850 3212 2446]
[7715 5882 5130 ...  162 1043 7960]
[1100 7424 6898 ... 4450  354 1077]
[4849 1790 2387 ... 6662 1105  604]
[1910 4926 6377 ...   76 5678 6305]
[5547 3891 7697 ... 4845 2804  900]
[2300 1804 4019 ... 1346 6145 1030]
[ 150 7137 2554 ...  221 2134 7730]
[1032 4949 1737 ... 4407  913   29]
[6616 5137 2187 ... 2915 3269 2019]
[ 934 7163 5653 ... 7312 1682 1109]
[3797 6954 2057 ...  922 3392 3844]
[6498 1033  349 ... 2556 7557 3537]
[2974 1420 6557 ...  595 1306 1931]
[7358 2685 1524 ... 2224 5425 6679]
[2200 7212  133 ... 6963 2203 4648]
[3653 3516 3569 ... 6275 1573 7700]
[7691 2178 6630 ... 1857 2367 4942]
[7525 3729  136 ... 3241 7675 3332]
[6738 4947 4710 ... 1154 5209 5976]
[3467 1423 2869 ... 5434 7932 1779]
[7366 6298 3690 ... 6784 672

[7151 5301 2093 ... 6032 1289 3612]
[4240 4430 4355 ... 3590 6944 5540]
[4024 4969 5390 ... 4915 1197 3290]
[1860 5384 4757 ... 2776 4136 5484]
[2842 7054  582 ... 3400 6986 4447]
[6669 5811 3605 ... 2836 1085 5295]
[6059 1815 2657 ... 3743 3163 1953]
[ 564 5396 2964 ... 2352 3810 5483]
[ 792 7899 6799 ... 6750 5327 7266]
[6130  209 7254 ... 6070 4118 7317]
[ 689 2556 4500 ...  592 7599 6002]
[3716 1096 5908 ... 7127 7025 1548]
[4046 6693 5145 ... 1512 5430 2765]
[1092 4686  175 ... 6901 1387 3739]
[6135  415 4302 ... 6437  685  185]
[1305 3094 2479 ... 2561 4107 4735]
[2082 2192 7345 ...  176 5838 7875]
[2065 1901 3622 ... 7815 7328 5497]
[4302 2592 4845 ... 7960 7176 6872]
[ 280 3211 5773 ... 3676 6116 3350]
[2296 3982 5510 ... 1992 3916 4836]
[6074 3693 4012 ... 2036 4730 5914]
[4799  381 2935 ... 4293 5017 3594]
[5265   20   42 ... 1460 1681  495]
[2802 4820 6628 ... 1212 7139 6057]
[ 301 3684 6368 ... 7902 7733  826]
[2983 3224 2231 ... 3433 3400 3987]
[3716 2962 4390 ... 7859 679

[5064 4401  258 ... 6985  795 7238]
[6186 5324 1231 ... 4242  411 3320]
[5548 5189 3125 ... 4529 5722 4479]
[3045 1334 4120 ... 3775 5289 2356]
[6489 2973 3616 ... 4348 7627  206]
[2536 4714 4560 ... 2710  120 1568]
[5717 2554 7980 ... 6055 4849  394]
[1798 4259 6148 ... 1773 6867 1122]
[3012 5578 6941 ... 2963 7473 1022]
[1732 1676  124 ... 4485 1036 3957]
[7373 6650 2004 ... 4580 7214  984]
[2193 7114 4794 ... 4424 1108 4675]
[5285 7313  447 ... 7569 1543 6923]
[3204  424  190 ... 1221  510 1200]
[4200 4781 4221 ... 4247 4057  550]
[4735 2490 6290 ... 5330 7057 6862]
[7549 5124 2708 ... 7199  632 1359]
[3574  839 7329 ...  190 3444 3611]
[5140 4336 2343 ... 6584 1824 6674]
[4262 5408 5236 ... 6992 7199 2056]
[4808 1313 5177 ...  370  892 5571]
[  69 6970 3320 ... 2237 3907 3762]
[4559 1760 7328 ... 6970 7584 3651]
[4128 5069 2571 ... 6293 5334  741]
[4080 6307  488 ... 2798 3509 4830]
[4799 2475  107 ... 5724 3191 6352]
[6967 2825 6974 ... 1950   74 5460]
[4965 6122 4618 ... 2726  52

[2972 7550 1635 ... 7339  742 2888]
[ 945 1579 6342 ... 2830 2799  294]
[4830 6491 1070 ...  192 2917 1421]
[5163 3800 4334 ... 7040 4249 2070]
[6999 7500 1759 ... 1673 1901 5106]
[1540  230 6118 ... 7673 6389 1803]
[4850 5490 3461 ... 4127 2612 7487]
[6778 7514  634 ... 4460 5734 1044]
[5937 3496 3185 ...  934 5927 7430]
[7066 4555 5619 ... 7629 3290 1842]
[7965 5531 5130 ...  597 3227 5391]
[ 550 5164 7460 ... 2287 4986 3606]
[3939 3298 5752 ... 1663 6258 5376]
[7788 6583  173 ... 3874 7107 4340]
[1053 3901 1892 ... 6857 4442 1481]
[2294 6744 4157 ... 5456 7632 6727]
[4881 1809 6303 ...  406 2510 2065]
[6079 3098  377 ...  424 5297 2524]
[  51 5971 1855 ... 1300 7049 3539]
[1349  883 6867 ... 5906 6883 3980]
[7742 5958 7852 ... 4040 3520 6958]
[7789 7260 4315 ... 1514 2479 6510]
[7690 2937  546 ... 4847 3988 4660]
[7103 6327 7697 ...  322 2185 7331]
[5341 1477 2281 ... 6045  574 5458]
[7364 5625 4321 ... 3725 2245 5195]
[2428  958 3623 ... 5210  178 5328]
[ 686 6085 7188 ... 4426  18

[2199 4316  260 ... 6110 6957 5607]
[5743 3751 7193 ... 3169 3510 3138]
[ 824 1483 4779 ...  861 7082 5465]
[7507  552 3085 ... 3744 5961 2012]
[3367 4197  864 ... 7338 2411 1986]
[5889 5374 6321 ...  812 7267 1531]
[6876 5081 3118 ... 4956 4812 1034]
[ 273 6165 4872 ... 1997 5227 1213]
[3073 5800 2282 ... 3874 7815 1081]
[2943 7855 7560 ... 2850   12 5087]
[5739 4462 3146 ... 5727 7223 7683]
[2020 5065 2726 ... 6656 2992 1109]
[5684  195 7899 ... 6822  333 7299]
[4300 3962 1926 ... 1144 2261  638]
[7289  947  418 ... 4762  298 5966]
[6573  666 3771 ... 3226 5052 1978]
[7715 2112 4855 ... 6207 1388 6773]
[6766  737 1350 ... 6343 5084 7556]
[3930 6794 6030 ... 3089 7492 2202]
[4680  834 5666 ... 2217 4681  855]
[5381 7383 4613 ... 2866 4069  966]
[6716   19 2765 ... 3242 2555 4431]
[7916 1100 4820 ... 4459 4310 2767]
[5836 7798 7871 ... 3121 2168 7193]
[7137 4629 4401 ... 5361 1305 4918]
[7429 4970 1680 ... 5114 2305   40]
[7748 2232 3713 ... 1001 6976 2316]
[6584 4092 2681 ... 6808 265

[1843 7707  333 ... 7665 7877 3226]
[1928 1369 7735 ...  103 1540 3625]
[ 137 1236 6904 ... 7120 3253 1037]
[ 477 3702  486 ... 7824  573 3445]
[6139 6632 7548 ...  328 2402 5725]
[7669 5196  983 ... 4019 5258  126]
[1127 6053 2035 ... 5585 1489 1682]
[4651  563 7553 ... 3414 2330 2520]
[1130 5108 1328 ... 4411 6281 6550]
[1086 4951 6556 ...  856 1748 2022]
[5436 6804 3151 ... 3554 4558 4853]
[7580 1939 2063 ... 3529 2687 1030]
[4160 4017 1905 ... 5415 2234 6861]
[1151 5756 5539 ... 2356 5060  478]
[ 688 7841 7465 ... 3895 4435 5010]
[6841 3409 6323 ...  226 2840 7585]
[2776 3681 7354 ... 1248 1258 2075]
[ 897 4519   65 ... 1984 3828 5959]
[2710 4284 1412 ...  326 1001  158]
[5824 6758 7729 ... 7128 5488 6201]
[5061 3989 1471 ... 3958 2803 5496]
[7031 2905 6689 ... 6113 1341 2435]
[5070 7639  280 ... 7723 2686 5865]
[5448 5569 4733 ... 3793 2700 5244]
[2234  783 2691 ...  848 1978 4375]
[1350 6550 3238 ... 4832  175 1151]
[1501 2353 7089 ... 2950  490 6767]
[2715 3339 4816 ... 4425 115

In [8]:
from sklearn.base import clone

# サブセットごとに訓練

decision_clf_list = [clone(grid_search.best_estimator_) for _ in range(1000)]

accuracy_scores = []

for decision_clf, (X_train_n, y_train_n) in zip(decision_clf_list, subsets):
    decision_clf.fit(X_train_n, y_train_n)
    
    y_pred = decision_clf.predict(X_test)
    accuracy_scores.append(accuracy_score(y_test, y_pred))

# 配列(正確度)の平均値    
np.mean(accuracy_scores)

0.8646885

In [9]:
# 多数決予測
Y_pred = np.empty([1000, len(X_test)], dtype=np.uint8)

for tree_index, tree in enumerate(decision_clf_list):
    print(tree_index)
    Y_pred[tree_index] = tree.predict(X_test)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [10]:
from scipy.stats import mode

y_pred_majority_votes, n_votes = mode(Y_pred, axis=0)
print(mode(Y_pred, axis=0))

ModeResult(mode=array([[1, 1, 0, ..., 0, 0, 0]], dtype=uint8), count=array([[1000, 1000, 1000, ..., 1000, 1000,  840]]))


In [11]:
# 多数決予測による正確度
accuracy_score(y_test, y_pred_majority_votes.reshape([-1]))

0.867