In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import precision_score, recall_score, confusion_matrix, classification_report, accuracy_score, f1_score
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

In [2]:
fetal = pd.read_csv("../fetal_health.csv")
X = fetal.drop(["fetal_health"],axis=1).values
y = fetal["fetal_health"].values.ravel().astype(int) - 1

In [3]:
fetal[fetal.duplicated()]
fetal_dup = fetal.drop_duplicates(subset = None , keep = 'first', inplace = False)
corr = fetal.corr()

In [4]:
X_train, X_test, y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=123, stratify = y)

In [5]:
def get_results_simple(model, prediction): # 각 지표들 리턴해주는 함수
    result_dict = dict()
    
    test_accuracy = round(accuracy_score(y_test, prediction),4)
    recall = round(recall_score(y_test, prediction, average = "weighted", labels = np.unique(prediction)), 3)
    precision = round(precision_score(y_test, prediction, average = "weighted", labels = np.unique(prediction)), 3)
    f1 = round(f1_score(y_test, prediction, average = "micro", labels = np.unique(prediction)), 3)

    result_dict["test_accuracy"] = test_accuracy
    result_dict["recall"] = recall
    result_dict["f1_score"] = f1
    result_dict["precision"] = precision
    
    return result_dict

In [6]:
clf = DecisionTreeClassifier(random_state=123)

params =  {
    'min_samples_split': [2, 3, 4],
    'max_depth': [6, 16, None]
}

grid = GridSearchCV(estimator=clf,
                    param_grid=params,
                    cv=100,
                    n_jobs=1,
                    verbose=2)

grid.fit(X_train, y_train)

dt_grid = DecisionTreeClassifier(random_state=123, max_depth = grid.best_params_['max_depth'], min_samples_split = grid.best_params_['min_samples_split'])
dt_grid.fit(X_train, y_train)

Fitting 100 folds for each of 9 candidates, totalling 900 fits
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=2; total time=   0.0s
[CV] END ...................max_depth=6, min_s

[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=3; total time=   0.0s
[CV] END ...................

[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................max_depth=6, min_samples_split=4; total time=   0.0s
[CV] END ...................

[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=2; total time=   0.0s
[CV] END ..................m

[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=3; total time=   0.0s
[CV] END ..................m

[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................max_depth=16, min_samples_split=4; total time=   0.0s
[CV] END ..................m

[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=2; total time=   0.0s
[CV] END ................max

[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=3; total time=   0.0s
[CV] END ................max

[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max_depth=None, min_samples_split=4; total time=   0.0s
[CV] END ................max

DecisionTreeClassifier(max_depth=16, min_samples_split=4, random_state=123)

In [7]:
print(f"Training Accuracy: {dt_grid.score(X_train, y_train):0.3f}")
print(f"Test Accuracy: {dt_grid.score(X_test, y_test):0.3f}")

Training Accuracy: 0.990
Test Accuracy: 0.933


In [8]:
prediction = dt_grid.predict(X_test)
dt_grid_result = get_results_simple(dt_grid, prediction)
dt_grid_result['train_accuracy'] = round(dt_grid.score(X_train, y_train),3)
dt_grid_result 

{'test_accuracy': 0.9326,
 'recall': 0.933,
 'f1_score': 0.933,
 'precision': 0.932,
 'train_accuracy': 0.99}