In [20]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

from matplotlib import gridspec
from scipy.stats import norm
from collections import deque

In [21]:
# --- Shared state for the interactive demo -------------------------------
# Last `mu` the plot was drawn with; a different value on the next redraw
# signals that the visited-point history belongs to another curve.
prev_mu = 2

# Operating points visited so far, one bounded history per curve:
# `roc` holds [FPR, TPR] pairs, `pr` holds [recall, precision] pairs.
roc, pr = (deque(maxlen=200) for _ in range(2))

# Evaluation grid and the standard-normal density sampled on it.  The
# positive class reuses the same density, shifted along x by `mu`.
X = np.linspace(-4, 4, num=81)
pdf = norm.pdf(X)

def _trapezoid_auc(points):
    """Return the trapezoidal area under the polyline through `points`.

    `points` is a sequence of (x, y) pairs that is already sorted by x.
    """
    return sum((y1 + y2) * (x2 - x1) / 2
               for (x1, y1), (x2, y2) in zip(points, points[1:]))


def _draw_curve(ax, title, xlabel, ylabel, history, start, end):
    """Draw one unit-square curve panel (ROC or PR) from visited points.

    `history` holds the [x, y] operating points visited so far; its last
    element is the current one and is highlighted in red.  `start` and `end`
    are the anchor points added at both ends for the area computation only;
    they are not drawn.
    """
    ticks = np.linspace(0, 1, 6)
    ax.set_title(title)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(lw=.2)

    visited = sorted(history)
    auc = _trapezoid_auc([start, *visited, end])

    ax.plot(*np.array(visited).T, marker='x', alpha=.6, label=f"{title}-AUC={auc:.3f}")
    ax.scatter(*history[-1], marker='o', c='r', zorder=10)
    ax.legend(loc=4)


@widgets.interact(mu=(-1., 4.), theta=(-4.,8.))
def show(mu=prev_mu, theta=-1):
    """Plot a two-Gaussian classifier with its confusion matrix, ROC and PR.

    The negative class is a standard normal and the positive class is the
    same density shifted by `mu`; samples to the right of `theta` are
    predicted positive.

    Parameters
    ----------
    mu : float
        Shift of the positive-class density along x.
    theta : float
        Decision threshold (drawn as the vertical line).

    Side effects
    ------------
    Appends the current operating point to the module-level `roc` and `pr`
    histories, and clears both (updating `prev_mu`) whenever `mu` differs
    from the previous call.  The curves and AUC values are therefore built
    only from the thresholds visited so far for the current `mu`.
    """
    global prev_mu

    fig = plt.figure(figsize=(9,6))
    gs = gridspec.GridSpec(2, 3, figure=fig)

    # Top row: both class densities and the decision threshold.
    ax = fig.add_subplot(gs[0, :])
    ax.plot(X, pdf, c="tab:blue", label="neg")
    ax.plot(X+mu, pdf, c="tab:orange", label="pos")
    ax.axvline(theta, c='k')
    ax.annotate("neg ◀ ", (theta,.2), ha="right", c="tab:blue")
    ax.annotate(" ▶ pos", (theta,.2), ha="left", c="tab:orange")
    ax.set_ylim(0)
    ax.set_xlim(-4,8)
    ax.legend(loc=1)

    # Per-class rates.  The upper tails use the survival function instead of
    # `1 - cdf`, which loses all precision once cdf rounds to 1.
    TN = norm.cdf(theta)
    FN = norm.cdf(theta - mu)
    FP = norm.sf(theta)
    TP = norm.sf(theta - mu)
    cm = [[TP, FN],
          [FP, TN]]

    # Bottom-left: confusion matrix, rows = actual class (pos, neg).
    ax = fig.add_subplot(gs[1, 0])
    ax.set_title("CM")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(cm, cmap="binary", vmin=0, vmax=1)
    ax.text(0, 0, f"TP\n{TP:.3f}", ha="center", bbox=dict(facecolor="tab:orange"))
    ax.text(1, 0, f"FN\n{FN:.3f}", ha="center", bbox=dict(facecolor="tab:orange"))
    ax.text(0, 1, f"FP\n{FP:.3f}", ha="center", bbox=dict(facecolor="tab:blue"))
    ax.text(1, 1, f"TN\n{TN:.3f}", ha="center", bbox=dict(facecolor="tab:blue"))

    # Points collected for a different `mu` lie on a different curve.
    if prev_mu != mu:
        roc.clear()
        pr.clear()
        prev_mu = mu

    # With nothing predicted positive the ratio would be 0/0 (nan), which
    # would poison the sort and the area below; use 1.0, consistent with the
    # (recall=0, precision=1) anchor.
    predicted_pos = TP + FP
    precision = TP / predicted_pos if predicted_pos > 0 else 1.0

    # Stored as lists, like the anchors, so every point compares uniformly
    # when sorted.
    roc.append([FP, TP])
    pr.append([TP, precision])

    _draw_curve(fig.add_subplot(gs[1, 1]), "ROC", "FPR", "TPR",
                roc, start=[0,0], end=[1,1])

    # Both classes have equal mass here, so precision at recall=1 is 0.5.
    # NOTE(review): the (0, 1) start anchor is the true limit only for
    # mu > 0; for mu <= 0 it overstates the area -- confirm whether intended.
    _draw_curve(fig.add_subplot(gs[1, 2]), "PR", "Recall", "Precision",
                pr, start=[0,1], end=[1,.5])

    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=2.0, description='mu', max=4.0, min=-1.0), FloatSlider(value=-1.0, des…