In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import ipywidgets as widgets
from scipy.stats import norm
from collections import deque

In [2]:
# Interactive state read and mutated by `show` in this cell.
# `prev_mu` is the separation the stored curve points were computed for;
# when the mu slider moves away from it, both histories are reset.
prev_mu = 2
# Bounded histories of visited operating points, oldest evicted first:
# `roc` holds [FPR, TPR] pairs, `pr` holds [recall, precision] pairs.
roc, pr = deque(maxlen=200), deque(maxlen=200)

# Score grid for the density plot and the standard-normal pdf on it; the
# same curve is drawn for both classes, the negative one shifted by mu.
X = np.linspace(-4, 4, 81)
pdf = norm(loc=0, scale=1).pdf(X)

def _confusion_rates(mu, theta):
    """Return (TP, FN, FP, TN) as rates for threshold `theta`.

    Positive-class scores follow N(0, 1) and negative-class scores follow
    N(mu, 1); a sample is predicted positive when its score is below
    `theta`.  Hence TP = TPR, FN = 1 - TPR, FP = FPR, TN = 1 - FPR, and
    each true-class row (TP, FN) / (FP, TN) sums to 1.
    """
    tpr = norm.cdf(theta)
    fpr = norm.cdf(theta - mu)
    return tpr, 1 - tpr, fpr, 1 - fpr


def _draw_score_distributions(ax, mu, theta):
    """Plot both class score densities and the decision threshold on `ax`."""
    ax.plot(X, pdf, c="tab:orange", label="pos.")
    ax.plot(X + mu, pdf, c="tab:blue", label="neg.")
    ax.axvline(theta, c='k')
    # Scores left of the threshold are predicted positive, right negative.
    ax.annotate('pos ◀ ', (theta, .2), ha="right", c="tab:orange")
    ax.annotate(' ▶ neg', (theta, .2), ha="left", c="tab:blue")
    ax.set_ylim(bottom=0)
    # Matches the theta slider range so the threshold line is always visible.
    ax.set_xlim(-4, 8)
    ax.set_xlabel("score")
    ax.set_ylabel("density")
    ax.legend(loc="upper right")


def _draw_confusion_matrix(ax, cm):
    """Render a 2x2 confusion matrix of rates as a greyscale heat map.

    `cm` is [[TP, FN], [FP, TN]]: rows are the true class (pos, neg),
    columns the predicted class (pos, neg).
    """
    ax.set_title("CM")
    ax.set_xticks([])
    ax.set_yticks([])
    # Pin the colour scale to [0, 1]; by default imshow rescales to the
    # data's own min/max, so the shading would not reflect the real rates.
    ax.imshow(cm, cmap="binary", vmin=0, vmax=1)
    labels = {(0, 0): "TP", (0, 1): "FN", (1, 0): "FP", (1, 1): "TN"}
    for (row, col), name in labels.items():
        # Colour by true class, matching the density plot above.
        colour = "tab:orange" if row == 0 else "tab:blue"
        # imshow puts cell (row, col) at data coordinates x=col, y=row.
        ax.text(col, row, name, ha="center", va="center",
                bbox=dict(facecolor=colour))


def _draw_curve(ax, points, title, xlabel, ylabel):
    """Plot accumulated (x, y) operating points on unit-square axes.

    Points are sorted (by x, then y) so they join up as a curve regardless
    of the order in which the threshold slider visited them.  `points`
    must be non-empty.
    """
    ticks = np.linspace(0, 1, 6)
    ax.set_title(title)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(lw=.2)
    ax.plot(*np.array(sorted(points)).T, marker='x')


@widgets.interact(mu=(-1., 4.), theta=(-4., 8.))
def show(mu=prev_mu, theta=-1):
    """Redraw the threshold demo for class separation `mu` and threshold `theta`.

    Panels: the two class score densities with the threshold (top), the
    confusion matrix of rates, and the ROC and PR curves assembled from
    every threshold visited since `mu` last changed (bottom row).

    Side effects: appends one point each to the module-level `roc` and
    `pr` deques and updates the global `prev_mu`; both deques are cleared
    first whenever `mu` differs from the value they were built for.
    """
    global prev_mu
    if prev_mu != mu:
        # A new separation is a different classifier: drop its stale curve.
        roc.clear()
        pr.clear()
        prev_mu = mu

    TP, FN, FP, TN = _confusion_rates(mu, theta)
    roc.append([FP, TP])
    # TP and FP are rates (TPR, FPR), so TP / (TP + FP) equals precision
    # only when both classes are equally prevalent.  If nothing is predicted
    # positive (0/0), treat precision as 1 by convention.
    precision = TP / (TP + FP) if TP + FP > 0 else 1.0
    pr.append([TP, precision])

    fig = plt.figure(figsize=(9, 6))
    gs = gridspec.GridSpec(2, 3, figure=fig)
    _draw_score_distributions(fig.add_subplot(gs[0, :]), mu, theta)
    _draw_confusion_matrix(fig.add_subplot(gs[1, 0]), [[TP, FN], [FP, TN]])
    _draw_curve(fig.add_subplot(gs[1, 1]), roc, "ROC", "FPR", "TPR")
    _draw_curve(fig.add_subplot(gs[1, 2]), pr, "PR", "Recall", "Precision")

    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=2.0, description='mu', max=4.0, min=-1.0), FloatSlider(value=-1.0, des…