In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from pprint import pprint
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
import pyedflib
from scipy.signal import convolve2d, firwin

project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, misc
from sleeprnn.data.mass_ss import PATH_MASS_RELATIVE, PATH_REC, PATH_MARKS, PATH_STATES

# Results folder, located directly under the project root
RESULTS_PATH = os.path.join(project_root, 'results')
# Location of the MASS files: <project_root>/<utils.PATH_DATA>/<PATH_MASS_RELATIVE>
MASS_PATH = os.path.join(project_root, utils.PATH_DATA, PATH_MASS_RELATIVE)

%matplotlib inline
viz.notebook_full_width()

# Load subject's data and predictions

In [None]:
# Subject whose recording, annotations and predictions are inspected
subject_id = 14

# Channel used for the annotations, working sampling rate [Hz] and page length [s]
annot_chn = 'EEG C3-CLE'
fs = 256
page_duration = 20

# Numeric level given to each hypnogram label (used as y-value when plotting)
hypno_num_dict = {
    '1': -1,
    '2': -2,
    '3': -3,
    '4': -4,
    'R': 0,
    'W': 1,
    '?': 2,
}
# Hypnogram labels, listed in the same order as the mapping above
state_ids = np.array(list(hypno_num_dict))
unknown_id = '?'  # Character for unknown state in hypnogram
n2_id = '2'  # Character for N2 identification in hypnogram

In [None]:
# Load PSG
# Every channel of the subject's recording is read, brought to the common
# sampling rate `fs` and stored as float32 in `signals`, keyed by channel label.
file_rec = os.path.join(MASS_PATH, PATH_REC, '01-02-%04d PSG.edf' % subject_id)
signals = {}
ignore_channels = ['Resp nasal']  # Channel labels that are skipped
with pyedflib.EdfReader(file_rec) as file:
    print("Reading %s" % file_rec)
    channel_names = file.getSignalLabels()
    for i, name in enumerate(channel_names):
        if name in ignore_channels:
            continue
        this_signal = file.readSignal(i)
        fs_decimal = file.samplefrequency(i)
        # Particular fix for mass dataset:
        # the sampling frequency reported by the file is rounded to an integer
        fs_rounded = int(np.round(fs_decimal))
        # Transform the original fs frequency with decimals to rounded version
        this_signal = utils.resample_signal_linear(this_signal, fs_old=fs_decimal, fs_new=fs_rounded)
        # Now resample to the required frequency
        # (skipped when the rounded rate already matches `fs`)
        if fs != fs_rounded:
            print('Resampling from %d Hz to required %d Hz' % (fs_rounded, fs))
            this_signal = utils.resample_signal(this_signal, fs_old=fs_rounded, fs_new=fs)
        this_signal = this_signal.astype(np.float32)
        signals[name] = this_signal
print("Done.")

In [None]:
def hypnogram_str2num(my_hypno, state_ids, num_values_dict):
    """Translate an array of hypnogram labels into numeric values.

    Args:
        my_hypno: numpy array with one stage label per page.
        state_ids: iterable with the labels to translate.
        num_values_dict: dict with the numeric value of every label in
            `state_ids`.

    Returns:
        Float array with `my_hypno.size` entries. Entries whose label does not
        appear in `state_ids` keep the value 0.
    """
    numeric_hypno = np.zeros(my_hypno.size)
    for label in state_ids:
        label_value = num_values_dict[label]
        label_mask = (my_hypno == label)
        numeric_hypno[label_mask] = label_value
    return numeric_hypno

In [None]:
# Load annotations -- Hypnogram
# Builds one stage label per page of the recording (`hypnogram`), its numeric
# version (`hypnogram_int`) and the indices of the N2 pages (`n2_pages`).
file_states = os.path.join(MASS_PATH, PATH_STATES, '01-02-%04d Base.edf' % subject_id)
# Total pages not necessarily equal to total_annots
page_size = int(page_duration * fs)  # Samples per page
signal_length = signals[annot_chn].size
total_pages = int(np.ceil(signal_length / page_size))
with pyedflib.EdfReader(file_states) as file:
    print("Reading %s" % file_states)
    annotations = file.readAnnotations()
onsets = np.array(annotations[0])
durations = np.round(np.array(annotations[1]))
stages_str = annotations[2]
# Keep only annotations whose rounded duration equals the page duration (20 s)
valid_idx = (durations == page_duration)
onsets = onsets[valid_idx]
# Page index of each annotation, from its onset time
onsets_pages = np.round(onsets / page_duration).astype(np.int32)
stages_str = stages_str[valid_idx]
# The stage label is the last character of each annotation text
stages_char = [single_annot[-1] for single_annot in stages_str]
# Build complete hypnogram
total_annots = len(stages_char)
not_unkown_ids = [state_id for state_id in state_ids if state_id != unknown_id]
# For every known stage, the page indices that were annotated with it
not_unkown_state_dict = {}
for state_id in not_unkown_ids:
    state_idx = np.where([stages_char[i] == state_id for i in range(total_annots)])[0]
    not_unkown_state_dict[state_id] = onsets_pages[state_idx]
# Each page receives the first stage (in `not_unkown_ids` order) that lists it;
# pages that no stage lists are marked with `unknown_id`
hypnogram = []
for page in range(total_pages):
    state_not_found = True
    for state_id in not_unkown_ids:
        if page in not_unkown_state_dict[state_id] and state_not_found:
            hypnogram.append(state_id)
            state_not_found = False
    if state_not_found:
        hypnogram.append(unknown_id)
hypnogram = np.asarray(hypnogram)
hypnogram_int = hypnogram_str2num(hypnogram, state_ids, hypno_num_dict)
# Extract N2 pages
n2_pages = np.where(hypnogram == n2_id)[0]
# Drop first, last and second to last page of the whole registers if they were selected.
last_page = total_pages - 1
n2_pages = n2_pages[(n2_pages != 0) & (n2_pages != last_page) & (n2_pages != last_page - 1)]
# NOTE(review): page indices are stored as int16; arithmetic that turns them
# into sample positions can overflow unless they are cast to a wider integer
# first -- confirm that every consumer of `n2_pages` does so.
n2_pages = n2_pages.astype(np.int16)
print("Done.")

In [None]:
# Load annotations -- Spindles
# Produces `marks` (sample-stamps, one [start, end] row per mark) and
# `marks_bin` (sequence built from those stamps over the annotated channel).
file_marks = os.path.join(MASS_PATH, PATH_MARKS, '01-02-%04d SpindleE1.edf' % subject_id)
with pyedflib.EdfReader(file_marks) as file:
    print("Reading %s" % file_marks)
    annotations = file.readAnnotations()
# The first two annotation fields are used as onsets and durations of the marks
onsets = np.array(annotations[0])
durations = np.array(annotations[1])
offsets = onsets + durations
marks_time = np.stack((onsets, offsets), axis=1)  # time-stamps
# Transforms to sample-stamps
marks = np.round(marks_time * fs).astype(np.int32)
# Combine marks that are too close according to standards
# (presumably 0.3 is the minimum separation in seconds -- confirm in stamp_correction)
marks = stamp_correction.combine_close_stamps(marks, fs, 0.3)
# Fix durations that are outside standards
# (presumably 0.3 and 3.0 are the duration limits in seconds -- confirm in stamp_correction)
marks = stamp_correction.filter_duration_stamps(marks, fs, 0.3, 3.0)
# keep only N2
marks = utils.extract_pages_for_stamps(marks, n2_pages, page_size)
# Build binary sequence
marks_bin = utils.stamp2seq(marks, 0, signals[annot_chn].size - 1)
print("Done.")

In [None]:
def center_probabilities(probabilities, center):
    """Move the decision threshold of a probability array from `center` to 0.5.

    input: probas with class 1 iff proba > center.
    output: probas with class 1 iff proba > 0.5

    The shift is done in logit space: the logit of `center` is subtracted from
    the logit of every probability (clipped to [1e-6, 1 - 1e-6] beforehand)
    and the result is mapped back with a sigmoid. The output is float16.
    """
    eps = 1e-6
    clipped = np.clip(probabilities.astype(np.float32), eps, 1 - eps)
    center_logit = np.log(center / (1 - center))
    shifted_logits = np.log(clipped / (1 - clipped)) - center_logit
    centered = 1 / (1 + np.exp(-shifted_logits))
    return centered.astype(np.float16)

In [None]:
# Load predictions -- RED-CWT
# Reads the validation predictions of every seed and keeps those of the seed
# whose validation split contains `subject_id`.
ckpt_folder = '20200724_reproduce_red_n2_train_mass_ss/v19_rep1'
fs_proba = 25  # Sampling rate used for the time axis of the probabilities
fs_detections = 200  # Sampling rate of the detection stamps read from the predictions
seed_id_list = [0, 1, 2, 3]
set_list = [constants.VAL_SUBSET]
dataset_name = constants.MASS_SS_NAME
task_mode = constants.N2_RECORD
predictions_dict = reader.read_prediction_with_seeds(ckpt_folder, dataset_name, task_mode, seed_id_list, set_list)
print("Selecting subject's probabilities and detections.")
for seed_id in seed_id_list:
    seed_predictions = predictions_dict[seed_id][constants.VAL_SUBSET]
    seed_val_subjects = seed_predictions.all_ids
    if subject_id in seed_val_subjects:
        seed_thr = OPTIMAL_THR_FOR_CKPT_DICT[ckpt_folder][seed_id]
        seed_predictions.set_probability_threshold(seed_thr)
        probabilities = seed_predictions.get_subject_probabilities(subject_id).copy()
        # Re-centre so that the seed's threshold corresponds to a probability of 0.5
        probabilities = center_probabilities(probabilities, seed_thr)
        detections = seed_predictions.get_subject_stamps(subject_id).copy()
        # Express the stamps at the signal rate `fs` instead of `fs_detections`
        detections = (detections.astype(np.float32) * fs / fs_detections).astype(np.int32)
        detections_bin = utils.stamp2seq(detections, 0, signals[annot_chn].size - 1)
        break
else:
    # Without this guard the cell would finish silently and later cells would
    # fail, or reuse stale variables left in the kernel by a previous subject.
    raise ValueError(
        "Subject %d was not found in the validation subset of any seed in %s"
        % (subject_id, seed_id_list))
print("Done.")

# Global view

In [None]:
fig, ax = plt.subplots(5, 1, figsize=(12, 6), dpi=100, sharex=True)

# Hypnogram, drawn as a staircase: every page contributes its start and end
# time, both paired with the numeric value of the page's stage
page_edges = np.arange(hypnogram.size + 1)
hypnogram_times_to_plot = np.repeat(page_edges, 2)[1:-1] * page_duration
hypnogram_int_to_plot = np.repeat(hypnogram_int, 2)
ax[0].plot(hypnogram_times_to_plot, hypnogram_int_to_plot, linewidth=0.8)
ax[0].set_yticks([hypno_num_dict[s] for s in state_ids])
ax[0].set_yticklabels(state_ids)
ax[0].set_title("S%02d - Hypnogram" % subject_id)

# Time axes: signal, marks and detections share the signal rate, while the
# probabilities have their own rate
signal_to_plot = signals[annot_chn]
signal_times_to_plot = np.arange(signal_to_plot.size) / fs
proba_times_to_plot = np.arange(probabilities.size) / fs_proba

# Signal, marks, detections and probabilities, one panel each
panel_traces = [
    (signal_times_to_plot, signal_to_plot, "S%02d - Signal %s" % (subject_id, annot_chn)),
    (signal_times_to_plot, marks_bin, "S%02d - Marks" % subject_id),
    (signal_times_to_plot, detections_bin, "S%02d - Detections" % subject_id),
    (proba_times_to_plot, probabilities, "S%02d - Probabilities" % subject_id),
]
for panel_ax, (panel_times, panel_values, panel_title) in zip(ax[1:], panel_traces):
    panel_ax.plot(panel_times, panel_values, linewidth=0.6)
    panel_ax.set_title(panel_title)
ax[4].set_xlabel("Time [s]")

plt.tight_layout()
plt.show()

# Distribution of marks and predictions over the night

In [None]:
# Center of every mark / detection, in seconds
marks_centers = marks.mean(axis=1) / fs
detections_centers = detections.mean(axis=1) / fs

# Histogram bins that span a fixed number of pages
pages_per_bin = 3
n_pages = hypnogram.size
bin_width = pages_per_bin * page_duration
last_second = n_pages * page_duration
# The stop value is one bin past the end of the recording so that the last
# edge is always >= last_second. With a stop of last_second + 1 the final,
# partial bin was dropped whenever n_pages is not a multiple of pages_per_bin,
# and the centers falling there were not counted.
bins = np.arange(0, last_second + bin_width, bin_width)

fig, ax = plt.subplots(2, 1, figsize=(12, 4), dpi=100, sharex=True, sharey=True)
ax[0].hist(marks_centers, bins=bins)
ax[0].set_title("Marks centers")

ax[1].hist(detections_centers, bins=bins)
ax[1].set_title("Detections centers")
ax[1].set_xlabel("Time [s]")

plt.tight_layout()
plt.show()

# Single channel visualization with bandpass filters

In [None]:
def apply_fir_filter_np(signal, kernel):
    """Filter `signal` with the FIR `kernel`, keeping the length of the input.

    Both arrays are reshaped to single-row matrices so that `convolve2d` can
    be used in "same" mode; the result is returned as a 1-D array with
    `signal.size` samples.
    """
    signal_row = signal.reshape(1, -1)
    kernel_row = kernel.reshape(1, -1)
    filtered_row = convolve2d(signal_row, kernel_row, mode="same")
    return filtered_row.flatten()


def lowpass_np(signal, fs, cutoff, filter_duration_ref=6, wave_expansion_factor=0.5):
    """Low-pass filter `signal` with a Hamming-windowed FIR kernel.

    Args:
        signal: array to filter.
        fs: sampling frequency of `signal`.
        cutoff: cutoff frequency, in the same units as `fs`.
        filter_duration_ref: reference value that scales the kernel length.
        wave_expansion_factor: exponent applied to `cutoff` when computing
            the kernel length.

    Returns:
        Filtered array, as returned by `apply_fir_filter_np`.

    The number of taps is fs * filter_duration_ref / cutoff ** wave_expansion_factor,
    forced to be odd. The kernel is cast to float32 and normalized to sum 1.
    """
    taps_estimate = fs * filter_duration_ref / (cutoff ** wave_expansion_factor)
    numtaps = int(2 * (taps_estimate // 2) + 1)  # ensure odd numtaps
    kernel = firwin(numtaps, cutoff=cutoff, window="hamming", fs=fs)
    kernel = kernel.astype(np.float32)
    kernel /= kernel.sum()
    return apply_fir_filter_np(signal, kernel)


def highpass_np(signal, fs, cutoff, filter_duration_ref=6, wave_expansion_factor=0.5):
    """High-pass filter `signal` using the complement of a low-pass FIR kernel.

    Args:
        signal: array to filter.
        fs: sampling frequency of `signal`.
        cutoff: cutoff frequency, in the same units as `fs`.
        filter_duration_ref: reference value that scales the kernel length.
        wave_expansion_factor: exponent applied to `cutoff` when computing
            the kernel length.

    Returns:
        Filtered array, as returned by `apply_fir_filter_np`.

    A Hamming-windowed low-pass kernel (float32, sum 1, odd number of taps) is
    designed first; the high-pass kernel is a unit impulse at the center tap
    minus that low-pass kernel.
    """
    taps_estimate = fs * filter_duration_ref / (cutoff ** wave_expansion_factor)
    numtaps = int(2 * (taps_estimate // 2) + 1)  # ensure odd numtaps
    lowpass_kernel = firwin(numtaps, cutoff=cutoff, window="hamming", fs=fs)
    lowpass_kernel = lowpass_kernel.astype(np.float32)
    lowpass_kernel /= lowpass_kernel.sum()
    # HP = delta - LP
    kernel = np.negative(lowpass_kernel)
    kernel[numtaps // 2] += 1
    return apply_fir_filter_np(signal, kernel)


def bandpass_np(signal, fs, lowcut, highcut, filter_duration_ref=6, wave_expansion_factor=0.5):
    """Band-pass filter `signal` as a high-pass followed by a low-pass.

    Args:
        signal: array to filter.
        fs: sampling frequency of `signal`.
        lowcut: cutoff of the high-pass stage; None skips that stage.
        highcut: cutoff of the low-pass stage; None skips that stage.
        filter_duration_ref: forwarded to both stages.
        wave_expansion_factor: forwarded to both stages.

    Returns:
        The filtered array, or `signal` itself when both cutoffs are None.
    """
    kernel_shape_args = (filter_duration_ref, wave_expansion_factor)
    filtered = signal
    if lowcut is not None:
        filtered = highpass_np(filtered, fs, lowcut, *kernel_shape_args)
    if highcut is not None:
        filtered = lowpass_np(filtered, fs, highcut, *kernel_shape_args)
    return filtered

In [None]:
# Frequency bands, each given as [lowcut, highcut] in the units of `fs`
bandpass_filters = {
    'theta': [4, 8],
    'alpha': [8, 12],
    'sigma': [11, 16],
    'beta': [16, 30],
}
# Order in which the bands are processed (and later plotted)
band_names = ['theta', 'alpha', 'sigma', 'beta']
n_filters = len(bandpass_filters)

# Band-filtered versions of the annotated channel, stored as float32
band_signals = {}
for band_name in band_names:
    print("Processing band %s" % band_name, flush=True)
    lowcut, highcut = bandpass_filters[band_name]
    band_signal = bandpass_np(signals[annot_chn], fs, lowcut, highcut)
    band_signal = band_signal.astype(np.float32)
    band_signals[band_name] = band_signal
print("Done.")

In [None]:
def plot_signal_with_filters(page_idx, show_only_n2=False, border=5):
    """Plot one page of the annotated channel together with its band signals.

    The figure has two panels: the centered probabilities on top and, below,
    the raw signal with every band signal of `band_signals` drawn at a fixed
    vertical offset. Marks are shaded below the zero line (blue) and
    detections above it (red).

    Args:
        page_idx: position inside `n2_pages` when `show_only_n2` is True,
            otherwise the page number in the recording. Page 0 is replaced
            by page 1.
        show_only_n2: whether `page_idx` indexes the N2 pages only.
        border: seconds of context shown before and after the page.

    The function reads the notebook-level variables `n2_pages`, `hypnogram`,
    `signals`, `band_signals`, `band_names`, `probabilities`, `marks`,
    `detections`, `fs`, `fs_proba`, `page_duration`, `annot_chn` and
    `subject_id`.
    """
    fig = plt.figure(figsize=(12, 4), dpi=120)
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 6])

    page_loc = n2_pages[page_idx] if show_only_n2 else page_idx
    # Entries of n2_pages are fixed-width NumPy integers (int16 in this
    # notebook). Casting to a Python int keeps the page-to-sample arithmetic
    # below from overflowing.
    page_loc = max(int(page_loc), 1)

    # Probability
    ax = fig.add_subplot(gs[0])
    # The window start is clamped at 0 so that a large border cannot turn it
    # into a negative (wrap-around) slice index.
    start_sample_proba = max(int(page_loc * page_duration * fs_proba - border * fs_proba), 0)
    end_sample_proba = int((page_loc + 1) * page_duration * fs_proba + border * fs_proba)
    proba_segment = probabilities[start_sample_proba:end_sample_proba]
    # The time axis follows the samples actually available: a window reaching
    # past the end of the record yields a shorter slice, and x and y must
    # have the same length to be plotted.
    time_axis_proba = (start_sample_proba + np.arange(proba_segment.size)) / fs_proba
    ax.plot(time_axis_proba, proba_segment, linewidth=1.1, color=viz.PALETTE['red'])
    ax.set_ylim([-0.05, 1.05])
    ax.set_xlim(start_sample_proba / fs_proba, end_sample_proba / fs_proba)
    ax.set_xticks([])
    ax.set_xticks(np.arange(start_sample_proba / fs_proba, end_sample_proba / fs_proba, 1), minor=True)
    ax.set_yticks([0.5], minor=True)
    ax.grid(which="minor")
    ax.set_title("S%02d - Page %d (Stage %s)" % (subject_id, page_loc, hypnogram[page_loc]))

    # Signal + bands
    ax = fig.add_subplot(gs[1])
    start_sample = max(int(page_loc * page_duration * fs - border * fs), 0)
    end_sample = int((page_loc + 1) * page_duration * fs + border * fs)
    signal_segment = signals[annot_chn][start_sample:end_sample]
    time_axis = (start_sample + np.arange(signal_segment.size)) / fs
    ax.plot(time_axis, signal_segment, linewidth=0.6, color=viz.PALETTE['dark'])
    ax.set_xlim(start_sample / fs, end_sample / fs)
    # Every band is drawn below the raw signal, `dy` units under the previous one
    dy = 70
    offset = -150 + dy
    offsets_list = []
    grid_list = []
    for band_name in band_names:
        offset -= dy
        offsets_list.append(offset)
        grid_list.extend([offset - 20, offset - 10, offset + 10, offset + 20])
        ax.plot(
            time_axis,
            band_signals[band_name][start_sample:end_sample] + offset,
            linewidth=0.6, color=viz.PALETTE['blue'],
            label=band_name)
    ax.set_ylim([offset - 50, 150])
    # Major ticks: one per band (labelled with its name) plus the amplitude
    # scale of the raw signal
    ax.set_yticks(offsets_list + [-100, -50, 0, 50, 100])
    ax.set_yticks(grid_list, minor=True)
    ax.set_yticklabels(band_names + [-100, -50, 0, 50, 100])
    ax.set_xticks(np.arange(start_sample / fs, end_sample / fs, 1), minor=True)
    ax.grid(which="minor")
    ax.set_xticks(np.arange(start_sample / fs, end_sample / fs + 0.1, 5))

    # Detections and marks
    marks_in_page = utils.filter_stamps(marks, start_sample, end_sample)
    for s_mark in marks_in_page:
        ax.fill_between(s_mark / fs, -100, 0, alpha=0.3, facecolor=viz.PALETTE['blue'])
    dets_in_page = utils.filter_stamps(detections, start_sample, end_sample)
    for s_mark in dets_in_page:
        ax.fill_between(s_mark / fs, 100, 0, alpha=0.3, facecolor=viz.PALETTE['red'])

    plt.tight_layout()
    plt.show()

In [None]:
# Browse pages interactively: only the N2 pages, or every page of the recording
show_only_n2 = True
start_value = 0

# Highest slider position, i.e. the index of the last selectable page
max_pages = (n2_pages.size if show_only_n2 else hypnogram.size) - 1
style = {'description_width': 'initial'}
layout = widgets.Layout(width='1000px')
page_slider = widgets.IntSlider(
    min=0, max=max_pages, step=1, value=start_value,
    continuous_update=False, style=style, layout=layout)
widgets.interact(
    lambda page_idx: plot_signal_with_filters(page_idx, show_only_n2=show_only_n2),
    page_idx=page_slider);

# All channels visualization