In [1]:
import os
import sys
import pickle
import time

PROJECT_ROOT = os.path.abspath('..')
sys.path.append(PROJECT_ROOT)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd
import scipy
from scipy import stats
from tqdm import tqdm

from sleeprnn.helpers.reader import load_dataset
from sleeprnn.common import constants, viz, pkeys
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection.predicted_dataset import PredictedDataset
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection import det_utils
from figs_thesis import fig_utils

viz.notebook_full_width()

param_filtering_fn = fig_utils.get_filtered_signal_for_event
param_frequency_fn = fig_utils.get_frequency_by_fft
param_amplitude_fn = fig_utils.get_amplitude_event

RESULTS_PATH = os.path.join(PROJECT_ROOT, 'results')
LETTERS = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P']

  return f(*args, **kwds)


In [2]:
def load_predictions(parts_to_load, dataset, thr=0.5, verbose=True):
    if thr == 0.5:
        extra_str = ''
    else:
        extra_str = '_%1.2f' % thr
    pred_objects = []
    for part in parts_to_load:
        filepath = os.path.join(
            RESULTS_PATH, 'predictions_nsrr_ss',
            'ckpt_20210716_from_20210529_thesis_indata_5cv_e1_n2_train_moda_ss_ensemble_to_e1_n2_train_nsrr_ss',
            'v2_time',
            'prediction%s_part%d.pkl' % (extra_str, part)
        )
        with open(filepath, 'rb') as handle:
            pred_object = pickle.load(handle)
        pred_object.set_parent_dataset(dataset)
        pred_objects.append(pred_object)
    return pred_objects

# Load NSRR dataset and pre-computed predicted dataset

In [3]:
parts_to_load = [0]  # 0 to 11

nsrr = load_dataset(constants.NSRR_SS_NAME, load_checkpoint=True)
pred_objects_1 = load_predictions(parts_to_load, nsrr)
pred_objects_0 = load_predictions(parts_to_load, nsrr, thr=0.25)

Dataset nsrr_ss with 3 patients.
Loading from checkpoint... Loaded
Global STD: None
Dataset nsrr_ss with 11593 patients.


In [4]:
# Filenames of dataset checkpoints
byevent_proba_ckpt_path = os.path.join(
    RESULTS_PATH, 'predictions_nsrr_ss',
    'ckpt_20210716_from_20210529_thesis_indata_5cv_e1_n2_train_moda_ss_ensemble_to_e1_n2_train_nsrr_ss',
    'v2_time',
    'table_byevent_proba.csv'
)

In [25]:
params_load_checkpoint = False

# ############################

if params_load_checkpoint:
    print("Loading from checkpoint")
    table_byevent_proba = pd.read_csv(byevent_proba_ckpt_path)

else:
    # Perform computation and save checkpoint
    table_byevent_proba = {
        'subject_id': [],
        'center_sample': [],
        'prediction_part': [],
        'category': [],
        'probability': [],
        'duration': [], 
        #'frequency': [],
        #'amplitude_pp': [],
        #'amplitude_rms': [],
        #'c10_density': [],
        #'c20_density': [],
    }

    min_n2_minutes = 60
    verbose_min_minutes = False

    start_time = time.time()
    print("Generating table of parameters")
    n_parts = len(pred_objects_1)
    
    counter = 0
    
    for part_id in range(n_parts):
        predictions_1 = pred_objects_1[part_id]
        predictions_0 = pred_objects_0[part_id]
        print("Processing Part %d / %d" % (part_id + 1, n_parts))
        for i_subject in range(10):
        # for i_subject in tqdm(range(100)):
            subject_id = predictions_1.all_ids[i_subject]
            n2_pages = predictions_1.data[subject_id]['n2_pages']
            n2_minutes = n2_pages.size * nsrr.original_page_duration / 60
            if n2_minutes < min_n2_minutes:
                if verbose_min_minutes:
                    print("Skipped by N2 minutes: Subject %s with %d N2 minutes" % (subject_id, n2_minutes))
                continue

            # Class 1 spindles (real):
            marks_1 = predictions_1.get_subject_stamps(subject_id, pages_subset='wn')
            # Class 0 "spindles" (false):
            marks_0 = predictions_0.get_subject_stamps(subject_id, pages_subset='wn')
            # Let only those class 0 without intersecting class 1
            # If marks_1.size = 0 then marks_0 is by definition not intersecting
            if marks_1.size > 0:
                ov_mat = utils.get_overlap_matrix(marks_0, marks_1)
                is_intersecting = ov_mat.sum(axis=1)
                marks_0 = marks_0[is_intersecting == 0]
            if (marks_1.size + marks_0.size) == 0:
                continue  # There are no marks to work with
            
            #print(marks_0.shape[0], marks_1.shape[0])
            
            # Now only keep N2 stage marks
            n2_pages = predictions_1.data[subject_id]['n2_pages']
            print(n2_pages.shape)
            page_size = int(nsrr.fs * nsrr.original_page_duration)
            if marks_1.size > 0:
                #print(marks_1.shape)
                marks_1 = utils.extract_pages_for_stamps(marks_1, n2_pages, page_size)
                #print(marks_1.shape)
            #if marks_0.size > 0:
            #    marks_0 = utils.extract_pages_for_stamps(marks_0, n2_pages, page_size)
            if (marks_1.size + marks_0.size) == 0:
                continue  # There are no marks to work with
                
            print(marks_0.shape[0], marks_1.shape[0])
            
            marks = []
            marks_class = []
            if marks_1.size > 0:
                marks.append(marks_1)
                marks_class.append([1] * marks_1.shape[0])
            if marks_0.size > 0:
                marks.append(marks_0)
                marks_class.append([0] * marks_0.shape[0])
            marks = np.concatenate(marks, axis=0).astype(np.int32)
            marks_class = np.concatenate(marks_class).astype(np.int32)
            n_marks = marks.shape[0]
            counter += n_marks
            
            # Extract proba
            subject_proba = predictions_1.get_subject_probabilities(subject_id, return_adjusted=False)
            marks_proba = det_utils.get_event_probabilities(marks, subject_proba, downsampling_factor=8, proba_prc=75)
            marks_proba = marks_proba.astype(np.float32)
            
            # Parameters
            duration = (marks[:, 1] - marks[:, 0] + 1) / nsrr.fs
            
            table_byevent_proba['subject_id'].append([subject_id] * n_marks)
            table_byevent_proba['center_sample'].append(marks.mean(axis=1).astype(np.int32))
            table_byevent_proba['prediction_part'].append(np.array([part_id] * n_marks, dtype=np.int32))
            table_byevent_proba['category'].append(marks_class)
            table_byevent_proba['probability'].append(marks_proba)
            table_byevent_proba['duration'].append(duration)
            
    for key in table_byevent_proba:
        table_byevent_proba[key] = np.concatenate(table_byevent_proba[key])
    table_byevent_proba = pd.DataFrame.from_dict(table_byevent_proba)
    print("Done.") 
    
print(counter)
print(table_byevent_proba.shape)

Generating table of parameters
Processing Part 1 / 1
145 1778
(326,)
(1778, 2)
(1778, 2)
145 1778
207 1642
(435,)
(1642, 2)
(1642, 2)
207 1642
212 1316
(445,)
(1316, 2)
(1316, 2)
212 1316
206 1113
(428,)
(1113, 2)
(1113, 2)
206 1113
266 1072
(476,)
(1072, 2)
(1072, 2)
266 1072
326 2114
(490,)
(2114, 2)
(2114, 2)
326 2114
188 1601
(487,)
(1601, 2)
(1601, 2)
188 1601
190 1159
(393,)
(1159, 2)
(1159, 2)
190 1159
104 384
(366,)
(384, 2)
(384, 2)
104 384
213 1154
(486,)
(1154, 2)
(1154, 2)
213 1154
Done.
15390
(15390, 6)


In [29]:
marks_1.shape

(1154, 2)

In [35]:
a = utils.extract_pages_for_stamps(marks_1, [10], page_size)
a.shape

(4, 2)

In [36]:
a

array([[61544, 61695],
       [62560, 62639],
       [64144, 64319],
       [65272, 65423]], dtype=int32)

In [None]:
plt.scatter(table_byevent_proba.probability, table_byevent_proba.category)
plt.show()

In [None]:
locs = table_byevent_proba.category == 0
plt.scatter(table_byevent_proba.probability[locs], table_byevent_proba.duration[locs], alpha=0.5)
plt.show()

In [None]:
table_byevent_proba[
    (table_byevent_proba.category == 0) & (table_byevent_proba.probability > 0.6)
].sort_values(by="probability", ascending=False)

In [None]:
table_byevent_proba[
    (table_byevent_proba.subject_id == "ccshs-trec-1800065") 
    & (table_byevent_proba.center_sample > 1692000)
    & (table_byevent_proba.center_sample < 1692400)
]

In [None]:
# visualize
loc_to_viz = 5197
window_duration = 20

#
subject_info = table_byevent_proba.loc[loc_to_viz]
print(subject_info)
subject_data = nsrr.read_subject_data(subject_info.subject_id, exclusion_of_pages=False)
signal = subject_data['signal']
predictions = pred_objects_1[subject_info.prediction_part]
center_sample = subject_info.center_sample
start_sample = int(center_sample - window_duration * nsrr.fs // 2)
end_sample = int(start_sample + window_duration * nsrr.fs)
proba = predictions.get_subject_probabilities(
    subject_info.subject_id, )
proba_up = np.repeat(proba, 8)
time_axis = np.arange(start_sample, end_sample) / nsrr.fs
n2_pages = predictions.data[subject_info.subject_id]['n2_pages']
n2_pages_vector = np.zeros(signal.shape, dtype=np.int32)
page_size = int(nsrr.original_page_duration * nsrr.fs)
for p in n2_pages:
    start_page = p * page_size
    end_page = start_page + page_size
    n2_pages_vector[start_page:end_page] = 1

fig, ax = plt.subplots(1, 1, figsize=(12, 2.5), dpi=140)
ax.plot(time_axis, signal[start_sample:end_sample], linewidth=.6)
ax.fill_between(
    time_axis,
    200 * (1 - n2_pages_vector[start_sample:end_sample]),
    -200 * (1 - n2_pages_vector[start_sample:end_sample]),
    facecolor="k", alpha=0.1
)

ax.fill_between(
    time_axis, 
    -300 - 50 * proba_up[start_sample:end_sample], 
    -300 + 50 * proba_up[start_sample:end_sample],
    color=viz.PALETTE['red'], alpha=1.0
)
ax.axhline(-300 - 50, linewidth=0.7, linestyle="-", color="k")
ax.axhline(-300 + 50, linewidth=0.7, linestyle="-", color="k")
ax.axhline(-300 - 25, linewidth=0.7, linestyle="--", color="k")
ax.axhline(-300 + 25, linewidth=0.7, linestyle="--", color="k")
ax.axhline(-300 + 0, linewidth=0.7, linestyle="-", color="k")
ax.set_ylim([-400, 200])
ax.set_xlim([start_sample/nsrr.fs, end_sample/nsrr.fs])

ax.grid()
ax.set_xlabel("Time (s)", fontsize=8)
ax.tick_params(labelsize=8)
title_str = 'Subject %s. Loc %d. Center category %d' % (subject_info.subject_id, loc_to_viz, subject_info.category)
ax.set_title(title_str)
plt.tight_layout()
plt.show()

In [None]:
signal.shape