In [1]:
import dtreeviz
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from IPython.display import display, clear_output
# from IPython.display.DisplayHandle import display, update
# import IPython.display
from math import ceil, sqrt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.decomposition import PCA
from ipywidgets import HBox, VBox, Layout, widgets
from plotly.graph_objs import FigureWidget, Scatter, Table
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.preprocessing import LabelEncoder

# Load the feature matrix and attach the class labels.
# (Alternative demo data: sklearn's datasets.load_iris() packed into a DataFrame with a 'target' column.)
# 'Unnamed: 0' is presumably an index column saved by DataFrame.to_csv — dropped, not used as a feature.
df = pd.read_csv('data.csv').drop(columns='Unnamed: 0')
labels = pd.read_csv('labels.csv')
# Both frames carry a fresh RangeIndex, so the assignment below pairs rows by position.
# A length mismatch would silently produce NaN labels, so fail loudly instead.
# NOTE(review): assumes labels.csv rows are in the same order as data.csv rows — confirm.
assert len(labels) == len(df), f"labels.csv has {len(labels)} rows but data.csv has {len(df)}"
df['label'] = labels['Class']
df.head()

## Version 2 (interactive)

In [2]:
# create histogram
def create_histograms(df=None, exclude_cols=('target', '_x', '_y'), legend=True):
    """
    Creates a grid of histograms, one subplot per (non-excluded) column of a DataFrame
    ---
    Input: Dataframe
    Output: plotly.graph_objects.Figure with one histogram subplot per plotted column
            (an empty Figure if no columns remain after exclusion)
    ---
    df: Pandas Dataframe of data to analyze. If None, the notebook-level `df` is used,
        looked up when the function is called (not frozen when the cell defining it ran).
    exclude_cols: Column name or collection of names not to plot (output and otherwise);
        names that are not present in df are ignored
    legend (default True): Whether to show legend on plots
    """
    if df is None:
        df = globals()['df']
    if isinstance(exclude_cols, str):
        exclude_cols = [exclude_cols]

    plot_df = df.drop(columns=list(exclude_cols), errors='ignore')
    columns = list(plot_df.columns)
    n_plots = len(columns)
    if n_plots == 0:
        return go.Figure()

    # Near-square grid with n_grid_rows * n_grid_cols >= n_plots, so every column gets a cell.
    n_grid_rows = int(sqrt(n_plots))
    n_grid_cols = ceil(n_plots / n_grid_rows)
    fig = make_subplots(rows=n_grid_rows, cols=n_grid_cols)

    # Fill the grid row-major; plotly rows/cols are 1-based.
    for idx, col_name in enumerate(columns):
        grid_row, grid_col = divmod(idx, n_grid_cols)
        fig.add_trace(go.Histogram(x=plot_df[col_name], name=col_name),
                      row=grid_row + 1, col=grid_col + 1)
        # Bold column name centred just above each subplot.
        fig.add_annotation(xref="x domain", yref="y domain", x=0.5, y=1.2, showarrow=False,
                           text=f"<b>{col_name}</b>", row=grid_row + 1, col=grid_col + 1)
    fig.update_layout(margin=dict(l=0, r=0, b=0))
    fig.update_traces(showlegend=legend)
    return fig

In [3]:
def explain_cluster(df, x_cols, num_factors = 10, dtreeviz_plot=True, random_state=0):
    """
    Fits a decision tree separating the selected rows from the rest, prints the top
    feature importances and draws the tree
    ---
    Input: Pandas Dataframe
    Output: None (importances are printed and the tree is displayed in the current output area)
    ---
    df: Pandas Dataframe of data to analyze; its '_selected' column is the target
        (True for selected rows, False otherwise)
    x_cols: Input (feature) columns of df
    num_factors (default 10): How many of the most important features to print
    dtreeviz_plot:
     - If True, plots decision tree of selection boundaries using dtreeviz library
     - Else, plots decision tree of selection boundaries using sklearn (faster)
    random_state (default 0): Seed for the decision tree so importances are reproducible
    """
    X = df[x_cols].values
    y = df['_selected'].values

    # A tree needs both classes; an empty (or select-everything) lasso has nothing to explain.
    if len(np.unique(y)) < 2:
        print('Select some, but not all, points to explain the selection.')
        return None

    clf = DecisionTreeClassifier(random_state=random_state)
    clf.fit(X, y)

    # Pair each feature with its importance, most important first.
    ranked = sorted(zip(x_cols, clf.feature_importances_),
                    key=lambda pair: pair[1], reverse=True)
    print('Feature Importances in Decision Tree')
    for feature_name, importance in ranked[:num_factors]:
        print(f"{feature_name}: {importance * 100:.2f}%")

    # clf.classes_ is sorted, so False ("not selected") precedes True ("selected").
    class_names = ['not selected', 'selected']
    if dtreeviz_plot:
        viz_model = dtreeviz.model(clf,
                                   X_train=X, y_train=y,
                                   feature_names=list(x_cols),
                                   target_name='_selected', class_names=class_names)
        display(viz_model.view(scale=1.3))
    else:
        plot_tree(clf,
                  feature_names=list(x_cols),
                  class_names=class_names,
                  filled=True)
        # plot_tree only draws on the current matplotlib figure; show it explicitly so it
        # renders when called from a widget callback rather than at the end of a cell.
        plt.show()

In [4]:
# def get_colors(df, label_col):
#     # Using LabelEncoder to convert categories into numerical labels
#     le = LabelEncoder()
#     labels = le.fit_transform(df[label_col])
#     # Define color palette (can be customized)
#     cmap = mpl.colormaps['tab10']  # or any other colormap
#     palette = [cmap(i) for i in np.linspace(0, 1, np.max(labels))]
#     color = [palette[label] for label in labels]
#     print(dict(color=mcolors.rgb2hex(df)))
#     print(le.classes_)
#     return dict(color=mcolors.rgb2hex(df))
# get_colors(df, 'label')

In [17]:
def get_colors(df, label_col):
    """
    Generate a per-row color for the distinct labels in a column of a DataFrame.
    ---
    Input: DataFrame, Label Column
    Output: Dictionary {'color': list of hex color strings, one per row of df}
    ---
    df: Pandas DataFrame containing the data.
    label_col: String specifying the column in the DataFrame that contains the labels to be colored.

    Distinct labels are ordered with sklearn's LabelEncoder, and each one is assigned a color
    sampled evenly from matplotlib's 'tab10' colormap. The result can be passed directly as
    plotly marker keyword arguments, e.g. Scatter(..., marker=get_colors(df, 'label')).
    """
    le = LabelEncoder()
    le.fit(df[label_col])
    # len(classes_) equals max(encoded labels) + 1 but also works for an empty column.
    n_classes = len(le.classes_)

    cmap = mpl.colormaps['tab10']
    # Plotly expects color strings, so convert each sampled RGBA color to hex.
    colors = [mcolors.rgb2hex(rgba) for rgba in cmap(np.linspace(0, 1, n_classes))]
    color_map = dict(zip(le.classes_, colors))

    # Map each label in df[label_col] to its class color.
    return dict(color=list(df[label_col].map(color_map)))

In [18]:
def create_lasso(df, mode='table', label_col=None, exclude_cols=None, num_factors = 10, dtreeviz_plot=True):
    """
    Create Lasso tool to analyze lassoed data via below mode options
    ---
    Input: Dataframe
    Output: ipywidgets VBox holding a plotly FigureWidget with a lasso select tool plus
            the mode-specific view of the selection
    ---
    df: Pandas Dataframe of data. Its '_x' and '_y' columns are (over)written with the
        2-D PCA coordinates that are plotted.
    mode:
     - 'table' shows a table of selected points
     - 'histogram' shows an interactive histogram selected points
     - 'explainer' predicts which factors lead to clustered selection
    label_col: Optional column whose distinct values color the scatter points
    exclude_cols: Column name or list of columns to leave out of the PCA / analysis
        (the caller's list is not modified)
    num_factors (default 10): In 'explainer' mode, how many top features to print
    dtreeviz_plot:
     - If mode = 'explainer' and True, plots decision tree of selection boundaries using dtreeviz library
     - Else, plots decision tree of selection boundaries using sklearn (faster)
    ---
    Raises ValueError for an unknown mode.
    """
    if mode not in ('table', 'histogram', 'explainer'):
        raise ValueError(f"mode must be 'table', 'histogram' or 'explainer', got {mode!r}")
    IS_HIST = mode == 'histogram'
    TOP_FACTORS = mode == 'explainer'

    # Work on a copy so neither the caller's list nor a shared default is mutated.
    if exclude_cols is None:
        excluded = []
    elif isinstance(exclude_cols, str):
        excluded = [exclude_cols]
    else:
        excluded = list(exclude_cols)
    # Columns this tool writes itself; they may already exist from an earlier call on the
    # same df and must never be fed back into the PCA or the explainer.
    helper_cols = ['_x', '_y', '_selected']
    feature_cols = [c for c in df.columns if c not in excluded and c not in helper_cols]

    # Project the features to 2-D. Assign raw arrays (positional) rather than a Series with
    # a fresh RangeIndex, which would misalign if df has a non-default index.
    coords = PCA(n_components=2).fit_transform(df[feature_cols])
    df['_x'] = coords[:, 0]
    df['_y'] = coords[:, 1]
    # Everything that is not a feature (excluded + helper columns) is hidden from histograms.
    non_feature_cols = [c for c in df.columns if c not in feature_cols]

    marker = dict(opacity=0.5)
    if label_col is not None:
        # Plotly wants color strings; normalize whatever get_colors returns to hex.
        marker['color'] = [c if isinstance(c, str) else mcolors.to_hex(c)
                           for c in get_colors(df, label_col)['color']]
    f = FigureWidget([Scatter(x=df['_x'], y=df['_y'], mode='markers', marker=marker)])
    f.update_layout(dragmode='lasso', xaxis_title='PC1', yaxis_title='PC2')
    f.layout.title = "Data Lasso Scatterplot"
    scatter = f.data[0]

    s = widgets.Output()
    t = None
    table_cols = list(df.columns)

    if mode == 'table':
        # Table FigureWidget that updates on selection from points in the scatter plot f.
        t = FigureWidget([Table(
            header=dict(values=table_cols,
                        fill=dict(color='#C2D4FF'),
                        align=['left'] * 5),
            cells=dict(values=[df[col] for col in table_cols],
                       fill=dict(color='#F5F8FF'),
                       align=['left'] * 5))])
    if IS_HIST:
        # t holds histograms of all points; s (selected) is refilled on every lasso.
        t = go.FigureWidget(create_histograms(df, exclude_cols=non_feature_cols, legend=False))
        t.layout.title = 'All Points'
        s = go.FigureWidget(create_histograms(df, exclude_cols=non_feature_cols, legend=True))
        s.layout.title = 'Selected Points'

    def selection_fn(trace, points, selector):
        # point_inds are positions within the scatter trace, i.e. row positions in df.
        inds = list(points.point_inds)
        if mode == 'table':
            selected = df.iloc[inds]
            t.data[0].cells.values = [selected[col] for col in table_cols]
        elif IS_HIST:
            new_charts = create_histograms(df.iloc[inds], exclude_cols=non_feature_cols, legend=True)
            s.data = []
            s.add_traces(new_charts.data)
        elif TOP_FACTORS:
            mask = np.zeros(len(df), dtype=bool)
            mask[inds] = True
            # Output produced in a widget callback is only shown if captured by an Output widget.
            with s:
                clear_output(wait=True)
                explain_cluster(df.assign(_selected=mask), feature_cols, num_factors,
                                dtreeviz_plot=dtreeviz_plot)
    scatter.on_selection(selection_fn)

    # Put everything together
    if IS_HIST:
        return VBox((f, s, t), layout=Layout(align_items='flex-start', margin='0px', justify_content='center'))
    return VBox(tuple(w for w in (f, s, t) if w is not None))

# Example (df is required): create_lasso(df, mode='explainer', exclude_cols=['target'], dtreeviz_plot=False)

In [19]:
# Explain lasso selections with a fast sklearn tree; color points by class and keep the
# categorical 'label' column out of the PCA.
create_lasso(
    df,
    mode='explainer',
    label_col='label',
    exclude_cols=['label'],
    dtreeviz_plot=False,
)

VBox(children=(FigureWidget({
    'data': [{'marker': {'opacity': 0.5},
              'mode': 'markers',
     …

In [None]:
# Inspect the first rows, including the _x/_y PCA columns that create_lasso writes into df.
df.head(n=5)

In [None]:
# TODO: add SVM visualization - project the decision boundary down to 2D
# - want a PCA that accounts for both the actual data and the SVM boundary
# - look at the gene expression dataset
# - later: counterfactual explanations
# - another idea: on a graph, select one dimension and run PCA on the others
# - could also use color as another dimension

In [None]:
# NOTE(review): `git add .` stages every change under the current working directory,
# possibly including data files such as data.csv / labels.csv — confirm that is intended.
!git add .