In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from scipy.signal import find_peaks, freqz

# Add the parent directory to the import path so that the `sleeprnn` package
# below can be imported (assumes this notebook sits one level below the
# package root — TODO confirm).
sys.path.append('..')

from sleeprnn.helpers import reader
from sleeprnn.data import utils
from sleeprnn.common import constants, pkeys, viz

# Display helper provided by sleeprnn.common.viz (presumably widens the
# notebook cells — verify against the helper's implementation).
viz.notebook_full_width()
%matplotlib inline

# Load dataset

In [None]:
# Load the dataset identified by constants.CAP_SS_NAME and keep its `fs`
# attribute at hand; every later cell divides sample counts by `fs`
# (presumably the sampling frequency in Hz — confirm against the dataset class).
dataset = reader.load_dataset(constants.CAP_SS_NAME)
fs = dataset.fs

# Visualize single page

In [None]:
# Subject and annotator inspected by this cell and by the page viewer below.
which_expert = 1
subject_id = dataset.all_ids[31]

# Signal without normalization/clipping, plus the stamps and the pages of the
# constants.N2_RECORD subset for the same subject.
signal = dataset.get_subject_signal(
    subject_id, which_expert=which_expert, normalize_clip=False)
marks = dataset.get_subject_stamps(
    subject_id, which_expert=which_expert, pages_subset=constants.N2_RECORD)
n2_pages = dataset.get_subject_pages(subject_id, pages_subset=constants.N2_RECORD)

print("Subject", subject_id)
print("Total spindle marks:", marks.shape[0])
print("Total N2 pages:", n2_pages.size)

# Weird amplitude: pages holding at least one sample whose magnitude exceeds 300.
exceeds_limit = np.abs(signal) > 300
weird_locs = np.nonzero(exceeds_limit)[0]
weird_pages = np.unique(weird_locs // dataset.page_size)
print(weird_pages)

# Positions inside n2_pages (not page numbers) of the pages flagged above.
n2_locs = [position for position, page in enumerate(n2_pages) if page in weird_pages]
print("N2 pages indices subset:", n2_locs)

# Overview of the whole recording with a fixed vertical range.
fig, ax = plt.subplots(figsize=(12, 3), dpi=120)
ax.plot(signal, linewidth=0.8)
ax.set_ylim(-500, 500)
plt.show()



# Weird change
#changes = np.abs(np.diff(signal))
#plt.hist(changes[changes > 20])
#plt.show()

In [None]:
def draw_single_page(n2_page_index_to_show, dpi):
    """Plot one page of the selected subject: full signal on top, filtered below.

    The function reads the notebook globals `n2_pages`, `dataset`, `signal`,
    `marks`, `fs` and `subject_id` at call time, so it shows whatever those
    names are bound to when it is invoked (re-run the cell that defines them
    if a later cell has rebound any of them).

    Args:
        n2_page_index_to_show: Position inside `n2_pages` (not the page number
            in the record) of the page to draw.
        dpi: Resolution passed to `plt.subplots` for the created figure.
    """
    selected_page = n2_pages[n2_page_index_to_show]
    start_sample = int(selected_page * dataset.page_size)
    end_sample = start_sample + dataset.page_size

    segment_signal = signal[start_sample:end_sample]
    segment_marks = utils.filter_stamps(marks, start_sample, end_sample)
    sigma_signal = utils.broad_filter(segment_signal, fs, lowcut=11, highcut=16)

    # Build the time axis from the samples that were actually sliced, so a page
    # cut short by the end of the record still plots instead of failing with a
    # length mismatch between x and y.
    time_axis = (start_sample + np.arange(len(segment_signal))) / fs
    start_time = start_sample / fs
    end_time = end_sample / fs

    # Clip every mark to the page boundaries once; both axes reuse the result.
    clipped_mark_times = [
        np.clip(mark, a_min=start_sample, a_max=end_sample) / fs
        for mark in segment_marks
    ]
    minor_ticks = np.linspace(start_time, end_time, 41)

    fig, axes = plt.subplots(2, 1, figsize=(12, 6), dpi=dpi)

    axes[0].set_title("Subject %s, page in record %d" % (subject_id, selected_page))
    axes[0].plot(time_axis, segment_signal, linewidth=0.6)
    axes[0].set_ylabel("Signal")
    axes[1].plot(time_axis, sigma_signal, linewidth=0.6)
    axes[1].set_ylabel("Filtered (11-16)")
    axes[1].set_xlabel("Time (s)")
    for ax in axes:
        ax.set_ylim([-300, 300])
        ax.set_xlim([start_time, end_time])
        # Each mark is drawn as a thick translucent bar at a fixed height.
        for mark_times in clipped_mark_times:
            ax.plot(mark_times, [-50, -50], linewidth=5, alpha=0.5, color=viz.PALETTE['red'])
        ax.set_xticks(minor_ticks, minor=True)
        ax.grid(axis="x", which="minor")

    plt.tight_layout()
    plt.show()


    # power, freq = utils.power_spectrum(segment_signal, fs)
    # fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=120)
    # ax.plot(freq, power, linewidth=0.7)
    # ax.set_xlim([0, 30])
    # ax.set_xlabel("Frequency (Hz)")
    # plt.show()
    
    
# Position (inside n2_pages) selected when the slider is first rendered.
n2_page_index_to_show = 0  # 110
dpi = 120

print('Total N2 pages: %d' % n2_pages.size)

# Slider spanning every valid position of n2_pages; the plot is refreshed only
# when the handle is released (continuous_update=False).
style = {'description_width': 'initial'}
layout = widgets.Layout(width='1000px')
page_slider = widgets.IntSlider(
    value=n2_page_index_to_show,
    min=0,
    max=n2_pages.size - 1,
    step=1,
    continuous_update=False,
    style=style,
    layout=layout,
)


def _show_page(page_id):
    """Slider callback: draw the page at position `page_id` of n2_pages."""
    draw_single_page(page_id, dpi=dpi)


widgets.interact(_show_page, page_id=page_slider);

# Check duration histogram

In [None]:
# Stamps restricted to the constants.N2_RECORD pages subset for the annotator
# `which_expert`, as a collection with one entry per subject (later cells zip
# it with dataset.all_ids, so the same ordering is assumed — confirm).
marks_list = dataset.get_stamps(pages_subset=constants.N2_RECORD, which_expert=which_expert)

In [None]:
# Pool the stamps of every entry of marks_list into a single array.
# NOTE(review): this rebinds the global `marks`, which draw_single_page reads
# at call time; using the page viewer after running this cell overlays these
# pooled stamps instead of the selected subject's stamps.
marks = np.concatenate(marks_list, axis=0)
print(marks.shape)

In [None]:
# Length of every pooled mark: second column minus first column, divided by fs
# (i.e. seconds if the columns are sample indices and fs is in Hz — confirm).
durations = (marks[:, 1] - marks[:, 0]) / fs

In [None]:
# Histogram of the pooled mark durations in 0.2-wide bins starting at 0.3.
# Durations that fall outside the bin edges are not counted by `hist`, so the
# bars may add up to fewer marks than `durations.size`.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=120)
ax.hist(durations, bins=np.arange(0.3, 3.001, 0.2))
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_title("Duration of pooled SS marks")
ax.set_xlabel("SS duration (s)")
ax.set_ylabel("Number of marks")
plt.show()

In [None]:
# Number of pooled marks whose duration is below 0.4 (same units as `durations`);
# shown as the cell output.
np.sum(durations < 0.4)

In [None]:
# Per-subject summary of how many marks are shorter than `thr`, ordered from the
# largest to the smallest percentage of short marks.
thr = 0.5
all_msg = []
all_fractions = []
all_means = []
# The loop variables are deliberately not named `subject_id`: draw_single_page
# reads that global at call time, and rebinding it here would mislabel the page
# viewer if it is used after this cell.
for cap_id, cap_marks in zip(dataset.all_ids, marks_list):
    cap_durations = (cap_marks[:, 1] - cap_marks[:, 0]) / fs
    n_marks = cap_durations.size
    n_short = np.sum(cap_durations < thr)
    if n_marks > 0:
        percent_short = 100 * n_short / n_marks
        mean_duration = np.mean(cap_durations)
    else:
        # No marks for this subject: avoid 0/0 and the mean of an empty array.
        percent_short = 0.0
        mean_duration = np.nan
    msg = "Subject %s: %3d SS marks, %3d less than %ss (%1.2f%%). Mean duration: %1.3fs" % (
        cap_id, n_marks, n_short, thr, percent_short, mean_duration)
    all_msg.append(msg)
    # Stored negated so that an ascending argsort yields a descending ranking.
    all_fractions.append(-percent_short)
    all_means.append(mean_duration)
idx_sorted = np.argsort(all_fractions)
all_msg = [all_msg[i] for i in idx_sorted]
# for msg in all_msg:
    # print(msg)

# print("\n Mean of means in duration: %1.3fs" % np.mean(all_means))
# print("Mean of all durations aggregated: %1.3fs" % np.mean(durations))

# Analysis of parameters

## Duration histogram per subject

In [None]:
which_expert = 1

marks_list = dataset.get_stamps(pages_subset=constants.N2_RECORD, which_expert=which_expert)
duration_resolution = 0.2

# Size the grid from the number of subjects instead of assuming exactly 9 x 9
# panels, so more than 81 subjects no longer raises an IndexError.
n_cols = 9
n_subjects = len(dataset.all_ids)
n_rows = max(1, (n_subjects + n_cols - 1) // n_cols)
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(14, 8), dpi=80, sharex=True, squeeze=False)
flat_axes = axes.flatten()

# Loop variables are not named `subject_id` so that the global read by
# draw_single_page is left untouched by this cell.
n_drawn = 0
for panel, (cap_id, cap_marks) in zip(flat_axes, zip(dataset.all_ids, marks_list)):
    cap_durations = (cap_marks[:, 1] - cap_marks[:, 0]) / fs
    n_drawn += 1
    if cap_durations.size == 0:
        # Nothing to average or to histogram for a subject without marks.
        panel.set_title("CAP %s (no marks)" % cap_id, fontsize=8)
        continue
    mean_duration = np.mean(cap_durations)
    panel.set_title("CAP %s (Mean %1.2fs)" % (cap_id, mean_duration), fontsize=8)
    panel.hist(cap_durations, bins=np.arange(0.3, 3.001, duration_resolution))
    panel.axvline(mean_duration, color='k', linestyle="--", linewidth=1.5, alpha=0.5)

# Hide the panels of the grid that have no subject assigned.
for panel in flat_axes[n_drawn:]:
    panel.axis('off')

# The x axis is shared, so setting the limits on the bottom row applies to all.
for panel in axes[-1, :]:
    panel.set_xlim([0, 3])
    panel.set_xlabel("SS duration (s)", fontsize=6)
for panel in flat_axes:
    panel.tick_params(labelsize=6)
plt.tight_layout()
plt.show()

## Spindle parameters

In [None]:
def analyze_spindle(spindle, fs):
    """Compute descriptive parameters of a single spindle segment.

    Args:
        spindle: One-dimensional sequence of samples (array or list). It is
            converted to a float array, so integer input cannot overflow when
            squared.
        fs: Sampling frequency used to convert sample counts to time and
            normalized frequency to the units of `fs`.

    Returns:
        dict with the keys:
            'duration': number of samples divided by `fs`.
            'pp_amplitude': maximum minus minimum sample value.
            'rms': root mean square of the samples.
            'central_freq_count': number of local maxima divided by the duration.
            'central_freq_fft': frequency at which the magnitude of the
                spectrum (evaluated with `freqz`) is largest.

    Raises:
        ValueError: If `spindle` is empty or not one-dimensional.
    """
    spindle = np.asarray(spindle, dtype=float)
    if spindle.ndim != 1:
        raise ValueError(
            "spindle must be one-dimensional, got %d dimensions" % spindle.ndim)
    if spindle.size == 0:
        # Fail with an explicit message instead of an obscure reduction error
        # (max of an empty array) or a division by a zero duration.
        raise ValueError("spindle must contain at least one sample")

    duration = spindle.size / fs
    pp_amplitude = spindle.max() - spindle.min()
    rms = np.sqrt(np.mean(spindle ** 2))
    # Every oscillation contributes one local maximum.
    central_freq_count = find_peaks(spindle)[0].size / duration
    # Compute peak frequency by fft: freqz evaluates the spectrum of the samples
    # on a grid of normalized frequencies in [0, pi).
    w, h = freqz(spindle)
    resp_freq = w * fs / (2 * np.pi)
    resp_amp = np.abs(h)
    central_freq_fft = resp_freq[np.argmax(resp_amp)]
    return {
        'duration': duration,
        'pp_amplitude': pp_amplitude,
        'rms': rms,
        'central_freq_count': central_freq_count,
        'central_freq_fft': central_freq_fft
    }


def listify_dictionaries(list_of_dicts):
    """Turn a sequence of dictionaries into a dictionary of lists.

    The keys of the result are the keys of the first dictionary; each value is
    the list of the entries found under that key, in input order.

    Args:
        list_of_dicts: Iterable of dictionaries that all contain at least the
            keys of the first one.

    Returns:
        dict mapping each key to a list of values. An empty input yields an
        empty dict instead of failing when looking up the first element.

    Raises:
        KeyError: If a later dictionary lacks a key of the first one.
    """
    # Materialize once so that any iterable is accepted and can be re-iterated.
    records = list(list_of_dicts)
    if not records:
        return {}
    return {key: [record[key] for record in records] for key in records[0]}

In [None]:
which_expert = 3

# One dictionary of parameter lists per subject, in the order of dataset.all_ids.
train_analysis = []
# The loop uses its own names (`cap_id`, `cap_signal`) rather than `subject_id`
# and `signal`: draw_single_page reads those two globals at call time, and
# rebinding them here would make the page viewer show the last subject of
# this loop.
for cap_id in dataset.all_ids:
    cap_signal = dataset.get_subject_signal(
        cap_id, normalize_clip=False, which_expert=which_expert)
    cap_stamps = dataset.get_subject_stamps(
        cap_id, which_expert=which_expert, pages_subset=constants.N2_RECORD)
    # Filter the whole recording once, then cut one segment per stamp.
    cap_sigma = utils.broad_filter(cap_signal, fs, lowcut=9, highcut=17)
    per_spindle = [
        analyze_spindle(cap_sigma[s_start:s_end], fs)
        for (s_start, s_end) in cap_stamps
    ]
    train_analysis.append(listify_dictionaries(per_spindle))

In [None]:
def _draw_parameter_boxplot(ax, analyses, key, subject_labels, ylabel, yticks, ylim,
                            showfliers=False, label_rotation=90):
    """Draw one box per subject for the spindle parameter `key` on `ax`.

    Args:
        ax: Matplotlib axes to draw on.
        analyses: Sequence of per-subject dictionaries of parameter lists.
        key: Name of the parameter to plot (a key of every dictionary).
        subject_labels: Tick labels, one per entry of `analyses`.
        ylabel: Label of the vertical axis.
        yticks: Tick positions of the vertical axis.
        ylim: Two-element [low, high] range of the vertical axis.
        showfliers: Whether outlier points are drawn.
        label_rotation: Rotation in degrees of the subject tick labels.
    """
    ax.set_ylabel(ylabel)
    ax.boxplot(
        [analysis[key] for analysis in analyses],
        labels=subject_labels, showfliers=showfliers,
        flierprops={'markersize': 2})
    ax.grid(axis='y')
    ax.set_xlabel("Subject")
    # Ticks first, limits last, so the final visible range is exactly `ylim`.
    ax.set_yticks(yticks)
    ax.set_ylim(ylim)
    ax.tick_params(axis='x', rotation=label_rotation)


fig, axes = plt.subplots(2, 2, figsize=(20, 8), dpi=80)
showfliers = False
xticklabels_rotation = 90

axes[0, 0].set_title("Spindle Parameters (CAP-E%d)" % which_expert, fontweight='bold', loc="left", fontsize=14)

# One row per panel: (axes, parameter key, y label, y ticks, y limits).
# NOTE(review): both amplitude labels read "[V]"; confirm the signal units.
panel_specs = [
    (axes[0, 0], 'duration', "Duration [s]", np.arange(0.5, 3.0 + 0.001, 0.25), [0.5, 3.0]),
    (axes[1, 0], 'central_freq_fft', "Frequency [Hz]", [11, 12, 13, 14, 15], [10, 16]),
    (axes[0, 1], 'pp_amplitude', "Amplitude PP [V]", [20, 40, 60, 80], [10, 100]),
    (axes[1, 1], 'rms', "Amplitude RMS [V]", [5, 10, 15, 20], [0, 25]),
]
for ax, key, ylabel, yticks, ylim in panel_specs:
    _draw_parameter_boxplot(
        ax, train_analysis, key, dataset.all_ids, ylabel, yticks, ylim,
        showfliers=showfliers, label_rotation=xticklabels_rotation)

for ax in axes.flatten():
    ax.tick_params(labelsize=8)


plt.tight_layout()

# Save the figure next to the notebook in vector and raster form before showing it.
fname_prefix = 'params_cap_e%d' % which_expert
plt.savefig("%s.pdf" % fname_prefix, bbox_inches="tight", pad_inches=0.1)
plt.savefig("%s.png" % fname_prefix, bbox_inches="tight", pad_inches=0.1, dpi=200)
plt.show()