## HLMA 408: Courbes ROC

***
> __Auteur__: Joseph Salmon
> <joseph.salmon@umontpellier.fr>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from matplotlib import rc
import seaborn as sns
from scipy.stats import norm
import matplotlib.patches as mpatches
from ipywidgets import interact

In [2]:
# Global matplotlib look: LaTeX text rendering, 8x6-inch default figures and
# zero-size tick labels (individual axes may re-enable them via tick_params).
# NOTE(review): 'Computer Modern Roman' is a serif face listed under
# 'sans-serif'; with text.usetex the glyphs come from LaTeX anyway.
rc('font', family='sans-serif', **{'sans-serif': ['Computer Modern Roman']})
params = dict()
params['axes.labelsize'] = 12
params['font.size'] = 10
params['legend.fontsize'] = 8
params['xtick.labelsize'] = 0  # tick labels hidden by default
params['ytick.labelsize'] = 0
params['text.usetex'] = True  # requires a working LaTeX installation
params['figure.figsize'] = (8, 6)
plt.rcParams.update(params)

In [3]:
# Seaborn defaults for the whole notebook: large "poster" text, plain white
# background and the colour-blind-safe palette.
sns.set_context("poster")
sns.set_style("white")
sns.set_palette("colorblind")
# No sns.despine() here: with no axes yet it acts on plt.gcf(), which only
# creates (and displays) an empty stray figure. Despine each figure after
# drawing it instead.
color_blind_list = sns.color_palette("colorblind", 8)
# In seaborn >= 0.9 the "colorblind" palette starts blue, orange, green.
my_blue = color_blind_list[0]
my_orange = color_blind_list[1]
my_green = color_blind_list[2]

<Figure size 576x432 with 0 Axes>

In [4]:
# Common horizontal grid for the density plots (150 points on [-1, 9]) and the
# style of the rounded annotation boxes used for the $H_0$ / $H_1$ labels.
xs = np.linspace(start=-1, stop=9, num=150)
bbox_props = {'boxstyle': "round", 'fc': "w", 'ec': "0.5", 'alpha': 0.9}

In [5]:
def _plot_roc_panel(ax, fp_curve, tp_curve, fp_rate, tp_rate):
    """Draw the ROC curve on ``ax`` and mark the operating point in red."""
    ax.plot(fp_curve, tp_curve, 'k', linewidth=7)
    ax.plot(fp_rate, tp_rate, 'o', c='red', markersize=20)
    ax.set_xlabel('FP', fontsize=30)
    ax.set_ylabel('1-FN', fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.set_xlim([-0.1, 1.3])
    ax.set_ylim([-0.1, 1.3])


def _plot_density_panel(ax, q, mu_normal, sd_normal, mu_abnormal, sd_abnormal):
    """Draw both class densities on ``xs``, shade the FN/FP areas, mark ``q``."""
    y_normal = norm.pdf(xs, loc=mu_normal, scale=sd_normal)
    y_abnormal = norm.pdf(xs, loc=mu_abnormal, scale=sd_abnormal)
    y_top = max(y_normal.max(), y_abnormal.max())

    # Peak height of each Gaussian, 1 / sqrt(2 pi sigma^2); only used to place
    # the H0 / H1 labels next to their curves.
    peak_normal = 1. / (np.sqrt(2 * np.pi) * sd_normal)
    peak_abnormal = 1. / (np.sqrt(2 * np.pi) * sd_abnormal)

    ax.text(mu_normal - 1.9, peak_normal - .21, '$H_0$', ha="center",
            va="bottom", size=20, bbox=bbox_props)
    ax.text(mu_abnormal + 1.49, peak_abnormal - 0.1, '$H_1$', ha="center",
            va="bottom", size=20, bbox=bbox_props)
    # ':g' avoids slider float noise such as 0.30000000000000004 in the label.
    ax.text(q + .2, y_top + 0.05, f'$q= {q:g}$',
            fontsize=30, fontweight='bold')

    ax.plot(xs, y_normal, color='k', linewidth=1)
    ax.plot(xs, y_abnormal, color='k', linewidth=1)

    # Abnormal mass below q = missed detections (FN);
    # normal mass above q = false alarms (FP).
    fn_colour, fp_colour = color_blind_list[1], color_blind_list[0]
    ax.fill_between(xs, y_abnormal, where=xs <= q, facecolor=fn_colour)
    ax.fill_between(xs, y_normal, where=xs >= q, facecolor=fp_colour)

    ax.set_xlim([xs.min(), xs.max()])
    ax.set_ylim([-0.05, y_top + .1])
    ax.axvline(x=q, color='k', linewidth=5)

    handles = [mpatches.Rectangle((0, 0), 1, 1, fc=colour)
               for colour in (fn_colour, fp_colour)]
    ax.legend(handles, ['FN', 'FP'], loc=1, fontsize=14)


def plot_roc_n_graph(q=1.75, mu_normal=1.5, sigma2_normal=1, mu_abnormal=3.2, sigma2_abnormal=1.2):
    """Draw a ROC curve next to the two class densities it is derived from.

    The "normal" population ($H_0$) is N(mu_normal, sigma2_normal) and the
    "abnormal" one ($H_1$) is N(mu_abnormal, sigma2_abnormal). A sample is
    declared abnormal when it lies above the threshold ``q``.

    Left panel: ROC curve (false-positive rate on x, 1 - false-negative rate
    on y) over all thresholds, with the operating point for ``q`` in red.
    Right panel: both densities, ``q`` as a vertical line, and the FN / FP
    areas shaded.

    Parameters
    ----------
    q : float
        Decision threshold.
    mu_normal, sigma2_normal : float
        Mean and variance (must be > 0) of the normal ($H_0$) distribution.
    mu_abnormal, sigma2_abnormal : float
        Mean and variance (must be > 0) of the abnormal ($H_1$) distribution.

    Notes
    -----
    This function is driven by ``ipywidgets.interact``, which builds a widget
    for each parameter; adding parameters here would add unwanted sliders.
    """
    sd_normal = np.sqrt(sigma2_normal)
    sd_abnormal = np.sqrt(sigma2_abnormal)

    # Operating point at q: a normal sample above q is a false positive, an
    # abnormal sample above q is a true positive (= 1 - FN rate).
    # norm.sf is 1 - cdf, computed accurately in the tails.
    fp_rate = norm.sf(q, loc=mu_normal, scale=sd_normal)
    tp_rate = norm.sf(q, loc=mu_abnormal, scale=sd_abnormal)

    # Sweep thresholds 8 standard deviations past both means so the ROC curve
    # reaches (0, 0) and (1, 1) for every slider setting, instead of being
    # clipped by the plotting grid ``xs``.
    lo = min(mu_normal - 8 * sd_normal, mu_abnormal - 8 * sd_abnormal)
    hi = max(mu_normal + 8 * sd_normal, mu_abnormal + 8 * sd_abnormal)
    thresholds = np.linspace(lo, hi, 1000)
    fp_curve = norm.sf(thresholds, loc=mu_normal, scale=sd_normal)
    tp_curve = norm.sf(thresholds, loc=mu_abnormal, scale=sd_abnormal)

    # Draw inside a style context rather than calling sns.set_style() after
    # the axes exist: that global call only took effect from the second redraw
    # on and leaked into the rest of the session.
    with sns.axes_style("ticks"):
        fig, (ax_roc, ax_pdf) = plt.subplots(1, 2, figsize=(13, 6))
        _plot_roc_panel(ax_roc, fp_curve, tp_curve, fp_rate, tp_rate)
        _plot_density_panel(ax_pdf, q, mu_normal, sd_normal,
                            mu_abnormal, sd_abnormal)
        sns.despine(fig=fig)
        fig.tight_layout()
        plt.show()

In [6]:
# Each tuple is (min, max, step) for a slider. The function's default values
# should lie on their slider's grid, otherwise the frontend may snap the handle
# to a value different from the one first plotted.
interact(plot_roc_n_graph,
         mu_normal=(1, 3, 0.5),
         sigma2_normal=(0.2, 2, 0.2),
         mu_abnormal=(1, 5, 0.1),  # step 0.1 so the default 3.2 is reachable
         sigma2_abnormal=(0.2, 2, 0.2),
         q=(0, 8, 0.05));

interactive(children=(FloatSlider(value=1.75, description='q', max=8.0, step=0.05), FloatSlider(value=1.5, des…