In [3]:
import data_provider
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn import metrics

Create Model

In [4]:
grid_params = {
    'criterion': ['gini', 'entropy'],
    'splitter' : ["best", "random"],
}

for dataset in data_provider.data():
    features_train = dataset['train'][0]
    labels_train = dataset['train'][1]

    features_test = dataset['test'][0]
    labels_test = dataset['test'][1]

    gs = GridSearchCV(
        DecisionTreeClassifier(),
        grid_params,
        cv=3,
        n_jobs=-1
    )
    gs_results = gs.fit(features_train, labels_train)
    model = gs_results.best_estimator_

    model.fit(features_train, labels_train)

    predictions = model.predict(features_test)
    print("Model: {}".format(dataset['label']))
    print("Accuracy: {}".format(model.score(features_test, labels_test)))
    fpr, tpr, tresholds = metrics.roc_curve(labels_test, predictions)
    print("AUC: {}".format(metrics.auc(fpr, tpr)))
    print("Precision: {}".format(metrics.precision_score(labels_test, predictions)))
    print("Recall: {}".format(metrics.recall_score(labels_test, predictions)))
    print("ROC: {}".format(metrics.roc_curve(labels_test, predictions)))
    print()

    export_graphviz(
        model,
        out_file =  "decision_trees/Final Tree - {}.dot".format(dataset['label']),
        feature_names = list(features_train.columns.values),
        filled = True,
        rounded = True
    )


Model: Wine Quality
Accuracy: 0.7625
AUC: 0.7605086541858
Precision: 0.7714285714285715
Recall: 0.7894736842105263
ROC: (array([0.        , 0.26845638, 1.        ]), array([0.        , 0.78947368, 1.        ]), array([2, 1, 0]))
Model: Spam Base
Accuracy: 0.9261939218523878
AUC: 0.9223729467920819
Precision: 0.9077490774907749
Recall: 0.9044117647058824
ROC: (array([0.        , 0.05966587, 1.        ]), array([0.        , 0.90441176, 1.        ]), array([2, 1, 0], dtype=int64))
Model: Wine Quality Standard
Accuracy: 0.7375
AUC: 0.7349582008713058
Precision: 0.7457627118644068
Recall: 0.7719298245614035
ROC: (array([0.        , 0.30201342, 1.        ]), array([0.        , 0.77192982, 1.        ]), array([2, 1, 0]))
Model: Spam Base Standard
Accuracy: 0.914616497829233
AUC: 0.9134713252842903
Precision: 0.8790035587188612
Recall: 0.9080882352941176
ROC: (array([0.        , 0.08114558, 1.        ]), array([0.        , 0.90808824, 1.        ]), array([2, 1, 0], dtype=int64))
