In [6]:
# To support both Python 2 and Python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os

# Seed NumPy's global random generator so repeated runs of the notebook give the same output
np.random.seed(seed=42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
# Label font sizes: 14 for axis labels, 12 for the tick labels on both axes
mpl.rcParams.update({
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Figures are written to <PROJECT_ROOT_DIR>/images/<CHAPTER_ID>
PROJECT_ROOT_DIR, CHAPTER_ID = ".", "decision_trees"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
# Create the directory now; exist_ok=True makes re-running this cell harmless
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Write the current pyplot figure to IMAGES_PATH as ``<fig_id>.<fig_extension>``.

    Parameters
    ----------
    fig_id : str
        Base file name without extension; also echoed in the progress message.
    tight_layout : bool, default True
        When true, ``plt.tight_layout()`` is called before saving.
    fig_extension : str, default "png"
        Used both as the file suffix and as the ``format`` given to ``savefig``.
    resolution : int, default 300
        Forwarded to ``savefig`` as ``dpi``.
    """
    # Build the destination first so a bad fig_id fails before anything is printed
    target = os.path.join(IMAGES_PATH, "".join((fig_id, ".", fig_extension)))
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(target, format=fig_extension, dpi=resolution)

In [7]:
PROJECT_ROOT_DIR

'.'

In [8]:
IMAGES_PATH

'.\\images\\decision_trees'

In [9]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

In [10]:
# Load the iris dataset via scikit-learn; later cells read its .data, .target,
# .feature_names and .target_names attributes
iris = load_iris()


In [11]:
iris.feature_names

['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [12]:
# Features: columns 2 onward of the data (petal length and petal width, per
# iris.feature_names); labels: the target array
X, y = iris.data[:, 2:], iris.target

In [13]:
# Depth-2 decision tree. random_state is pinned because scikit-learn permutes
# the candidate features at every split and breaks ties between equally good
# splits at random. Without a fixed seed the fitted tree depends on how much of
# the global NumPy RNG earlier cells have consumed, so the exported tree can
# change with cell execution order.
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)

In [14]:
# Train the tree on the two petal features against the class labels
tree_clf.fit(X, y)

In [15]:
from sklearn.tree import export_graphviz
def image_path(fig_id):
    """Return the path of ``fig_id`` inside the IMAGES_PATH directory.

    ``fig_id`` is joined onto IMAGES_PATH as given, so it should already
    include any file extension.
    """
    images_dir = IMAGES_PATH
    return os.path.join(images_dir, fig_id)


In [16]:
# Write the fitted tree to a Graphviz .dot file in the chapter's image directory
export_graphviz(
    tree_clf,
    out_file=image_path("iris_tree.dot"),
    feature_names=iris.feature_names[2:],  # same column slice that X was built from
    class_names=iris.target_names,
    filled=True,
    rounded=True,
)

In [17]:
# Per-class probabilities for one sample: petal length 5 cm, petal width 1.5 cm
tree_clf.predict_proba([[5, 1.5]])

array([[0.        , 0.90740741, 0.09259259]])