# Interact with prediction thresholds

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import pandas as pd
import matplotlib.patches as patches
import itertools
from sklearn import metrics

from ipywidgets import (interact,
    interactive,
    FloatSlider)

In [3]:
def sim_preds(n=100, seed=None):
    """
    Simulate predicted probabilities and matching binary outcomes.

    Probabilities are drawn from a Beta(1, 2) distribution. Each outcome is
    1.0 when its probability is at least an independent Uniform(0, 1) draw,
    so an observation with probability ``p`` is positive with probability ``p``.

    Parameters
    ----------
    n : int
        Number of (probability, outcome) pairs to generate.
    seed : int, optional
        When given, draws come from a dedicated ``RandomState(seed)`` so the
        result is reproducible. When omitted, NumPy's global random state is
        used, exactly as before this parameter existed.

    Returns
    -------
    p : ndarray of float, shape (n,)
        Simulated predicted probabilities.
    y : ndarray of float, shape (n,)
        Simulated outcomes encoded as 0.0 / 1.0.
    """
    # The np.random module and RandomState expose the same beta/uniform API,
    # so the default path keeps consuming the global stream in the same order.
    rng = np.random if seed is None else np.random.RandomState(seed)
    p = rng.beta(1, 2, n)
    u = rng.uniform(0, 1, n)
    y = p >= u
    return p, y * 1.0

In [4]:
def plot_confusion_matrix(cm, classes,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          outcome='Observed',
                          criteria='Predicted',
                          ax=None):
    """
    Draw a confusion matrix as an annotated heat map.

    Parameters
    ----------
    cm : 2-D array-like
        Confusion matrix; ``cm[i, j]`` is drawn at row ``i``, column ``j``.
    classes : sequence of str
        Tick labels, applied to both axes.
    title : str
        Axes title.
    cmap : matplotlib colormap
        Colormap used for the cell background.
    outcome : str
        Label of the y axis (rows of ``cm``).
    criteria : str
        Label of the x axis (columns of ``cm``).
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure with a single Axes is created when
        omitted.
    """
    if ax is None:
        fig, ax = plt.subplots(1,1)

    # Accept nested lists as well as arrays; .max(), .shape and cm[i, j]
    # below all need an ndarray.
    cm = np.asarray(cm)

    ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes,rotation=90)

    # Cells darker than the midpoint of the colour scale get white text so
    # the counts stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # Pass a real boolean: the string 'off' is truthy, so on current
    # Matplotlib it switches the grid on rather than off.
    ax.grid(False)
    ax.set_ylabel(outcome)
    ax.set_xlabel(criteria)

In [5]:
def plt_pred_obs(p,y,thresh,shape,ax=None):
    """
    Show observed outcomes on a grid ordered by predicted probability.

    The (p, y) pairs are sorted by ascending ``p`` and laid out row by row on
    a ``shape[0]`` x ``shape[1]`` image coloured by ``y``. Every cell is
    labelled with its probability (two decimals). Red circles mark where the
    sorted probabilities cross ``thresh``: every cell from the first one with
    ``p >= thresh`` up to and including the first one with ``p > thresh``.

    Parameters
    ----------
    p : array-like
        Predicted probabilities.
    y : array-like
        Observed outcomes; cells whose value equals 0 get black text, all
        others white text.
    thresh : float
        Decision threshold to highlight.
    shape : sequence of two ints
        Grid dimensions passed to ``reshape``; must be compatible with the
        number of observations.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 20x10 figure is created when omitted.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 10))

    ranked = pd.DataFrame({'p': p, 'y': y}).sort_values('p')
    grid_dims = (shape[0], shape[1])
    obs_grid = ranked['y'].values.reshape(grid_dims)
    prob_grid = ranked['p'].values.reshape(grid_dims)

    ax.imshow(obs_grid, cmap=plt.cm.Blues)
    ax.set_xticks([])
    ax.set_yticks([])

    reached = False  # some cell so far has p >= thresh
    passed = False   # an earlier cell already had p > thresh
    n_rows, n_cols = obs_grid.shape
    for row in range(n_rows):
        for col in range(n_cols):
            prob = prob_grid[row, col]
            label_color = "white" if obs_grid[row, col] != 0 else "black"
            ax.text(col, row, "%0.2f" % prob,
                    horizontalalignment="center",
                    color=label_color)

            if thresh <= prob:
                reached = True

            # Checked after `reached` but before `passed` is updated, so the
            # first cell strictly above the threshold is still circled.
            if reached and not passed:
                marker = patches.Circle((col, row + 0.3), 0.1,
                                        facecolor='none', edgecolor='red',
                                        linewidth=5)
                ax.add_patch(marker)

            if thresh < prob:
                passed = True

    ax.set_title('Sorted predictions vs observed')

#plt_pred_obs(p,y,0.5,(10,10))

In [6]:
# Grid used to lay out the sorted predictions. The sample size is derived
# from it so the reshape inside plt_pred_obs always fits.
GRID_SHAPE = (10, 10)
SEED = 0

n = GRID_SHAPE[0] * GRID_SHAPE[1]
np.random.seed(SEED)  # sim_preds draws from NumPy's global RNG; fix it for reproducible plots
p, y = sim_preds(n)

thesh_w = FloatSlider(min=0, max=1, step=0.01, value=0.5,description='threshold',readout_format='0.2f')

def cm_thresh_plots(thresh):
    """
    Plot the confusion matrix and the sorted prediction grid for one threshold.

    Observations with ``p > thresh`` are predicted positive. Uses the
    module-level simulated ``p`` and ``y``.
    """
    pred = (p > thresh).astype(int)
    obs = y.astype(int)
    fig, axx = plt.subplots(1,2,figsize=(12,6))
    # Fix the label set so the matrix is always 2x2, even when only one class
    # occurs at this threshold; otherwise it would not match the two tick labels.
    cm = metrics.confusion_matrix(obs, pred, labels=[0, 1])
    plot_confusion_matrix(cm,classes=['0','1'],ax=axx[0])
    plt_pred_obs(p,y,thresh,GRID_SHAPE,ax=axx[1])
    plt.show()

def interact_cm_thresh():
    """Wire the threshold slider to ``cm_thresh_plots`` and display the widget."""
    interact(cm_thresh_plots,
             thresh=thesh_w)

interact_cm_thresh()