In [151]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('default')
from activation_functions import sigmoid
from metrics import accuracy
from BaseRegression import BaseRegression
from collections import Counter

from sklearn import datasets
from sklearn.model_selection import train_test_split

# Decision Tree

Decision trees are powerful algorithms that can fit complex datasets and perform classification, regression, and even multioutput tasks.

Advantages of decision trees:
* They make very few assumptions about the data.
* They are fairly intuitive, and their decisions are easy to interpret (white box model).
* Feature scaling and centering are not necessary to obtain good results.

Other noteworthy info:
* Decision trees are the fundamental components of Random Forests.
* scikit-learn uses the CART algorithm, which produces only binary trees, whereas other algorithms such as ID3 can produce nodes with more than two children.

<img src="https://miro.medium.com/v2/resize:fit:1060/1*H6thrs5CR_wdxQyMCwWawQ.png" alt="Image of Decision tree" style="background-color:white;">

### <span style="color:#217AB8"> Making predictions</span> 

Start at the root node and follow the conditions that apply to your current instance down to a leaf. Each condition asks whether a feature $x$ of your instance is less than or equal to a threshold (for example 1). If yes, follow the left branch; if no, follow the right branch. Once you reach a leaf node (a node without any child nodes), use that node's class as the prediction for your instance.\
Decision trees are usually approximately balanced, so predicting an instance's class requires traversing roughly $O(\log_2(m))$ nodes, where $m$ is the number of training instances. Each node checks only one feature, so this cost is independent of the number of features and predictions are very fast, even with large training sets.
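The sketch below makes the traversal concrete by walking a small hand-built tree stored as nested dictionaries. The tree layout, feature indices, thresholds and class names are made-up assumptions for illustration. It follows this notebook's convention: go left when the feature value is less than or equal to the threshold. With a million training instances, a roughly balanced tree is only about $\log_2(10^6) \approx 20$ levels deep, so at most about 20 such tests are needed.

```python
# Hypothetical hand-built tree: internal nodes hold a feature index and a threshold,
# leaves hold only the predicted class.
toy_tree = {
    "feature": 0, "threshold": 2.45,
    "left": {"value": "class A"},
    "right": {
        "feature": 1, "threshold": 1.75,
        "left": {"value": "class B"},
        "right": {"value": "class C"},
    },
}


def predict_one(node, x):
    """Follow the conditions from the root down to a leaf and return its class."""
    depth = 0
    while "value" not in node:  # internal node: test one feature against a threshold
        if x[node["feature"]] <= node["threshold"]:
            node = node["left"]
        else:
            node = node["right"]
        depth += 1
    return node["value"], depth  # the leaf's class and the number of tests performed


print(predict_one(toy_tree, [1.4, 0.2]))  # ('class A', 1)
print(predict_one(toy_tree, [4.8, 1.8]))  # ('class C', 2)
```

Only one feature is inspected per level. That is why the prediction cost grows with the depth of the tree rather than with the number of features.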

* *samples*: the number of training instances the node's condition applies to.
* *value*: the number of training instances of each class the node applies to (e.g. in node x, class a is represented once and class b 20 times).
* *gini*: the impurity of the node (a node is pure, with gini = 0, when all the training instances it applies to belong to the same class).

Gini impurity: $G_i = 1 - \sum_{k=1}^{n} p_{i,k}^2$, where $p_{i,k}$ is the ratio of class $k$ instances among the training instances in the $i$-th node and $n$ is the number of classes.
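As a worked example, take a node whose training instances are distributed as `value = [0, 49, 5]` over three classes (illustrative counts). Its Gini impurity is $1 - (0/54)^2 - (49/54)^2 - (5/54)^2 \approx 0.168$. The short sketch below computes the same quantity from raw class counts. `gini_from_counts` is a hypothetical helper name.

```python
import numpy as np


def gini_from_counts(counts):
    """Gini impurity G = 1 - sum_k p_k^2, computed from per-class instance counts."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()  # p_{i,k}: share of class k among the node's instances
    return 1.0 - np.sum(p ** 2)


print(round(gini_from_counts([0, 49, 5]), 3))  # 0.168 (mixed node)
print(round(gini_from_counts([0, 54, 0]), 3))  # 0.0 (pure node)
print(round(gini_from_counts([25, 25]), 3))    # 0.5 (two classes, perfectly mixed)
```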

### <span style="color:#217AB8">CART (Classification And Regression Tree) algorithm</span> 
CART is a greedy algorithm. At each step it searches for the best split at the current level, without checking whether that split leads to the lowest possible impurity several levels further down.\
Therefore it does not guarantee an optimal solution.\
Finding the optimal tree is known to be an NP-complete problem: it requires $O(\exp(m))$ time, which makes it intractable even for small training sets. This is why we settle for a reasonably good solution instead.

1. Split the training set into two subsets using a single feature $k$ and a threshold $t_k$ (e.g. petal length >= 1.3). Search for the pair $(k, t_k)$ that produces the purest subsets, weighted by their size.

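&emsp;&emsp;&emsp;Minimize the cost function: $J(k, t_k) = \frac{m_{left}}{m}G_{left} + \frac{m_{right}}{m}G_{right}$, where $G_{left/right}$ measures the impurity of the left/right subset, $m_{left/right}$ is the number of instances in the left/right subset, and $m$ is the total number of instances in the node.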

2. Continue this on the subsets recursively.
3. Stop when `max_depth` is reached or when no split reduces the impurity any further.\
Training requires comparing all features (unless `max_features` is set) on all samples at each node.\
This brings the training complexity to $O(n \times m \log_2(m))$, where $n$ is the number of features.
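The split search in step 1 can be sketched for a single numeric feature. Every distinct feature value is tried as the threshold $t_k$, and the one that minimizes $J(k, t_k)$ (the Gini impurity of each side, weighted by its size) wins. This is a simplified illustration that assumes one numeric feature and non-negative integer class labels. A full CART implementation repeats the search for every feature $k$ and then recurses on the two subsets. The function names and toy data are hypothetical.

```python
import numpy as np


def gini_impurity(labels):
    """Gini impurity of an array of non-negative integer class labels."""
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)


def split_cost(feature, labels, threshold):
    """CART cost J(k, t_k) = m_left/m * G_left + m_right/m * G_right."""
    left = labels[feature <= threshold]
    right = labels[feature > threshold]
    m = len(labels)
    return len(left) / m * gini_impurity(left) + len(right) / m * gini_impurity(right)


def best_threshold(feature, labels):
    """Greedy search: try each distinct value as t_k and keep the lowest-cost one."""
    best_t, best_cost = None, np.inf
    # the largest value would put every instance on the left, so it is skipped
    for t in np.unique(feature)[:-1]:
        cost = split_cost(feature, labels, t)
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t, best_cost


# toy data: one numeric feature (e.g. petal length) and two classes
petal_length = np.array([1.3, 1.4, 1.5, 4.5, 4.7, 5.1])
y = np.array([0, 0, 0, 1, 1, 1])
t, cost = best_threshold(petal_length, y)
print(f"best threshold: {t}, cost: {cost}")  # best threshold: 1.5, cost: 0.0
```

Because each candidate is scored only by the impurity it produces right now, the search is greedy. It never looks ahead at how good the splits below it would be.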

If the tree is left unconstrained, it will fit itself very closely to the training data and will most likely overfit it.\
This is why a decision tree is often described as a non-parametric model. It does have parameters, but their number is not determined before training, so the model structure is free to adapt to the data. In contrast, parametric models such as linear models have a predetermined number of parameters, so their degrees of freedom are limited.\
(This reduces the risk of overfitting but can lead to underfitting, especially when the data contains more complex patterns than the model is able to capture.)



### <span style="color:#217AB8">Gini impurity or Entropy</span> 
In Shannon's information theory, entropy measures the average information content of a message. It is 0 when all messages are identical.\
In a decision tree, a node's entropy is 0 when it contains instances of only one class.

Entropy in the $i$-th node: $H_i = -\sum_{k=1}^{n} p_{i,k} \log_2(p_{i,k})$, summing only over classes with $p_{i,k} \neq 0$.

[Article explaining entropy, information gain and Gini impurity](https://www.machinelearningnuggets.com/splitting-criteria-in-decision-trees/#:~:text=Entropy%20measures%20data%20points%27%20degree,ranges%20between%200%20and%201.&text=We%20can%20see%20that%20the,the%20data%20is%20perfectly%20randomized.)

Example:
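Consider a node whose class counts are `[0, 49, 5]` (illustrative numbers). Its entropy is $-\frac{49}{54}\log_2\frac{49}{54} - \frac{5}{54}\log_2\frac{5}{54} \approx 0.445$, and the class with $p_{i,k} = 0$ is skipped. The sketch below computes the entropy of a few class distributions. It also shows that the maximum is $\log_2(n)$ for $n$ equally frequent classes, so the range is $[0, 1]$ only for two classes. `entropy_from_counts` is a hypothetical helper name.

```python
import numpy as np


def entropy_from_counts(counts):
    """Entropy H = -sum_k p_k * log2(p_k), skipping classes with p_k = 0."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()
    p = p[p > 0]  # 0 * log2(0) is treated as 0
    return np.sum(p * np.log2(1.0 / p))  # same as -sum(p * log2(p))


for counts in ([0, 54, 0], [0, 49, 5], [25, 25], [10, 10, 10, 10]):
    print(counts, round(entropy_from_counts(counts), 3))
# [0, 54, 0] 0.0
# [0, 49, 5] 0.445
# [25, 25] 1.0
# [10, 10, 10, 10] 2.0
```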

Should you use Gini impurity or entropy?
According to O'Reilly's "Hands-On Machine Learning", most of the time they lead to similar trees.\
Gini impurity is slightly faster to compute, so it is a good default.\
However, when they differ, Gini impurity tends to isolate the most frequent class in its own branch of the tree, while entropy tends to produce slightly more balanced trees (Sebastian Raschka's analysis).
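In scikit-learn, the impurity measure is selected with the `criterion` parameter of `DecisionTreeClassifier`: `"gini"` is the default and `"entropy"` is the alternative. The sketch below compares both on the breast cancer split used later in this notebook. The `max_depth` and `random_state` values are arbitrary choices, and the printed numbers depend on the data split, so none are claimed here.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)

trees = {}
for criterion in ("gini", "entropy"):
    # same settings for both trees, only the impurity measure differs
    clf = DecisionTreeClassifier(criterion=criterion, max_depth=10, random_state=42)
    clf.fit(X_train, y_train)
    trees[criterion] = clf
    print(f"{criterion:>7}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}, "
          f"test accuracy={clf.score(X_test, y_test):.3f}")

# fraction of test instances on which both trees predict the same class
agreement = (trees["gini"].predict(X_test) == trees["entropy"].predict(X_test)).mean()
print(f"prediction agreement: {agreement:.3f}")
```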

### <span style="color:#217AB8">Regularization</span> 
Unconstrained trees tend to overfit, so their freedom is restricted during training. For example, you can:
* restrict the maximum depth of the tree
* set the minimum number of samples a node must have before it can be split
* set the minimum number of samples a leaf must have
* set the minimum fraction of all training data that a leaf must represent
* restrict the maximum number of leaf nodes
* restrict the maximum number of features evaluated for splitting at each node
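In scikit-learn, these constraints correspond to hyperparameters of `DecisionTreeClassifier`: `max_depth`, `min_samples_split`, `min_samples_leaf`, `min_weight_fraction_leaf`, `max_leaf_nodes` and `max_features`. The sketch below compares an unconstrained tree with a regularized one. The hyperparameter values are arbitrary assumptions, not tuned settings.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)

# no restrictions: the tree keeps splitting until its leaves are pure
unconstrained = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# illustrative regularization settings (not tuned)
regularized = DecisionTreeClassifier(
    max_depth=4,           # restrict the maximum depth
    min_samples_split=10,  # a node needs at least 10 samples to be split
    min_samples_leaf=5,    # every leaf must keep at least 5 samples
    max_leaf_nodes=12,     # at most 12 leaves in total
    random_state=42,
).fit(X_train, y_train)

for name, clf in (("unconstrained", unconstrained), ("regularized", regularized)):
    print(f"{name:>13}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}, "
          f"train acc={clf.score(X_train, y_train):.3f}, test acc={clf.score(X_test, y_test):.3f}")
```

Typically, the unconstrained tree fits the training set (almost) perfectly. The regularized tree gives up some training accuracy in exchange for a simpler model that may generalize better. How the test accuracies compare depends on the split.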

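You could also train the tree without restrictions and then prune it after training.
For example:
Prune a node if all of its children are leaves and its split provides no statistically significant improvement in purity.\
A standard statistical test such as the chi-squared ($\chi^2$) test estimates the probability that the improvement is purely the result of chance (the null hypothesis). If this probability, the p-value, is higher than a chosen threshold (typically 0.05, i.e. the null hypothesis cannot be rejected with 95% confidence), the node is considered unnecessary and its children are deleted.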
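Below is a minimal sketch of this pruning test, assuming SciPy is available. The class counts of the two child leaves form a contingency table. `scipy.stats.chi2_contingency` then returns the p-value for the null hypothesis that the class distribution does not depend on which child an instance falls into, i.e. that the purity gain is due to chance. The helper name `should_prune` and the example counts are hypothetical.

```python
import numpy as np
from scipy.stats import chi2_contingency


def should_prune(left_counts, right_counts, alpha=0.05):
    """Return True if splitting a node into these two leaves is not significant.

    left_counts / right_counts: number of training instances of each class
    in the left and right child leaf.
    """
    table = np.array([left_counts, right_counts])
    # drop classes that appear in neither child (they would give zero expected counts)
    table = table[:, table.sum(axis=0) > 0]
    if table.shape[1] < 2:
        return True  # both children contain the same single class: no purity gain
    _, p_value, _, _ = chi2_contingency(table)
    # p-value above alpha: "the improvement is due to chance" cannot be rejected -> prune
    return bool(p_value > alpha)


print(should_prune([25, 25], [24, 26]))  # True: the split barely changes the class mix
print(should_prune([45, 5], [5, 45]))    # False: the split clearly separates the classes
```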
 


In [152]:

def entropy(classes):
    # Entropy measures the data points' degree of impurity, uncertainty, or surprise
    # (the expected value of surprise): H = -sum_k p_k * log2(p_k).
    # Range: [0, log2(n_classes)], i.e. [0, 1] for two classes; the maximum is reached
    # when all classes are equally frequent (perfectly mixed data).
    # classes: array-like of non-negative integer class labels (required by np.bincount)
    probs = np.bincount(classes) / len(classes)
    return -np.sum([px * np.log2(px) for px in probs if px > 0])

def info_gain(labels, column):
    # column: the chosen (categorical) attribute column of the dataset
    # labels: the target array y
    # Information gain of an attribute a:
    # IG(a) = Entropy(y) - sum over unique values v in a of (count(v) / len(y)) * Entropy(y[a == v])
    labels, column = np.asarray(labels), np.asarray(column)  # accept lists and numpy arrays
    dataset_entropy = entropy(labels)
    sum_val_entropies = 0
    for val in np.unique(column):
        mask = column == val  # boolean mask of the samples that take value `val`
        sum_val_entropies += mask.sum() / len(column) * entropy(labels[mask])
    return dataset_entropy - sum_val_entropies

def gini(classes):
    # Gini impurity: G = 1 - sum_k p_k^2
    probs = np.bincount(classes) / len(classes)
    return 1 - np.sum(np.square(probs))


In [153]:
clss = [1,1,1,1,1,1,1,2,2,2,2,2,2,3,3,3,4,4,4,4,4,4,4,4,4]
vls = ['b','b','b','b','b','b','c','c','c','b','b','b','b','a','a','a','b','b','b','b','a','a','a','a','a']

info_gain(clss, vls)

0.6186951736180384

In [154]:
class Node:
    def __init__(self, feature=None, threshold=None, samples=None, left=None, right=None, *, value=None):
        self.feature = feature
        self.threshold = threshold
        self.value = value
        self.left = left
        self.right = right
        self.samples = samples
        self.gini = None
    
    def is_leaf_node(self):
        return self.value is not None

In [155]:
class DecisionTree:
    def __init__(self, min_samples_split=2, max_depth=100, n_features=None):
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.n_features = n_features
        self.root = None

    def fit(self, X, y):
        # ensure that the number of features to use is never larger than the actual existing features.
        self.n_features = X.shape[1] if not self.n_features else min(self.n_features, X.shape[1])
        # grow the tree with its choices and thresholds
        self.root = self._grow_tree(X, y)
    
    def _grow_tree(self, X, y, depth=0):
        n_samples, n_feats = X.shape
        n_labels = len(np.unique(y))

        # check for stopping criteria before continuing to grow the tree.
        if (depth >= self.max_depth 
            or n_samples < self.min_samples_split
            or n_labels == 1):
            leaf_value = self._most_common_label(y)
            return Node(value=leaf_value, samples=n_samples)
        
        # if stopping criteria is not met countinue growing the tree.
        feature_ids = np.random.choice(n_feats, self.n_features, replace=False)
        
        # greedy search for best features and thresholds
        best_feat, best_thresh = self._best_criteria(X,y, feature_ids)

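        # split based on best feature and threshold
        left_ids, right_ids = self._split(X[:, best_feat], best_thresh)

        # if no threshold separates the samples (e.g. all selected features are constant),
        # an empty child would be created: make this node a leaf instead
        if len(left_ids) == 0 or len(right_ids) == 0:
            return Node(value=self._most_common_label(y), samples=n_samples)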

        # Continue growing the tree for the left and right children of current node
        left  = self._grow_tree( X[left_ids, :], y[left_ids], depth+1)
        right = self._grow_tree( X[right_ids, :], y[right_ids], depth+1)

        return Node(best_feat, best_thresh, n_samples, left, right)


    def _info_gain(self, labels, column, threshold):
        # column: the values of the candidate feature for the samples in this node
        # labels: the target array y for the same samples
        # Information gain of a binary split at `threshold`:
        # IG = Entropy(parent) - (n_left / n * Entropy(left) + n_right / n * Entropy(right))
        parent_entropy = entropy(labels)
        left_ids, right_ids = self._split(column, threshold)
        
        if len(left_ids) == 0 or len(right_ids) == 0:
            return 0
        
        left_entropy, right_entropy = entropy(labels[left_ids]), entropy(labels[right_ids])
        child_entropy = (len(left_ids) / len(labels)) * left_entropy + (len(right_ids) / len(labels)) * right_entropy
        return (parent_entropy - child_entropy)

    def _split(self, col, split_thresh):
        left_ids = np.argwhere(col <= split_thresh).flatten()
        right_ids = np.argwhere(col > split_thresh).flatten()
        return left_ids, right_ids

    def _most_common_label(self, y):
        most_common = Counter(y).most_common(1)[0][0]
        return most_common

    def _best_criteria(self, X, y, feat_ids):
        # brute force search over all unique values of all relevant features
        # Goal: find the features and splits that have the highest information gain
        best_gain = -1
        split_id, split_thresh = None, None
        for feat_id in feat_ids:
            X_col = X[:, feat_id]
            threshold = np.unique(X_col)
            for thresh in threshold:
                gain = self._info_gain(y, X_col, thresh)
                if gain > best_gain:
                    best_gain = gain
                    split_id = feat_id
                    split_thresh = thresh
        return split_id, split_thresh


    def predict(self, X):
        # for each sample run through the tree and predict the class of its leaf node.
        return np.array([self._traverse_tree(x, self.root) for x in X])
    
    def _traverse_tree(self, x, node):
        # with Recursion
        # if we are at a leaf node return the most common label saved as the value of the node 
        if node.is_leaf_node():
            return node.value
        # if the feature f of our sample s is equal to or below the threshold go to the left child node
        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        # otherwise the feature f is larger than the threshold: continue to the right child node
        return self._traverse_tree(x, node.right)


In [156]:
cancer_data = datasets.load_breast_cancer()

X = cancer_data.data
# Target "1" : Benign, "0":Malignant
y = cancer_data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)

In [157]:
tree = DecisionTree(max_depth=10)
tree.fit(X_train, y_train)

predicted = tree.predict(X_test)
# store the score under a new name so the imported `accuracy` function is not overwritten
acc = accuracy(predicted, y_test)
print(acc)

0.9210526315789473


array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',
       'mean smoothness', 'mean compactness', 'mean concavity',
       'mean concave points', 'mean symmetry', 'mean fractal dimension',
       'radius error', 'texture error', 'perimeter error', 'area error',
       'smoothness error', 'compactness error', 'concavity error',
       'concave points error', 'symmetry error',
       'fractal dimension error', 'worst radius', 'worst texture',
       'worst perimeter', 'worst area', 'worst smoothness',
       'worst compactness', 'worst concavity', 'worst concave points',
       'worst symmetry', 'worst fractal dimension'], dtype='<U23')

In [168]:
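from collections.abc import Iterable

# scikit-learn internals used by the copied exporter code below
# (private modules, their contents may change between scikit-learn versions)
from sklearn.tree import _criterion, _tree
from sklearn.tree._export import SENTINEL, _color_brew
from sklearn.tree._reingold_tilford import Tree, buchheim
from sklearn.utils.validation import check_is_fitted

label_names = {0:"Malignant", 1:"Benign"}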

# To Do: See if you can learn something from scikit learn implementation for the visualization.

def plot_tree(decision_tree, *, max_depth=None, feature_names=None, class_names=None, label="all", filled=False, impurity=True, 
              node_ids=False, proportion=False, rounded=False, precision=3, ax=None, fontsize=None,):
    """Plot a decision tree.
    The sample counts that are shown are weighted with any sample_weights that
    might be present.

    The visualization is fit automatically to the size of the axis.
    Use the ``figsize`` or ``dpi`` arguments of ``plt.figure``  to control
    the size of the rendering.

    Read more in the :ref:`User Guide <tree>`.

    .. versionadded:: 0.21

    Parameters
    ----------
    decision_tree : decision tree regressor or classifier
        The decision tree to be plotted.

    max_depth : int, default=None
        The maximum depth of the representation. If None, the tree is fully
        generated.

    feature_names : array-like of str, default=None
        Names of each of the features.
        If None, generic names will be used ("x[0]", "x[1]", ...).

    class_names : array-like of str or True, default=None
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
        If ``True``, shows a symbolic representation of the class name.

    label : {'all', 'root', 'none'}, default='all'
        Whether to show informative labels for impurity, etc.
        Options include 'all' to show at every node, 'root' to show only at
        the top root node, or 'none' to not show at any node.

    filled : bool, default=False
        When set to ``True``, paint nodes to indicate majority class for
        classification, extremity of values for regression, or purity of node
        for multi-output.

    impurity : bool, default=True
        When set to ``True``, show the impurity at each node.

    node_ids : bool, default=False
        When set to ``True``, show the ID number on each node.

    proportion : bool, default=False
        When set to ``True``, change the display of 'values' and/or 'samples'
        to be proportions and percentages respectively.

    rounded : bool, default=False
        When set to ``True``, draw node boxes with rounded corners and use
        Helvetica fonts instead of Times-Roman.

    precision : int, default=3
        Number of digits of precision for floating point in the values of
        impurity, threshold and value attributes of each node.

    ax : matplotlib axis, default=None
        Axes to plot to. If None, use current axis. Any previous content
        is cleared.

    fontsize : int, default=None
        Size of text font. If None, determined automatically to fit figure.

    Returns
    -------
    annotations : list of artists
        List containing the artists for the annotation boxes making up the
        tree.
    """

    check_is_fitted(decision_tree)

    exporter = _MPLTreeExporter(max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled,
                                impurity=impurity, node_ids=node_ids, proportion=proportion, rounded=rounded, precision=precision, fontsize=fontsize,)
    return exporter.export(decision_tree, ax=ax)

class _BaseTreeExporter:
    def __init__(self, max_depth=None, feature_names=None, class_names=None, label="all", filled=False, impurity=True,
                node_ids=False, proportion=False, rounded=False, precision=3, fontsize=None,):
        self.max_depth = max_depth
        self.feature_names = feature_names
        self.class_names = class_names
        self.label = label
        self.filled = filled
        self.impurity = impurity
        self.node_ids = node_ids
        self.proportion = proportion
        self.rounded = rounded
        self.precision = precision
        self.fontsize = fontsize

    def get_color(self, value):
        # Find the appropriate color & intensity for a node
        if self.colors["bounds"] is None:
            # Classification tree
            color = list(self.colors["rgb"][np.argmax(value)])
            sorted_values = sorted(value, reverse=True)
            if len(sorted_values) == 1:
                alpha = 0.0
            else:
                alpha = (sorted_values[0] - sorted_values[1]) / (1 - sorted_values[1])
        else:
            # Regression tree or multi-output
            color = list(self.colors["rgb"][0])
            alpha = (value - self.colors["bounds"][0]) / (
                self.colors["bounds"][1] - self.colors["bounds"][0]
            )
        # compute the color as alpha against white
        color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
        # Return html color code in #RRGGBB format
        return "#%02x%02x%02x" % tuple(color)
    
    def get_fill_color(self, tree, node_id):
        # Fetch appropriate color for node
        if "rgb" not in self.colors:
            # Initialize colors and bounds if required
            self.colors["rgb"] = _color_brew(tree.n_classes[0])
            if tree.n_outputs != 1:
                # Find max and min impurities for multi-output
                self.colors["bounds"] = (np.min(-tree.impurity), np.max(-tree.impurity))
            elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1:
                # Find max and min values in leaf nodes for regression
                self.colors["bounds"] = (np.min(tree.value), np.max(tree.value))
        if tree.n_outputs == 1:
            node_val = tree.value[node_id][0, :] / tree.weighted_n_node_samples[node_id]
            if tree.n_classes[0] == 1:
                # Regression or degraded classification with single class
                node_val = tree.value[node_id][0, :]
                if isinstance(node_val, Iterable) and self.colors["bounds"] is not None:
                    # Only unpack the float only for the regression tree case.
                    # Classification tree requires an Iterable in `get_color`.
                    node_val = node_val.item()
        else:
            # If multi-output color node by impurity
            node_val = -tree.impurity[node_id]
        return self.get_color(node_val)


    def node_to_str(self, tree, node_id, criterion):
        # Generate the node content string
        if tree.n_outputs == 1:
            value = tree.value[node_id][0, :]
        else:
            value = tree.value[node_id]

        # Should labels be shown?
        labels = (self.label == "root" and node_id == 0) or self.label == "all"

        characters = self.characters
        node_string = characters[-1]

        # Write node ID
        if self.node_ids:
            if labels:
                node_string += "node "
            node_string += characters[0] + str(node_id) + characters[4]

        # Write decision criteria
        if tree.children_left[node_id] != _tree.TREE_LEAF:
            # Always write node decision criteria, except for leaves
            if self.feature_names is not None:
                feature = self.feature_names[tree.feature[node_id]]
            else:
                feature = "x%s%s%s" % (
                    characters[1],
                    tree.feature[node_id],
                    characters[2],
                )
            node_string += "%s %s %s%s" % (
                feature,
                characters[3],
                round(tree.threshold[node_id], self.precision),
                characters[4],
            )

        # Write impurity
        if self.impurity:
            if isinstance(criterion, _criterion.FriedmanMSE):
                criterion = "friedman_mse"
            elif isinstance(criterion, _criterion.MSE) or criterion == "squared_error":
                criterion = "squared_error"
            elif not isinstance(criterion, str):
                criterion = "impurity"
            if labels:
                node_string += "%s = " % criterion
            node_string += (
                str(round(tree.impurity[node_id], self.precision)) + characters[4]
            )


        # Write node sample count
        if labels:
            node_string += "samples = "
        if self.proportion:
            percent = (
                100.0 * tree.n_node_samples[node_id] / float(tree.n_node_samples[0])
            )
            node_string += str(round(percent, 1)) + "%" + characters[4]
        else:
            node_string += str(tree.n_node_samples[node_id]) + characters[4]


        # Write node class distribution / regression value
        if self.proportion and tree.n_classes[0] != 1:
            # For classification this will show the proportion of samples
            value = value / tree.weighted_n_node_samples[node_id]
        if labels:
            node_string += "value = "
        if tree.n_classes[0] == 1:
            # Regression
            value_text = np.around(value, self.precision)
        elif self.proportion:
            # Classification
            value_text = np.around(value, self.precision)
        elif np.all(np.equal(np.mod(value, 1), 0)):
            # Classification without floating-point weights
            value_text = value.astype(int)
        else:
            # Classification with floating-point weights
            value_text = np.around(value, self.precision)
        # Strip whitespace
        value_text = str(value_text.astype("S32")).replace("b'", "'")
        value_text = value_text.replace("' '", ", ").replace("'", "")
        if tree.n_classes[0] == 1 and tree.n_outputs == 1:
            value_text = value_text.replace("[", "").replace("]", "")
        value_text = value_text.replace("\n ", characters[4])
        node_string += value_text + characters[4]

        # Write node majority class
        if (
            self.class_names is not None
            and tree.n_classes[0] != 1
            and tree.n_outputs == 1
        ):
            # Only done for single-output classification trees
            if labels:
                node_string += "class = "
            if self.class_names is not True:
                class_name = self.class_names[np.argmax(value)]
            else:
                class_name = "y%s%s%s" % (
                    characters[1],
                    np.argmax(value),
                    characters[2],
                )
            node_string += class_name

        # Clean up any trailing newlines
        if node_string.endswith(characters[4]):
            node_string = node_string[: -len(characters[4])]

        return node_string + characters[5]



class _DOTTreeExporter(_BaseTreeExporter):
    def __init__(self, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label="all", filled=False,
                 leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False,
                 special_characters=False, precision=3, fontname="helvetica",):
        super().__init__(
            max_depth=max_depth,
            feature_names=feature_names,
            class_names=class_names,
            label=label,
            filled=filled,
            impurity=impurity,
            node_ids=node_ids,
            proportion=proportion,
            rounded=rounded,
            precision=precision,
        )
        self.leaves_parallel = leaves_parallel
        self.out_file = out_file
        self.special_characters = special_characters
        self.fontname = fontname
        self.rotate = rotate

        # PostScript compatibility for special characters
        if special_characters:
            self.characters = ["&#35;", "<SUB>", "</SUB>", "&le;", "<br/>", ">", "<"]
        else:
            self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']

        # The depth of each node for plotting with 'leaf' option
        self.ranks = {"leaves": []}
        # The colors to render each node with
        self.colors = {"bounds": None}

    def export(self, decision_tree):
        # Check length of feature_names before getting into the tree node
        # Raise error if length of feature_names does not match
        # n_features_in_ in the decision_tree
        if self.feature_names is not None:
            if len(self.feature_names) != decision_tree.n_features_in_:
                raise ValueError(
                    "Length of feature_names, %d does not match number of features, %d"
                    % (len(self.feature_names), decision_tree.n_features_in_)
                )
        # each part writes to out_file
        self.head()
        # Now recurse the tree and add node & edge attributes
        if isinstance(decision_tree, _tree.Tree):
            self.recurse(decision_tree, 0, criterion="impurity")
        else:
            self.recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)

        self.tail()

    def tail(self):
        # If required, draw leaf nodes at same depth as each other
        if self.leaves_parallel:
            for rank in sorted(self.ranks):
                self.out_file.write(
                    "{rank=same ; " + "; ".join(r for r in self.ranks[rank]) + "} ;\n"
                )
        self.out_file.write("}")

    def head(self):
        self.out_file.write("digraph Tree {\n")

        # Specify node aesthetics
        self.out_file.write("node [shape=box")
        rounded_filled = []
        if self.filled:
            rounded_filled.append("filled")
        if self.rounded:
            rounded_filled.append("rounded")
        if len(rounded_filled) > 0:
            self.out_file.write(
                ', style="%s", color="black"' % ", ".join(rounded_filled)
            )
        self.out_file.write(', fontname="%s"' % self.fontname)
        self.out_file.write("] ;\n")

        # Specify graph & edge aesthetics
        if self.leaves_parallel:
            self.out_file.write("graph [ranksep=equally, splines=polyline] ;\n")

        self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname)

        if self.rotate:
            self.out_file.write("rankdir=LR ;\n")


    def recurse(self, tree, node_id, criterion, parent=None, depth=0):
        if node_id == _tree.TREE_LEAF:
            raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)

        left_child = tree.children_left[node_id]
        right_child = tree.children_right[node_id]

        # Add node with description
        if self.max_depth is None or depth <= self.max_depth:
            # Collect ranks for 'leaf' option in plot_options
            if left_child == _tree.TREE_LEAF:
                self.ranks["leaves"].append(str(node_id))
            elif str(depth) not in self.ranks:
                self.ranks[str(depth)] = [str(node_id)]
            else:
                self.ranks[str(depth)].append(str(node_id))

            self.out_file.write(
                "%d [label=%s" % (node_id, self.node_to_str(tree, node_id, criterion))
            )

            if self.filled:
                self.out_file.write(
                    ', fillcolor="%s"' % self.get_fill_color(tree, node_id)
                )
            self.out_file.write("] ;\n")
        
            if parent is not None:
                # Add edge to parent
                self.out_file.write("%d -> %d" % (parent, node_id))
                if parent == 0:
                    # Draw True/False labels if parent is root node
                    angles = np.array([45, -45]) * ((self.rotate - 0.5) * -2)
                    self.out_file.write(" [labeldistance=2.5, labelangle=")
                    if node_id == 1:
                        self.out_file.write('%d, headlabel="True"]' % angles[0])
                    else:
                        self.out_file.write('%d, headlabel="False"]' % angles[1])
                self.out_file.write(" ;\n")

            if left_child != _tree.TREE_LEAF:
                self.recurse(
                    tree,
                    left_child,
                    criterion=criterion,
                    parent=node_id,
                    depth=depth + 1,
                )
                self.recurse(
                    tree,
                    right_child,
                    criterion=criterion,
                    parent=node_id,
                    depth=depth + 1,
                )
        else:
            self.ranks["leaves"].append(str(node_id))

            self.out_file.write('%d [label="(...)"' % node_id)
            if self.filled:
                # color cropped nodes grey
                self.out_file.write(', fillcolor="#C0C0C0"')
            self.out_file.write("] ;\n")

            if parent is not None:
                # Add edge to parent
                self.out_file.write("%d -> %d ;\n" % (parent, node_id))

            
class _MPLTreeExporter(_BaseTreeExporter):
    def __init__(self, max_depth=None, feature_names=None, class_names=None, label="all", filled=False, 
                 impurity=True, node_ids=False, proportion=False, rounded=False, precision=3, fontsize=None,):
        super().__init__(
            max_depth=max_depth,
            feature_names=feature_names,
            class_names=class_names,
            label=label,
            filled=filled,
            impurity=impurity,
            node_ids=node_ids,
            proportion=proportion,
            rounded=rounded,
            precision=precision,
        )
        self.fontsize = fontsize

        # The depth of each node for plotting with 'leaf' option
        self.ranks = {"leaves": []}
        # The colors to render each node with
        self.colors = {"bounds": None}

        self.characters = ["#", "[", "]", "<=", "\n", "", ""]
        self.bbox_args = dict()
        if self.rounded:
            self.bbox_args["boxstyle"] = "round"

        self.arrow_args = dict(arrowstyle="<-")

    def _make_tree(self, node_id, et, criterion, depth=0):
        # traverses _tree.Tree recursively, builds intermediate
        # "_reingold_tilford.Tree" object
        name = self.node_to_str(et, node_id, criterion=criterion)
        if et.children_left[node_id] != _tree.TREE_LEAF and (
            self.max_depth is None or depth <= self.max_depth
        ):
            children = [
                self._make_tree(
                    et.children_left[node_id], et, criterion, depth=depth + 1
                ),
                self._make_tree(
                    et.children_right[node_id], et, criterion, depth=depth + 1
                ),
            ]
        else:
            return Tree(name, node_id)
        return Tree(name, node_id, *children)

    def export(self, decision_tree, ax=None):
        import matplotlib.pyplot as plt
        from matplotlib.text import Annotation

        if ax is None:
            ax = plt.gca()
        ax.clear()
        ax.set_axis_off()
        my_tree = self._make_tree(0, decision_tree.tree_, decision_tree.criterion)
        draw_tree = buchheim(my_tree)

        # Add 1 to the layout extents so that boxes centred on the outermost
        # grid positions still fit inside the axes; one extra cell is enough
        # because a box is roughly as wide as the spacing between boxes.
        max_x, max_y = draw_tree.max_extents() + 1
        ax_width = ax.get_window_extent().width
        ax_height = ax.get_window_extent().height

        scale_x = ax_width / max_x
        scale_y = ax_height / max_y
        self.recurse(draw_tree, decision_tree.tree_, ax, max_x, max_y)

        anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)]

        # update sizes of all bboxes
        renderer = ax.figure.canvas.get_renderer()
        for ann in anns:
            ann.update_bbox_position_size(renderer)

        if self.fontsize is None:
            # No explicit fontsize: rescale the current font size so that the
            # largest box fits within one grid cell (scale_x by scale_y
            # display units), which keeps neighbouring boxes from overlapping.
            extents = [ann.get_bbox_patch().get_window_extent() for ann in anns]
            max_width = max([extent.width for extent in extents])
            max_height = max([extent.height for extent in extents])
            # width should be around scale_x in axis coordinates
            size = anns[0].get_fontsize() * min(
                scale_x / max_width, scale_y / max_height
            )
            for ann in anns:
                ann.set_fontsize(size)

        return anns

    def recurse(self, node, tree, ax, max_x, max_y, depth=0):
        import matplotlib.pyplot as plt

        kwargs = dict(
            bbox=self.bbox_args.copy(),
            ha="center",
            va="center",
            zorder=100 - 10 * depth,
            xycoords="axes fraction",
            arrowprops=self.arrow_args.copy(),
        )

        kwargs["arrowprops"]["edgecolor"] = plt.rcParams["text.color"]

        if self.fontsize is not None:
            kwargs["fontsize"] = self.fontsize

        # offset things by .5 to center them in plot
        xy = ((node.x + 0.5) / max_x, (max_y - node.y - 0.5) / max_y)

        if self.max_depth is None or depth <= self.max_depth:
            if self.filled:
                kwargs["bbox"]["fc"] = self.get_fill_color(tree, node.tree.node_id)
            else:
                kwargs["bbox"]["fc"] = ax.get_facecolor()

            if node.parent is None:
                # root
                ax.annotate(node.tree.label, xy, **kwargs)
            else:
                xy_parent = (
                    (node.parent.x + 0.5) / max_x,
                    (max_y - node.parent.y - 0.5) / max_y,
                )
                ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
            for child in node.children:
                self.recurse(child, tree, ax, max_x, max_y, depth=depth + 1)

        else:
            xy_parent = (
                (node.parent.x + 0.5) / max_x,
                (max_y - node.parent.y - 0.5) / max_y,
            )
            kwargs["bbox"]["fc"] = "grey"
            ax.annotate("\n  (...)  \n", xy_parent, xy, **kwargs)
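`_MPLTreeExporter` first lays the tree out on a grid with `buchheim`, a Reingold-Tilford layout. `recurse` then maps each node to axes-fraction coordinates with `((x + 0.5) / max_x, (max_y - y - 0.5) / max_y)`, where `max_x, max_y` are the layout extents plus one. `export` picks a font size of `fontsize * min(scale_x / max_width, scale_y / max_height)`. The sketch below reproduces those two formulas. It pairs them with a deliberately simpler layout instead of Reingold-Tilford: leaves are placed left to right and each parent is centred over its children. The nested-dict nodes and the function names `simple_layout`, `to_axes_fraction` and `fit_fontsize` are hypothetical.

```python
def simple_layout(tree):
    """Assign grid positions to a hypothetical nested-dict tree.

    This is NOT the Reingold-Tilford layout computed by `buchheim`: leaves
    take consecutive x values from left to right and each internal node is
    centred between its first and last child. y is the depth (root at 0).
    Returns (label, x, y) tuples in pre-order.
    """
    positions = []
    next_leaf_x = 0

    def place(node, depth):
        nonlocal next_leaf_x
        slot = len(positions)
        positions.append(None)  # reserve this node's pre-order slot
        children = node.get("children", [])
        if children:
            child_xs = [place(child, depth + 1) for child in children]
            x = (child_xs[0] + child_xs[-1]) / 2
        else:
            x = next_leaf_x
            next_leaf_x += 1
        positions[slot] = (node["label"], x, depth)
        return x

    place(tree, 0)
    return positions


def to_axes_fraction(x, y, max_x, max_y):
    """Same mapping as `xy` in _MPLTreeExporter.recurse.

    The 0.5 offset centres a node in its grid cell, and y is flipped so the
    root (y = 0) is drawn at the top of the axes.
    """
    return ((x + 0.5) / max_x, (max_y - y - 0.5) / max_y)


def fit_fontsize(fontsize, ax_width, ax_height, max_x, max_y, box_w, box_h):
    """Font-size rule from export(): scale so the largest box fits one cell.

    Assumes box size grows roughly linearly with font size.
    """
    scale_x = ax_width / max_x  # display units per grid cell, horizontally
    scale_y = ax_height / max_y  # display units per grid cell, vertically
    return fontsize * min(scale_x / box_w, scale_y / box_h)


example = {
    "label": "root",
    "children": [
        {"label": "left"},
        {"label": "right", "children": [{"label": "rl"}, {"label": "rr"}]},
    ],
}
layout = simple_layout(example)
# Extents plus one, as in `draw_tree.max_extents() + 1`
max_x = max(x for _, x, _ in layout) + 1
max_y = max(y for _, _, y in layout) + 1
coords = {label: to_axes_fraction(x, y, max_x, max_y) for label, x, y in layout}
print(coords)
```

Every coordinate lands strictly inside `(0, 1)`. This is why `recurse` can annotate with `xycoords="axes fraction"` without boxes being placed outside the axes.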


@validate_params(
    {
        "decision_tree": "no_validation",
        "out_file": [str, None, HasMethods("write")],
        "max_depth": [Interval(Integral, 0, None, closed="left"), None],
        "feature_names": ["array-like", None],
        "class_names": ["array-like", "boolean", None],
        "label": [StrOptions({"all", "root", "none"})],
        "filled": ["boolean"],
        "leaves_parallel": ["boolean"],
        "impurity": ["boolean"],
        "node_ids": ["boolean"],
        "proportion": ["boolean"],
        "rotate": ["boolean"],
        "rounded": ["boolean"],
        "special_characters": ["boolean"],
        "precision": [Interval(Integral, 0, None, closed="left"), None],
        "fontname": [str],
    },
    prefer_skip_nested_validation=True,
)
def export_graphviz(
    decision_tree,
    out_file=None,
    *,
    max_depth=None,
    feature_names=None,
    class_names=None,
    label="all",
    filled=False,
    leaves_parallel=False,
    impurity=True,
    node_ids=False,
    proportion=False,
    rotate=False,
    rounded=False,
    special_characters=False,
    precision=3,
    fontname="helvetica",
):
    """Export a decision tree in DOT format.

    This function generates a GraphViz representation of the decision tree,
    which is then written into `out_file`. Once exported, graphical renderings
    can be generated using, for example::

        $ dot -Tps tree.dot -o tree.ps      (PostScript format)
        $ dot -Tpng tree.dot -o tree.png    (PNG format)

    The sample counts that are shown are weighted with any sample_weights that
    might be present.

    Read more in the :ref:`User Guide <tree>`.

    Parameters
    ----------
    decision_tree : object
        The decision tree estimator to be exported to GraphViz.

    out_file : object or str, default=None
        Handle or name of the output file. If ``None``, the result is
        returned as a string.

        .. versionchanged:: 0.20
            Default of out_file changed from "tree.dot" to None.

    max_depth : int, default=None
        The maximum depth of the representation. If None, the tree is fully
        generated.

    feature_names : array-like of shape (n_features,), default=None
        An array containing the feature names.
        If None, generic names will be used ("x[0]", "x[1]", ...).

    class_names : array-like of shape (n_classes,) or bool, default=None
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
        If ``True``, shows a symbolic representation of the class name.

    label : {'all', 'root', 'none'}, default='all'
        Whether to show informative labels for impurity, etc.
        Options include 'all' to show at every node, 'root' to show only at
        the top root node, or 'none' to not show at any node.

    filled : bool, default=False
        When set to ``True``, paint nodes to indicate majority class for
        classification, extremity of values for regression, or purity of node
        for multi-output.

    leaves_parallel : bool, default=False
        When set to ``True``, draw all leaf nodes at the bottom of the tree.

    impurity : bool, default=True
        When set to ``True``, show the impurity at each node.

    node_ids : bool, default=False
        When set to ``True``, show the ID number on each node.

    proportion : bool, default=False
        When set to ``True``, change the display of 'values' and/or 'samples'
        to be proportions and percentages respectively.

    rotate : bool, default=False
        When set to ``True``, orient tree left to right rather than top-down.

    rounded : bool, default=False
        When set to ``True``, draw node boxes with rounded corners.

    special_characters : bool, default=False
        When set to ``False``, ignore special characters for PostScript
        compatibility.

    precision : int, default=3
        Number of digits of precision for floating point in the values of
        impurity, threshold and value attributes of each node.

    fontname : str, default='helvetica'
        Name of font used to render text.

    Returns
    -------
    dot_data : str
        String representation of the input tree in GraphViz dot format.
        Only returned if ``out_file`` is None.

        .. versionadded:: 0.18

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn import tree

    >>> clf = tree.DecisionTreeClassifier()
    >>> iris = load_iris()
    >>> clf = clf.fit(iris.data, iris.target)
    >>> tree.export_graphviz(clf)
    'digraph Tree {...
    """
    if feature_names is not None:
        feature_names = check_array(
            feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
        )
    if class_names is not None and not isinstance(class_names, bool):
        class_names = check_array(
            class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
        )

    check_is_fitted(decision_tree)
    own_file = False
    return_string = False
    try:
        if isinstance(out_file, str):
            out_file = open(out_file, "w", encoding="utf-8")
            own_file = True

        if out_file is None:
            return_string = True
            out_file = StringIO()

        exporter = _DOTTreeExporter(
            out_file=out_file,
            max_depth=max_depth,
            feature_names=feature_names,
            class_names=class_names,
            label=label,
            filled=filled,
            leaves_parallel=leaves_parallel,
            impurity=impurity,
            node_ids=node_ids,
            proportion=proportion,
            rotate=rotate,
            rounded=rounded,
            special_characters=special_characters,
            precision=precision,
            fontname=fontname,
        )
        exporter.export(decision_tree)

        if return_string:
            return exporter.out_file.getvalue()

    finally:
        if own_file:
            out_file.close()


def _compute_depth(tree, node):
    """
    Return the depth of the subtree rooted at `node` (a single leaf has depth 1).
    """

    def compute_depth_(
        current_node, current_depth, children_left, children_right, depths
    ):
        depths += [current_depth]
        left = children_left[current_node]
        right = children_right[current_node]
        # -1 is _tree.TREE_LEAF: only internal nodes have two children
        if left != -1 and right != -1:
            compute_depth_(
                left, current_depth + 1, children_left, children_right, depths
            )
            compute_depth_(
                right, current_depth + 1, children_left, children_right, depths
            )

    depths = []
    compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
    return max(depths)


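`_compute_depth` measures a subtree on sklearn's parallel-array tree layout. In that layout, `children_left[i]` and `children_right[i]` hold child indices, `-1` (`_tree.TREE_LEAF`) marks a missing child, and a lone leaf counts as depth 1. The recursive helper could in principle reach Python's default recursion limit on an extremely deep tree. Below is an iterative sketch of the same computation using an explicit stack. The plain lists stand in for `tree_.children_left` / `tree_.children_right`, and `subtree_depth` is a hypothetical name.

```python
TREE_LEAF = -1  # same sentinel value as sklearn's _tree.TREE_LEAF


def subtree_depth(children_left, children_right, node=0):
    """Depth of the subtree rooted at `node`; a lone leaf has depth 1."""
    deepest = 0
    stack = [(node, 1)]  # (node index, depth of that node in the subtree)
    while stack:
        current, depth = stack.pop()
        deepest = max(deepest, depth)
        left = children_left[current]
        right = children_right[current]
        # Like compute_depth_, descend only when both children exist
        if left != TREE_LEAF and right != TREE_LEAF:
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
    return deepest


# Node 0 splits into leaf 1 and node 2; node 2 splits into leaves 3 and 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
print(subtree_depth(children_left, children_right))  # 3
```

`export_text` relies on this value when it truncates a branch. A subtree of depth 1 is simply a leaf and is printed as one. Anything deeper is reported as "truncated branch of depth N".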
@validate_params(
    {
        "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
        "feature_names": ["array-like", None],
        "class_names": ["array-like", None],
        "max_depth": [Interval(Integral, 0, None, closed="left"), None],
        "spacing": [Interval(Integral, 1, None, closed="left"), None],
        "decimals": [Interval(Integral, 0, None, closed="left"), None],
        "show_weights": ["boolean"],
    },
    prefer_skip_nested_validation=True,
)
def export_text(
    decision_tree,
    *,
    feature_names=None,
    class_names=None,
    max_depth=10,
    spacing=3,
    decimals=2,
    show_weights=False,
):
    """Build a text report showing the rules of a decision tree.

    Note that backwards compatibility may not be supported.

    Parameters
    ----------
    decision_tree : object
        The decision tree estimator to be exported.
        It can be an instance of DecisionTreeClassifier or DecisionTreeRegressor.

    feature_names : array-like of shape (n_features,), default=None
        An array containing the feature names.
        If None, generic names will be used ("feature_0", "feature_1", ...).

    class_names : array-like of shape (n_classes,), default=None
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.

        - if `None`, the class names are delegated to `decision_tree.classes_`;
        - otherwise, `class_names` will be used as class names instead of
          `decision_tree.classes_`. The length of `class_names` must match
          the length of `decision_tree.classes_`.

        .. versionadded:: 1.3

    max_depth : int, default=10
        Only the first max_depth levels of the tree are exported.
        Truncated branches will be marked with "...".

    spacing : int, default=3
        Number of spaces between edges. The higher it is, the wider the result.

    decimals : int, default=2
        Number of decimal digits to display.

    show_weights : bool, default=False
        If True, the classification weights will be exported on each leaf.
        The classification weights are the number of samples of each class.

    Returns
    -------
    report : str
        Text summary of all the rules in the decision tree.

    Examples
    --------

    >>> from sklearn.datasets import load_iris
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> from sklearn.tree import export_text
    >>> iris = load_iris()
    >>> X = iris['data']
    >>> y = iris['target']
    >>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
    >>> decision_tree = decision_tree.fit(X, y)
    >>> r = export_text(decision_tree, feature_names=iris['feature_names'])
    >>> print(r)
    |--- petal width (cm) <= 0.80
    |   |--- class: 0
    |--- petal width (cm) >  0.80
    |   |--- petal width (cm) <= 1.75
    |   |   |--- class: 1
    |   |--- petal width (cm) >  1.75
    |   |   |--- class: 2
    """

    if feature_names is not None:
        feature_names = check_array(
            feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
        )
    if class_names is not None:
        class_names = check_array(
            class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
        )

    check_is_fitted(decision_tree)
    tree_ = decision_tree.tree_
    if is_classifier(decision_tree):
        if class_names is None:
            class_names = decision_tree.classes_
        elif len(class_names) != len(decision_tree.classes_):
            raise ValueError(
                "When `class_names` is an array, it should contain as"
                " many items as `decision_tree.classes_`. Got"
                f" {len(class_names)} while the tree was fitted with"
                f" {len(decision_tree.classes_)} classes."
            )
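    # Despite the names, `right_child_fmt` formats the "<=" test, which is
    # followed by the *left* child, and `left_child_fmt` formats the ">" test,
    # which is followed by the right child.
    right_child_fmt = "{} {} <= {}\n"
    left_child_fmt = "{} {} >  {}\n"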
    truncation_fmt = "{} {}\n"

    if feature_names is not None and len(feature_names) != tree_.n_features:
        raise ValueError(
            "feature_names must contain %d elements, got %d"
            % (tree_.n_features, len(feature_names))
        )
    if isinstance(decision_tree, DecisionTreeClassifier):
        value_fmt = "{}{} weights: {}\n"
        if not show_weights:
            value_fmt = "{}{}{}\n"
    else:
        value_fmt = "{}{} value: {}\n"

    if feature_names is not None:
        feature_names_ = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else None
            for i in tree_.feature
        ]
    else:
        feature_names_ = ["feature_{}".format(i) for i in tree_.feature]

    export_text.report = ""

    def _add_leaf(value, class_name, indent):
        val = ""
        is_classification = isinstance(decision_tree, DecisionTreeClassifier)
        if show_weights or not is_classification:
            val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
            val = "[" + "".join(val)[:-2] + "]"
        if is_classification:
            val += " class: " + str(class_name)
        export_text.report += value_fmt.format(indent, "", val)

    def print_tree_recurse(node, depth):
        indent = ("|" + (" " * spacing)) * depth
        indent = indent[:-spacing] + "-" * spacing
        value = None
        if tree_.n_outputs == 1:
            value = tree_.value[node][0]
        else:
            value = tree_.value[node].T[0]
        class_name = np.argmax(value)

        if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
            class_name = class_names[class_name]

        if depth <= max_depth + 1:
            info_fmt = ""
            info_fmt_left = info_fmt
            info_fmt_right = info_fmt

            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_names_[node]
                threshold = tree_.threshold[node]
                threshold = "{1:.{0}f}".format(decimals, threshold)
                export_text.report += right_child_fmt.format(indent, name, threshold)
                export_text.report += info_fmt_left
                print_tree_recurse(tree_.children_left[node], depth + 1)

                export_text.report += left_child_fmt.format(indent, name, threshold)
                export_text.report += info_fmt_right
                print_tree_recurse(tree_.children_right[node], depth + 1)
            else:  # leaf
                _add_leaf(value, class_name, indent)
        else:
            subtree_depth = _compute_depth(tree_, node)
            if subtree_depth == 1:
                _add_leaf(value, class_name, indent)
            else:
                trunc_report = "truncated branch of depth %d" % subtree_depth
                export_text.report += truncation_fmt.format(indent, trunc_report)

    print_tree_recurse(0, 1)
    return export_text.report

InvalidParameterError: The 'decision_tree' parameter of plot_tree must be an instance of 'sklearn.tree._classes.DecisionTreeClassifier' or an instance of 'sklearn.tree._classes.DecisionTreeRegressor'. Got <__main__.DecisionTree object at 0x129f6b320> instead.
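The `InvalidParameterError` above is raised by the `@validate_params` decorator before any drawing happens. `plot_tree` only accepts instances of `DecisionTreeClassifier` or `DecisionTreeRegressor`, and `export_text` has the same constraint in its decorator shown above, so the notebook's own `DecisionTree` class is rejected. One workaround is to walk the custom tree directly. The sketch below mirrors `export_text`'s output format: the indent is `("|" + " " * spacing) * depth` with its last `spacing` characters replaced by dashes, thresholds get `decimals` digits, and the `<=` branch is followed by the left child. The notebook's node class is not shown here, so the sketch assumes a hypothetical `Node` dataclass with `feature`, `threshold`, `left`, `right` and, on leaves, `value`. Those names will likely need adapting. `tree_to_text` is also a hypothetical name.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical node type; adapt the field names to your own tree class."""
    feature: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: Optional[object] = None  # predicted class, set only on leaves


def tree_to_text(root, feature_names=None, spacing=3, decimals=2):
    """Render a Node tree in the same layout that export_text produces."""
    lines = []

    def recurse(node, depth):
        # Same indent construction as export_text: "|" plus `spacing` spaces
        # per level, with the final `spacing` characters turned into dashes.
        indent = ("|" + " " * spacing) * depth
        indent = indent[:-spacing] + "-" * spacing
        if node.left is None and node.right is None:
            lines.append(f"{indent} class: {node.value}")
            return
        if feature_names is not None:
            name = feature_names[node.feature]
        else:
            name = f"feature_{node.feature}"
        threshold = f"{node.threshold:.{decimals}f}"
        lines.append(f"{indent} {name} <= {threshold}")
        recurse(node.left, depth + 1)
        lines.append(f"{indent} {name} >  {threshold}")
        recurse(node.right, depth + 1)

    recurse(root, 1)
    return "\n".join(lines)


# Hand-built tree with the same splits as the export_text docstring example
iris_like = Node(
    feature=3,
    threshold=0.8,
    left=Node(value=0),
    right=Node(feature=3, threshold=1.75, left=Node(value=1), right=Node(value=2)),
)
names = ["sepal length (cm)", "sepal width (cm)",
         "petal length (cm)", "petal width (cm)"]
print(tree_to_text(iris_like, feature_names=names))
```

For this hand-built tree the printed report matches the `export_text` docstring example line for line. The only difference is that lines are joined without a trailing newline.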