In [55]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import pywt
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from sleep_data import SleepDataMASS

In [2]:
dataset = SleepDataMASS(load_from_checkpoint=True)


Loading MASS from checkpoint

Pages in train set: 5838
Pages in val set: 1869
Pages in test set: 3094

Pages in MASS dataset: 10801


In [39]:
subset_name = "val"

fs = dataset.get_fs()
signals, exp_marks = dataset.get_augmented_numpy_subset(subset_name, 1, 0)
size_half_page = int(signals.shape[1]/4)
signals = signals[:, size_half_page:-size_half_page]
exp_marks = exp_marks[:, size_half_page:-size_half_page:10]
print("Signal", signals.shape)
print("Expert marks", exp_marks.shape)

nn_marks = np.loadtxt("falcondata/marks_ppt3/predictions_"+subset_name+"_central.csv")
print("NN marks", nn_marks.shape)

n_pages = exp_marks.shape[0]

Signal (1869, 4000)
Expert marks (1869, 400)
NN marks (1869, 400)


In [56]:
def set_cwt(fs):
    # Parameters CWT
    fb = 1.5
    lower_freq = 3
    upper_freq = 40
    n_scales = 32
    # Generate initial and last scale
    s_0 = fs / upper_freq
    s_n = fs / lower_freq
    # Generate the array of scales
    base = np.power(s_n / s_0, 1 / (n_scales - 1))
    scales = s_0 * np.power(base, np.arange(n_scales))
    # Pywavelets
    w = pywt.ContinuousWavelet('cmor')
    w.center_frequency = 1
    w.bandwidth_frequency = fb
    return scales, w
# scales, w = set_cwt(fs)

In [57]:
def get_cwt(segment, scales, w, fs):
    coef, freqs = pywt.cwt(segment, scales, w, 1/fs)
    abs_coef = np.abs(coef)
    return abs_coef * freqs[:, np.newaxis]

In [78]:
def plot_page(page):
    global signals, exp_marks, nn_marks, fs
    fig = plt.figure(figsize=(15, 6)) 
    gs = gridspec.GridSpec(4, 1, height_ratios=[2, 3, 1, 1]) 
    
    time_axis = np.arange(0, signals.shape[1])/fs
    this_signal = signals[page, :]
    this_mark = exp_marks[page, :]
    this_pred = nn_marks[page, :]
    scales, w = set_cwt(fs)
    this_cwt = get_cwt(this_signal, scales, w, fs)

    # Signal
    ax0 = fig.add_subplot(gs[0])
    ax0.plot(time_axis, this_signal)
    ax0.get_xaxis().set_ticks([])
    ax0.get_yaxis().set_ticks([])
    ax0.set_xlim([0, 20])
    
    # CWT (con Pywavelets)
    ax1 = fig.add_subplot(gs[1])
    ax1.imshow(this_cwt, interpolation=None, aspect='auto', extent=[1, 20, 32, 1])
    
    # ax1.plot(this_mark)
    
    # Expert mark
    
    # NN mark
    plt.tight_layout()
    plt.show()

interact(plot_page, page=widgets.IntSlider(min=1,max=n_pages,step=1,value=1));


interactive(children=(IntSlider(value=1, description='page', max=1869, min=1), Output()), _dom_classes=('widge…