In [1]:
from utils import start
# Project-local helper from utils.py; its implementation is not visible here.
# Judging by the saved output of this cell, it renders a "Click Here To Start!"
# button and emits some Javascript — confirm against utils.py.
start()

Button(description=u'Click Here To Start!', style=ButtonStyle())

<IPython.core.display.Javascript object>

In [2]:
%matplotlib inline

from IPython.display import Javascript, display
from ipywidgets import widgets
from ipywidgets import HBox, VBox, Label, FloatText, Layout, interact, interactive, fixed, interactive_output
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.stats import norm
import numpy as np
import matplotlib.patches as patches
from params import params_dict

In [3]:
# Global matplotlib font size for every figure below, taken from the project params.
mpl.rcParams['font.size'] = params_dict['font_size']

In [4]:
def _make_param_slider(key, step=0.5):
    """Build a FloatSlider whose default/min/max come from ``params_dict[key]``.

    Parameters
    ----------
    key : str
        Entry of ``params_dict`` that holds ``'default'``, ``'min'`` and ``'max'``.
    step : float, optional
        Slider increment; every slider in this notebook uses 0.5.

    Returns
    -------
    widgets.FloatSlider
        A 200px-wide slider with a two-decimal readout and no description
        (the caption is supplied by a separate Label in the grid).
    """
    spec = params_dict[key]
    return widgets.FloatSlider(
        value=spec['default'],
        min=spec['min'],
        max=spec['max'],
        step=step,
        description='',
        disabled=False,
        # Only publish the value when the user releases the handle, so the
        # plots are not redrawn on every intermediate drag position.
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
        layout={'width': '200px'},
    )


# The four model parameters of the demo.
sig_present = _make_param_slider('signal_present')
sig_absent = _make_param_slider('signal_absent')
std = _make_param_slider('std')
crit = _make_param_slider('crit')

In [5]:
# One caption per parameter slider, in the same order the sliders were built;
# the caption text itself lives in params_dict.
sig_present_label, sig_absent_label, std_label, crit_label = (
    Label(params_dict[key]['label'], layout={'width': '200px'})
    for key in ('signal_present', 'signal_absent', 'std', 'crit')
)

In [6]:
def _make_rate_box():
    """Build an editable rate box constrained to [0, 1] with a 0.05 step.

    Returns
    -------
    widgets.BoundedFloatText
        A 100px-wide box initialised to 0.0.
    """
    return widgets.BoundedFloatText(
        value=0.0,
        min=0.0,
        max=1.0,
        step=0.05,
        disabled=False,
        layout={'width': '100px'},
    )


# One box per signal-detection outcome.  BoundedFloatText has no ``color``
# trait; any text colouring belongs in CSS (see the %%html cell).
hit_rate = _make_rate_box()
miss_rate = _make_rate_box()
fp_rate = _make_rate_box()
cr_rate = _make_rate_box()

# d' is computed from the sliders, so its box is read-only.
dprime = widgets.FloatText(
    value=0.0,
    disabled=True,
    layout={'width': '100px'},
)

In [7]:
# Captions for the four outcome-rate boxes (same 200px width as the slider captions).
hit_rate_label, miss_rate_label, fp_rate_label, cr_rate_label = (
    Label(caption, layout={'width': '200px'})
    for caption in ('Hits', 'Misses', 'False Positives', 'Correct Rejections')
)
# Caption for the read-only d' box.
dp_label = Label('d\'')

In [8]:
# Tag the hit / false-positive captions with the CSS classes that the %%html
# cell styles (.hrl -> bold green, .fprl -> bold red), matching the plot colours.
# The trailing ';' suppresses the cell's displayed output.
hit_rate_label.add_class('hrl')
fp_rate_label.add_class('fprl');

In [9]:
# The d' box is disabled (read-only); the .disabled-font rule in the %%html cell
# keeps its number fully opaque and bold. ';' suppresses the cell output.
dprime.add_class('disabled-font');

In [10]:
# Seed the four rate boxes from the sliders' starting values.
# Misses / correct rejections are the mass of each curve below the criterion;
# hits / false positives are the complementary mass above it.
mr, cr = (
    norm.cdf(crit.value, loc=centre, scale=std.value)
    for centre in (sig_present.value, sig_absent.value)
)
hr, fp = 1 - mr, 1 - cr

# Values are written as 2-decimal strings, as elsewhere in the notebook.
hit_rate.value, miss_rate.value = '%.2f' % hr, '%.2f' % mr
cr_rate.value, fp_rate.value = '%.2f' % cr, '%.2f' % fp

In [11]:
def _four_column_grid_layout():
    """Return a fresh full-width Layout: two auto-height rows, columns 20/30/20/30 %.

    Each grid gets its own Layout instance; rows are (caption, control, caption, control).
    """
    return widgets.Layout(
        width='100%',
        grid_template_rows='auto auto',
        grid_template_columns='20% 30% 20% 30%',
    )


# Rate boxes: hits | false positives on the first row, misses | correct rejections on the second.
rates_layout = _four_column_grid_layout()
rate_widget = widgets.GridBox(
    children=[
        hit_rate_label, hit_rate, fp_rate_label, fp_rate,
        miss_rate_label, miss_rate, cr_rate_label, cr_rate,
    ],
    layout=rates_layout,
)

# Parameter sliders laid out on the same 2x4 grid.
slider_layout = _four_column_grid_layout()
slider_widget = widgets.GridBox(
    children=[
        sig_absent_label, sig_absent, sig_present_label, sig_present,
        std_label, std, crit_label, crit,
    ],
    layout=slider_layout,
)

In [12]:
def hit_rate_observer(ev):
    """React to the user typing a new hit rate.

    Writes the complementary miss rate and moves the signal-present mean so that
    the signal-present curve's mass above the criterion equals ``ev.new``.
    """
    target_hits = ev.new
    miss_rate.value = '%.2f' % (1 - target_hits)

    # P(x > c) = h  for x ~ N(mu, sigma)  <=>  mu = c + sigma * Phi^-1(h)
    sig_present.value = crit.value + norm.ppf(target_hits, scale=std.value)


hit_rate.observe(hit_rate_observer, names='value')

In [13]:
def miss_rate_observer(ev):
    """React to the user typing a new miss rate.

    Writes the complementary hit rate and moves the signal-present mean so that
    the signal-present curve's mass below the criterion equals ``ev.new``.
    """
    target_misses = ev.new
    hit_rate.value = '%.2f' % (1 - target_misses)

    # P(x < c) = m  <=>  mu = c + sigma * Phi^-1(1 - m)
    sig_present.value = crit.value + norm.ppf(1 - target_misses, scale=std.value)


miss_rate.observe(miss_rate_observer, names='value')

In [14]:
def fp_rate_observer(ev):
    """React to the user typing a new false-positive rate.

    Writes the complementary correct-rejection rate and moves the signal-absent
    mean so that the signal-absent curve's mass above the criterion equals ``ev.new``.
    """
    target_fp = ev.new
    cr_rate.value = '%.2f' % (1 - target_fp)

    # P(x > c) = f  <=>  mu = c + sigma * Phi^-1(f)
    sig_absent.value = crit.value + norm.ppf(target_fp, scale=std.value)


fp_rate.observe(fp_rate_observer, names='value')

In [15]:
def cr_rate_observer(ev):
    """React to the user typing a new correct-rejection rate.

    Writes the complementary false-positive rate and moves the signal-absent
    mean so that the signal-absent curve's mass below the criterion equals ``ev.new``.
    """
    target_cr = ev.new
    fp_rate.value = '%.2f' % (1 - target_cr)

    # P(x < c) = r  <=>  mu = c + sigma * Phi^-1(1 - r)
    sig_absent.value = crit.value + norm.ppf(1 - target_cr, scale=std.value)


cr_rate.observe(cr_rate_observer, names='value')

In [16]:
def _update_rate_boxes(mean_signal_present, mean_signal_absent, standard_deviation, criterion):
    """Write the hit / miss / false-positive / correct-rejection rates into their boxes.

    The rate boxes' ``'value'`` observers are detached while writing, so these
    programmatic updates do not feed back into the sliders. Only ``'value'``
    observers are detached; any other trait observers on the widgets are left
    alone. The four rate observers are re-attached in ``finally``, so an
    exception cannot leave the boxes unresponsive.
    """
    boxes_and_observers = (
        (hit_rate, hit_rate_observer),
        (miss_rate, miss_rate_observer),
        (fp_rate, fp_rate_observer),
        (cr_rate, cr_rate_observer),
    )
    for box, _ in boxes_and_observers:
        box.unobserve_all('value')
    try:
        # Miss rate = mass of the signal-present curve below the criterion.
        mr = norm.cdf(criterion, loc=mean_signal_present, scale=standard_deviation)
        hit_rate.value = '%.2f' % (1 - mr)
        miss_rate.value = '%.2f' % mr

        # Correct-rejection rate = mass of the signal-absent curve below the criterion.
        cr = norm.cdf(criterion, loc=mean_signal_absent, scale=standard_deviation)
        fp_rate.value = '%.2f' % (1 - cr)
        cr_rate.value = '%.2f' % cr
    finally:
        for box, observer in boxes_and_observers:
            box.observe(observer, names='value')


def two_curve_slider(mean_signal_present, mean_signal_absent, standard_deviation, criterion):
    """Plot the signal-present and signal-absent normal curves and sync the readouts.

    Driven by ``interactive_output``. Besides drawing a new figure, it sets the
    ``dprime`` box and the four rate boxes from the current parameters.

    Parameters
    ----------
    mean_signal_present : float
        Mean of the signal-present curve (green shading = hits).
    mean_signal_absent : float
        Mean of the signal-absent curve (red shading = false positives).
    standard_deviation : float
        Standard deviation shared by both curves.
    criterion : float
        Decision criterion; the area to its right is shaded.
    """
    _, ax = plt.subplots(figsize=(18, 5))

    minval = 0
    maxval = 40

    xs = np.linspace(minval, maxval, 1000)
    curve_1 = gaussian(xs, mean_signal_present, standard_deviation)
    curve_2 = gaussian(xs, mean_signal_absent, standard_deviation)

    # Shade each curve to the right of the criterion. The densities are evaluated
    # on their own grid that starts exactly at the criterion, so the fill sits
    # under the outlines. Re-using samples of `xs` on a shifted grid would
    # misplace it.
    shade_xs = np.linspace(criterion, maxval, xs.size)
    ax.fill_between(shade_xs, gaussian(shade_xs, mean_signal_absent, standard_deviation),
                    color='red', alpha=0.4)
    ax.fill_between(shade_xs, gaussian(shade_xs, mean_signal_present, standard_deviation),
                    color='green', alpha=0.4)

    # Readouts: d' is the separation of the means in standard-deviation units.
    dp = (mean_signal_present - mean_signal_absent) / float(standard_deviation)
    dprime.value = '%.2f' % dp
    _update_rate_boxes(mean_signal_present, mean_signal_absent, standard_deviation, criterion)

    ax.axvline(criterion, c='blue', linewidth=3, linestyle='dashed')
    ax.axvline(mean_signal_present, c='green', alpha=0.2)
    ax.axvline(mean_signal_absent, c='red', alpha=0.2)
    ax.plot(xs, curve_1, color='k', linewidth=4)
    ax.plot(xs, curve_2, color='k', linewidth=4)
    ax.set_xlim(minval, maxval)
    ticks = np.linspace(minval, maxval, 5)
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)

    # Pin the bottom of the y axis at 0 and keep the autoscaled top.
    ax.set_ylim(0, ax.get_ylim()[1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_ylabel(r'$P(x)$')
    ax.set_xlabel(r'$x$')
    ax.set_title('Normal Distributions')

In [17]:
def gaussian(x, mu, sig):
    """Normal probability density with mean ``mu`` and std ``sig``, evaluated at ``x``.

    Element-wise for numpy arrays; also accepts scalars.
    """
    squared_dev = np.power(x - mu, 2.)
    twice_var = 2 * np.power(sig, 2.)
    norm_const = np.sqrt(2*np.pi) * sig
    return np.exp(-squared_dev / twice_var) / norm_const

In [18]:
## ROC Widget
def plot_ROC(mean2, mean1, sigma, crit):
    """Draw the ROC curve for two equal-variance normals and mark the current criterion.

    Parameters
    ----------
    mean2 : float
        Mean of the signal-present distribution (hit rates come from it).
    mean1 : float
        Mean of the signal-absent distribution (false-positive rates come from it).
    sigma : float
        Standard deviation shared by both distributions.
    crit : float
        Current decision criterion, drawn as the blue operating point.
    """
    _, ax = plt.subplots(figsize=(8, 8))

    # Sweep the criterion from well below the lower mean to well above the upper
    # one, measured in units of sigma. This lets the curve run from (1, 1) to
    # (0, 0) whatever the means are.
    lo = min(mean1, mean2) - 5 * sigma
    hi = max(mean1, mean2) + 5 * sigma
    crits = np.linspace(lo, hi, 500)
    hits = 1 - norm.cdf(crits, loc=mean2, scale=sigma)
    fps = 1 - norm.cdf(crits, loc=mean1, scale=sigma)

    # Operating point evaluated exactly at `crit`, not snapped to the sweep grid,
    # so it matches the rate boxes even when crit lies outside [lo, hi].
    hit_pt = 1 - norm.cdf(crit, loc=mean2, scale=sigma)
    fp_pt = 1 - norm.cdf(crit, loc=mean1, scale=sigma)

    ax.plot(fps, hits, color='k', linewidth=4)
    ax.plot([0, 1], [0, 1], color='k', alpha=0.2)  # chance diagonal
    ax.plot(fp_pt, hit_pt, 'o', markersize=16, color='blue')
    ticks = np.linspace(0, 1, 5)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)

    bar_offset = 0.07

    # Dashed guides from the operating point to the rate bars on each axis.
    ax.plot([fp_pt, fp_pt], [-bar_offset, hit_pt],
            linestyle='dashed', alpha=0.5, color='red')
    ax.plot([-bar_offset, fp_pt], [hit_pt, hit_pt],
            linestyle='dashed', alpha=0.5, color='green')

    # Bars along the axes whose lengths are the false-positive (red) and hit
    # (green) rates. They are anchored at -0.1 while the axis limits start at
    # -bar_offset, so only their inner part is visible.
    ax.add_patch(patches.Rectangle((0, -0.1), fp_pt, bar_offset, facecolor='r'))
    ax.add_patch(patches.Rectangle((-0.1, 0), bar_offset, hit_pt, facecolor='g'))

    ax.set_title('ROC Curve')
    ax.set_xlabel('False Positives')
    ax.set_ylabel('Hits')
    ax.set_ylim((-bar_offset, 1.02))
    ax.set_xlim((-bar_offset, 1.02))
    
# Redraw the ROC panel whenever a parameter slider changes. The keys are
# plot_ROC's parameter names (mean2 <- signal present, mean1 <- signal absent).
roc_widget = interactive_output(
    plot_ROC,
    dict(mean2=sig_present, mean1=sig_absent, sigma=std, crit=crit),
)


In [19]:
# Column container for the ROC plot and the d' row: flex column, centred
# children, half of the parent's width.
box_layout = Layout(
    width='50%',
    display='flex',
    flex_flow='column',
    align_items='center',
)

In [20]:
# Redraw the distribution plot whenever a parameter slider changes; the keys are
# two_curve_slider's parameter names.
plot = interactive_output(
    two_curve_slider,
    dict(
        mean_signal_present=sig_present,
        mean_signal_absent=sig_absent,
        standard_deviation=std,
        criterion=crit,
    ),
)

# Left column: distribution plot above the slider grid and the rate grid.
sliders_and_plot = VBox([plot, slider_widget, rate_widget])
# Right column: ROC plot above the d' readout.
roc_and_dprime = VBox(
    [roc_widget, HBox([dp_label, dprime])],
    layout=box_layout,
)
plot_and_roc = HBox([sliders_and_plot, roc_and_dprime])

In [21]:
# Show the assembled dashboard: plot, sliders and rate boxes on the left;
# ROC curve and d' readout on the right.
display(plot_and_roc)

HBox(children=(VBox(children=(Output(), GridBox(children=(Label(value=u'Signal Absent', layout=Layout(width=u'…

In [22]:
%%html
<style>
/* Classes attached from Python with add_class():
   .hrl -> hit_rate_label, .fprl -> fp_rate_label, .disabled-font -> dprime. */
.hrl {
    font-weight: bold;
    color: green;
}
.fprl {
    font-weight: bold;
    color: red;
}
/* dprime is created with disabled=True; presumably this keeps its number
   fully opaque and bold instead of the default dimmed look. */
.disabled-font input[type="number"]:disabled {
    opacity: 1;
    font-weight: bold;
}
</style>

In [23]:
# Render a "Code Visibility Toggle" link that shows/hides all code input cells.
# The script also runs code_toggle once when the document is ready, so inputs
# start hidden.
# NOTE(review): it relies on jQuery `$` and the `div.input` selector, which
# presumably exist only in the classic Notebook UI — confirm before relying on
# it in JupyterLab or nbviewer.
from IPython.display import HTML
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<a href="javascript:code_toggle()">Code Visibility Toggle</a>''')