In [1]:
from typing import List
from matplotlib import pyplot as plt
import numpy as np

In [2]:
# WASA_THRESHOLD: target sleep accuracy (TPR) at which wake accuracy is reported.
# BALANCE_WEIGHTS: if True, reweight samples so sleep and wake count equally in the metrics.
WASA_THRESHOLD, BALANCE_WEIGHTS = 0.93, True

### Comparison: WASA93, ROC AUC, Cohen's Kappa

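* WASA93: <u>W</u>ake <u>A</u>ccuracy when <u>S</u>leep <u>A</u>ccuracy is fixed at <u>93%</u>. The decision threshold is chosen from the ROC curve so that just over 93% of scored sleep epochs are classified as sleep. WASA93 is the fraction of scored wake epochs classified as wake at that same threshold.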

In [3]:
from pisces.utils import pad_to_hat, plot_scores_CDF, plot_scores_PDF, add_rocs
from sklearn.metrics import roc_auc_score, roc_curve, cohen_kappa_score


def split_analysis(y, y_hat_sleep_proba, sleep_accuracy: float = WASA_THRESHOLD, balancing: bool = BALANCE_WEIGHTS):
    """Score sleep/wake predictions with WASA, ROC AUC and Cohen's kappa.

    Parameters
    ----------
    y : np.ndarray
        Ground-truth labels (flattened internally). Entries > 0 count as sleep,
        0 as wake, and negative entries (e.g. -1) as unscored/missing; the
        latter are excluded from every metric.
    y_hat_sleep_proba : np.ndarray
        Predicted sleep scores. May be longer than ``y`` because the model
        input was padded; ``pad_to_hat`` aligns ``y`` to it.
    sleep_accuracy : float
        Target sleep accuracy (TPR) at which wake accuracy is reported.
    balancing : bool
        If True, weight samples so the sleep and wake classes carry equal total
        weight in the ROC curve, AUC and kappa.

    Returns
    -------
    dict
        Keys ``"y_padded"``, ``"y_hat"``, ``"mask"``, ``"kappa"``, ``"auc"``,
        ``"roc_curve"`` (dict of ``tprs``/``fprs``/``thresholds``), plus
        ``"wasaXX_threshold"`` and ``"wasaXX"``, where XX is
        ``sleep_accuracy`` as a whole percentage (e.g. 93).
    """
    y_flat = y.reshape(-1,)
    n_sleep = np.sum(y_flat > 0)
    n_wake = np.sum(y_flat == 0)
    n_scored = n_sleep + n_wake

    if balancing:
        # Inverse-frequency weights: all scored sleep samples and all scored
        # wake samples each get the same total weight.
        weights = np.where(y_flat > 0, n_scored / n_sleep, n_scored / n_wake)
    else:
        # Explicit float dtype: with integer labels, ones_like would yield an
        # int array and the normalisation below would fail to cast.
        weights = np.ones_like(y_flat, dtype=float)
    # Sums to 1.0 here; masking unscored entries below removes some weight,
    # which does not affect the weighted metrics (only relative weights matter).
    weights = weights / np.sum(weights)

    # Adjust y to match the length of y_hat, which was padded to fit model
    # constraints; the padded weights double as a mask for the padding.
    y_padded = pad_to_hat(y_flat, y_hat_sleep_proba)
    mask = pad_to_hat(weights, y_hat_sleep_proba)

    # Also ignore any unscored or missing (negative) labels.
    y_to_score = pad_to_hat(y_flat >= 0, y_hat_sleep_proba)
    mask = mask * y_to_score
    # roc_auc will complain if -1 is in y_padded, so zero those entries out.
    # Not in place: y_flat is a view of the caller's y, and pad_to_hat may hand
    # it back unchanged, so an in-place multiply could corrupt the caller's labels.
    y_padded = y_padded * y_to_score

    # ROC analysis
    fprs, tprs, thresholds = roc_curve(y_padded, y_hat_sleep_proba, sample_weight=mask)

    # Sleep accuracy = (n sleep correct) / (n sleep) = TP/AP = TPR.
    # tprs is non-decreasing, so this picks the first operating point whose TPR
    # exceeds the target; clamp so a target >= the final TPR picks the last point
    # instead of indexing past the end.
    threshold_idx = min(int(np.sum(tprs <= sleep_accuracy)), len(thresholds) - 1)
    wasa_threshold = thresholds[threshold_idx]
    # roc_curve's tprs[i] treats scores >= thresholds[i] as positive, so use >=
    # to realise exactly the selected operating point.
    y_guess = y_hat_sleep_proba >= wasa_threshold

    # WASA X: fraction of scored wake samples classified as wake at that threshold.
    guess_right = y_guess == y_padded
    y_wake = y_padded == 0
    wake_accuracy = np.sum(y_wake * guess_right * y_to_score) / n_wake

    # round() rather than int() alone: 100 * 0.29 == 28.999..., which truncates to 28.
    pct = int(round(100 * sleep_accuracy))

    return {
        "y_padded": y_padded,
        "y_hat": y_hat_sleep_proba,
        "mask": mask,
        "kappa": cohen_kappa_score(y_padded, y_guess, sample_weight=mask),
        "auc": roc_auc_score(y_padded, y_hat_sleep_proba, sample_weight=mask),
        "roc_curve": {
            "tprs": tprs,
            "fprs": fprs,
            "thresholds": thresholds,
        },
        f"wasa{pct}_threshold": wasa_threshold,
        f"wasa{pct}": wake_accuracy,
    }


In [4]:
import matplotlib.pyplot as plt
from pisces.experiments import DataSetObject


# Discover every data set under ../data_sets; the result is indexed by data-set name.
sets = DataSetObject.find_data_sets("../data_sets")
walch, hybrid = sets['walch_et_al'], sets['hybrid_motion']

In [5]:
from pisces.experiments import MOResUNetPretrained, evaluate_mo_on_data_set


In [6]:
# NOTE(review): presumably the Henry Ford disordered-sleep cohort (subjects AWS### in the log of the next cell).
hfd = sets['henry_ford_disordered']
# Input sampling rate handed to the pretrained model. Preprocessing resamples
# each subject's data to this rate ("resampling to 32Hz" in the log of the next cell).
MO_SAMPLING_HZ = 32
mo = MOResUNetPretrained(sampling_hz=MO_SAMPLING_HZ)

In [7]:
# Run the pretrained model `mo` over the `hfd` data set. The result unpacks into
# evaluation results and the preprocessed model inputs (per the variable names).
# NOTE(review): the saved output of this cell ends with
# `TypeError: 'NoneType' object is not subscriptable` (full traceback not shown),
# so evaluations_hfd / mo_preprocessed_data_hfd were never assigned in that run.
# Many subjects also log `sampling_period_s: 0.0`, which may indicate empty or
# missing input for them — investigate before relying on these results downstream.
evaluations_hfd, mo_preprocessed_data_hfd = evaluate_mo_on_data_set(mo, hfd)

Using 16 of 16 cores (100%) for parallel preprocessing.
This can cause memory or heat issues if  is too high; if you run into problems, call prepare_set_for_training() again with max_workers = -1, going more negative if needed. (See the docstring for more info.)


getting needed X, y for AWS001
sampling_period_s: 0.0
getting needed X, y for AWS017
sampling_period_s: 0.0
getting needed X, y for AWS018
sampling_period_s: 0.0
getting needed X, y for AWS019




sampling_period_s: 0.019994020462036133
resampling to 32Hz (0.03125s) from 50 Hz (0.01999s)
getting needed X, y for AWS006
getting needed X, y for AWS002
sampling_period_s: 0.0
getting needed X, y for AWS020
sampling_period_s: 0.0
getting needed X, y for AWS021
sampling_period_s: 0.0
getting needed X, y for AWS022
sampling_period_s: 0.0
getting needed X, y for AWS023
sampling_period_s: 0.0
getting needed X, y for AWS024
sampling_period_s: 0.0
getting needed X, y for AWS025
getting needed X, y for AWS007




sampling_period_s: 0.0
getting needed X, y for AWS026
getting needed X, y for AWS005
sampling_period_s: 0.0
getting needed X, y for AWS028
getting needed X, y for AWS003
sampling_period_s: 0.0
getting needed X, y for AWS009
getting needed X, y for AWS029
sampling_period_s: 0.0
getting needed X, y for AWS030
sampling_period_s: 0.0
getting needed X, y for AWS031
sampling_period_s: 0.0
sampling_period_s: 0.019951820373535156
resampling to 32Hz (0.03125s) from 50 Hz (0.01995s)
sampling_period_s: 0.0
sampling_period_s: 0.0
sampling_period_s: 0.0
getting needed X, y for AWS032
getting needed X, y for AWS033
getting needed X, y for AWS034
getting needed X, y for AWS035




getting needed X, y for AWS010
getting needed X, y for AWS011
sampling_period_s: 0.019952058792114258
resampling to 32Hz (0.03125s) from 50 Hz (0.01995s)
sampling_period_s: 0.0
getting needed X, y for AWS016
getting needed X, y for AWS008
getting needed X, y for AWS013
sampling_period_s: 0.020028114318847656
resampling to 32Hz (0.03125s) from 49 Hz (0.02003s)
sampling_period_s: 0.020013093948364258
resampling to 32Hz (0.03125s) from 49 Hz (0.02001s)
getting needed X, y for AWS014
sampling_period_s: 0.020016193389892578
resampling to 32Hz (0.03125s) from 49 Hz (0.02002s)
sampling_period_s: 0.0
sampling_period_s: 0.020027875900268555
resampling to 32Hz (0.03125s) from 49 Hz (0.02003s)
getting needed X, y for AWS015
sampling_period_s: 0.019958972930908203
sampling_period_s: 0.0
resampling to 32Hz (0.03125s) from 50 Hz (0.01996s)
sampling_period_s: 0.0
sampling_period_s: 0.020013093948364258
getting needed X, y for AWS004
resampling to 32Hz (0.03125s) from 49 Hz (0.02001s)
getting needed X

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


TypeError: 'NoneType' object is not subscriptable