In [1]:
import os
import sys
import pickle

PROJECT_ROOT = os.path.abspath('..')
sys.path.append(PROJECT_ROOT)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy
from tqdm import tqdm

from sleeprnn.helpers.reader import load_dataset
from sleeprnn.common import constants, viz, pkeys
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection.predicted_dataset import PredictedDataset
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection import det_utils
from figs_thesis import fig_utils

# Display helper from sleeprnn's viz module (presumably widens the notebook cells).
viz.notebook_full_width()

# Short aliases for the fig_utils event-parameter helpers.
# NOTE(review): none of these aliases is referenced in the cells visible here — confirm they are still needed.
param_filtering_fn = fig_utils.get_filtered_signal_for_event
param_frequency_fn = fig_utils.get_frequency_by_fft
param_amplitude_fn = fig_utils.get_amplitude_event

# 'results' directory located directly under the project root.
RESULTS_PATH = os.path.join(PROJECT_ROOT, 'results')

  return f(*args, **kwds)


# Max amplitude preliminaries (MASS-KC)

In [None]:
# Load the 'mass_kc' dataset, request its signals with normalize_clip=False and
# its stamps restricted to the 'n2' pages subset. `signals` and `marks` are
# iterated in parallel (one entry per subject) by the next cell.
mass = load_dataset('mass_kc')
signals = mass.get_signals(normalize_clip=False)
marks = mass.get_stamps(pages_subset='n2')

In [None]:
# For every stamped event, keep the lowest and the highest sample value found
# inside the stamp interval (both stamp ends are included in the slice).
kc_extrema_by_subject = []
for subject_signal, subject_marks in zip(signals, marks):
    segments = [subject_signal[stamp[0]:stamp[1] + 1] for stamp in subject_marks]
    subject_min = np.array([np.min(segment) for segment in segments])
    subject_max = np.array([np.max(segment) for segment in segments])
    kc_extrema_by_subject.append((subject_min, subject_max))

# Pool the events of all subjects into two flat vectors.
all_kc_min = np.concatenate([mins for mins, _unused in kc_extrema_by_subject])
all_kc_max = np.concatenate([maxs for _unused, maxs in kc_extrema_by_subject])
print(all_kc_min.shape, all_kc_max.shape)

In [None]:
# Side-by-side histograms of the negative and positive event peaks.
fig, axes = plt.subplots(1, 2, figsize=(6, 3), dpi=100)
peak_panels = zip(axes, (all_kc_min, all_kc_max), ("Neg Peak", "Pos Peak"))
for ax, peak_values, peak_title in peak_panels:
    ax.hist(peak_values)
    ax.set_title(peak_title)
plt.tight_layout()
plt.show()

# Tail percentiles: the negative peak is read from the lower tail (100 - prct)
# and the positive peak from the upper tail (prct).
prct = 98
negative_peak_value = np.percentile(all_kc_min, 100 - prct)
positive_peak_value = np.percentile(all_kc_max, prct)

print("Negative Peak prct %d: %1.4f uV" % (prct, negative_peak_value))
print("Positive Peak prct %d: %1.4f uV" % (prct, positive_peak_value))

# Signal parameters preliminaries (MODA)

In [None]:
# MODA dataset: signals requested with normalize_clip=False and the page
# indices of the 'n2' subset, one entry per subject.
moda = load_dataset('moda_ss')
signals = moda.get_signals(normalize_clip=False)
pages = moda.get_pages(pages_subset='n2')

# One list of per-subject arrays per statistic; filled by the loop below.
moda_page_stats = {
    stat_name: []
    for stat_name in ('scale', 'exponent', 'max_ratio', 'standard_deviation', 'r2')
}

# Frequency range (Hz) over which the spectrum is compared with the fitted power law.
f_min = 2
f_max = 30
moda_page_size = moda.fs * moda.page_duration

for sub_signal, sub_pages in tqdm(zip(signals, pages)):
    # Split the recording into pages and keep the selected ones: [n_pages, n_samples]
    sub_signal = sub_signal.reshape(-1, moda_page_size)[sub_pages]
    freq, pages_spectrum = utils.compute_pagewise_fft(sub_signal, moda.fs, window_duration=2)
    pages_scales, pages_exponents = utils.compute_pagewise_powerlaw(freq, pages_spectrum)  # (n_pages,)

    # Ratio between the measured spectrum and the fitted law, page by page.
    valid_locs = np.where((freq >= f_min) & (freq <= f_max))[0]
    dev_f = freq[valid_locs]
    dev_x = pages_spectrum[:, valid_locs]
    dev_x_law = np.stack(
        [fit_s * (dev_f ** fit_e) for fit_s, fit_e in zip(pages_scales, pages_exponents)],
        axis=0,
    )
    error = dev_x / dev_x_law  # n_pages, n_freqs
    max_error = np.max(error, axis=1)  # to detect weird peaks, shape (n_pages,)

    # R2 in log space, leaving out the 10-17 Hz (sigma) bins.
    valid_locs = np.where((dev_f < 10) | (dev_f > 17))[0]
    dev_f = dev_f[valid_locs]
    log_dev_x = np.log(dev_x[:, valid_locs])
    log_dev_x_law = np.log(dev_x_law[:, valid_locs])
    log_dev_x_centered = log_dev_x - log_dev_x.mean(axis=1, keepdims=True)
    squared_data = np.sum(log_dev_x_centered ** 2, axis=1)
    squared_residuals = np.sum((log_dev_x - log_dev_x_law) ** 2, axis=1)
    r2 = 1 - squared_residuals / squared_data  # (n_pages,)

    subject_stats = {
        'scale': pages_scales,
        'exponent': pages_exponents,
        'max_ratio': max_error,
        'standard_deviation': sub_signal.std(axis=1),
        'r2': r2,
    }
    for key, values in subject_stats.items():
        moda_page_stats[key].append(values)

# Merge the per-subject arrays into one flat vector per statistic.
for key, chunks in list(moda_page_stats.items()):
    moda_page_stats[key] = np.concatenate(chunks)
    print(key, moda_page_stats[key].shape)

In [None]:
# One histogram per MODA page statistic, with its range in the title.
fig, axes = plt.subplots(1, 5, figsize=(14, 4), dpi=120)
for ax, (key, values) in zip(axes, moda_page_stats.items()):
    stat_min = values.min()
    stat_max = values.max()
    ax.hist(values, bins=20)
    ax.set_title('MODA\n%s. Min %1.4f Max %1.4f' % (key, stat_min, stat_max), fontsize=8)
    print(key, stat_min, stat_max)
    ax.tick_params(labelsize=8)
plt.tight_layout()
plt.show()

# NSRR EXPLORATION

In [2]:
# Handle used further down (get_predictions) to fetch the per-fold NSRR predictions.
nsrr_preds = fig_utils.PredictedNSRR()

In [3]:
# NSRR dataset object, restored from its checkpoint (load_checkpoint=True).
nsrr = load_dataset(constants.NSRR_SS_NAME, load_checkpoint=True)

Dataset nsrr_ss with 3 patients.
Loading from checkpoint... Loaded
Global STD: None
Dataset nsrr_ss with 11593 patients.


# Reproducción de tendencias

In [4]:
# Gather the predictions of folds 0 .. n_folds - 1 for the NSRR dataset.
n_folds = 10
fold_ids_list = np.arange(n_folds)
predictions = nsrr_preds.get_predictions(fold_ids_list, nsrr)

n_loaded_subjects = len(predictions.all_ids)
print("Loaded predictions for %d subjects" % n_loaded_subjects)

Loaded predictions for 1000 subjects


In [None]:
# Subjects with fewer N2 minutes than this are left out of the tables.
min_n2_minutes = 60

subject_ids = predictions.all_ids

event_durations = []   # one array of event durations per kept subject
subject_records = []   # one row (dict) per kept subject
for subject_id in subject_ids:
    n2_pages = predictions.data[subject_id]['n2_pages']
    n2_minutes = n2_pages.size * nsrr.original_page_duration / 60
    if n2_minutes < min_n2_minutes:
        print("Skipped by N2 minutes: Subject %s with %d N2 minutes" % (subject_id, n2_minutes))
        continue

    marks = predictions.get_subject_stamps(subject_id)
    n_marks = marks.shape[0]
    if n_marks == 0:
        print("Skipped by Marks     : Subject %s with %d marks (%d N2 minutes)" % (subject_id, n_marks, n2_minutes))
        continue

    # Stamp ends are inclusive, hence the +1 before converting samples with fs.
    durations = (marks[:, 1] - marks[:, 0] + 1) / nsrr.fs
    subject_proba = predictions.get_subject_stamps_probabilities(subject_id)
    subject_data = nsrr.read_subject_data(subject_id, exclusion_of_pages=False)

    event_durations.append(durations)
    subject_records.append({
        'subject_id': subject_id,
        'duration': np.mean(durations),
        'density': n_marks / n2_minutes,
        'proba_event': np.mean(subject_proba),
        'n2_minutes': n2_minutes,
        # Subject id without its last four characters.
        'origin': subject_id[:-4],
        'age': float(subject_data['age'].item()),
        'female': int(subject_data['sex'].item() == 'f'),
    })

table_byevent = pd.DataFrame({'duration': np.concatenate(event_durations)})
table_bysubject = pd.DataFrame(subject_records)
print("Done.")

In [None]:
# Show the per-event table (single column 'duration', one row per detected event).
table_byevent

In [None]:
# Show the per-subject table (one row per subject kept by the filters above).
table_bysubject

In [None]:
# by-event stuff: duration histogram with 0.1-wide bins covering 0.3 to 3.0
duration_bin_edges = np.arange(0.3, 3.0 + 0.001, 0.1)
table_byevent.hist(bins=duration_bin_edges)
plt.tight_layout()
plt.show()

In [None]:
# Histograms (20 bins each) of the per-subject table columns.
table_bysubject.hist(bins=20)
plt.tight_layout()
plt.show()

In [None]:
# n2_minutes * density gives back the number of events of each subject
# (density was computed as n_marks / n2_minutes).
events_per_subject = table_bysubject['n2_minutes'] * table_bysubject['density']
few_events = events_per_subject[events_per_subject <= 20]
few_events.hist(bins=np.arange(0, 20 + 0.001, 1))
plt.show()
print(events_per_subject.min())

In [None]:
# Zoom on the subjects whose event density is below the threshold.
zoom_density = 0.5

is_low_density = table_bysubject['density'] < zoom_density
table_zoom = table_bysubject.loc[is_low_density]
table_zoom.hist()
plt.tight_layout()
plt.show()

In [None]:
# correlations: one scatter plot per unordered pair of numeric columns

param_names = table_bysubject.select_dtypes(include=np.number).columns.tolist()
n_params = len(param_names)
param_pairs = [
    (param_names[i], param_names[j])
    for i in range(n_params)
    for j in range(i + 1, n_params)
]
n_plots = n_params * (n_params - 1) / 2
n_cols = 5
n_rows = int(np.ceil(n_plots / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2.5, n_rows* 2.5), dpi=100)
axes = axes.flatten()
for ax, (x_name, y_name) in zip(axes, param_pairs):
    x_data = table_bysubject[x_name].values
    y_data = table_bysubject[y_name].values

    ax.plot(x_data, y_data, linestyle="none", marker='o', markersize=3, alpha=0.1)
    ax.set_xlabel(x_name, fontsize=8)
    ax.set_ylabel(y_name, fontsize=8)

    ax.tick_params(labelsize=8)
    ax.grid()
plt.tight_layout()
plt.show()

In [None]:
# Rows of the per-subject table whose mean event duration exceeds 1.2.
table_bysubject[table_bysubject.duration > 1.2]

# Check stats of all pages

In [None]:
# Per-page spectral statistics of the NSRR N2 pages, using the same recipe as
# the MODA cell so both sets of distributions can be compared afterwards.
page_stats = {
    'scale': [],
    'exponent': [],
    'max_ratio': [],
    'standard_deviation': [],
    'r2': [],
}

# Frequency range (Hz) over which the spectrum is compared with the fitted power law.
f_min = 2
f_max = 30
nsrr_page_size = nsrr.fs * nsrr.original_page_duration

for subject_id in tqdm(predictions.all_ids):
    # Check the N2 duration first: it only needs the page indices, so the
    # signal of a skipped subject is never read.
    sub_pages = predictions.data[subject_id]['n2_pages']
    n2_minutes = sub_pages.size * nsrr.original_page_duration / 60
    if n2_minutes < min_n2_minutes:  # same threshold as the subject tables
        continue

    subject_data = nsrr.read_subject_data(subject_id, exclusion_of_pages=False)
    sub_signal = subject_data['signal']

    sub_signal = sub_signal.reshape(-1, nsrr_page_size)[sub_pages]  # [n_pages, n_samples]
    freq, pages_spectrum = utils.compute_pagewise_fft(sub_signal, nsrr.fs, window_duration=2)
    pages_scales, pages_exponents = utils.compute_pagewise_powerlaw(freq, pages_spectrum)  # (n_pages,)

    # Deviation from power law
    valid_locs = np.where((freq >= f_min) & (freq <= f_max))[0]
    dev_f = freq[valid_locs]
    dev_x = pages_spectrum[:, valid_locs]
    dev_x_law = [fit_s * (dev_f ** fit_e) for fit_s, fit_e in zip(pages_scales, pages_exponents)]
    dev_x_law = np.stack(dev_x_law, axis=0)
    # Largest spectrum / law ratio of each page, to detect weird peaks; shape (n_pages,)
    max_error = np.max(dev_x / dev_x_law, axis=1)

    # for r2, we remove sigma (10-17 Hz bins)
    valid_locs = np.where((dev_f < 10) | (dev_f > 17))[0]
    log_dev_x = np.log(dev_x[:, valid_locs])
    log_dev_x_law = np.log(dev_x_law[:, valid_locs])
    squared_data = np.sum((log_dev_x - log_dev_x.mean(axis=1).reshape(-1, 1)) ** 2, axis=1)
    squared_residuals = np.sum((log_dev_x - log_dev_x_law) ** 2, axis=1)
    r2 = 1 - squared_residuals / squared_data  # (n_pages,)

    page_stats['scale'].append(pages_scales)
    page_stats['exponent'].append(pages_exponents)
    page_stats['max_ratio'].append(max_error)
    page_stats['standard_deviation'].append(sub_signal.std(axis=1))
    page_stats['r2'].append(r2)

# Merge the per-subject arrays into one flat vector per statistic.
for key in page_stats.keys():
    page_stats[key] = np.concatenate(page_stats[key])
    print(key, page_stats[key].shape)

In [None]:
# NSRR page statistics compared with the range observed in MODA. These
# statistics are shown (and compared) in log scale.
log_scaled_keys = ['scale', 'max_ratio', 'standard_deviation']

fig, axes = plt.subplots(1, 5, figsize=(14, 4), dpi=120)
for ax, key in zip(axes, page_stats.keys()):
    value_vector = page_stats[key]
    moda_value_vector = moda_page_stats[key]

    use_log = key in log_scaled_keys
    if use_log:
        value_vector = np.log(value_vector)
        moda_value_vector = np.log(moda_value_vector)
    name_str = ('log %s' % key) if use_log else key

    moda_min = moda_value_vector.min()
    moda_max = moda_value_vector.max()

    # Share of NSRR pages that fall inside the MODA range
    inside_moda_range = (value_vector >= moda_min) & (value_vector <= moda_max)
    n_inliers = np.sum(inside_moda_range)
    n_total = value_vector.size
    fraction_inliers = 100 * n_inliers / n_total

    title_values = (
        name_str, value_vector.min(), value_vector.max(), moda_min, moda_max, fraction_inliers)
    ax.hist(value_vector, bins=50)
    ax.set_title(
        'NSRR\n%s\nMin %1.4f Max %1.4f\n(MODA Min %1.4f Max %1.4f)\nInliers: %1.2f%%' % title_values,
        fontsize=8)
    ax.tick_params(labelsize=8)
    for moda_bound in (moda_min, moda_max):
        ax.axvline(moda_bound, color="k", linestyle="--", linewidth=0.8)
    ax.set_yticks([])
plt.tight_layout()
plt.show()

In [None]:
# Power spectrum and standard deviation of every N2 page of every subject with
# enough N2 time, together with the (subject, page) each row comes from.
large_std_threshold = 50  # pages above this STD flag their subject as "large"

spectra = []
all_std = []
large_subjects = []
origin_subject_id = []
origin_page_id = []
for subject_id in predictions.all_ids:
    # Check the N2 duration first: it only needs the page indices, so the
    # signal of a skipped subject is never loaded.
    n2_pages = predictions.data[subject_id]['n2_pages']
    n2_minutes = n2_pages.size * nsrr.original_page_duration / 60
    if n2_minutes < min_n2_minutes:  # same threshold as the subject tables
        continue

    signal = nsrr.get_subject_signal(subject_id, normalize_clip=False)
    x_pages = signal.reshape(-1, nsrr.original_page_duration * nsrr.fs)[n2_pages]
    for page_id, x in zip(n2_pages, x_pages):
        # `freq` is the same grid for every page and is reused by later cells.
        freq, power = utils.power_spectrum_by_sliding_window(x, nsrr.fs, window_duration=2)
        this_std = x.std()
        spectra.append(power)
        all_std.append(this_std)
        origin_subject_id.append(subject_id)
        origin_page_id.append(page_id)
        if this_std > large_std_threshold:
            large_subjects.append(subject_id)
spectra = np.stack(spectra, axis=0)
all_std = np.array(all_std)
large_subjects = np.unique(large_subjects)
origin_subject_id = np.array(origin_subject_id)
origin_page_id = np.array(origin_page_id)
print("Done")
print(spectra.shape, all_std.shape, large_subjects.shape)

In [None]:
# std hist
std_bin_edges = np.arange(0, 130 + 0.001, 2.5)
std_tick_locs = np.arange(0, 130 + 0.001, 10)
fig, ax = plt.subplots(1, 1, figsize=(8, 3), dpi=100)
ax.hist(all_std, bins=std_bin_edges)
ax.set_xticks(std_tick_locs)
ax.grid()
plt.show()

std_summary = (all_std.min(), all_std.mean(), np.median(all_std), all_std.max())
print("STD -- Min %1.4f, Mean %1.4f, Median %1.4f, Max %1.4f" % std_summary)

# Central 95% of pages
# STD between 7.5 and 36.
# Central 99% of pages
# STD between 6.15 and 73.34

# Limits of the central 99% of pages
for prctl in (0.5, 99.5):
    print("STD -- Percentile %s: %1.4f" % (prctl, np.percentile(all_std, prctl)))

# Count, total and percentage of pages whose STD is above 50
n_large_std = np.sum(all_std > 50)
print(n_large_std, all_std.size, 100 * n_large_std / all_std.size)

In [None]:
# Restrict the spectra to the 0.3-30 Hz bins.
valid_locs = np.where((freq >= 0.3) & (freq <= 30))[0]
freq_short = freq[valid_locs]
spectra_short = spectra[:, valid_locs]

# Power-law fit of one example page (row 350): bins from 4 Hz upwards,
# leaving out the 10-17 Hz (sigma) bins, fitted as a line in log-log space.
example_row = 350
locs_to_use = np.where(freq_short >= 4)[0]
x_data = freq_short[locs_to_use]
y_data = spectra_short[example_row, locs_to_use]

locs_no_sigma = np.where((x_data < 10) | (x_data > 17))[0]
x_data_no_sigma = x_data[locs_no_sigma]
y_data_no_sigma = y_data[locs_no_sigma]

log_x = np.log(x_data_no_sigma)
log_y = np.log(y_data_no_sigma)
pl_exponent, pl_intercept = scipy.stats.linregress(log_x, log_y)[:2]


def fitted_power_law(x):
    """Evaluate the fitted law exp(pl_intercept) * x ** pl_exponent (reads the current globals)."""
    return (x ** pl_exponent) * np.exp(pl_intercept)


print(pl_exponent, pl_intercept)

# Same three curves on linear axes (left) and log-log axes (right).
fig, axes = plt.subplots(1, 2, figsize=(6, 3), dpi=140)
for ax, use_log_axes in zip(axes, (False, True)):
    ax.plot(x_data, y_data, linewidth=0.8)
    ax.plot(x_data_no_sigma, y_data_no_sigma, linewidth=0.8)
    ax.plot(x_data, fitted_power_law(x_data), linewidth=0.8)
    ax.set_xlim([3, 35])
    if use_log_axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Frequency (Hz)", fontsize=8)
plt.tight_layout()
plt.show()

In [None]:
# Power-law exponent of every page spectrum.
# The frequency selection is the same for every page, so it is computed once:
# bins from 4 Hz upwards, leaving out the 10-17 Hz (sigma) bins.
# Therefore we are considering frequencies 4-10 and 17-30 Hz for the power law.
locs_to_use = np.where(freq_short >= 4)[0]
x_data = freq_short[locs_to_use]
locs_no_sigma = np.where((x_data < 10) | (x_data > 17))[0]
fit_locs = locs_to_use[locs_no_sigma]
log_x = np.log(freq_short[fit_locs])

# The slope of the log-log regression is the exponent of the power law.
exponent_list = np.array([
    scipy.stats.linregress(log_x, np.log(curve[fit_locs]))[0]
    for curve in spectra_short
])

In [None]:
# Number of fitted exponents (one per row of spectra_short).
exponent_list.shape

In [None]:
# Robust centre and spread of the fitted exponents.
# NOTE: `mean_val` keeps its original name but holds the MEDIAN.
mean_val = np.median(exponent_list)
# Median absolute deviation scaled by 1.4826, i.e. what the removed
# scipy.stats.median_absolute_deviation returned with its default scale.
mad_scale = 1.4826
disp_val = mad_scale * np.median(np.abs(exponent_list - mean_val))  # dispersion around median
disp_width = 2.2

lower_bound = mean_val - disp_width * disp_val
upper_bound = mean_val + disp_width * disp_val

n_inliers = np.sum((exponent_list >= lower_bound) & (exponent_list <= upper_bound))
print(n_inliers, 100 * n_inliers / exponent_list.size)
print(lower_bound, upper_bound)
print("")

fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=120)
ax.hist(exponent_list, bins=50)
for marked_value in (mean_val, upper_bound, lower_bound):
    ax.axvline(marked_value, color="k", linestyle="--", linewidth=1)
ax.set_title("Median %1.4f" % mean_val)
plt.show()

for prct in [0, 0.5, 5, 95, 99.5, 100]:
    value = np.percentile(exponent_list, prct)
    print("Percentile %s is %1.4f" % (prct, value))

In [None]:
# Restrict the spectra to the 0.3-30 Hz bins.
valid_locs = np.where((freq >= 0.3) & (freq <= 30))[0]
freq_short = freq[valid_locs]
spectra_short = spectra[:, valid_locs]

# Compute avg power of 2Hz freq bands
all_p_avg = []
all_f_avg = []
start_freq = 0
width = 2
n_bands = int(np.ceil((freq_short.max() - start_freq) / width))
for i_band in range(n_bands):
    end_freq = start_freq + width
    # Bins of the band [start_freq, end_freq), selected on this cell's own
    # frequency grid (freq_short) rather than on a variable from another cell.
    locs = np.where((freq_short >= start_freq) & (freq_short < end_freq))[0]
    center_freq = (start_freq + end_freq) / 2
    start_freq = end_freq
    if center_freq < 2:
        # Drops the first band (0-2 Hz, centred at 1 Hz).
        continue
    all_p_avg.append(spectra_short[:, locs].mean(axis=1))
    all_f_avg.append(center_freq)

all_f_avg = np.array(all_f_avg)          # (n_bands_kept,)
all_p_avg = np.stack(all_p_avg, axis=1)  # (n_pages, n_bands_kept)

# Percentile curves of the band power across pages, on linear axes (left)
# and log-log axes (right).
fig, axes = plt.subplots(1, 2, figsize=(6, 3), dpi=200)
prctl_results = np.percentile(all_p_avg, (0, 0.5, 25, 50, 75, 99.5, 100), axis=0)
for ax, use_log_axes in zip(axes, (False, True)):
    for curve in prctl_results:
        ax.plot(all_f_avg, curve, linewidth=0.8, marker='o')
    ax.set_xlim([0.1, 35])
    if use_log_axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Frequency (Hz)", fontsize=8)
plt.tight_layout()
plt.show()

In [None]:
# Restrict the spectra to the 0.3-30 Hz bins.
valid_locs = np.where((freq >= 0.3) & (freq <= 30))[0]
freq_short = freq[valid_locs]
spectra_short = spectra[:, valid_locs]

fig, axes = plt.subplots(1, 2, figsize=(6, 3), dpi=200)

# 0.5 and 99.5 percentiles of the band power, one value per band.
prctl_results = np.percentile(all_p_avg, (0.5, 99.5), axis=0)

# A page is an outlier when at least one of its band powers leaves the
# [0.5, 99.5] percentile range. Computed here so the cell does not depend on
# a variable that is only created by a later cell.
band_lower, band_upper = prctl_results
within_bounds = np.all((all_p_avg >= band_lower) & (all_p_avg <= band_upper), axis=1)
outliers_locs = np.where(~within_bounds)[0]

# Percentile limits (blue) over the full spectra of the outlier pages (black),
# on linear axes (left) and log-log axes (right).
for ax, use_log_axes in zip(axes, (False, True)):
    for curve in prctl_results:
        ax.plot(all_f_avg, curve, linewidth=0.8, marker='o', markersize=3, color="b")
    for curve in spectra_short[outliers_locs]:
        ax.plot(freq_short, curve, linewidth=0.8, color="k", alpha=0.1)
    ax.set_xlim([0.1, 35])
    if use_log_axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Frequency (Hz)", fontsize=8)
plt.tight_layout()
plt.show()

In [None]:
# Band-averaged spectrum of a single page (row `loc` of spectra_short).
loc = 350

p = spectra_short[loc]
f = freq_short
# Compute avg power of 2Hz freq bands
start_freq = 0
width = 2
n_bands = int(np.ceil((f.max() - start_freq) / width))
band_limits = [
    (start_freq + width * i_band, start_freq + width * (i_band + 1))
    for i_band in range(n_bands)
]
# Band centres and mean power over the bins inside [band_start, band_end).
f_avg = np.array([(band_start + band_end) / 2 for band_start, band_end in band_limits])
p_avg = np.array([
    p[np.where((f >= band_start) & (f < band_end))[0]].mean()
    for band_start, band_end in band_limits
])

page_label = "Subject %s, page %d" % (origin_subject_id[loc], origin_page_id[loc])

plt.plot(f, p)
plt.plot(f_avg, p_avg, marker='o')
print(f_avg)
plt.title(page_label)
plt.show()
print(page_label)

In [None]:
# Keep the pages whose band powers all lie inside the [0.5, 99.5] percentile
# range of their band; every other page is an outlier.
lower_bound, upper_bound = np.percentile(all_p_avg, (0.5, 99.5), axis=0)
print("Original: ", all_p_avg.shape)

# One boolean per page, computed for all pages at once. Unlike stacking a
# list of rows, this also works when no page qualifies as an inlier.
is_inlier = np.all((all_p_avg >= lower_bound) & (all_p_avg <= upper_bound), axis=1)
inliers_locs = np.where(is_inlier)[0].tolist()
outliers_locs = np.where(~is_inlier)[0].tolist()
spectra_filt = all_p_avg[is_inlier]

print("Filtered:", spectra_filt.shape)
print("Percentage kept: %1.4f" % (100 * spectra_filt.shape[0] / all_p_avg.shape[0]))

# Check single subject

In [None]:
subject_id = 'shhs1-200721'   # 'shhs1-201711'  # 'sof-visit-8-02332'

print("subject %s" % subject_id)
subject_data = nsrr.read_subject_data(subject_id)
signal, n2_pages = subject_data['signal'], subject_data['n2_pages']
n2_minutes = n2_pages.size * nsrr.original_page_duration / 60
marks = predictions.get_subject_stamps(subject_id, pages_subset='n2')
proba = predictions.get_subject_probabilities(subject_id)
# NOTE(review): the factor 8 presumably brings the probabilities to the signal rate — confirm.
proba_up = np.repeat(proba, 8)
# Stamp ends are inclusive, hence the +1.
durations = (marks[:, 1] - marks[:, 0] + 1) / nsrr.fs

print("Marks", marks.shape)
print("N2 minutes", n2_minutes)

plt.plot(signal)
plt.show()

# Locate the N2 pages whose standard deviation exceeds the threshold.
std_to_find = 50

page_size = nsrr.original_page_duration * nsrr.fs
x_pages_std = signal.reshape(-1, page_size)[n2_pages].std(axis=1)
n2_pages_locs = np.where(x_pages_std > std_to_find)[0]
std_selected = x_pages_std[n2_pages_locs]
pages_locs = n2_pages[n2_pages_locs]
start_sample = pages_locs * nsrr.original_page_duration * nsrr.fs
center_sample = start_sample + page_size // 2
print("Center samples where STD is greater than %1.4f" % std_to_find)
print(center_sample)
print(std_selected)

In [None]:
# Plot a 30 s window of the signal centred on the first N2 page, with the
# upsampled probability drawn as a band around y = -300.
# mark_selected = marks[5]

# Page length in samples taken from the dataset instead of a hard-coded
# 30 s * 200 Hz, matching how pages are sliced everywhere else in the notebook.
page_size = nsrr.original_page_duration * nsrr.fs
center_sample = int((n2_pages[0] + 0.5) * page_size)  # mark_selected.mean()

window_duration = 30
window_size = nsrr.fs * window_duration
start_sample = int(center_sample - window_size // 2)
end_sample = int(start_sample + window_size)
time_axis = np.arange(start_sample, end_sample) / nsrr.fs

fig, ax = plt.subplots(1, 1, figsize=(10, 2.5), dpi=100)

ax.plot(time_axis, signal[start_sample:end_sample], linewidth=.6)

# ax.plot(mark_selected / nsrr.fs, [-100, -100], linewidth=4, color=viz.PALETTE['red'], alpha=0.7)

# Probability band: centred at -300, half-height 50 when the probability is 1.
proba_center = -300
proba_half_height = 50
proba_window = proba_up[start_sample:end_sample]
ax.fill_between(
    time_axis,
    proba_center - proba_half_height * proba_window,
    proba_center + proba_half_height * proba_window,
    color=viz.PALETTE['red'], alpha=1.0
)
# Guide lines: full scale (solid), half scale (dashed) and the centre (solid).
guide_lines = ((-50, "-"), (50, "-"), (-25, "--"), (25, "--"), (0, "-"))
for guide_offset, guide_style in guide_lines:
    ax.axhline(proba_center + guide_offset, linewidth=0.7, linestyle=guide_style, color="k")

ax.set_ylim([-400, 200])
ax.set_xlim([start_sample/nsrr.fs, end_sample/nsrr.fs])

ax.grid()
ax.set_xlabel("Time (s)", fontsize=8)
ax.tick_params(labelsize=8)

page_std = signal[start_sample:end_sample].std()

# Extract plain Python scalars with .item(), as the per-subject table cell
# does, instead of formatting the stored fields directly.
subject_age = float(subject_data['age'].item())
subject_sex = subject_data['sex'].item()
ax.set_title("Subject %s (Age %1.4f, Sex %s). Page STD %1.4f" % (
    subject_id, subject_age, subject_sex, page_std
))

plt.tight_layout()
plt.show()

print("Page STD: %1.4f" % page_std)

# Compute spectrum of page
freq, power = utils.power_spectrum_by_sliding_window(signal[start_sample:end_sample], nsrr.fs, window_duration=2)
# The 0.3-30 Hz selection goes into page-specific names. `freq` keeps the
# full frequency grid, so the earlier cells that index `spectra` through
# `freq` still select the right columns if they are re-run after this one.
valid_locs = np.where((freq >= 0.3) & (freq <= 30))[0]
page_freq = freq[valid_locs]
page_power = power[valid_locs]

# Power-law fit in log-log space: bins from 2 Hz upwards, leaving out the
# 10-17 Hz (sigma) bins.
locs_to_use = np.where(page_freq >= 2)[0]
x_data = page_freq[locs_to_use]
y_data = page_power[locs_to_use]
locs_no_sigma = np.where((x_data < 10) | (x_data > 17))[0]
x_data_no_sigma = x_data[locs_no_sigma]
y_data_no_sigma = y_data[locs_no_sigma]
log_x = np.log(x_data_no_sigma)
log_y = np.log(y_data_no_sigma)
pl_exponent, pl_intercept, _, _, _ = scipy.stats.linregress(log_x,log_y)
def fitted_power_law(x):
    """Evaluate the fitted law exp(pl_intercept) * x ** pl_exponent (reads the current globals)."""
    return (x ** pl_exponent) * np.exp(pl_intercept)
print(pl_exponent, pl_intercept)

fig, axes = plt.subplots(1, 2, figsize=(6, 2), dpi=100)

# Left: page spectrum on linear axes.
ax = axes[0]
ax.plot(page_freq, page_power)
ax.set_xlim([0.1, 30])
ax.tick_params(labelsize=8)
ax.set_xlabel("Frequency (Hz)", fontsize=8)

# Right: page spectrum and fitted power law on log-log axes.
ax = axes[1]
ax.plot(page_freq, page_power)
ax.plot(page_freq, fitted_power_law(page_freq))
ax.set_xlim([0.1, 30])
ax.set_xscale('log')
ax.set_yscale('log')
ax.tick_params(labelsize=8)
ax.set_xlabel("Frequency (Hz)", fontsize=8)
ax.set_title("Exponent %1.4f" % (pl_exponent), fontsize=8)
plt.tight_layout()
plt.show()