In [11]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score, f1_score, precision_score, confusion_matrix, cohen_kappa_score
import ipywidgets as widgets
from ipywidgets import Button, Layout, GridspecLayout


In [12]:
def generate_data(difference, variance, prior, samples):
  """Simulate classifier scores for a two-class population.

  Each class's scores are drawn from a normal distribution centred at
  0.5 -/+ difference/2 and clipped to [0, 1].

  Parameters
  ----------
  difference : float
      Gap between the positive-class and negative-class means.
  variance : float
      Passed as the normal's `scale`, i.e. a standard deviation despite the name.
  prior : float
      Fraction of positives; the positive count is int(samples * prior).
  samples : int
      Total population size.

  Returns
  -------
  y_true : ndarray of 0.0/1.0 labels, all negatives first, then all positives.
  y_pred : ndarray of clipped scores aligned with y_true.
  y_true_jitter : y_true plus uniform noise in [-0.25, 0.25), used as plot x positions.

  Note: reseeds NumPy's global RNG (seed 42) on every call, so identical
  arguments always yield identical data.
  """
  np.random.seed(42)
  half_gap = difference / 2
  n_pos = int(samples * prior)
  n_neg = samples - n_pos

  # Draw order (positives, negatives, jitter) fixes which part of the seeded
  # stream each array gets; keep it unchanged.
  pos_scores = np.clip(np.random.normal(loc=0.5 + half_gap, scale=variance, size=n_pos), 0, 1)
  neg_scores = np.clip(np.random.normal(loc=0.5 - half_gap, scale=variance, size=n_neg), 0, 1)
  noise = np.random.uniform(low=-0.25, high=0.25, size=samples)

  labels = np.concatenate([np.zeros(n_neg), np.ones(n_pos)])
  scores = np.concatenate([neg_scores, pos_scores])
  return labels, scores, labels + noise

def roc_and_threshold(y_true, y_pred, threshold):
  """Return the ROC curve and the operating point produced by `threshold`.

  The operating point uses the same rule as `calculate_metrics`: a sample is
  predicted positive when its score is strictly greater than `threshold`. The
  marker therefore always agrees with the reported confusion matrix. It also
  lies on the ROC polyline, because thresholding at "> threshold" is the same
  as thresholding at ">= smallest score above threshold". The point is computed
  directly rather than looked up in `roc_curve`'s threshold array, so it is
  defined even when every score is >= `threshold`.

  Parameters
  ----------
  y_true : array-like of 0/1 labels (1 is the positive class).
  y_pred : array-like of scores.
  threshold : float

  Returns
  -------
  roc : dict
      'x' (false positive rates) and 'y' (true positive rates) from `roc_curve`.
  op : dict
      'x' (FPR) and 'y' (TPR) at `threshold`, each NaN if the corresponding
      class is absent, and 't', the threshold itself.
  """
  fpr, tpr, _ = roc_curve(y_true, y_pred)

  is_pos = np.asarray(y_true) == 1
  predicted_pos = np.asarray(y_pred) > threshold
  n_pos = np.count_nonzero(is_pos)
  n_neg = is_pos.size - n_pos
  tp = np.count_nonzero(predicted_pos & is_pos)
  fp = np.count_nonzero(predicted_pos & ~is_pos)

  op = {'x': fp / n_neg if n_neg else np.nan,
        'y': tp / n_pos if n_pos else np.nan,
        't': threshold}
  roc = {'x': fpr, 'y': tpr}
  return roc, op

def calculate_metrics(y_true, y_pred, threshold):
  """Score thresholded predictions against the true labels.

  Parameters
  ----------
  y_true : array-like of 0/1 labels.
  y_pred : ndarray of scores; a sample is predicted positive when its score
      is strictly greater than `threshold`.
  threshold : float

  Returns
  -------
  metrics : dict
      'Accuracy', 'F1-score', 'Sensitivity', 'Specificity', 'PPV', 'NPV',
      'AUC', 'Kappa' and 'conf_mat' (confusion matrix unpacked as
      [[TN, FP], [FN, TP]]), in this order. The plotting code lays the
      entries out by key order.
  roc, operating_point : dicts as returned by `roc_and_threshold`.
  """
  def _ratio(num, den):
    # 0/0 -> 0.0, matching the value precision_score (PPV) returns by default,
    # instead of a NaN plus a RuntimeWarning from a raw numpy division.
    return num / den if den else 0.0

  auc = roc_auc_score(y_true, y_pred)
  y_p_bin = y_pred > threshold
  conf_mat = confusion_matrix(y_true, y_p_bin)
  TN, FP, FN, TP = np.ravel(conf_mat)
  metrics = {
      'Accuracy': accuracy_score(y_true, y_p_bin),
      'F1-score': f1_score(y_true, y_p_bin),
      'Sensitivity': _ratio(TP, TP + FN),
      'Specificity': _ratio(TN, TN + FP),
      'PPV': precision_score(y_true, y_p_bin),
      'NPV': _ratio(TN, TN + FN),
      'AUC': auc,
      'Kappa': cohen_kappa_score(y_true, y_p_bin),
      'conf_mat': conf_mat,
  }
  roc, operating_point = roc_and_threshold(y_true, y_pred, threshold)
  return metrics, roc, operating_point

def update_data_and_metrics(difference, variance, prior, samples, threshold):
  """Regenerate the synthetic population and score it at `threshold`.

  Returns
  -------
  data : dict with 'y_true', 'y_pred' and 'y_true_jitter' from `generate_data`.
  metrics, roc, operating_point : as returned by `calculate_metrics`.
  """
  labels, scores, labels_jittered = generate_data(difference, variance, prior, samples)
  data = {'y_true': labels, 'y_pred': scores, 'y_true_jitter': labels_jittered}
  metrics, roc, operating_point = calculate_metrics(labels, scores, threshold)
  return data, metrics, roc, operating_point

In [13]:

# Shared slider styling; 'initial' description width presumably keeps the
# long labels from being truncated.
style = {'description_width': 'initial'}
# Every slider only reports its value on release (continuous_update=False).
slider_kwargs = dict(continuous_update=False, style=style)

# Model strength: gap between class means and per-class standard deviation
# (see generate_data).
difference = widgets.FloatSlider(value=0.5, min=0.0, max=0.95, step=0.05,
                                 description='Class separation', **slider_kwargs)
variance = widgets.FloatSlider(value=0.25, min=0.05, max=0.95, step=0.05,
                               description='Within class spread', **slider_kwargs)
# Post-processing: scores strictly above this value are classified positive.
threshold = widgets.FloatSlider(value=0.5, min=0.05, max=0.95, step=0.05,
                                description='Classification threshold', **slider_kwargs)
# Population: fraction of positives and total sample count.
prior_prob = widgets.FloatSlider(value=0.5, min=0.05, max=0.95, step=0.05,
                                 description='Prior probability', **slider_kwargs)
num_samples = widgets.IntSlider(value=1000, min=50, max=10000, step=50,
                                description='Number of samples', **slider_kwargs)

all_widgets = [difference, variance, num_samples, prior_prob, threshold]

def update(change):
  """Widget callback: recompute data/metrics from the sliders and redraw.

  `change` is ignored; the current slider values are read directly. Relies on
  the module-level `fig`, `axlist` and `plist` returned by
  `create_plot_objects_mpl`, so it can only run after that figure exists.
  """
  data, metrics, roc, operating_point = update_data_and_metrics(
      difference.value, variance.value, prior_prob.value, num_samples.value, threshold.value)
  update_plot_objects_mpl(data, metrics, roc, operating_point, axlist, plist)
  fig.canvas.draw()
  fig.canvas.flush_events()

def create_plot_objects_mpl(data, metrics, roc, operating_point):
  """Create the three-panel dashboard figure and return handles to its artists.

  Panels, left to right:
  1. Model scores against jittered true labels, coloured by outcome.
  2. A text table of `metrics`.
  3. The ROC curve with the current operating point.

  Parameters
  ----------
  data : dict
      'y_true' (0/1 labels), 'y_pred' (model scores) and 'y_true_jitter'
      (labels plus horizontal jitter, used as x positions).
  metrics : dict
      Metric name -> scalar, plus 'conf_mat' (confusion matrix unpacked as
      TN, FP, FN, TP). Entries are stacked bottom-to-top in key order. 'AUC'
      is also shown in the ROC title.
  roc : dict
      'x' (false positive rates) and 'y' (true positive rates).
  operating_point : dict
      'x' and 'y' of the marker on the ROC curve, and 't', the threshold line
      drawn on the first panel.

  Returns
  -------
  fig : the new figure; all previously open figures are closed first.
  axlist : [prediction axes, metrics axes, ROC axes].
  plist : [plist1, plist2, plist3], the artists that `update_plot_objects_mpl` mutates.
      plist1 = {'swarm': [TN, FP, TP, FN lines], 'line': threshold line, 'text': threshold label}.
      plist2 = {'plot<i>': {'name': Text, 'value': Text}} for the i-th metric key.
      plist3 = {'roc': ROC line, 'base': chance diagonal, 'op': operating-point marker}.
  """
  def _metric_color(value):
    # Traffic-light bands: green >= 0.85, orange [0.65, 0.85), red below (and NaN).
    if value >= 0.85:
      return 'green'
    if value >= 0.65:
      return 'orange'
    return 'red'

  plt.close('all')
  fig = plt.figure(figsize=(10, 4))
  thr = operating_point['t']

  # --- Panel 1: scores vs. jittered true label.
  ax = fig.add_subplot(131)
  y_true = np.asarray(data['y_true'])
  y_pred = np.asarray(data['y_pred'])
  x_jit = np.asarray(data['y_true_jitter'])
  is_neg = y_true == 0
  # Same rule as calculate_metrics (score > threshold => positive), so the
  # colours agree with the reported confusion matrix.
  predicted_pos = y_pred > thr
  outcome_groups = [
      (is_neg & ~predicted_pos, '#1f78b4'),   # true negatives
      (is_neg & predicted_pos, '#a6cee3'),    # false positives
      (~is_neg & predicted_pos, '#ff7f00'),   # true positives
      (~is_neg & ~predicted_pos, '#fdbf6f'),  # false negatives
  ]
  swarm = [ax.plot(x_jit[mask], y_pred[mask], color=colour, marker='o', ls='', markersize=1)[0]
           for mask, colour in outcome_groups]
  thr_line = ax.plot([-0.25, 1.25], [thr, thr], color='red')[0]
  thr_text = ax.text(x=0.5, y=thr + 0.02, s='Threshold', horizontalalignment='center', color='red')
  ax.set_xticks([0, 1])
  ax.set_xlabel('True class label')
  ax.set_ylabel('Model output probability')
  ax.set_title('Model predictions')
  plist1 = {'swarm': swarm, 'line': thr_line, 'text': thr_text}

  # --- Panel 2: metric table, one row per key (first key at the bottom).
  ax2 = fig.add_subplot(132)
  plist2 = {}
  for ii, (name, value) in enumerate(metrics.items()):
    y = 0.1 + 0.1 * ii
    if name == 'conf_mat':
      TN, FP, FN, TP = np.ravel(value)
      label = ' '
      text = 'TN: {0: <5}  FP: {1: <5} \nFN: {2: <5}  TP: {3: <5} '.format(TN, FP, FN, TP)
      xval = 0.1
      col = 'black'
    else:
      label = name
      text = '{:.2f}'.format(value)
      xval = 0.5
      col = _metric_color(value)
    plist2['plot{}'.format(ii)] = {
        'name': ax2.text(x=0.1, y=y, s=label),
        'value': ax2.text(x=xval, y=y, s=text, color=col),
    }
  ax2.axis('off')

  # --- Panel 3: ROC curve, chance diagonal and operating point.
  ax3 = fig.add_subplot(133)
  roc_line = ax3.plot(roc['x'], roc['y'])[0]
  base_line = ax3.plot([0, 1], [0, 1], color='orange')[0]
  op_marker = ax3.plot([operating_point['x']], [operating_point['y']],
                       markersize=8, color='red', marker='x', ls='')[0]
  ax3.set_xlabel('1-Specificity')
  ax3.set_ylabel('Sensitivity')
  ax3.set_title('ROC curve. AUC = {:.3f}'.format(metrics['AUC']))
  plist3 = {'roc': roc_line, 'base': base_line, 'op': op_marker}

  axlist = [ax, ax2, ax3]
  plist = [plist1, plist2, plist3]
  return fig, axlist, plist

def update_plot_objects_mpl(data, metrics, roc, operating_point, axlist, plist):
  """Refresh, in place, the artists created by `create_plot_objects_mpl`.

  Parameters have the same meaning as in `create_plot_objects_mpl`. `axlist`
  and `plist` are the lists it returned. Only artist data, texts, colours and
  the ROC title change; axis limits are left untouched.

  Returns
  -------
  (axlist, plist) : the same objects that were passed in.
  """
  thr = operating_point['t']
  y_true = np.asarray(data['y_true'])
  y_pred = np.asarray(data['y_pred'])
  x_jit = np.asarray(data['y_true_jitter'])

  # Split by true class and by the calculate_metrics rule (score > threshold
  # => positive) so the colours agree with the confusion matrix. Masks are
  # used instead of slicing, so no particular sample ordering is assumed.
  is_neg = y_true == 0
  predicted_pos = y_pred > thr
  masks = [is_neg & ~predicted_pos,    # true negatives  -> swarm[0]
           is_neg & predicted_pos,     # false positives -> swarm[1]
           ~is_neg & predicted_pos,    # true positives  -> swarm[2]
           ~is_neg & ~predicted_pos]   # false negatives -> swarm[3]
  for artist, mask in zip(plist[0]['swarm'], masks):
    artist.set_data(x_jit[mask], y_pred[mask])

  plist[0]['line'].set_data([-0.25, 1.25], [thr, thr])
  plist[0]['text'].set_position((0.5, thr + 0.02))

  for ii, (name, value) in enumerate(metrics.items()):
    value_text = plist[1]['plot{}'.format(ii)]['value']
    if name == 'conf_mat':
      TN, FP, FN, TP = np.ravel(value)
      s = 'TN: {0: <5}  FP: {1: <5} \nFN: {2: <5}  TP: {3: <5} '.format(TN, FP, FN, TP)
      col = 'black'
    else:
      s = '{:.2f}'.format(value)
      # Same bands as at creation: green >= 0.85, orange [0.65, 0.85), red below.
      if value >= 0.85:
        col = 'green'
      elif value >= 0.65:
        col = 'orange'
      else:
        col = 'red'
    value_text.set_text(s)
    value_text.set_color(col)

  plist[2]['roc'].set_data(roc['x'], roc['y'])
  # Line2D data must be sequences (recent matplotlib rejects scalars), so wrap the point.
  plist[2]['op'].set_data([operating_point['x']], [operating_point['y']])
  axlist[2].set_title('ROC curve. AUC = {:.3f}'.format(metrics['AUC']))
  return axlist, plist


In [14]:
# Controls in three titled columns; the cell at row 2, column 2 stays empty.
grid = GridspecLayout(3, 3)
grid_placement = [
    ((0, 0), widgets.Label(value=r'\(\bf Population \ properties\)')),
    ((1, 0), num_samples),
    ((2, 0), prior_prob),
    ((0, 1), widgets.Label(value=r'\(\bf Model \ strength\)')),
    ((1, 1), difference),
    ((2, 1), variance),
    ((0, 2), widgets.Label(value=r'\(\bf Post-processing\)')),
    ((1, 2), threshold),
]
for (row, column), control in grid_placement:
  grid[row, column] = control

display(grid)

# Initial state from the current slider values, then the interactive figure.
data, metrics, roc, operating_point = update_data_and_metrics(
    difference.value, variance.value, prior_prob.value, num_samples.value, threshold.value)
plt.ion()
fig, axlist, plist = create_plot_objects_mpl(data, metrics, roc, operating_point)

def update_plots(figure, axes_list, plot_list):
  """Build a widget-observer callback bound to one figure and its artists.

  The returned callback ignores its `change` argument and reads the current
  slider values. It recomputes data and metrics, updates the artists in place
  via `update_plot_objects_mpl`, then redraws `figure`'s canvas.
  """
  def _on_change(change):
    current = (difference.value, variance.value, prior_prob.value,
               num_samples.value, threshold.value)
    data, metrics, roc, operating_point = update_data_and_metrics(*current)
    update_plot_objects_mpl(data, metrics, roc, operating_point, axes_list, plot_list)
    canvas = figure.canvas
    canvas.draw()
    canvas.flush_events()
  return _on_change



# One shared redraw handler for every control. names='value' limits the
# callback to value changes. Without it, traitlets also invokes it for other
# trait updates on the widget (presumably ipywidgets' internal sync traits),
# which reruns the whole simulation several times per slider move.
redraw = update_plots(figure=fig, axes_list=axlist, plot_list=plist)
for widget in all_widgets:
  widget.observe(redraw, names='value')

GridspecLayout(children=(Label(value='\\(\\bf Population \\ properties\\)', layout=Layout(grid_area='widget001…

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …