In [None]:
import os
from pprint import pprint
import sys
import json

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append('..')

from sleeprnn.common import viz
from sleeprnn.data import utils

# Notebook display setup: call the project's viz helper (presumably widens the
# notebook cells -- see viz.notebook_full_width) and render figures inline.
viz.notebook_full_width()
%matplotlib inline

# Absolute path of the 'moda' folder inside the project's data directory.
DATASET_DIR = os.path.abspath(os.path.join(utils.PATH_DATA, 'moda'))

In [None]:
# Shared constants: sampling rate in samples per second, and the length (in
# seconds) of the margin trimmed from both ends of every segment further down.
fs = 256
border_duration = 30

# Open the preprocessed segments archive and pull out the arrays used below.
dataset = np.load(os.path.join(DATASET_DIR, 'segments/moda_preprocessed_segments.npz'))
signals, labels, subjects, phases = (
    dataset[name] for name in ('signals', 'labels', 'subjects', 'phases')
)

# Sanity checks

In [None]:
# One line per array stored in the archive: name, shape and dtype.
for key in dataset.files:
    print("%s %s %s" % (key, dataset[key].shape, dataset[key].dtype))

In [None]:
# Number of segments tagged as phase 1.
(phases == 1).sum()

In [None]:
# Number of segments tagged as phase 2.
(phases == 2).sum()

In [None]:
# Number of distinct subjects in the whole dataset.
len(np.unique(subjects))

In [None]:
# Number of distinct subjects among the phase-1 segments.
len(np.unique(subjects[phases == 1]))

In [None]:
# Number of distinct subjects among the phase-2 segments.
len(np.unique(subjects[phases == 2]))

In [None]:
# Total annotated time in hours, counting 115 s per segment.
len(signals) * 115 / 3600

In [None]:
# Rough page budget for a nested n-fold split, counting 120 s / 20 s = 6 pages
# per segment: one fold for test, then one fold of the remainder for validation.
n_folds = 5
total_pages = signals.shape[0] * 120 / 20
test_pages = int(total_pages / n_folds)
remaining_pages = total_pages - test_pages
val_pages = int(remaining_pages / n_folds)
train_pages = remaining_pages - val_pages
print("test %d, val %d, train %d" % (test_pages, val_pages, train_pages))

In [None]:
# Trim the border from both ends of every segment. The slice is built from the
# shared `border_duration` and `fs` constants (the same expression the label
# cells use) rather than a hard-coded 30*256, so the two cannot drift apart.
signals_valid = signals[:, border_duration*fs:-border_duration*fs]
signals_valid.shape

In [None]:
# Histogram of every amplitude value in the border-free signals.
ax = plt.gca()
ax.hist(signals_valid.ravel(), bins=30)
ax.set_title("Signal amplitudes (uV)")
plt.show()

In [None]:
# Global min, max, mean and standard deviation of the border-free signals.
np.min(signals_valid), np.max(signals_valid), np.mean(signals_valid), np.std(signals_valid)

In [None]:
# Smallest and largest amplitude found in each segment (one value per row).
min_in_segments = np.min(signals_valid, axis=1)
max_in_segments = np.max(signals_valid, axis=1)

In [None]:
# Distribution of the per-segment minimum amplitude.
ax = plt.gca()
ax.hist(min_in_segments)
ax.set_title("Min val in each segment (uV)")
plt.show()

In [None]:
# Distribution of the per-segment maximum amplitude.
ax = plt.gca()
ax.hist(max_in_segments)
ax.set_title("Max val in each segment (uV)")
plt.show()

In [None]:
# Inspect the first segment whose minimum amplitude drops below -220.
choose = np.where(min_in_segments < -220)[0][0]
chosen_signal = signals_valid[choose, :]
print(chosen_signal.size / fs)
start_sec = 70
end_sec = 80

# The time axis is built for the border-free signal, so the plotted samples
# must come from that same array (`chosen_signal`). Slicing the untrimmed
# `signals[choose]` instead would draw data shifted by `border_duration`
# seconds relative to the axis.
time_axis = np.arange(chosen_signal.size) / fs
window = slice(fs * start_sec, fs * end_sec)
plt.figure(figsize=(10, 3), dpi=80)
plt.plot(time_axis[window], chosen_signal[window], linewidth=0.8)
plt.show()

In [None]:
# Count segments per phase-2 subject, then tally how many subjects share each
# count. The result is indexed instead of unpacked into `_`: assigning to `_`
# in a notebook stops IPython from updating its last-output variable.
counts = np.unique(subjects[phases == 2], return_counts=True)[1]
n_blocks, freq = np.unique(counts, return_counts=True)
print(n_blocks, freq)

# Distribution of spindle labels along the length of the segments

In [None]:
# Per-sample label totals over all segments, with every positive total capped
# at 0: the 0 bin counts samples whose total is >= 0, and any negative total
# shows up as its own value.
np.unique(np.minimum(labels.sum(axis=0), 0), return_counts=True)

In [None]:
# Label total per time sample (summed over segments) in the border-free part,
# drawn as isolated dots.
labels_valid = labels[:, border_duration*fs:-border_duration*fs]
ax = plt.figure(figsize=(10, 2), dpi=140).gca()
ax.plot(np.sum(labels_valid, axis=0), linewidth=0.8, linestyle="none", marker='o', markersize=1)
plt.show()


In [None]:
# Number of time samples whose label total is below 20, minus an offset of 36
# samples (NOTE(review): the meaning of 36 is not stated here -- confirm).
np.count_nonzero(labels_valid.sum(axis=0) < 20) - 36

In [None]:
# Convert 36 and 33 samples to seconds.
tuple(n_samples / fs for n_samples in (36, 33))

In [None]:
# Smallest per-sample label total over the border-free part.
np.min(np.sum(labels_valid, axis=0))

# Activity distribution across segments

In [None]:
# Amount of labelled signal per segment, expressed in seconds (label total of
# the border-free part divided by the sampling rate).
labels_valid = labels[:, border_duration*fs:-border_duration*fs]
n_labels_in_segment = np.sum(labels_valid, axis=1)
n_seconds_in_segment = n_labels_in_segment / fs

ax = plt.gca()
ax.hist(n_seconds_in_segment, bins=40)
ax.set_xlabel("seconds of spindles")
plt.show()

In [None]:
# Segments whose label total is under 0.3 s worth of samples: how many there
# are, and which totals they take.
near_empty_locs = np.flatnonzero(n_labels_in_segment < 0.3 * fs)
near_empty_labels = labels_valid[near_empty_locs]
print(len(near_empty_locs))
np.unique(np.sum(near_empty_labels, axis=1), return_counts=True)

In [None]:
# Look at the first near-empty segment between start_sec and end_sec.
# NOTE(review): the notebook describes segments as 115 s long, so a 110-120 s
# window is silently truncated by the slice to the last 5 s -- confirm intended.
start_sec = 110
end_sec = 120
ax = plt.figure(figsize=(10, 2), dpi=100).gca()
ax.plot(signals_valid[near_empty_locs[0], fs*start_sec:fs*end_sec], linewidth=0.8)
plt.show()

In [None]:
# Run utils.seq2stamp on the border-free labels of every segment
# (one result per segment).
labels_valid = labels[:, border_duration*fs:-border_duration*fs]
spindles = list(map(utils.seq2stamp, labels_valid))

In [None]:
# Number of stamps (rows) found in each segment.
n_spindles_per_segment = [len(stamps) for stamps in spindles]
ax = plt.gca()
ax.hist(n_spindles_per_segment)
ax.set_xlabel("n spindles in segments")
plt.show()

In [None]:
# Pool the stamps of all segments and convert each (column 0, column 1) pair
# into a duration in seconds.
spindles_all = np.concatenate(spindles, axis=0)
starts, ends = spindles_all[:, 0], spindles_all[:, 1]
durations = (ends - starts) / fs

ax = plt.gca()
ax.hist(durations)
ax.set_xlabel("spindle duration")
plt.show()

# Density and number of spindles per subject

In [None]:
# Per-subject totals: number of stamps, number of segments, and density in
# stamps per minute (each segment counted as 115 s).
subject_ids = np.unique(subjects)
numbers = []
n_blocks = []
for subject_id in subject_ids:
    subject_rows = np.where(subjects == subject_id)[0]
    subject_labels = labels[subject_rows][:, border_duration*fs:-border_duration*fs]
    # The subject's total is the sum of the per-segment stamp counts.
    numbers.append(sum(utils.seq2stamp(row).shape[0] for row in subject_labels))
    n_blocks.append(subject_rows.size)
numbers = np.array(numbers)
n_blocks = np.array(n_blocks)
densities = numbers / (n_blocks * 115) * 60

In [None]:
# Distribution of the per-subject density.
ax = plt.gca()
ax.hist(densities)
ax.set_xlabel("Density (spm)")
plt.show()

In [None]:
# Distribution of the per-subject stamp count.
ax = plt.gca()
ax.hist(numbers)
ax.set_xlabel("Spindles per subject")
plt.show()

In [None]:
# How many subjects have a stamp count of zero (or less).
np.count_nonzero(numbers <= 0)


In [None]:
# Same distribution, restricted to subjects that have exactly 10 segments.
ax = plt.gca()
ax.hist(numbers[n_blocks == 10], bins=30)
ax.set_xlabel("Spindles per subject")
plt.show()

In [None]:
# Lowest stamp count among subjects that have exactly 10 segments.
print(np.min(numbers[n_blocks == 10]))

# Visualize pages

In [None]:
# Page length in seconds, used by the page viewer below.
page_duration = 20

# Report the number of segments and how many pages fit in the border-free part
# of each one.
valid_samples = signals[:, border_duration*fs:-border_duration*fs].shape[1]
print("There are %d segments of 115s" % len(signals))
print("There are %s pages of 20s per segment" % (valid_samples / page_duration / fs))
# Strategy adopted for display: each block is extended with 2.5s of "border"
# at its beginning and at its end.

In [None]:
def draw_signal(segment_id, page_id):
    """Plot one page of a segment with its labels and shaded context margins.

    The segment is cropped so that 2.5 s of context remain on each side of the
    border-free part, and a window of `page_duration` seconds is drawn from
    that crop. Relies on the notebook globals `signals`, `labels`, `subjects`,
    `phases`, `fs`, `border_duration` and `page_duration`.

    Args:
        segment_id: Row index into `signals`, `labels`, `subjects`, `phases`.
        page_id: Window position in units of `page_duration`; fractional
            values (e.g. 0.5) shift the window by that fraction of a page.
    """
    context = 2.5  # seconds of border kept on each side of the valid part

    border_block = int((border_duration - context) * fs)
    block_signal = signals[segment_id, border_block:-border_block]
    block_label = labels[segment_id, border_block:-border_block]

    # Time in seconds relative to the first valid sample. It is built from the
    # crop's own length so the axis and the signal always have the same size.
    time_axis = np.arange(block_signal.size) / fs - context
    valid_duration = block_signal.size / fs - 2 * context

    # Shade only the context margins. The valid part spans [0, valid_duration),
    # so the leading margin is t < 0; testing t < 2.5 would also grey out the
    # first 2.5 s of valid signal.
    mask = ((time_axis < 0) | (time_axis >= valid_duration)).astype(np.int32)

    start_sample = int(page_id * page_duration * fs)
    end_sample = int(start_sample + page_duration * fs)
    page = slice(start_sample, end_sample)

    page_time_axis = time_axis[page]
    page_signal = block_signal[page]
    page_mask = mask[page]
    # Collapse the labels to a 0/1 indicator before drawing the label bar.
    page_label = np.clip(block_label[page], a_min=0, a_max=1)

    fig, ax = plt.subplots(1, 1, figsize=(12, 2), dpi=140)
    ax.plot(page_time_axis, page_signal, linewidth=0.7)
    # Grey band over the context margins, red bar under the labelled samples.
    ax.fill_between(page_time_axis, page_mask * -150, page_mask * 150, facecolor=viz.PALETTE['grey'], linewidth=0.2, alpha=0.3)
    ax.fill_between(page_time_axis, page_label * -50, page_label * -60, facecolor=viz.PALETTE['red'], linewidth=0.2)
    ax.set_title("Segment %d, page %s (Subject %s, Phase %d)" % (segment_id, page_id, subjects[segment_id], phases[segment_id]))
    ax.set_ylim([-150, 150])
    ax.set_xlim([page_time_axis[0], page_time_axis[-1]])

    plt.show()

In [None]:
# Segment shown by the interactive viewer below.
segment_id = 200

style = {'description_width': 'initial'}
layout = widgets.Layout(width='1000px')

# Slider over page positions 0..5 in half-page steps; the figure is redrawn
# only when the slider is released.
page_slider = widgets.FloatSlider(
    value=0, min=0, max=5, step=0.5,
    continuous_update=False,
    style=style,
    layout=layout,
)


def _show_page(page_id):
    """Draw the selected page of the current `segment_id`."""
    draw_signal(segment_id, page_id)


widgets.interact(_show_page, page_id=page_slider);