In [157]:
import dtreeviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, clear_output
from math import ceil, sqrt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from ipywidgets import HBox, VBox, Layout, widgets
from plotly.graph_objs import FigureWidget, Scatter, Table
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Load the Iris dataset
iris = datasets.load_iris()
df = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
                     columns= iris['feature_names'] + ['target'])

## Version 2 (interactive)

In [158]:
# create histogram
def create_histograms(df=df, exclude_cols=['target', '_x', '_y'], legend=True):
    curr_df = df.drop(exclude_cols, axis=1)
    r = int(sqrt(len(curr_df.columns)))
    c = ceil(len(curr_df.columns) / r)
    # fig = make_subplots(rows=r+1, cols=c+1, column_width=[1/c for _ in range(c + 1)], horizontal_spacing=0.2)
    fig = make_subplots(rows=r+1, cols=c+1)
    col_num =0
    max_cols = len(df.columns)
    for i in range(1, r+1):
        for j in range(1, c+1):
            if col_num < max_cols:
                fig.add_trace(go.Histogram(x=curr_df[curr_df.columns[col_num]], name=curr_df.columns[col_num]), row=i, col=j) 
                fig.add_annotation(xref="x domain",yref="y domain",x=0.5, y=1.2, showarrow=False,
                       text=f"<b>{curr_df.columns[col_num]}</b>", row=i, col=j)
            col_num += 1
    fig.update_layout(margin=dict(l=0, r=0, b=0))
    # fig.update_layout(
    #     legend=dict(
    #     )
    # )
    fig.update_traces(showlegend=legend)
    return fig

In [162]:
# create an output widget
s = widgets.Output()
def explain_cluster(df, x_cols, dtreeviz_plot=True):
    # Split data into features and target
    X = df[x_cols].values  # replace with the names of the columns you want to use as features
    y = df['_selected'].values  # replace with the name of the target column you want to predict

    # Create and fit a decision tree classifier
    clf = DecisionTreeClassifier()
    clf.fit(X, y)
    
    feature_importances = clf.feature_importances_
#     print(feature_importances, x_cols)
    # Print the feature importances
    # Combine feature names and importances into a list of tuples
    feature_importances = list(zip(x_cols, feature_importances))

    # Sort the list in descending order by feature importance
    feature_importances_sorted = sorted(feature_importances, key=lambda x: x[1], reverse=True)

    # Iterate over the sorted list and print out the feature names and importances
    from ipywidgets import Output, VBox
    print('Feature Importances in Decision Tree')
    for feature_name, importance in feature_importances_sorted:
        importance_percent = importance * 100
        print(f"{feature_name}: {importance_percent:.2f}%")    
    viz_model = dtreeviz.model(clf,
                           X_train=X, y_train=y,
                           feature_names=x_cols,
                           target_name=['_selected'], class_names=["not selected", "selected"])
#     plot_tree(clf,
#        feature_names = x_cols, 
# #         feature_names = ['A', 'B', 'C', 'D'],
#        class_names=['not selected', 'selected'],
#        filled = True)
#     print(type(viz_model))

    out = viz_model.view(scale=0.8)
#     clear_output(wait=True)
    with s:
        s.clear_output()
        display(out)
print(type(s))

<class 'ipywidgets.widgets.widget_output.Output'>


In [163]:
# fit decision tree to selected and not zselected points

In [164]:
def create_lasso(data=df, mode='table', exclude_cols=[]):
    global s
    """
    Input: Datafame
    Output: Plotly FigureWidget with lasso select tool
    data: Pandas Dataframe of data
    exclude_cols: columns to exclude
    mode:
     - 'table' shows a table of selected points
     - 'histogram' shows an interactive histogram selected points
     - 'explainer' predicts which factors lead to clustered selection
    """
    IS_HIST = mode == 'histogram'
    TOP_FACTORS = mode == 'explainer'
    pca = PCA(n_components=2)
    pca.fit(df)
    pca_df = pd.DataFrame(pca.transform(df), columns=['_x', '_y'])
    df['_x'] = pca_df['_x']
    df['_y'] = pca_df['_y']

    # TODO: default to lasso select
    # data_name = f"{df=}".split('=')[0]
    f = FigureWidget([Scatter(y = df["_x"], x = df["_y"], mode = 'markers')])
    f.update_layout(dragmode='lasso')
    f.layout.title = "Data Lasso Scatterplot"
    scatter = f.data[0]
    df.dropna()
    exclude_cols.extend(['_x', '_y'])

    N = len(df)
    scatter.marker.opacity = 0.5
    t = None
    
    if mode=='table':
        # Create a table FigureWidget that updates on selection from points in the scatter plot of f
        t = FigureWidget([Table(
            header=dict(values=df.columns,
                        fill = dict(color='#C2D4FF'),
                        align = ['left'] * 5),

            cells=dict(values=[df[col] for col in df.columns],
                    fill = dict(color='#F5F8FF'),
                    align = ['left'] * 5
                    ))])
    if IS_HIST:
        hist = create_histograms(df, exclude_cols=exclude_cols, legend=True)
        no_legend = create_histograms(df, exclude_cols=exclude_cols)
        # t is for "table", but can also be where data is
        t = go.FigureWidget(no_legend, )
        t.layout.title = 'All Points'
        # s is selected
        s = go.FigureWidget(hist)
        s.layout.title = 'Selected Points'
    if TOP_FACTORS:
        pass
    def selection_fn(trace,points,selector):
        global s
        if mode=='table':
            t.data[0].cells.values = [df.loc[points.point_inds][col] for col in df.columns]
        if IS_HIST:
            selected = df[df.index.isin(points.point_inds)]
            new_charts = create_histograms(selected, exclude_cols=exclude_cols, legend=True)
            s.data = []
            s.add_traces(new_charts.data)
        if TOP_FACTORS:
            df['_selected'] = df.index.isin(points.point_inds)
            x_cols = list(filter(lambda x: x not in exclude_cols and x != '_selected', df.columns))
            s = explain_cluster(df, x_cols)
            # t.data = []
            # t.add_traces(decision_tree.data)            
    scatter.on_selection(selection_fn)

    # iplot({data : scatter.on_selection(selection_fn)})
    # Put everything together
    if IS_HIST:
        return VBox((f, s, t), layout=Layout(align_items='flex-start', margin='0px', justify_content='center'))
    return VBox(tuple(x for x in [f, s, t] if x))

create_lasso(mode='explainer', exclude_cols=['target'])

VBox(children=(FigureWidget({
    'data': [{'marker': {'opacity': 0.5},
              'mode': 'markers',
     …

Feature Importances in Decision Tree
sepal length (cm): 46.93%
sepal width (cm): 22.59%
petal width (cm): 18.51%
petal length (cm): 11.97%
Feature Importances in Decision Tree
sepal length (cm): 36.34%
sepal width (cm): 31.43%
petal length (cm): 28.06%
petal width (cm): 4.17%


AttributeError: __enter__

Feature Importances in Decision Tree
sepal length (cm): 41.60%
petal length (cm): 26.26%
sepal width (cm): 25.15%
petal width (cm): 6.98%


AttributeError: __enter__