In [1]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris, load_diabetes, load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text


In [4]:
# Bundled scikit-learn datasets used by the cells below.
# (load_breast_cancer is imported above but intentionally not loaded here.)
diabetes = load_diabetes()
iris_data = load_iris()

In [5]:
# Datasets to iterate over when fitting and printing one tree per dataset.
data_list = [
    iris_data,
    diabetes,
]

In [6]:
def get_train_test_split(data, test_size=None, random_state=0):
    """Split a dataset into train and test partitions.

    Parameters
    ----------
    data : object
        Any object exposing ``data`` (features) and ``target`` (labels)
        attributes; both are forwarded unchanged to ``train_test_split``.
    test_size : optional
        Forwarded to ``train_test_split``. ``None`` (the default) keeps that
        function's own default split, which is what this helper always used.
    random_state : optional
        Seed forwarded to ``train_test_split``. Defaults to 0, the value that
        was previously hard-coded, so existing calls behave as before.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        data.data,
        data.target,
        test_size=test_size,
        random_state=random_state,
    )
    return X_train, X_test, y_train, y_test

In [7]:
def create_decision_tree(X_train, X_test, depth = None, max_leaf_nodes=None, y_train=None):
    """Fit a ``DecisionTreeClassifier`` on the training data and return it.

    Parameters
    ----------
    X_train :
        Training features passed to ``fit``.
    X_test :
        Unused; kept only so existing positional calls keep working.
    depth : optional
        Forwarded as ``max_depth``.
    max_leaf_nodes : optional
        Forwarded as ``max_leaf_nodes``.
    y_train : optional
        Training labels passed to ``fit``. Prefer passing this explicitly.
        When omitted, the function falls back to a module-level ``y_train``
        variable, which is what it implicitly relied on before.

    Returns
    -------
    The fitted classifier (built with ``random_state=0``).

    Raises
    ------
    NameError
        If ``y_train`` is not given and no module-level ``y_train`` exists.
    """
    if y_train is None:
        # Backward-compatible fallback: older calls never passed the labels and
        # depended on a notebook-global y_train. Make that dependency explicit
        # and fail with a clear message instead of an opaque NameError.
        try:
            y_train = globals()["y_train"]
        except KeyError:
            raise NameError(
                "y_train was not passed to create_decision_tree and no "
                "module-level 'y_train' is defined"
            ) from None

    model = DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes, max_depth=depth, random_state=0)
    model.fit(X_train, y_train)
    return model

In [8]:
def get_decision_text(model, feature_name=None):
    """Render a fitted tree as text, print it between separator rules, and return it.

    Parameters
    ----------
    model :
        Fitted tree estimator handed to ``export_text``.
    feature_name : optional
        Iterable of feature names, one per input column. ``None`` lets
        ``export_text`` fall back to its own generic names.

    Returns
    -------
    str
        The text produced by ``export_text``.
    """
    # export_text takes feature_names as a keyword argument; passing it
    # positionally fails on current scikit-learn releases. A plain list is
    # accepted by every release, so normalise other iterables to one.
    names = None if feature_name is None else list(feature_name)
    text = export_text(model, feature_names=names)

    separator = 100*'='
    print(separator)
    print(text)
    print(separator)
    return text


In [9]:
# Fit a depth-2 tree on iris and print its decision rules.
# The split is unpacked into top-level names because create_decision_tree
# is called here without labels and picks up y_train from this namespace.
X_train, X_test, y_train, y_test = get_train_test_split(iris_data)
model = create_decision_tree(X_train, X_test, depth=2)
text = get_decision_text(model, iris_data.feature_names)

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal length (cm) <= 4.95
|   |   |--- class: 1
|   |--- petal length (cm) >  4.95
|   |   |--- class: 2



In [10]:
# Repeat the split / fit / print pipeline for every dataset in data_list.
# NOTE(review): diabetes.target holds continuous values, so the classifier
# treats each distinct value as its own class (see the `class: 72.0` leaves
# in the output) - confirm whether a regressor was intended for that dataset.
for data in data_list:
    # Rebinds the top-level X_train / y_train names on each pass; the model
    # builder is called without labels and reads y_train from this namespace.
    X_train, X_test, y_train, y_test = get_train_test_split(data)
    model = create_decision_tree(X_train, X_test, depth=2)
    text = get_decision_text(model, data.feature_names)

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal length (cm) <= 4.95
|   |   |--- class: 1
|   |--- petal length (cm) >  4.95
|   |   |--- class: 2

|--- s5 <= -0.03
|   |--- bmi <= -0.05
|   |   |--- class: 72.0
|   |--- bmi >  -0.05
|   |   |--- class: 52.0
|--- s5 >  -0.03
|   |--- bmi <= -0.01
|   |   |--- class: 91.0
|   |--- bmi >  -0.01
|   |   |--- class: 220.0



In [11]:
# Show the column names of the diabetes dataset (these label the tree splits above).
diabetes.feature_names

['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']

In [12]:
# Inspect the raw diabetes target: the values are floats, not class labels.
# NOTE(review): this displays the whole array; a slice would keep the output short.
diabetes.target

array([151.,  75., 141., 206., 135.,  97., 138.,  63., 110., 310., 101.,
        69., 179., 185., 118., 171., 166., 144.,  97., 168.,  68.,  49.,
        68., 245., 184., 202., 137.,  85., 131., 283., 129.,  59., 341.,
        87.,  65., 102., 265., 276., 252.,  90., 100.,  55.,  61.,  92.,
       259.,  53., 190., 142.,  75., 142., 155., 225.,  59., 104., 182.,
       128.,  52.,  37., 170., 170.,  61., 144.,  52., 128.,  71., 163.,
       150.,  97., 160., 178.,  48., 270., 202., 111.,  85.,  42., 170.,
       200., 252., 113., 143.,  51.,  52., 210.,  65., 141.,  55., 134.,
        42., 111.,  98., 164.,  48.,  96.,  90., 162., 150., 279.,  92.,
        83., 128., 102., 302., 198.,  95.,  53., 134., 144., 232.,  81.,
       104.,  59., 246., 297., 258., 229., 275., 281., 179., 200., 200.,
       173., 180.,  84., 121., 161.,  99., 109., 115., 268., 274., 158.,
       107.,  83., 103., 272.,  85., 280., 336., 281., 118., 317., 235.,
        60., 174., 259., 178., 128.,  96., 126., 28