# Visualize a Decision Tree in 4 Ways with Scikit-Learn and Python

## Source: https://mljar.com/blog/visualize-decision-tree/


## Train Decision Tree on Classification Task

In [None]:
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree

In [None]:
# Prepare the data: load the iris dataset and split it into the feature
# matrix X and the class labels y.
iris = datasets.load_iris()
X, y = iris.data, iris.target

In [None]:
# Grow a tree with scikit-learn's default hyper-parameters; only the random
# seed is pinned so re-running the notebook grows the same tree.
# fit() returns the estimator itself, so `model` and `clf` are the same object.
model = DecisionTreeClassifier(random_state=1234).fit(X, y)
clf = model

## Print Text Representation

In [None]:
# Render the fitted tree as indented text. Passing the real feature names
# labels each split readably instead of with generic feature_0, feature_1, ...
text_representation = tree.export_text(clf, feature_names=iris.feature_names)
print(text_representation)

In [None]:
# Optionally save the text representation to disk:
# with open("decision_tree.log", "w", encoding="utf-8") as fout:
#     fout.write(text_representation)

## Plot Tree with plot_tree

In [None]:
import numpy as np
from matplotlib import pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

iris = load_iris()
X = iris.data
y = iris.target
# Hold out a test split; only the training part is used to fit the tree below.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# NOTE: this rebinds `clf` (and X, y). The cells that follow visualize this
# smaller tree (max_leaf_nodes=3, fit on X_train), not the first section's tree.
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)

# Symbolic class name representation
# tree.plot_tree(clf, class_names=None)

# Specific class name representation
class_names = list(iris.target_names)
# Assign to `_` so the notebook does not echo the list of matplotlib
# annotation objects that plot_tree returns (same as the next cell).
_ = tree.plot_tree(clf, class_names=class_names)

In [None]:
# Draw the tree on a large canvas so node labels stay legible.
# filled=True colours each node by its majority class.
fig = plt.figure(figsize=(20, 15))
_ = tree.plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
)

In [None]:
# Optionally save the figure to disk:
# fig.savefig("decision_tree.png")

## Visualize Decision Tree with graphviz

In [None]:
import graphviz
# Export the fitted tree as Graphviz DOT text; out_file=None makes
# export_graphviz return the source as a string instead of writing a file.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
)

# Wrap the DOT text in a renderable graphviz object (format="png" is the
# output format for a later render() call). Leaving `graph` as the cell's
# last expression lets Jupyter display it inline.
graph = graphviz.Source(dot_data, format="png")
graph

In [None]:
# Optionally render the graph to disk (uses the format="png" set on `graph`):
# graph.render("decision_tree_graphviz")

## Plot Decision Tree with dtreeviz Package

In [None]:
#!pip install dtreeviz

In [None]:
import dtreeviz

# dtreeviz accepts class_names only as a dict or a sequence (it checks against
# collections.abc.Sequence). A NumPy array such as iris.target_names is
# neither, so convert it to a plain list.
# NOTE(review): X, y are the full iris data while `clf` was fit on X_train in
# the plot_tree section; pass X_train, y_train if the per-node plots should
# reflect exactly the data the tree was trained on.
viz = dtreeviz.model(
    clf,
    X,
    y,
    target_name="target",
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
)

viz.view(scale=1.5)

In [None]:
# viz.save("decision_tree.svg")