In [1]:
import random
import numpy as np
from numpy import sqrt, argmax
import pandas as pd

import sklearn.linear_model
from sklearn import metrics
from sklearn.datasets import make_classification
#from sklearn.inspection import DecisionBoundaryDisplay

from plotnine import *
import ipywidgets as widgets
from ipywidgets import FloatSlider, Dropdown, IntSlider, HBox, VBox, interactive_output, Layout, AppLayout

import warnings
warnings.filterwarnings("ignore", module = "plotnine\..*" )

In [2]:
### GENERATE DATA ###
# temporarily calling data generation function for each plot due to subplotting/refresh issue

# Keyword arguments forwarded unchanged to every make_classification call:
# 300 samples, two features (both informative, none redundant), one cluster
# per class, flip_y=0, and a fixed random_state with shuffling turned off.
common_params = dict(
    n_samples=300,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    flip_y=0,
    random_state=11,
    shuffle=False,
)
    
def generate_data_for_plotting(class_imbalance, separation, cutoff):
    """Simulate a two-class dataset, fit a logistic regression, and score it.

    Parameters
    ----------
    class_imbalance : float
        Forwarded to ``make_classification`` as ``weights=[class_imbalance]``.
    separation : float
        Forwarded to ``make_classification`` as ``class_sep``.
    cutoff : float
        Probability threshold. Rows whose predicted probability (second column
        of ``predict_proba``) is strictly greater than ``cutoff`` are
        labelled 1, all others 0.

    Returns
    -------
    logreg : sklearn.linear_model.LogisticRegression
        The model fitted on the full simulated dataset.
    df : pandas.DataFrame
        One row per sample with columns ``x1`` and ``x2`` (features), ``y``
        (true label, categorical dtype), ``y_pred`` (thresholded prediction),
        ``pr`` (predicted probability) and ``color`` (``'true'`` when the
        prediction matches the label, ``'false'`` otherwise).
    """
    X, y = make_classification(**common_params, weights=[class_imbalance], class_sep=separation)

    # fit a logistic regression model - will eventually update to include more classifiers
    logreg = sklearn.linear_model.LogisticRegression()
    logreg.fit(X, y)

    probs = logreg.predict_proba(X)[:, 1]
    # Vectorised thresholding; int64 matches what a list of Python ints yields.
    y_pred = (probs > cutoff).astype(np.int64)

    # np.where receives both string labels explicitly. np.select would also
    # have to reconcile them with its implicit integer default of 0, which
    # NumPy versions without string/number promotion reject with a TypeError.
    # The comparison is done on the raw arrays, before ``y`` becomes categorical.
    correctness = np.where(y == y_pred, 'true', 'false')

    # create dataframe for plotting
    df = pd.DataFrame({
        "x1": X[:, 0],
        "x2": X[:, 1],
        "y": y,
        "y_pred": y_pred,
        "pr": probs,
    }).astype({'y': 'category'})
    df['color'] = correctness

    return logreg, df


In [3]:
### CREATE PLOTS ###

# Shared five-color palette; supplies the manual fill scale of the metrics bar chart.
colors = [
    "#34585F",
    "#A4B89E",
    "#C6CDA7",
    "#EFE0B7",
    "#BEAC7C",
]

def plot_decision_boundary(class_imbalance, separation, cutoff):
    """Draw the simulated points with the classifier's decision boundary.

    The boundary is the line on which the fitted logistic regression predicts
    a probability equal to ``cutoff``, so the line and the shaded half-planes
    agree with the predictions that are highlighted as incorrect. At
    ``cutoff == 0.5`` this is the usual ``b + w1*x1 + w2*x2 = 0`` line.

    Parameters
    ----------
    class_imbalance : float
        Forwarded to ``generate_data_for_plotting``.
    separation : float
        Forwarded to ``generate_data_for_plotting``.
    cutoff : float
        Probability threshold; must lie strictly between 0 and 1 because its
        logit is taken.

    Notes
    -----
    The line is expressed as ``x2 = m*x1 + c`` and therefore assumes the
    fitted coefficient on ``x2`` is non-zero. The figure is drawn with
    ``coord_flip()`` and nothing is returned.
    """
    # prep data
    logreg, df = generate_data_for_plotting(class_imbalance, separation, cutoff)

    # Index into the (1, 2) coefficient matrix to get plain scalars rather
    # than the length-1 arrays that unpacking ``coef_.T`` produces.
    b = logreg.intercept_[0]
    w1, w2 = logreg.coef_[0]

    # P(y=1) > cutoff  <=>  b + w1*x1 + w2*x2 > logit(cutoff),
    # hence the boundary is x2 = m*x1 + c with the intercept shifted by the logit.
    logit_cutoff = np.log(cutoff / (1 - cutoff))
    m = -w1 / w2
    c = (logit_cutoff - b) / w2

    xmin, xmax = -2.25, 4.25
    ymin, ymax = -3.75, 3.75

    # Boundary evaluated at as many x positions as df has rows.
    xd = np.linspace(start=xmin, stop=xmax, num=len(df.index))
    yd = m * xd + c

    # The side above the line is the predicted-1 region only when w2 > 0;
    # for a negative w2 the inequality flips, so the shading is swapped.
    shade_one = {'fill': '#83ad76', 'alpha': .15}
    shade_zero = {'fill': '#196675', 'alpha': .1}
    if w2 > 0:
        shade_above, shade_below = shade_one, shade_zero
    else:
        shade_above, shade_below = shade_zero, shade_one

    # plot data
    (ggplot(df, aes(x = 'x1', y = 'x2', fill = 'y')) +
            geom_point(aes(color = 'color'), size = 3.5, alpha = .85) +
            geom_abline(intercept = c,
                slope = m,
                linetype='dotted') +
            geom_ribbon(mapping = aes(x = xd, ymin = yd, ymax = float('inf')),
                **shade_above) +
            geom_ribbon(mapping = aes(x = xd, ymin = yd, ymax = float('-inf')),
                **shade_below) +
            scale_fill_manual(labels = ['True 0', 'True 1'], values = ['#196675', '#83ad76']) +
            scale_color_manual(labels = ['Incorrect pred'], limits = ['false'], values = ['#e03a2f', 'grey']) +
            scale_x_continuous(limits=(xmin,xmax), expand = (0,0)) +
            scale_y_continuous(limits=(ymin,ymax), expand = (0,0)) +
            labs(title = "Classifier decision boundary",
                 x = r'$x_1$',
                 y = r'$x_2$',
                 fill = "",
                 color = "") +
            coord_flip() +
            theme_minimal() +
            theme(legend_position=(.68, .15),
                  panel_grid_major = element_blank(),
                  legend_background=element_blank(),
                  legend_box_background=element_blank(),
                  legend_direction='horizontal') +
            guides(color = guide_legend(override_aes = {'fill': "white"}))
    ).draw()

def plot_metrics_bar_chart(class_imbalance, separation, cutoff):
    """Draw a bar chart of accuracy, recall and precision for the current settings.

    All three parameters are forwarded to ``generate_data_for_plotting``; the
    scores compare its ``y`` column against its ``y_pred`` column and are
    rounded to two decimals before plotting. Nothing is returned.
    """
    # prep data
    _, scored = generate_data_for_plotting(class_imbalance, separation, cutoff)
    actual, predicted = scored['y'], scored['y_pred']

    # Metric label -> scoring function, evaluated in this order.
    scorers = {
        'Accuracy': metrics.accuracy_score,
        'Recall': metrics.recall_score,
        'Precision': metrics.precision_score,
    }
    metrics_df = pd.DataFrame({
        'cols': list(scorers),
        'vals': [score(actual, predicted) for score in scorers.values()],
    })
    metrics_df['vals'] = metrics_df['vals'].round(2)

    # plot data
    chart = ggplot(metrics_df, aes(x='cols', y='vals', fill='cols'))
    chart = chart + geom_col()
    # Value labels sit 0.1 below the top of each bar, in white.
    chart = chart + geom_text(aes(label='vals'), nudge_y=-.1, color="white")
    chart = chart + scale_fill_manual(values=colors, guide=False)
    chart = chart + ylim(0, 1)
    chart = chart + labs(title="Accuracy, precision, and recall", x=" ", y=None)
    chart = chart + theme_minimal()
    chart = chart + theme(panel_grid_major=element_blank())
    chart.draw()


def plot_roc_curve(class_imbalance, separation, cutoff):
    """Plot the ROC curve for the fitted classifier, with its AUC in the title.

    All three parameters are forwarded to ``generate_data_for_plotting``. The
    curve is computed from the ``y`` and ``pr`` columns it returns. A marker is
    placed at the ROC point that maximises ``sqrt(tpr * (1 - fpr))``.
    Nothing is returned.
    """
    # prep data
    _, scored = generate_data_for_plotting(class_imbalance, separation, cutoff)

    fpr, tpr, _ = metrics.roc_curve(scored['y'], scored['pr'])
    roc_auc = metrics.auc(fpr, tpr).round(3)

    # Index of the ROC point with the largest geometric mean of tpr and 1 - fpr.
    ix = argmax(sqrt(tpr * (1 - fpr)))
    best_fpr, best_tpr = fpr[ix], tpr[ix]

    roc_df = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'roc_auc': roc_auc})

    # plot data
    curve = ggplot(roc_df, aes(x='fpr', y='tpr'))
    curve = curve + geom_line(color="#34585F", size=1.5)
    curve = curve + geom_abline(linetype='dashed')
    # NOTE(review): the constant x/y aesthetics are evaluated against the
    # inherited ROC frame, so this presumably draws one marker per row at the
    # same spot - confirm the resulting compounded alpha is intended.
    curve = curve + geom_point(aes(x=best_fpr, y=best_tpr), size=7, color="#A4B89E", alpha=.2)
    curve = curve + labs(title=f"ROC AUC = {roc_auc}",
                         x="False positive rate",
                         y="True positive rate")
    curve = curve + theme_minimal()
    curve = curve + theme(panel_grid_major=element_blank())
    curve.draw()

    
def plot_prob_density(class_imbalance, separation, cutoff):
    """Plot per-class densities of the predicted probabilities.

    All three parameters are forwarded to ``generate_data_for_plotting``. The
    ``pr`` column is drawn as one density per value of ``y``, and a dashed
    vertical line marks ``cutoff``. Nothing is returned.
    """
    # prep data
    _, scored = generate_data_for_plotting(class_imbalance, separation, cutoff)

    # plot data
    density = ggplot(scored, aes(x='pr', fill='y'))
    density = density + geom_density(alpha=0.5, color='grey')
    density = density + geom_vline(xintercept=cutoff, color='grey', linetype='dashed')
    density = density + scale_fill_manual(labels=['0', '1'], values=['#196675', '#83ad76'])
    density = density + labs(title="Predicted probability density",
                             x=" ",
                             y="Density",
                             fill="")
    density = density + theme_minimal()
    # Legend and y tick labels are hidden.
    density = density + theme(legend_position='none',
                              panel_grid_major=element_blank(),
                              axis_text_y=element_blank())
    density.draw()
    

In [4]:
%%html
<style>
/* Attached to the control sidebar via widget_box.add_class("box_style"):
   light grey panel with a border on every side except the right. */
.box_style{
    background-color: #f7f7f7;
    font-size: .25rem;
    border-left: 1px solid #cecece;
    border-top: 1px solid #cecece;
    border-bottom: 1px solid #cecece;

}
/* Attached to the plots container via plots_box.add_class("plot_style"):
   white panel with a border on every side except the left. */
.plot_style{
    background-color: #fff;
    border-right: 1px solid #cecece;
    border-top: 1px solid #cecece;
    border-bottom: 1px solid #cecece;
}

/* Unscoped element selector: applies to every <p> on the page, which
   includes the footer note rendered by the app. */
p{
    font-style: italic;
    font-size: .75rem;
    line-height: .85rem;
}
</style>

In [5]:
### CREATE WIDGETS ###

# define widgets
# Classifier picker; logistic regression is the only option for now.
class_widget = Dropdown(
    options=[('Logistic Regression', 1)],  # to update with additional classifiers
    value=1,
    description='Classifier',
    layout=Layout(margin='10px 10px 20px 0', padding='0 50px 0 0'),
)

# All sliders move in steps of 0.1 and are created with continuous_update=False.
cb_widget = FloatSlider(
    description="Imbalance",
    continuous_update=False,
    min=0.1,
    max=0.9,
    step=0.1,
    value=0.5,
)
sep_widget = FloatSlider(
    description="Separation",
    continuous_update=False,
    min=0.5,
    max=1.5,
    step=0.1,
    value=1.0,
)
cutoff_widget = FloatSlider(
    description="Cutoff",
    continuous_update=False,
    min=.1,
    max=.9,
    step=0.1,
    value=0.5,
)

# interact plots with widgets
# Keys are the keyword names shared by every plotting function.
widget_vars = dict(
    class_imbalance=cb_widget,
    separation=sep_widget,
    cutoff=cutoff_widget,
)

# One output per plot, created in this order: boundary, density, metrics, ROC.
scatter_plot, density_plot, metrics_plot, roc_plot = (
    interactive_output(plot_fn, widget_vars)
    for plot_fn in (
        plot_decision_boundary,
        plot_prob_density,
        plot_metrics_bar_chart,
        plot_roc_curve,
    )
)

In [6]:
### SET UP APP ###

# Title banner and explanatory footnote.
header = widgets.HTML(
    "<h1>Visualizing class imbalance and model evaluation metrics</h1>",
    layout=widgets.Layout(margin='0 0 10px 0'),
)
footer = widgets.HTML(
    """<p>Note: Results should be taken with a grain of salt, as this visualization is only meant to convey the intuition behind the 
                      effects of imbalanced data and other parameters on model evaluation metrics.</p>""",
    layout=widgets.Layout(margin='20px 0 0 0', max_width='1000px'),
)

# Sidebar holding the classifier dropdown and the three sliders.
widget_box = VBox(
    [class_widget, cb_widget, sep_widget, cutoff_widget],
    layout=Layout(
        padding='15px 0px 3px 3px',
        border_left='solid 1px #cecece',
        flex_flow='row wrap',
        align_content='flex-start',
        justify_content='flex-start',
    ),
)

# Two plot rows with identical layout settings; each row gets its own Layout.
top_box, bottom_box = (
    HBox(
        list(plot_pair),
        layout=Layout(
            max_width='650px',
            display='flex',
            flex_flow='row',
            justify_content='space-between',
            align_items='center',
        ),
    )
    for plot_pair in ((scatter_plot, density_plot), (metrics_plot, roc_plot))
)

# Rows stacked vertically into the central plot area.
plots_box = VBox(
    [top_box, bottom_box],
    layout=Layout(
        max_width='700px',
        padding='10px',
        justify_content='space-between',
        align_items='center',
    ),
)

# CSS class hooks for the sidebar and the plot area.
widget_box.add_class("box_style")
plots_box.add_class("plot_style")

app = AppLayout(
    header=header,
    left_sidebar=widget_box,
    center=plots_box,
    right_sidebar=None,
    footer=footer,
    pane_widths=['250px', 1, 1],
    pane_heights=['70px', 1, '90px'],
    border='solid 1px #cecece',
)


In [7]:
### DISPLAY APP ###

# Render the assembled `app` layout as this cell's output.
display(app)

AppLayout(children=(HTML(value='<h1>Visualizing class imbalance and model evaluation metrics</h1>', layout=Lay…