In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from pprint import pprint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

# Make the project package importable from this notebook's folder.
project_root = os.path.abspath('..')
# Guard the append so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from sleeprnn.nn import expert_feats
from sleeprnn.helpers import reader
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection import metrics
from sleeprnn.common import constants, pkeys, viz

# Project display helper; presumably widens notebook cells to full width (see sleeprnn.common.viz).
viz.notebook_full_width()

# Load data, extract a segment, and define plotting helpers

In [None]:
# Choose segment
subject_id = 1
which_kc_stamp = 195  # index of the K-complex stamp the window is centered on
left_context = 10  # seconds kept before the KC center
right_context = 10  # seconds kept after the KC center

# Load data
# KC stamps come from the MASS_KC dataset; the signal, SS stamps and fs come
# from MASS_SS. `dataset` is rebound so only the SS dataset stays referenced.
dataset = reader.load_dataset(constants.MASS_KC_NAME, verbose=False)
stamps_kc = dataset.get_subject_stamps(subject_id=subject_id)
dataset = reader.load_dataset(constants.MASS_SS_NAME, verbose=False)
signal = dataset.get_subject_signal(subject_id=subject_id, normalize_clip=False)
stamps_ss = dataset.get_subject_stamps(subject_id=subject_id)
fs = dataset.fs
print('%d SS stamps.' % stamps_ss.shape[0])
print('%d KC stamps.' % stamps_kc.shape[0])

# Extract chosen segment
central_sample = stamps_kc[which_kc_stamp, :].mean()
start_sample = int(central_sample - fs * left_context)
end_sample = int(central_sample + fs * right_context)
# A window running off either end of the recording would otherwise be sliced
# silently (a negative start wraps around and yields an empty segment).
assert 0 <= start_sample < end_sample <= len(signal), (
    'Segment [%d, %d) is outside the recording.' % (start_sample, end_sample))
segment_signal = signal[start_sample:end_sample]
segment_stamps_ss = utils.filter_stamps(stamps_ss, start_sample, end_sample)
segment_stamps_kc = utils.filter_stamps(stamps_kc, start_sample, end_sample)
time_axis = np.arange(start_sample, end_sample) / fs

# Binary SS mask over the segment only. Stamps are inclusive [start, end]
# sample pairs; offsets are clipped so stamps that overlap the window edges
# are truncated (same result as masking the full prefix and slicing it).
segment_size = end_sample - start_sample
segment_label_ss = np.zeros(segment_size)
for ss_start, ss_end in segment_stamps_ss:
    lo = int(np.clip(ss_start - start_sample, 0, segment_size))
    hi = int(np.clip(ss_end + 1 - start_sample, 0, segment_size))
    segment_label_ss[lo:hi] = 1
print("Segment of %d samples extracted." % segment_signal.size)

# Plotting style shared by draw_signal / draw_stamps
general_lw = 1.0
signal_lw = 0.6  # line width of signal traces
mark_lw = 3  # line width of annotation bars
signal_color = viz.GREY_COLORS[7]
custom_color = viz.PALETTE['red']
ss_color = viz.PALETTE['blue']
kc_color = viz.PALETTE['green']


def draw_signal(
    ax, time_axis, my_signal, y_min=-100, y_max=100, yticks=(-50, -25, 0, 25, 50),
    xlabel='Intervals of 1 s', ylabel=r'EEG ($\mu$V)',
    linewidth=signal_lw, color=signal_color
):
    """Plot a trace on ``ax`` with a 1-unit vertical grid and no x tick labels.

    Args:
        ax: matplotlib Axes to draw on.
        time_axis: x value of each sample (the notebook passes seconds);
            same length as ``my_signal``.
        my_signal: values to plot.
        y_min, y_max: vertical axis limits.
        yticks: y tick locations.
        xlabel, ylabel: axis labels.
        linewidth, color: style of the trace.

    Returns:
        The same ``ax``, so calls can be chained.
    """
    ax.plot(time_axis, my_signal, linewidth=linewidth, color=color)
    ax.set_xlabel(xlabel, fontsize=viz.FONTSIZE_GENERAL)
    ax.set_ylabel(ylabel, fontsize=viz.FONTSIZE_GENERAL)
    ax.set_yticks(yticks)
    ax.set_xlim([time_axis[0], time_axis[-1] + 0.05])
    ax.set_ylim([y_min, y_max])
    # No major x ticks; minor ticks every 1 unit exist only to anchor the grid.
    ax.set_xticks([])
    ax.set_xticks(np.arange(time_axis[0], time_axis[-1] + 0.1, 1), minor=True)
    # Visibility is passed positionally: the `b=` keyword was deprecated in
    # Matplotlib 3.5 (renamed to `visible`) and later removed.
    ax.grid(True, axis='x', which='minor')
    ax.tick_params(labelsize=viz.FONTSIZE_GENERAL)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='x', which='both', length=0)
    return ax
    

def draw_stamps(ax, my_stamps_ss, my_stamps_kc, ss_loc=-60, kc_loc=-70):
    """Draw annotation intervals as thick horizontal bars on ``ax``.

    Each stamp is a (start, end) pair in samples, converted to x positions by
    dividing by the notebook-level ``fs``. Sleep-spindle bars are drawn at
    height ``ss_loc`` and K-complex bars at ``kc_loc``. Only the first bar of
    each kind carries a legend label, so the legend lists each type once.

    Returns:
        The same ``ax``.
    """
    event_specs = (
        (my_stamps_ss, ss_loc, "Sleep Spindle", ss_color),
        (my_stamps_kc, kc_loc, "K-Complex", kc_color),
    )
    for stamps, y_level, legend_name, bar_color in event_specs:
        for position, interval in enumerate(stamps):
            first_of_kind = position == 0
            ax.plot(
                interval / fs,
                [y_level, y_level],
                label=legend_name if first_of_kind else None,
                color=bar_color,
                linewidth=mark_lw,
            )
    return ax

# A7 expert features (TensorFlow layer)

In [None]:
# A7 feature settings used for the first segment.
a7_config = dict(
    window_duration=0.5,
    window_duration_absSigPow=0.2,
    sigma_lowcut=11,
    sigma_highcut=16,
    use_log_absSigPow=True,
    use_log_relSigPow=True,
    use_log_sigCov=True,
    use_zscore_relSigPow=True,
    use_zscore_sigCov=True,
    use_zscore_sigCorr=False,
    remove_delta_in_cov=False,
    dispersion_mode="std",
)

# Build a fresh TF1 graph for the segment (as a batch of one) and evaluate it.
tf.reset_default_graph()
signal_batch = segment_signal.reshape(1, -1)
expert_feats_tf = expert_feats.a7_layer_tf(signal_batch, fs, **a7_config)
with tf.Session() as sess:
    expert_feats_np = sess.run(expert_feats_tf)

In [None]:
# Order of the features along the last axis of expert_feats_np (as used below).
feats_names = ['absSigPow', 'relSigPow', 'sigCov', 'sigCorr']

# One panel for the raw signal plus one per feature.
fig, axes = plt.subplots(len(feats_names) + 1, 1, dpi=100, figsize=(8, 6))

# Top panel: raw signal with SS / KC annotation bars.
ax = axes[0]
ax = draw_signal(ax, time_axis, segment_signal)
ax = draw_stamps(ax, segment_stamps_ss, segment_stamps_kc)
ax.legend(loc='lower left', bbox_to_anchor=(0, 1), ncol=2, fontsize=viz.FONTSIZE_GENERAL, frameon=False)

# Feature panels: symmetric y-limits covering at least [-1, 1].
for feat_idx, (ax, this_name) in enumerate(zip(axes[1:], feats_names)):
    this_feat = expert_feats_np[..., feat_idx].flatten()
    # Use only finite values: a NaN/inf (e.g. log of zero power) would make
    # the limit non-finite and set_ylim would raise.
    finite_abs = np.abs(this_feat[np.isfinite(this_feat)])
    y_max = max(finite_abs.max() if finite_abs.size else 0.0, 1)
    y_min = -y_max
    ax = draw_signal(ax, time_axis, this_feat, y_min=y_min, y_max=y_max, yticks=[-1, 0, 1], ylabel=this_name)

plt.tight_layout()
plt.show()

# Visualization for other segments

In [None]:
which_kc_stamp = 300

# Extract chosen segment (same procedure as the first segment, other KC).
central_sample = stamps_kc[which_kc_stamp, :].mean()
start_sample = int(central_sample - fs * left_context)
end_sample = int(central_sample + fs * right_context)
# Fail loudly instead of silently slicing an empty or wrapped segment.
assert 0 <= start_sample < end_sample <= len(signal), (
    'Segment [%d, %d) is outside the recording.' % (start_sample, end_sample))
segment_signal = signal[start_sample:end_sample]
segment_stamps_ss = utils.filter_stamps(stamps_ss, start_sample, end_sample)
segment_stamps_kc = utils.filter_stamps(stamps_kc, start_sample, end_sample)
time_axis = np.arange(start_sample, end_sample) / fs

# Binary SS mask over the segment only (inclusive [start, end] stamps,
# clipped to the window) instead of allocating the whole recording prefix.
segment_size = end_sample - start_sample
segment_label_ss = np.zeros(segment_size)
for ss_start, ss_end in segment_stamps_ss:
    lo = int(np.clip(ss_start - start_sample, 0, segment_size))
    hi = int(np.clip(ss_end + 1 - start_sample, 0, segment_size))
    segment_label_ss[lo:hi] = 1
print("Segment of %d samples extracted." % segment_signal.size)

# Build a fresh TF1 graph for this segment and evaluate the A7 features.
tf.reset_default_graph()
expert_feats_tf = expert_feats.a7_layer_tf(
    segment_signal.reshape(1, -1), fs,
    window_duration=0.5,
    window_duration_absSigPow=0.2,
    sigma_lowcut=11,
    sigma_highcut=16,
    use_log_absSigPow=True,
    use_log_relSigPow=True,
    use_log_sigCov=True,
    use_zscore_relSigPow=True,
    use_zscore_sigCov=True,
    use_zscore_sigCorr=True,  # NOTE(review): first A7 cell uses False here
    remove_delta_in_cov=False,
    # NOTE(review): first A7 cell uses "std"; confirm "made" is a mode that
    # expert_feats.a7_layer_tf supports and not a typo.
    dispersion_mode="made"
)
with tf.Session() as sess:
    expert_feats_np = sess.run(expert_feats_tf)

# Order of the features along the last axis of expert_feats_np.
feats_names = ['absSigPow', 'relSigPow', 'sigCov', 'sigCorr']

fig, axes = plt.subplots(len(feats_names) + 1, 1, dpi=160, figsize=(8, 6))

# Top panel: raw signal with SS / KC annotation bars.
ax = axes[0]
ax = draw_signal(ax, time_axis, segment_signal)
ax = draw_stamps(ax, segment_stamps_ss, segment_stamps_kc)
ax.legend(loc='lower left', bbox_to_anchor=(0, 1), ncol=2, fontsize=viz.FONTSIZE_GENERAL, frameon=False)

# Feature panels: symmetric y-limits covering at least [-1, 1], computed from
# finite values only so NaN/inf features do not break set_ylim.
for feat_idx, (ax, this_name) in enumerate(zip(axes[1:], feats_names)):
    this_feat = expert_feats_np[..., feat_idx].flatten()
    finite_abs = np.abs(this_feat[np.isfinite(this_feat)])
    y_max = max(finite_abs.max() if finite_abs.size else 0.0, 1)
    y_min = -y_max
    ax = draw_signal(ax, time_axis, this_feat, y_min=y_min, y_max=y_max, yticks=[-1, 0, 1], ylabel=this_name)

plt.tight_layout()
plt.show()