## Sensitivity vs Specificity vs Accuracy vs Precision

In [None]:
from ipywidgets import interact
from ipywidgets import Output
from IPython.display import clear_output
from IPython.display import display
from matplotlib import pyplot as plt
import ipywidgets
import matplotlib.patches as patches
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

In [None]:
def get_confusion_matrix(sensitivity, specificity, prevalence):
    """Return the confusion matrix as shares of the whole population.

    Rows are the predicted class (positive, negative) and columns the actual
    class (positive, negative), i.e. ``[[TP, FP], [FN, TN]]``. For inputs in
    [0, 1] the four entries sum to 1.

    Args:
        sensitivity: True-positive rate, P(predicted + | actual +).
        specificity: True-negative rate, P(predicted - | actual -).
        prevalence: Share of the population that is actually positive.

    Scalars give a (2, 2) array. Array inputs give shape (2, 2, *shape),
    provided all four computed entries end up with the same shape.
    """
    actual_pos = prevalence
    actual_neg = 1 - prevalence

    true_pos = sensitivity * actual_pos
    false_neg = (1 - sensitivity) * actual_pos
    true_neg = specificity * actual_neg
    false_pos = (1 - specificity) * actual_neg

    return np.array([[true_pos, false_pos], [false_neg, true_neg]])

@interact(
    prevalence=(0., 1., 0.01),
    sensitivity=(0., 1., 0.01),
    specificity=(0., 1., 0.01),
)
def f(
    prevalence=0.1,
    sensitivity=0.99,
    specificity=0.95,
):
    """Display confusion-matrix shares, accuracy and precision as a table.

    Args:
        prevalence: Share of actual positives, a fraction in [0, 1].
        sensitivity: True-positive rate, a fraction in [0, 1].
        specificity: True-negative rate, a fraction in [0, 1].

    All three values come from the interact sliders. Every row is shown as
    a percentage of the whole population, except precision. Precision is
    TP / (TP + FP) and is "undefined" when nothing is predicted positive.
    """
    confusion_matrix = get_confusion_matrix(sensitivity, specificity, prevalence)
    # Layout is [[TP, FP], [FN, TN]] (rows: predicted, columns: actual).
    true_pos, false_pos = confusion_matrix[0]
    false_neg, true_neg = confusion_matrix[1]

    predicted_pos = true_pos + false_pos
    # With no predicted positives (e.g. prevalence=0, specificity=1) precision
    # is 0/0; report it as undefined instead of a NaN plus a RuntimeWarning.
    precision = true_pos / predicted_pos if predicted_pos > 0 else None

    def as_percent(value):
        """Format a fraction as a percentage string; None means undefined."""
        return 'undefined' if value is None else '{:.2f}%'.format(value * 100)

    # Format before building the frame: avoids the deprecated DataFrame.applymap.
    df = pd.DataFrame(
        [
            ['True positive', as_percent(true_pos)],
            ['True negative', as_percent(true_neg)],
            ['False positive', as_percent(false_pos)],
            ['False negative', as_percent(false_neg)],
            ['Accuracy', as_percent(true_pos + true_neg)],
            ['Precision', as_percent(precision)],
        ],
        columns=['', ' '],  # TODO: figure out better way to hide column names.
    )

    styler = df.style
    # Styler.hide_index() was removed in pandas 2.0; Styler.hide(axis=...)
    # replaces it (pandas >= 1.4). Fall back for older pandas.
    if hasattr(styler, 'hide'):
        styler = styler.hide(axis='index')
    else:
        styler = styler.hide_index()
    display(styler)

In [None]:
def get_confusion_matrix(sensitivity, specificity, prevalence):
    """Build the population-share confusion matrix ``[[TP, FP], [FN, TN]]``.

    Row 0 holds predicted positives and row 1 predicted negatives. Column 0
    holds actual positives and column 1 actual negatives.

    ``sensitivity`` and ``specificity`` may be scalars or arrays (e.g. a
    meshgrid). The arithmetic is elementwise, so array inputs produce shape
    (2, 2, *shape) as long as all four entries share that shape.
    """
    positive_share, negative_share = prevalence, 1 - prevalence
    predicted_positive_row = [
        sensitivity * positive_share,
        (1 - specificity) * negative_share,
    ]
    predicted_negative_row = [
        (1 - sensitivity) * positive_share,
        specificity * negative_share,
    ]
    return np.array([predicted_positive_row, predicted_negative_row])

@interact(
    prevalence=(0., 100., 0.1),
)
def f(
    prevalence=50.,
):
    """Plot confusion-matrix shares, accuracy and precision as heat maps.

    Each of the six panels covers sensitivity (x) and specificity (y) from
    0% to 100% in 20% steps. Every coloured cell is annotated with the
    metric's value at the (sensitivity, specificity) pair shown on its ticks.

    Args:
        prevalence: Share of actual positives in percent (0-100), from the
            interact slider.
    """
    prevalence /= 100
    # Cell *edges* as fractions: 0.0, 0.2, ..., 1.2. The trailing 1.2 edge only
    # closes the last cell, and values computed there are never drawn. Integer
    # arange avoids the float-step length ambiguity of np.arange(0, 1.4, 0.2).
    x_vals = np.arange(0, 121, 20) / 100
    y_vals = np.arange(0, 121, 20) / 100
    x_grid, y_grid = np.meshgrid(x_vals, y_vals)
    z_grids = get_confusion_matrix(x_grid, y_grid, prevalence)
    # Lower-left corners of the drawn cells (the closing edge is dropped).
    x_ravel = x_grid[:-1, :-1].ravel()
    y_ravel = y_grid[:-1, :-1].ravel()

    with np.errstate(all='ignore'):
        # NaN where TP + FP == 0 (precision undefined).
        precision_grid = z_grids[0, 0] / (z_grids[0, 0] + z_grids[0, 1])
    accuracy_grid = z_grids[0, 0] + z_grids[1, 1]

    # Cell sizes and tick positions in percent; ticks sit at cell centres.
    cell_w = (x_vals[1] - x_vals[0]) * 100
    cell_h = (y_vals[1] - y_vals[0]) * 100
    x_ticks = x_vals[:-1] * 100 + cell_w / 2
    y_ticks = y_vals[:-1] * 100 + cell_h / 2
    x_tick_labels = ['{:.0f}%'.format(v * 100) for v in x_vals[:-1]]
    y_tick_labels = ['{:.0f}%'.format(v * 100) for v in y_vals[:-1]]

    def do_plot(ax, z_grid, title):
        """Draw one annotated heat map of ``z_grid`` (values in [0, 1]) on ``ax``."""
        ax.set_xlabel('sensitivity')
        ax.set_ylabel('specificity')
        ax.set_title(title)
        drawn = z_grid[:-1, :-1]
        # Explicit 'flat' shading with C one smaller than X/Y: cell [i, j] spans
        # x_vals[j]..x_vals[j+1] and is coloured by z_grid[i, j]. Passing
        # same-shape X/Y/C instead makes matplotlib >= 3.5 pick cell-centred
        # ('nearest') shading, which misaligns the annotations below.
        ax.pcolormesh(
            x_grid * 100, y_grid * 100, drawn,
            vmin=0, vmax=1, cmap='cool', shading='flat',
        )
        ax.set_xlim(x_vals[0] * 100, x_vals[-1] * 100)
        ax.set_ylim(y_vals[0] * 100, y_vals[-1] * 100)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_tick_labels)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_tick_labels)

        for x, y, z in zip(x_ravel, y_ravel, drawn.ravel()):
            ax.text(
                x * 100 + cell_w / 2,
                y * 100 + cell_h / 2,
                '{:.1f}%'.format(z * 100) if np.isfinite(z) else 'n/a',
                horizontalalignment='center',
                verticalalignment='center',
                color='black',
            )

    # tight_layout=True replaces the deprecated Figure.set_tight_layout(True).
    fig, ((ax0, ax1, ax2), (ax3, ax4, ax5)) = plt.subplots(
        2, 3, figsize=(12, 8), tight_layout=True,
    )

    do_plot(ax0, z_grids[0, 0], 'True positive')
    do_plot(ax1, z_grids[0, 1], 'False positive')
    do_plot(ax3, z_grids[1, 0], 'False negative')
    do_plot(ax4, z_grids[1, 1], 'True negative')
    do_plot(ax2, accuracy_grid, 'Accuracy')
    do_plot(ax5, precision_grid, 'Precision')