### Training and Visualization of a Decision Tree

In [None]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

In [None]:
# Load the Iris bunch; the features are every column from index 2 onward
# (the same slice later used for the feature names) and the target is the
# class label array.
iris = load_iris()
X, y = iris.data[:, 2:], iris.target

In [None]:
# Classifier limited to at most 10 levels; random_state=45 pins the estimator's
# internal randomness so repeated runs build the same tree.
tree = DecisionTreeClassifier(max_depth=10, random_state=45)

In [None]:
# Train the classifier on the selected feature columns (X) and labels (y).
tree.fit(X, y)

In [None]:
# Installs the Python `graphviz` bindings imported below. Rendering a graph also
# needs the Graphviz system executables (e.g. `dot`) on PATH; pip does not
# install those.
%pip install graphviz

In [None]:
from graphviz import Source

In [None]:
from sklearn.tree import export_graphviz

In [None]:
# Dump the fitted classifier as a Graphviz DOT file, labelling nodes with the
# selected feature names and the class names, then load the file back so the
# notebook renders the tree as the cell output.
dot_path = "iris_tree.dot"
export_graphviz(
    tree,
    out_file=dot_path,
    class_names=iris.target_names,
    feature_names=iris.feature_names[2:],
    rounded=True,
    filled=True,
)

Source.from_file(dot_path)

### Decision Tree for Regression

In [None]:
import numpy as np

# Synthetic regression data: m inputs drawn uniformly from [0, 1) and a
# quadratic target centred on 0.5 with Gaussian noise (standard deviation 1/20).
# The global seed is fixed so the same samples are produced on every run.
np.random.seed(45)
m = 200
X = np.random.rand(m, 1)
y = 4 * (X - 0.5) ** 2 + np.random.randn(m, 1) / 20

In [None]:
# Show the input column: m uniform samples in [0, 1), shape (m, 1).
X

In [None]:
# Show the noisy quadratic targets, shape (m, 1).
y

In [None]:
from sklearn.tree import DecisionTreeRegressor

In [None]:
# Baseline regressor with no growth limits set; `fit` returns the estimator
# itself, so construction and training can be chained.
tree1 = DecisionTreeRegressor(random_state=55).fit(X, y)

# Regularised regressor: every leaf must hold at least 7 training samples.
# The final fit stays a bare expression so the notebook shows the estimator.
tree2 = DecisionTreeRegressor(min_samples_leaf=7, random_state=100)
tree2.fit(X, y)

In [None]:
import matplotlib.pyplot as plt
def plot_regression_predictions(tree_reg, X, y, axes=(0, 1, -0.2, 1), ylabel="$y$"):
    """Plot training points and a regressor's prediction curve on the current axes.

    Args:
        tree_reg: Fitted estimator; only its ``predict`` method is used.
        X: Training inputs, drawn along the horizontal axis.
        y: Training targets, drawn along the vertical axis.
        axes: ``(xmin, xmax, ymin, ymax)`` view limits. The prediction curve is
            evaluated at 500 evenly spaced points between ``xmin`` and ``xmax``.
            The default is a tuple (not a list) so it cannot be mutated and
            shared between calls.
        ylabel: Text for the vertical axis label; a falsy value (``None`` or
            ``""``) leaves the axis unlabelled.
    """
    # predict() expects a 2-D array, hence the reshape to a single column.
    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel("$x_1$", fontsize=18)
    if ylabel:
        plt.ylabel(ylabel, fontsize=18, rotation=0)
    plt.plot(X, y, "b.")
    # The label is only shown if the caller later adds a legend.
    plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")

# Line style and width used for the split lines at each tree depth.
_SPLIT_STYLES = {0: ("k-", 2), 1: ("k--", 2), 2: ("k:", 1)}


def _split_thresholds_by_depth(fitted_tree, deepest):
    """Map each depth up to *deepest* to the sorted split thresholds found there.

    The thresholds are read from the estimator's fitted ``tree_`` structure, so
    the annotations always describe the model that is actually plotted instead
    of relying on hand-copied numbers.
    """
    structure = fitted_tree.tree_
    thresholds = {}
    pending = [(0, 0)]  # (node id, depth); node 0 is the root
    while pending:
        node, depth = pending.pop()
        left = structure.children_left[node]
        right = structure.children_right[node]
        # A leaf has identical child ids; levels below `deepest` are not drawn.
        if left == right or depth > deepest:
            continue
        thresholds.setdefault(depth, []).append(structure.threshold[node])
        pending.append((left, depth + 1))
        pending.append((right, depth + 1))
    return {depth: sorted(found) for depth, found in sorted(thresholds.items())}


def _draw_split_lines(fitted_tree, deepest):
    """Draw a vertical line per split down to *deepest*, one legend entry per depth."""
    for depth, found in _split_thresholds_by_depth(fitted_tree, deepest).items():
        style, width = _SPLIT_STYLES[depth]
        for position, threshold in enumerate(found):
            # Only the first line of a depth gets a visible legend label;
            # labels starting with an underscore are skipped by the legend.
            label = f"Depth={depth}" if position == 0 else "_nolegend_"
            plt.plot([threshold, threshold], [-0.2, 1], style,
                     linewidth=width, label=label)


fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)

# Left panel: the unrestricted tree with its two top levels of splits.
plt.sca(axes[0])
plot_regression_predictions(tree1, X, y)
_draw_split_lines(tree1, deepest=1)
plt.legend(loc="upper center", fontsize=18)
plt.title("No restrictions", fontsize=14)

# Right panel: the regularised tree with its three top levels of splits.
# The title is built from the estimator so it cannot drift from the model.
plt.sca(axes[1])
plot_regression_predictions(tree2, X, y, ylabel=None)
_draw_split_lines(tree2, deepest=2)
plt.legend(loc="upper center", fontsize=13)
plt.title(f"min_samples_leaf={tree2.min_samples_leaf}", fontsize=14)