In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output
from collections import deque, Counter

import matplotlib
import matplotlib.cm

from manify.predictors.decision_tree import _angular_greater

######################
# HELPER FUNCTIONS
######################


def build_edges_and_id_map(pdt):
    """
    Index the tree's nodes and list its parent -> child links.

    Args:
      pdt: object exposing `nodes`, an ordered collection of hashable node objects,
        each with `left` and `right` attributes (a child node or None).

    Returns:
      node2id: dict mapping each node object -> its position in pdt.nodes
      edges: list of (id_parent, id_child); a node's left child is listed before its right
    """
    node2id = {}
    for index, tree_node in enumerate(pdt.nodes):
        node2id[tree_node] = index

    edges = [
        (parent_id, node2id[child])
        for parent_id, parent in enumerate(pdt.nodes)
        for child in (parent.left, parent.right)
        if child is not None
    ]
    return node2id, edges


def compute_node_masks(pdt, X, y, node2id):
    """
    BFS from the root to produce, for each reachable node, a bool mask over the rows of X.

    The root's mask is all True. At a split node (node.feature is not None) the rows for
    which `_angular_greater(angles[:, node.feature], theta)` is True are passed to the
    left child and the remaining rows to the right child, each ANDed with the parent's mask.

    Args:
      pdt: fitted tree; uses pdt._preprocess(X, y) (its first return value is indexed as
        angles[:, feature]) and pdt.tree (the root node).
      X: data whose first dimension (X.shape[0]) is the number of samples
      y: labels; only forwarded to pdt._preprocess
      node2id: dict mapping node object -> int id

    Returns:
      dict {node_id: bool torch tensor of length X.shape[0]}
    """
    angles, _, _, _ = pdt._preprocess(X, y)
    node_masks = {}

    queue = deque([(pdt.tree, torch.ones(X.shape[0], dtype=torch.bool))])
    while queue:
        node, mask = queue.popleft()
        node_masks[node2id[node]] = mask

        # Nothing to propagate from leaves or from nodes without children
        if node.feature is None or (node.left is None and node.right is None):
            continue

        # Evaluate the split once and reuse it for both children.
        # as_tensor avoids a copy (and torch's warning) if theta is already a tensor.
        theta = torch.as_tensor(node.theta)
        goes_left = _angular_greater(angles[:, node.feature], theta).flatten()

        if node.left is not None:
            queue.append((node.left, mask & goes_left))
        if node.right is not None:
            queue.append((node.right, mask & ~goes_left))
    return node_masks


def bfs_tree_layout(pdt, node2id):
    """
    Manual BFS layout:
      - For depth d, y = -d
      - Nodes of one level are spaced 1 apart in BFS order and centred on x = 0

    Children are detected with `is not None` (not truthiness), so every node that has an
    edge to it also receives a position, even if the node object happens to be falsy.

    Args:
      pdt: object exposing `tree`, the root node; nodes carry `left` / `right` (child or None)
      node2id: dict mapping node object -> int id

    Returns:
      dict {node_id: (x, y)} for every node reachable from the root
    """
    levels = []  # levels[d] holds the nodes at depth d, in BFS order
    queue = deque([(pdt.tree, 0)])
    while queue:
        node, depth = queue.popleft()
        # BFS visits depths in non-decreasing order, so a new depth is always len(levels)
        if depth == len(levels):
            levels.append([])
        levels[depth].append(node)
        for child in (node.left, node.right):
            if child is not None:
                queue.append((child, depth + 1))

    positions = {}
    for depth, row in enumerate(levels):
        half_width = (len(row) - 1) / 2
        for i, n in enumerate(row):
            positions[node2id[n]] = (i - half_width, -depth)
    return positions


def make_tree_figure(edges, positions):
    """
    Render the tree as a Plotly figure with two traces: gray edge lines and blue node markers.

    Args:
      edges: iterable of (src_id, dst_id) pairs
      positions: dict {node_id: (x, y)}

    Returns:
      go.Figure whose node trace stores each node id in `customdata` (and as hover text),
      so a click event on a marker can be mapped back to the node.
    """
    # Every edge is a two-point segment followed by None, which separates it from the next one
    seg_x, seg_y = [], []
    for parent_id, child_id in edges:
        (px, py), (cx, cy) = positions[parent_id], positions[child_id]
        seg_x.extend((px, cx, None))
        seg_y.extend((py, cy, None))

    edge_trace = go.Scatter(
        x=seg_x,
        y=seg_y,
        mode="lines",
        line=dict(width=1, color="gray"),
        hoverinfo="none",
        name="Edges",
    )

    # Keep ids and coordinates aligned by walking the dict in a single, fixed order
    node_ids = list(positions)
    node_x = [x for x, _ in positions.values()]
    node_y = [y for _, y in positions.values()]

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        marker=dict(size=12, color="blue"),
        text=list(map(str, node_ids)),
        hoverinfo="text",
        customdata=node_ids,
        name="Nodes",
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(title="Decision Tree (manual BFS layout)", showlegend=False, clickmode="event+select")
    return fig


def get_xy_coords(X, dim1, dim2):
    """
    Pick two plotting coordinates out of X, tolerating a missing ('None') dimension.

    - Both dims None: returns (None, None)
    - Exactly one dim None: that coordinate becomes an array of 1s of length X.shape[0]
    - Otherwise: returns (X[:, dim1], X[:, dim2])
    """
    n_rows = X.shape[0]
    if dim1 is None and dim2 is None:
        return None, None

    # A missing dimension is replaced by a constant column of ones
    first = np.ones(n_rows) if dim1 is None else X[:, dim1]
    second = np.ones(n_rows) if dim2 is None else X[:, dim2]
    return first, second


def create_class_color_map(y):
    """
    Use matplotlib's tab10 to get consistent colors for each unique class in y.

    Classes are taken from np.unique(y) (ascending order); the i-th class gets tab10
    color i % 10, so colors repeat when there are more than 10 classes.

    Returns a dict {class_value: "rgba(r,g,b,a)"} with r, g, b as 0-255 integers and
    a as the colormap's alpha in [0..1].
    """
    # matplotlib.cm.get_cmap is deprecated and removed in newer Matplotlib releases;
    # prefer the colormap registry and fall back only where the registry does not exist.
    try:
        tab10 = matplotlib.colormaps["tab10"]
    except AttributeError:
        tab10 = matplotlib.cm.get_cmap("tab10")

    class2color = {}
    # np.unique already returns the classes sorted
    for i, c in enumerate(np.unique(y)):
        r, g, b, a = tab10(i % 10)  # (r, g, b, alpha) in [0..1]
        class2color[c] = f"rgba({int(r * 255)},{int(g * 255)},{int(b * 255)},{a})"

    return class2color


def make_halfplane_background(X, dim1, dim2, theta, line_color="red"):
    """
    Build the background traces for a split node: a faint two-color heatmap plus a line.

    The heatmap is sampled on a 50x50 grid spanning the bounding box of the selected
    coordinates and is colored by the sign of sin(theta)*x + cos(theta)*y. The line passes
    through the origin along the direction (sin(theta), cos(theta)) and extends 1.2x the
    larger bounding-box side in each direction.

    NOTE(review): the direction (sin, cos) is perpendicular to the zero-set of
    sin*x + cos*y, so the color change does not coincide with the drawn line - confirm
    which of the two matches the tree's angle convention.

    Returns:
      [heatmap, line], or [] when both dims are None.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    xs, ys = get_xy_coords(X, dim1, dim2)
    if xs is None:  # both dims None => nothing to draw
        return []

    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()

    # Sample the bounding box on a regular grid and label each cell by its side
    grid_n = 50
    grid_x = np.linspace(x_lo, x_hi, grid_n)
    grid_y = np.linspace(y_lo, y_hi, grid_n)
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    side = (sin_t * mesh_x + cos_t * mesh_y > 0).astype(int)

    # Two flat bands make the continuous colorscale behave like a discrete one
    color_side0 = "rgba(0,0,200,0.08)"
    color_side1 = "rgba(200,0,0,0.08)"
    two_tone = [
        [0.0, color_side0],
        [0.4999999, color_side0],
        [0.5, color_side1],
        [1.0, color_side1],
    ]

    traces = [
        go.Heatmap(
            x=grid_x, y=grid_y, z=side, opacity=0.3, showscale=False, colorscale=two_tone, hoverinfo="none"
        )
    ]

    # Keep the line only slightly longer than the data extent so it does not blow up the axes
    half_len = max(x_hi - x_lo, y_hi - y_lo) * 1.2
    traces.append(
        go.Scatter(
            x=[-sin_t * half_len, sin_t * half_len],
            y=[-cos_t * half_len, cos_t * half_len],
            mode="lines",
            line=dict(color=line_color),
            name="Threshold",
        )
    )
    return traces


def draw_data_figure(X, y, mask, node_info, class2color):
    """
    Build the detail figure for one tree node.

    - Leaf (node_info["feature"] is None): pie chart of the labels selected by `mask`,
      with one slice per class in class2color (zero counts included); an annotated
      placeholder figure if the mask selects nothing.
    - Split node: scatter of the two coordinates chosen by node_info["dim1"/"dim2"],
      with the half-plane background and threshold line when theta is not None. Rows
      outside the mask are drawn in light gray, rows inside it in their class color.

    The scatter is drawn transposed: the horizontal axis carries the second coordinate
    returned by get_xy_coords and the vertical axis the first.
    NOTE(review): the background traces are added as returned by
    make_halfplane_background, i.e. without this transposition - confirm this is intended.
    """
    feature = node_info["feature"]
    node_id = node_info["id"]

    # ---- Leaf: label distribution --------------------------------------------------
    if feature is None:
        leaf_labels = y[mask]
        if len(leaf_labels) == 0:
            fig = go.Figure()
            fig.add_annotation(
                x=0.5,
                y=0.5,
                xref="paper",
                yref="paper",
                text=f"Leaf node {node_id}<br>No data in mask",
                showarrow=False,
            )
            fig.update_layout(title=f"Leaf node {node_id} - Empty")
            return fig

        tally = Counter(leaf_labels.tolist())
        # Every known class gets a slice (possibly 0) so colors and legend stay consistent
        classes = sorted(class2color.keys())
        pie = go.Pie(
            labels=[str(cl) for cl in classes],
            values=[tally.get(cl, 0) for cl in classes],
            marker=dict(colors=[class2color[cl] for cl in classes]),
            hoverinfo="label+value+percent",
        )
        fig = go.Figure()
        fig.add_trace(pie)
        fig.update_layout(title=f"Leaf node {node_id} - Class distribution")
        return fig

    # ---- Split node: scatter + background -------------------------------------------
    dim1, dim2 = node_info["dim1"], node_info["dim2"]
    theta = node_info["theta"]

    x_all, y_all = get_xy_coords(X, dim1, dim2)
    if x_all is None:
        fig = go.Figure()
        fig.update_layout(title=f"Node {node_id}: no 2D dims")
        return fig

    x_in, y_in = x_all[mask], y_all[mask]
    labels_in = y[mask]
    outside = ~mask

    fig = go.Figure()

    # Background first so the points are drawn on top of it
    if theta is not None:
        for trace in make_halfplane_background(X, dim1, dim2, theta):
            fig.add_trace(trace)

    # Rows that did not reach this node
    fig.add_trace(
        go.Scatter(
            x=y_all[outside],
            y=x_all[outside],
            mode="markers",
            marker=dict(color="lightgray"),
            name="Excluded",
            hoverinfo="none",
        )
    )

    # Rows that reached this node, one trace per class with the tree-wide class color
    for cl in np.unique(labels_in):
        of_class = labels_in == cl
        fig.add_trace(
            go.Scatter(
                x=y_in[of_class],
                y=x_in[of_class],
                mode="markers",
                marker=dict(color=class2color[cl]),
                name=f"Class {cl}",
            )
        )

    # Pin the axes to the padded bounding box of all rows so the threshold line cannot zoom the view out
    x_min, x_max = x_all.min(), x_all.max()
    y_min, y_max = y_all.min(), y_all.max()
    pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1
    pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1

    horizontal_range = [y_min - pad_y, y_max + pad_y]
    vertical_range = [x_min - pad_x, x_max + pad_x]
    fig.update_layout(
        title=f"Node {node_id} (feature={feature}, θ={theta:.2f})",
        xaxis=dict(range=horizontal_range),
        yaxis=dict(range=vertical_range),
    )

    return fig


######################
# MAIN DASH APP
######################


def create_dashboard(pdt, X, y):
    """
    Assemble the interactive Dash app: the tree on the left, a per-node data view on the right.

    Clicking a node marker in the tree redraws the right-hand graph with
    draw_data_figure for that node (pie chart for leaves, scatter for split nodes).

    Args:
      pdt: fitted tree exposing `nodes`, `tree`, `angle_dims` and `_preprocess`
      X: data passed to the tree and plotted
      y: labels used for masks' class colors and distributions

    Returns:
      The configured Dash app (not yet running; call `.run(...)` on it).
    """
    # Local import keeps the file's top-level import block untouched
    from dash import no_update

    node2id, edges = build_edges_and_id_map(pdt)
    node_masks = compute_node_masks(pdt, X, y, node2id)
    positions = bfs_tree_layout(pdt, node2id)

    # Gather the per-node info consumed by draw_data_figure
    node_data = {}
    for i, node in enumerate(pdt.nodes):
        is_split = node.feature is not None
        dim1, dim2 = pdt.angle_dims[node.feature] if is_split else (None, None)
        node_data[i] = {
            "id": i,
            "feature": node.feature,
            "theta": node.theta if is_split else None,
            "dim1": dim1,
            "dim2": dim2,
        }

    # One color per class, shared by every node view
    class2color = create_class_color_map(y)

    tree_fig = make_tree_figure(edges, positions)

    app = Dash(__name__)
    app.layout = html.Div(
        [
            html.H1("Interactive Decision Tree"),
            html.Div(
                [
                    dcc.Graph(
                        id="tree-graph",
                        figure=tree_fig,
                        style={"width": "45%", "display": "inline-block", "verticalAlign": "top"},
                    ),
                    dcc.Graph(
                        id="data-graph", style={"width": "50%", "display": "inline-block", "verticalAlign": "top"}
                    ),
                ]
            ),
        ]
    )

    @app.callback(Output("data-graph", "figure"), Input("tree-graph", "clickData"))
    def update_plot(clickData):
        if not clickData:
            return go.Figure()

        # Only node markers carry customdata; a click that resolves to an edge point
        # (or to a node without a mask) must not crash the callback - keep the current view.
        node_id = clickData["points"][0].get("customdata")
        if node_id not in node_masks:
            return no_update

        # draw_data_figure handles both leaves (pie chart) and split nodes (scatter),
        # so leaves get the same titled, all-classes pie everywhere.
        return draw_data_figure(X, y, node_masks[node_id], node_data[node_id], class2color)

    return app

In [None]:
# Driver cell: fit a ProductSpaceDT on synthetic data and launch the dashboard on the held-out split.
import manify
from sklearn.model_selection import train_test_split

# Synthetic dataset: gaussian_mixture(1000, num_classes=4) on a single-component manifold
# (presumably 1000 samples and their class labels - confirm against manify's API).
pm = manify.ProductManifold(signature=[(1, 2)])
X, y = pm.gaussian_mixture(1000, num_classes=4)
# 80/20 split; no random_state is passed, so the split (and the score below) varies between runs
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

pdt = manify.ProductSpaceDT(pm=pm, n_features="d_choose_2", max_depth=5)
pdt.fit(X_train, y_train)

# Mean of the values returned by score() on the test split, printed with 4 decimals
print(f"{pdt.score(X_test, y_test).float().mean().item():.4f}")

# Build the dashboard for the test split and start the Dash server in debug mode
create_dashboard(pdt, X_test, y_test).run(debug=True)

0.7050


  tab10 = matplotlib.cm.get_cmap("tab10")
