In [None]:
# Plot a scikit-learn DecisionTree with Plotly (no matplotlib)
import numpy as np
import plotly.graph_objects as go

def plot_tree_plotly(
    clf,
    feature_names=None,
    class_names=None,
    max_depth=None,
    precision=2,
    filled=True,
    node_size=14,
    height=600,
    width=1000,
):
    """
    Plot a scikit-learn DecisionTreeClassifier/Regressor using Plotly.

    The tree is laid out with the root at the top. Horizontal space is split
    between the two subtrees of a node in proportion to the number of
    displayed terminal nodes under each one, and a split node is drawn at the
    boundary between its two subtrees.

    Parameters
    ----------
    clf : fitted DecisionTreeClassifier or DecisionTreeRegressor
        Only ``clf.tree_`` and (for classifiers) ``clf.classes_`` are read.
    feature_names : list[str] | None
        Names for features; defaults to ``x0, x1, ...`` if None.
    class_names : list[str] | None
        Optional class names for classifiers; inferred from clf.classes_ if None.
    max_depth : int | None
        Limit the depth displayed (root is depth 0). A split node sitting
        exactly at ``max_depth`` is drawn as a gray "truncated" placeholder
        and nothing below it is drawn. None shows the whole tree.
    precision : int
        Digits after the decimal point for thresholds, impurity and values.
    filled : bool
        If True, color leaf nodes by predicted class (classifier) or with a
        single accent color (regressor); if False, leaves are white.
    node_size : int
        Marker size for nodes.
    height, width : int
        Figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with two traces: the edges (lines) and the nodes
        (markers + text), listed in ascending node-id order. An empty
        figure is returned when the tree has no nodes.
    """
    tree = clf.tree_
    n_nodes = tree.node_count
    if n_nodes == 0:
        return go.Figure()

    children_left = tree.children_left
    children_right = tree.children_right
    feature = tree.feature
    threshold = tree.threshold
    impurity = tree.impurity

    # Prefer raw sample counts; fall back to the weighted counts only when the
    # attribute is genuinely missing (rather than swallowing any exception).
    n_node_samples = getattr(tree, 'n_node_samples', None)
    if n_node_samples is None:
        n_node_samples = tree.weighted_n_node_samples.astype(int)

    is_classifier = hasattr(clf, 'classes_')
    if is_classifier and class_names is None:
        class_names = [str(c) for c in getattr(clf, 'classes_', [])]

    if feature_names is None:
        # Leaf entries in `feature` are negative, hence the clamp to 0.
        feature_names = [f'x{idx}' for idx in range(max(0, feature.max()) + 1)]

    LEAF = -1  # child id used by the tree arrays to mean "no child"
    depths = np.zeros(n_nodes, dtype=int)

    def is_leaf(node):
        return children_left[node] == LEAF and children_right[node] == LEAF

    def is_terminal(node):
        # Drawn without outgoing edges: a real leaf, or a split node cut off
        # by max_depth. Relies on depths[node] being set already.
        return is_leaf(node) or (max_depth is not None and depths[node] >= max_depth)

    def kids(node):
        pair = (int(children_left[node]), int(children_right[node]))
        return [c for c in pair if c != LEAF]

    # Iterative pre-order walk (no recursion limit on deep trees). Every node
    # that will be drawn gets its true depth; nodes below a truncated node are
    # never reached and therefore never drawn.
    order = []
    stack = [(0, 0)]
    while stack:
        node, d = stack.pop()
        depths[node] = d
        order.append(node)
        if not is_terminal(node):
            # Push right first so the left child is visited first.
            stack.extend((c, d + 1) for c in reversed(kids(node)))

    # Displayed-terminal count per subtree. Children always follow their
    # parent in pre-order, so a reversed sweep sees children first.
    leaf_counts = np.zeros(n_nodes, dtype=int)
    for node in reversed(order):
        if is_terminal(node):
            leaf_counts[node] = 1
        else:
            leaf_counts[node] = sum(leaf_counts[c] for c in kids(node))

    # Top-down x assignment: each node owns an interval [x0, x1] which it
    # splits between its children proportionally to their leaf counts.
    xs = np.zeros(n_nodes, dtype=float)
    spans = {0: (0.0, 1.0)}
    for node in order:
        x0, x1 = spans.pop(node)
        if is_terminal(node):
            xs[node] = (x0 + x1) / 2.0
            continue
        left, right = int(children_left[node]), int(children_right[node])
        lc = leaf_counts[left] if left != LEAF else 0
        rc = leaf_counts[right] if right != LEAF else 0
        total = max(lc + rc, 1)
        x_mid = x0 + (lc / total) * (x1 - x0)
        xs[node] = x_mid
        if left != LEAF:
            spans[left] = (x0, x_mid)
        if right != LEAF:
            spans[right] = (x_mid, x1)

    # Root at y=0, each level one unit lower; with a normal (non-reversed)
    # y axis this puts the root at the top.
    ys = (0 - depths).astype(float)

    # Plot in ascending node-id order so that, when nothing is truncated,
    # point k of the node trace is tree node k.
    shown = sorted(order)

    # Build edge coordinates (None separates individual segments).
    edge_x, edge_y = [], []
    for node in shown:
        if is_terminal(node):
            continue
        for child in kids(node):
            edge_x += [xs[node], xs[child], None]
            edge_y += [ys[node], ys[child], None]

    edges = go.Scatter(
        x=edge_x, y=edge_y, mode='lines',
        line=dict(color='rgba(150,150,150,0.6)', width=1),
        hoverinfo='skip',
    )

    # Build node labels and colors
    palette = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692']
    texts = []
    colors = []
    for i in shown:
        samples = int(n_node_samples[i])
        if not is_leaf(i) and is_terminal(i):
            # Split node cut off by max_depth
            texts.append(f'Depth {int(depths[i])} (truncated)')
            colors.append('lightgray')
        elif is_leaf(i):
            # value shape: (n_outputs, n_classes) for classifier or (n_outputs, 1) for regressor
            val = tree.value[i]
            if is_classifier:
                val_str = np.array2string(val.squeeze(), precision=precision, separator=', ')
                pred_idx = int(np.argmax(val.squeeze()))
                if class_names and pred_idx < len(class_names):
                    pred_name = class_names[pred_idx]
                else:
                    pred_name = str(pred_idx)
                texts.append(f'class: {pred_name}<br>samples: {samples}<br>value: {val_str}')
                colors.append(palette[pred_idx % len(palette)] if filled else 'white')
            else:
                pred = float(val.squeeze())
                texts.append(f'value: {pred:.{precision}f}<br>samples: {samples}')
                colors.append('#636EFA' if filled else 'white')
        else:
            fidx = feature[i]
            fname = feature_names[fidx] if 0 <= fidx < len(feature_names) else f'f{fidx}'
            thr = threshold[i]
            texts.append(
                f'{fname} <= {thr:.{precision}f}<br>impurity: {impurity[i]:.{precision}f}<br>samples: {samples}'
            )
            colors.append('lightsteelblue')

    nodes = go.Scatter(
        x=xs[shown], y=ys[shown], mode='markers+text',
        marker=dict(size=node_size, color=colors, line=dict(color='black', width=1)),
        text=texts, textposition='top center',
        hoverinfo='text',
    )

    # Half a level of padding above the root and below the deepest row so the
    # 'top center' labels are not clipped.
    deepest = int(depths[shown].max())
    fig = go.Figure(data=[edges, nodes])
    fig.update_layout(
        showlegend=False,
        height=height, width=width,
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.05]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-deepest - 0.5, 0.5]),
    )
    return fig


In [None]:
# Example: Train a small tree on the provided Titanic CSV and plot with Plotly
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Load the dataset: 'Survived' is the target, every other column is a feature.
df = pd.read_csv('Machine Learning/Classification & Regression/Predicting Housing Prices/DecisionTrees/DecisionTrees_titanic.csv')
X = df.drop('Survived', axis=1)
y = df['Survived']
feature_names = X.columns.tolist()

# Shallow tree so the plot stays readable; random_state=0 fixes the estimator's seed.
clf = DecisionTreeClassifier(random_state=0, max_depth=3).fit(X, y)

# Render the fitted tree, add a title, and display it.
fig = plot_tree_plotly(
    clf,
    feature_names=feature_names,
    class_names=['Not Survived', 'Survived'],
    max_depth=3,
    precision=2,
    filled=True,
)
fig.update_layout(title='Decision Tree (Plotly)')
fig.show()
