# 붓꽃 품종 분류 (KFold 교차 검증)

In [25]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier # 분류기: 이진 트리
from sklearn.model_selection import KFold # KFold 모델
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

In [26]:
iris = load_iris()

# Keep the feature matrix and the integer class labels as separate arrays.
iris_data, iris_label = iris.data, iris.target
print(iris_label)
print(iris.target_names)

# Tabular view of the features, with the class label appended as the last column.
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names).assign(label=iris.target)
iris_df

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
['setosa' 'versicolor' 'virginica']


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),label
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2


## 분류

In [27]:
# Decision-tree classifier. random_state is fixed so that repeated fits on the
# same data produce the same tree.
dt_clf = DecisionTreeClassifier(
    random_state=156,
)

In [28]:
# Fold counter and per-fold accuracy list, filled in by the cross-validation loop.
n_iter = 0
cv_accuracy = []
# 5-fold splitter. Without shuffle, KFold cuts the rows into 5 contiguous
# chunks of equal size (see the test indices printed by the loop below).
# NOTE(review): the iris rows are ordered by class (see the label printout
# above), so contiguous folds are class-imbalanced. For example, the last
# fold's test set is entirely class 2. Consider
# KFold(n_splits=5, shuffle=True, random_state=...) or StratifiedKFold
# for a less biased estimate.
kfold = KFold(n_splits=5)

## KFold를 사용한 train/test set 분리 및 교차 검증

In [29]:
# Run K-fold cross-validation with dt_clf and report each fold's accuracy.
# The accumulators are (re)initialised here, in the same cell as the loop.
# Re-executing this cell on its own therefore starts over, instead of
# appending to the previous run's results and continuing the fold counter at 6.
cv_accuracy = []
folds = kfold.split(iris_data)
for n_iter, (train_index, test_index) in enumerate(folds, start=1):
    # kfold.split yields positional index arrays; slice features and labels with them.
    X_train, X_test = iris_data[train_index], iris_data[test_index]
    y_train, y_test = iris_label[train_index], iris_label[test_index]

    # Refit the estimator on this fold's training rows, then score its test rows.
    dt_clf.fit(X_train, y_train)
    pred = dt_clf.predict(X_test)

    accuracy = np.round(accuracy_score(y_test, pred), 4)
    train_size = X_train.shape[0]
    test_size = X_test.shape[0]
    # fold number, fold accuracy, #train rows, #test rows
    print(n_iter, accuracy, train_size, test_size)
    print(n_iter, test_index)
    print()
    cv_accuracy.append(accuracy)

# Mean of the (rounded) per-fold accuracies.
print(np.mean(cv_accuracy))

1 1.0 120 30
1 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]

2 0.9667 120 30
2 [30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
 54 55 56 57 58 59]

3 0.8667 120 30
3 [60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
 84 85 86 87 88 89]

4 0.9333 120 30
4 [ 90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119]

5 0.7333 120 30
5 [120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149]

0.9
