# Set up

In [None]:
import os
import shutil
import random
import tqdm
import glob
import re

import pyedflib
from scipy import signal

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
import tensorflow as tf
import keras
import keras.backend as K
from keras.layers import Conv1D, Conv1DTranspose, LeakyReLU, Lambda
from keras.layers import Input, Multiply, LayerNormalization, GlobalAveragePooling1D, Softmax, Add
from keras.utils.vis_utils import plot_model
from keras.models import load_model
from keras.callbacks import EarlyStopping
from livelossplot import PlotLossesKeras

In [None]:
# U-Time data: the single test-set recording that is stylized further below.
DATASET_PATH = '/***/TEST_SET_3600/'
TARGET       = 'shhs1-200006'

hyp_path = os.path.join(DATASET_PATH, TARGET, f'{TARGET}_hyp_FIR.npz')
edf_path = os.path.join(DATASET_PATH, TARGET, f'{TARGET}_FIR.edf')

# Load the hypnogram first: its length determines how the EEG is cut into epochs.
# The context manager closes the underlying .npz file handle right away.
with np.load(hyp_path) as npz:
    hyp = npz['arr_0']

with pyedflib.EdfReader(edf_path) as edf:
    # The headers are kept at module level because StylizeModel.make_edf reuses
    # them when writing EDF files; make_edf resamples its output to 100, hence
    # the overridden sample rate.
    signal_headers = [edf.getSignalHeader(0)]
    signal_headers[0]['sample_rate'] = 100
    header = edf.getHeader()
    # One row per hypnogram entry: (n_epochs, samples_per_epoch, 1).
    eeg = edf.readSignal(0).reshape((len(hyp), -1, 1))
print(edf_path)
print(hyp.shape, eeg.shape)

In [None]:
# Display names of the stage codes 0..4.
_stage_label = ['W', 'N1', 'N2', 'N3', 'R']

# Shared plotting parameters
_ylim1 = [-150., 150.]  # y-limits applied to waveform axes
_ylim3 = [0., 30.]      # y-limits applied to FFT axes
_width = 1/18           # bar width of the FFT bar plots
_alpha = .8             # opacity of the FFT bar plots
_fontsize = 30

# For every stage code, the index of the first epoch in `hyp` scored as that
# stage (-1 when the stage never occurs).
five_stage_idx = []
for stage_code in range(5):
    hits = np.flatnonzero(hyp == stage_code)
    five_stage_idx.append(int(hits[0]) if hits.size else -1)
# NOTE(review): the N2 example is pinned by hand to epoch 212 — presumably a
# more representative epoch than the first N2 one; confirm.
five_stage_idx[2] = 212
five_stage_idx

In [None]:
def init_seed(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and install a TF1-compat session.

    The session is created on the current default graph with intra- and
    inter-op parallelism both limited to one thread, and is registered as the
    Keras backend session.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

    tf_v1 = tf.compat.v1
    single_thread_conf = tf_v1.ConfigProto(
        intra_op_parallelism_threads=1,
        inter_op_parallelism_threads=1,
    )
    tf_v1.keras.backend.set_session(
        tf_v1.Session(graph=tf_v1.get_default_graph(), config=single_thread_conf)
    )

def _fft(data):
    Fs = 120
    data = data.ravel()
    N = len(data)
    if N > 4000:
        print('Seems to be some error; got data length:', N)
        return
    window = np.hanning(N)

    freq = np.fft.fftfreq(N, 1.0/Fs)
    F = np.fft.fft(data * window)
    F = abs(F)/(N/2)
    F[0] = F[0]/2
    return freq[:N//2], F[:N//2] * 1/(sum(window)/2/N)

def print_fft(data, idx=29):
    """Plot epoch `idx` of `data` next to its FFT magnitude spectrum.

    The left panel title shows the stage of the same epoch index taken from
    the module-level `y_test`, so `data` is expected to be aligned with the
    test set (raw test signals or their reconstructions).
    """
    epoch = data[idx].ravel()
    freqs, amps = _fft(epoch)

    _, (wave_ax, fft_ax) = plt.subplots(1, 2, figsize=(24, 4))
    wave_ax.plot(epoch)
    wave_ax.set_title(f'Raw data: {_stage_label[int(y_test[idx])]}')
    fft_ax.plot(freqs, amps, color='tab:orange')
    fft_ax.set_title('FFT(Raw data)')
    plt.show()

In [None]:
class LeakyAlt(keras.layers.Layer):
    """Activation layer computing ``-relu(-x, alpha)``.

    This mirrors a leaky ReLU: non-positive inputs pass through unchanged
    while positive inputs are scaled by `alpha`.
    """

    def __init__(self, alpha=0.3, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def call(self, inputs):
        negated = -inputs
        return -K.relu(negated, alpha=self.alpha)

    def get_config(self):
        # `alpha` is included so the layer can be rebuilt from a saved model.
        config = super().get_config().copy()
        config['alpha'] = self.alpha
        return config

In [None]:
class StylizeModel:
    """Split a trained scoring/reconstruction model into encoder and masked decoder.

    The decoder takes the latent tensor together with a same-shaped mask, so
    individual latent channels can be zeroed ("stylization") and the effect on
    stage scoring and on the reconstructed waveform can be inspected.

    Several methods read module-level names instead of arguments:
    `x_test` / `y_test` (evaluate, the importance search, raw-data panels),
    `signal_headers`, `header` and `TARGET` (make_edf), `five_stage_idx`,
    `_stage_label` and the `_ylim*` / `_width` / `_alpha` / `_fontsize`
    plotting constants.
    """

    def __init__(self, eeg, hyp, model=None) -> None:
        """Build the encoder/decoder pair and encode `eeg` when it is given.

        Args:
            eeg: signal to stylize, fed to the encoder; may be None when only
                `evaluate` is needed.
            hyp: stage labels aligned with `eeg`; may be None.
            model: trained Keras model. When None the model is loaded from
                './model/model_7746.h5'.
        """
        if model is None:
            self.model = keras.models.load_model(
                './model/model_7746.h5',
                custom_objects={
                    'LeakyReLU': LeakyReLU,
                    'LeakyAlt': LeakyAlt,
                    'root_mean_squared_error': StylizeModel.root_mean_squared_error,
                },
            )
            print('Model loaded')
        else:
            self.model = model
        self.eeg = eeg
        self.hyp = hyp
        self.encoder, self.decoder = self._get_en_decoder()

        # Always define the attribute so later code sees None rather than an
        # AttributeError when no signal was supplied.
        self.latent = self.encoder.predict(self.eeg) if self.eeg is not None else None
        w, _ = self.decoder.get_layer(name='squeeze').get_weights()
        self.n_mfilter = w.shape[1] # w.shape: (kernel_size, f_in, f_out)
        # Latent channels that are not fed to the scoring branch.
        self.edge = self.encoder.output_shape[-1] - self.n_mfilter
        self.imp_ascending = self._get_importance(load=True)

        # Filled lazily by _get_transition().
        self.rmse_transition = self.rec_transition = None

    def _removal_mask(self, n_removed):
        """Return an all-ones mask shaped like `self.latent` with channels zeroed.

        The zeroed channels are the first `n_removed` entries of
        `self.imp_ascending`, shifted by `self.edge // 2`.
        """
        mask = np.ones_like(self.latent)
        offset = self.edge // 2
        for j in range(n_removed):
            mask[:, :, offset + self.imp_ascending[j]] = 0
        return mask

    def evaluate(self):
        """show stylize model's performance

        Uses the module-level `x_test` / `y_test`: prints raw and normalised
        confusion matrices plus a classification report, then plots one raw
        and one reconstructed epoch with their FFTs.
        """
        pred_test_oh, pred_test_wave = self.model.predict(x_test)
        pred_test_score = np.argmax(pred_test_oh, axis=-1)

        print(pd.DataFrame(confusion_matrix(y_test, pred_test_score), columns=_stage_label, index=_stage_label))
        print(pd.DataFrame(confusion_matrix(y_test, pred_test_score, normalize='true'), columns=_stage_label, index=_stage_label))

        print(classification_report(y_test, pred_test_score, target_names=_stage_label, digits=4, zero_division=0))

        print_fft(x_test)
        print_fft(pred_test_wave)

    def _get_en_decoder(self):
        """Rebuild encoder and decoder as separate models from the trained layers.

        Returns:
            (encoder, decoder): `encoder` maps a (3600, 1) input to the latent
            tensor; `decoder` takes ``[latent, mask]`` and returns
            ``[scoring, reconstruction]``.
        """
        StylizeModel.init_seed()
        kernel_size = 32
        inputs       = Input((3600, 1))
        x_1          = self.model.get_layer(name=f'ord_{kernel_size}_1')(inputs)
        x_2          = self.model.get_layer(name=f'alt_{kernel_size}_1')(inputs)
        x_1          = self.model.get_layer(name=f'ord_{kernel_size}_2')(x_1)
        x_2          = self.model.get_layer(name=f'alt_{kernel_size}_2')(x_2)
        outputs      = Add()([x_1, x_2])
        encoder = keras.Model(inputs, outputs)

        inputs       = Input(shape=encoder.output_shape[1:])
        mask_layer   = Input(shape=encoder.output_shape[1:])
        # The mask is applied before both branches, so a zeroed channel is
        # hidden from reconstruction and scoring alike.
        x            = Multiply()([inputs, mask_layer])
        enc          = self.model.get_layer(name=f'dec_{kernel_size}')(x)
        enc          = self.model.get_layer(name='rec')(enc)
        scoring      = self.model.get_layer(name='lambda')(x)
        scoring      = self.model.get_layer(name='ln')(scoring)
        scoring      = self.model.get_layer(name='squeeze')(scoring)
        scoring      = GlobalAveragePooling1D()(scoring)
        scoring      = Softmax(name='scoring')(scoring)
        decoder = keras.Model([inputs, mask_layer], [scoring, enc])

        return encoder, decoder

    def _get_importance(self, load=True, plot=False):
        """Return the scoring-filter indices ordered by ascending importance.

        Args:
            load: when True, return a precomputed ordering; otherwise run a
                greedy search on the module-level `x_test` / `y_test`.
            plot: when searching, plot the per-candidate accuracy after each
                round.

        The search repeatedly removes, on top of the filters already removed,
        the filter whose removal leaves the highest test accuracy.
        """
        if load:
            return [16, 23, 17, 2, 5, 7, 4, 25, 20, 21, 12, 15, 22, 26, 10, 24, 11, 27, 19, 9, 1, 8, 14, 18, 13, 3, 6, 0]

        imp_ascending = []
        l_test = self.encoder.predict(x_test)
        offset = self.edge // 2
        # Mask with every already-selected filter zeroed; it is updated once
        # per round instead of being rebuilt for each candidate.
        base_mask = np.ones_like(l_test)
        for _ in tqdm.tqdm(range(self.n_mfilter)):
            pred_acc = []
            for i in range(self.n_mfilter):
                # no need to calculate
                if i in imp_ascending:
                    pred_acc.append(0)
                    continue
                # set mask
                mask = base_mask.copy()
                mask[:, :, offset + i] = 0
                # evaluation
                pred_oh = self.decoder.predict([l_test, mask])[0]
                pred_acc.append(accuracy_score(y_test, np.argmax(pred_oh, axis=-1)))

            # save importance
            best = int(np.argmax(pred_acc))
            imp_ascending.append(best)
            base_mask[:, :, offset + best] = 0

            if plot:
                # plot
                _, ax = plt.subplots(1, 1, figsize=(10, 3))
                ax.plot(pred_acc)
                ax.scatter(np.argmax(pred_acc), np.max(pred_acc), c='r')
                ax.set_ylim([0., .9])
                ax.set_ylabel('Accuracy')
                ax.set_xticks(range(self.n_mfilter))
                ax.grid()
                plt.show()
        return imp_ascending

    def make_edf(self, dir_name=''):
        """save stylized signal to EDF

        One EDF file is written per removal level (0..len(imp_ascending)
        filters removed) under ``REC/<dir_name>/<TARGET>/squeeze_<i>/``, and a
        hypnogram file is copied next to it. Relies on the module-level
        `signal_headers`, `header` and `TARGET`.
        """
        for i in range(len(self.imp_ascending)+1):
            mask = self._removal_mask(i)
            pred_wave = self.decoder.predict([self.latent, mask])[1]

            pred_wave = signal.resample_poly(pred_wave, 100, 120, axis=1)
            pred_wave = pred_wave.reshape((1, -1)) # (n_channels, len_data)

            base = f'squeeze_{i}'
            dir_path = os.path.join('REC', dir_name, TARGET, base) # REC/dir_name/shhs1-200006/squeeze_0
            os.makedirs(dir_path, exist_ok=True)
            with pyedflib.EdfWriter(os.path.join(dir_path, f'{base}.edf'), n_channels=1, file_type=pyedflib.FILETYPE_EDF) as f:
                f.setSignalHeaders(signal_headers)
                f.setHeader(header)
                f.writeSamples(pred_wave)
            # NOTE(review): source and target names are hard-coded to
            # shhs1-200006 although the directory is built from TARGET.
            shutil.copyfile('/***/shhs1-200006_hyp_FIR.npz',
                            os.path.join(dir_path, 'shhs1-200006_hyp.npz'))
            print('Saved to', dir_path)

    def _get_transition(self):
        """Compute RMSE and example reconstructions for every removal level.

        Fills `self.rmse_transition` (one RMSE against `self.eeg` per level)
        and `self.rec_transition` (the reconstructed epochs listed in the
        module-level `five_stage_idx`, per level).
        """
        self.rmse_transition = []
        self.rec_transition = []
        for i in tqdm.tqdm(range(len(self.imp_ascending)+1), desc='### calculate transition'):
            mask = self._removal_mask(i)
            pred_wave = self.decoder.predict([self.latent, mask])[1]

            # Compare against the signal this instance was built with, not
            # against whatever the module-level `eeg` currently holds.
            rmse_val = StylizeModel.root_mean_squared_error(self.eeg, pred_wave)
            self.rmse_transition.append(rmse_val.numpy())

            rec = np.array([pred_wave[j] for j in five_stage_idx]) # (5, 3600, 1)
            self.rec_transition.append(rec)

    def show_rmse_transition(self):
        """Plot the reconstruction RMSE against the number of removed filters."""
        if self.rmse_transition is None:
            self._get_transition()

        plt.figure(figsize=(15, 3))
        plt.plot(self.rmse_transition)
        plt.xticks(range(len(self.rmse_transition)))
        plt.xlabel('Number of latent vector removed', fontsize=20)
        plt.ylabel('RMSE', fontsize=20)
        plt.grid()
        plt.show()

    def show_rec_transition(self, row=1):
        """Plot example epochs per stage for the first `row` removal levels.

        The top row shows the raw test epochs when a module-level `x_test`
        exists; each following row shows the reconstructions of one level.
        """
        if self.rec_transition is None:
            self._get_transition()

        def _plot_some(ax, data, label, c=None):
            # Waveform on `ax`, its FFT on an overlaid twin axis, one legend.
            ax.plot(data, label=label, c=c)
            ax.set_ylim(_ylim1)
            x_fft, y_fft = _fft(data.ravel())
            ax3 = ax.twinx().twiny()
            ax3.plot(x_fft, y_fft, c='tab:orange', label=f'FFT({label})')
            ax3.set_ylim(_ylim3)
            h1, l1 = ax.get_legend_handles_labels()
            h3, l3 = ax3.get_legend_handles_labels()
            ax.legend(h1 + h3, l1 + l3)
            plt.tight_layout()

        _, ax = plt.subplots(1+row, 5, figsize=(30, 4*(1+row)), sharey='row')
        if 'x_test' in globals():
            for i, idx in enumerate(five_stage_idx):
                ax1 = ax[0][i]
                _plot_some(ax1, x_test[idx], label='raw')
                ax1.set_title(f'Raw data, {_stage_label[i]}', fontsize=_fontsize)
        for i in tqdm.tqdm(range(row), desc='### plot transition'):
            for j in range(5):
                ax1 = ax[i+1][j]
                _plot_some(ax1, self.rec_transition[i][j], label='rec', c='navy') # specify threshold
                ax1.set_title(f'Removed {i}th({self.imp_ascending[i]}), {_stage_label[j]}', fontsize=_fontsize)
        plt.show()

    def compare(self, idx, extent):
        """comparison of raw/rec signals

        Args:
            idx: epoch index to display.
            extent: number of least-important filters to remove.
        """
        print(f'### Signal index: {idx}, stylized to {extent}/{self.n_mfilter}')

        mask = self._removal_mask(extent)
        stylized = self.decoder.predict([self.latent, mask])[1]

        _, ax = plt.subplots(2, 1, figsize=(20, 6), sharey='row')
        # Upper panel: raw epoch (only when a module-level x_test exists).
        ax1 = ax[0]
        ax2 = ax1.twinx()
        ax3 = ax2.twiny()
        if 'x_test' in globals():
            ax1.plot(x_test[idx], label='raw')
            x_fft, y_fft = _fft(x_test[idx].ravel())
            ax3.bar(x_fft, y_fft, color='tab:orange', width=_width, alpha=_alpha)

        ax1.set_xticks(np.arange(0, 3600+1, 600))
        ax1.set_ylim(_ylim1)
        ax3.set_xlabel('frequency[Hz]', fontsize=_fontsize)
        ax3.set_ylim(_ylim3)
        plt.tight_layout()
        ##################################
        # Lower panel: stylized epoch.
        ax1 = ax[1]
        ax2 = ax1.twinx()
        ax3 = ax2.twiny()
        ax1.plot(stylized[idx], c='navy')
        x_fft, y_fft = _fft(stylized[idx].ravel())
        ax3.bar(x_fft, y_fft, color='tab:orange', width=_width, alpha=_alpha)

        ax1.set_xticks(np.arange(0, 3600+1, 600))
        ax1.set_ylim(_ylim1)
        ax1.set_xlabel(f'Time[s]', fontsize=_fontsize)
        ax3.set_ylim(_ylim3)
        plt.tight_layout()
        plt.show()

    def show_ttp(self):
        """show TTP(tend to predict) and plot

        For every filter, only that filter's latent channel is kept and the
        most frequently predicted stage is recorded. `self.sensitive_to[s]`
        holds the filter indices that tend to predict stage `s`; `self.ttp`
        maps filter index to stage label.
        """
        self.sensitive_to = [set(), set(), set(), set(), set()]
        self.ttp = {}

        for imp in tqdm.tqdm(self.imp_ascending, desc='calculate ttp       ', total=self.n_mfilter):
            mask = np.zeros_like(self.latent)
            mask[:, :, self.edge//2+imp] = 1
            pred_oh = self.decoder.predict([self.latent, mask])[0]

            # Labels of this instance, not the module-level `hyp`.
            cm = confusion_matrix(self.hyp, np.argmax(pred_oh, axis=-1))
            pred_amb = np.argmax(np.sum(cm, axis=0))
            # Keyed by filter index: `imp_score` below is indexed by filter
            # index too, so storing the rank position here would pair each
            # filter with another filter's importance in the scatter plot.
            self.sensitive_to[pred_amb].add(imp)
            self.ttp[imp] = _stage_label[pred_amb]

        print('imp_ascending:', self.imp_ascending)

        for i in range(5):
            print(f'Sensitive to {_stage_label[i]}: {self.sensitive_to[i]}')

        print(self.ttp)

        # Importance of each filter as its rank scaled to 0..100.
        imp_score = [0] * len(self.imp_ascending)
        for i, idx in enumerate(self.imp_ascending):
            imp_score[idx] = i / (len(self.imp_ascending)-1) * 100

        plt.figure(figsize=(20, 2))
        sns.heatmap([imp_score], cmap='Reds', annot=True, cbar=True, fmt='.1f')
        plt.title('Importance')
        plt.xlabel('index')
        plt.yticks([])
        plt.show()

        plt.figure(figsize=(10, 5))
        colorlist = ['r', 'b', 'c', 'm', 'g']
        for j in range(5):
            for i in self.sensitive_to[j]:
                plt.scatter(j, imp_score[i], color=colorlist[j])
                plt.text(j+.05, imp_score[i], f'({i})')
        plt.xlabel('sensitive to')
        plt.xticks(range(5), _stage_label)
        plt.ylabel('Importance')
        plt.title('Importance / Tend to predict')
        plt.show()

    @classmethod
    def init_seed(cls, seed=0):
        """Seed Python, NumPy and TensorFlow and install a single-threaded TF1-compat session."""
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)
        session_conf = tf.compat.v1.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=1
        )
        sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
        tf.compat.v1.keras.backend.set_session(sess)

    @classmethod
    def root_mean_squared_error(cls, y_true, y_pred):
        """Return sqrt(mean((y_pred - y_true)**2)) over all elements; also used as the Keras loss."""
        return K.sqrt(K.mean(K.square(y_pred - y_true)))

# New model training

In [None]:
def load_edf_hyp(dataset_path):
    """Load every ``shhs1-200*`` recording under `dataset_path` into stacked arrays.

    Each matching directory must contain ``shhs1-200XXX_FIR.edf`` and
    ``shhs1-200XXX_hyp_FIR.npz``, where ``XXX`` are the last three characters
    of the directory name. Signal 0 of the EDF is cut into as many epochs as
    the hypnogram has entries.

    Returns:
        (x, y): `x` has shape (n_epochs, 3600, 1) and `y` shape (n_epochs,),
        concatenated over all recordings in sorted directory order. Both are
        empty (0 epochs) when no directory matches.
    """
    # Seeding with empty arrays keeps the empty-dataset result unchanged and
    # makes the final concatenation reject epochs that are not 3600 samples.
    x_parts = [np.empty((0, 3600, 1))]
    y_parts = [np.empty((0,))]
    for dir_name in sorted(glob.glob(os.path.join(dataset_path, 'shhs1-200*/'))):
        # The pattern ends with a separator, so [-4:-1] is the 3-character
        # suffix of the directory name.
        stem = f'shhs1-200{dir_name[-4:-1]}'
        hyp_path = os.path.join(dir_name, f'{stem}_hyp_FIR.npz')
        edf_path = os.path.join(dir_name, f'{stem}_FIR.edf')
        with np.load(hyp_path) as npz:  # closes the .npz handle promptly
            hyp = npz['arr_0']
        with pyedflib.EdfReader(edf_path) as edf:
            eeg = edf.readSignal(0).reshape((len(hyp), -1, 1))
        x_parts.append(eeg)
        y_parts.append(hyp)
    # A single concatenation avoids re-copying the accumulated data per file.
    return np.concatenate(x_parts, axis=0), np.concatenate(y_parts, axis=0)

# Load the three dataset splits.
x_train, y_train = load_edf_hyp('/***/TRAIN_SET_3600/')
x_val, y_val     = load_edf_hyp('/***/VAL_SET_3600/')
x_test, y_test   = load_edf_hyp('/***/TEST_SET_3600/')
for split_x, split_y in ((x_train, y_train), (x_val, y_val), (x_test, y_test)):
    print(split_x.shape, split_y.shape)

# One-hot encode the stage labels; the encoder is fitted on the training
# labels only and then applied to both the training and validation labels.
ohe = OneHotEncoder(categories="auto", sparse=False)
ohe.fit(y_train.reshape((-1, 1)))
y_train_oh = ohe.transform(y_train.reshape((-1, 1)))
y_val_oh   = ohe.transform(y_val.reshape((-1, 1)))
print(y_train_oh.shape, y_val_oh.shape)

In [None]:
# Persist the test split. The arrays are passed positionally, so they are
# stored under the keys 'arr_0' (x_test) and 'arr_1' (y_test); NumPy appends
# the '.npz' extension to the file name.
np.savez('test_data', x_test, y_test)

In [None]:
ks = [32]       # kernel size of every encoder/decoder convolution
fi = [5, 32]    # filters of the first / second encoder stage
st = [6, 8]     # strides of the first / second encoder stage
edge = 32 - 28  # latent channels hidden from scoring (edge//2 trimmed on each side)

StylizeModel.init_seed()
inputs  = keras.layers.Input(shape=x_train.shape[1:])

# Set en/decoder
# Two parallel encoder paths with mirrored activations (LeakyReLU / LeakyAlt),
# summed into a single latent tensor.
enc_1   = Conv1D(fi[0], kernel_size=ks[0], strides=st[0], padding='same', activation=LeakyReLU(), name=f'ord_{ks[0]}_1')(inputs)
enc_2   = Conv1D(fi[0], kernel_size=ks[0], strides=st[0], padding='same', activation=LeakyAlt(),  name=f'alt_{ks[0]}_1')(inputs)
enc_1   = Conv1D(fi[1], kernel_size=ks[0], strides=st[1], padding='same', activation=LeakyReLU(), name=f'ord_{ks[0]}_2')(enc_1)
enc_2   = Conv1D(fi[1], kernel_size=ks[0], strides=st[1], padding='same', activation=LeakyAlt(),  name=f'alt_{ks[0]}_2')(enc_2)

x       = Add()([enc_1, enc_2])
rec     = x
# Only the inner latent channels reach the scoring branch. The lambda reads the
# module-level `edge`, so that name has to exist wherever this layer runs.
scoring = Lambda(lambda x: x[:, :, edge//2:-edge//2], name='lambda')(x)

# Reconstruction
rec     = Conv1DTranspose(fi[0], kernel_size=ks[0], strides=st[1], padding='same', activation='linear', name=f'dec_{ks[0]}')(rec)
rec     = Conv1DTranspose(    1, kernel_size=ks[0], strides=st[0], padding='same', activation='linear', name='rec')(rec)

# Stage scoring
scoring = LayerNormalization(name='ln')(scoring)
scoring = Conv1D(5, kernel_size=10, strides=2, padding='same', name='squeeze')(scoring)
scoring = GlobalAveragePooling1D()(scoring)
scoring = Softmax(name='scoring')(scoring)

model = keras.Model(inputs=inputs, outputs=[scoring, rec])
# `learning_rate` replaces the deprecated `lr` argument name.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0005), # 0.001 or 0.0005
                loss= {'scoring': 'categorical_crossentropy', 'rec': StylizeModel.root_mean_squared_error},
                loss_weights={'scoring': 20., 'rec': 1.},
                metrics={'scoring': 'accuracy'})

plot_model(model, show_shapes=True, show_layer_names=False, dpi=50)

In [None]:
# Train both heads jointly; the reconstruction target is the input itself.
# `callbacks` is passed as the flat list the Keras API expects, rather than a
# nested list.
_ = model.fit(x_train, {'scoring': y_train_oh, 'rec': x_train},
              batch_size=128, epochs=100,
              callbacks=[EarlyStopping(monitor='val_loss', patience=3), PlotLossesKeras()],
              validation_data=(x_val, {'scoring': y_val_oh, 'rec': x_val}))

In [None]:
# Save the trained model to the working directory.
# NOTE(review): StylizeModel loads './model/model_7746.h5' by default, so this
# file presumably has to be moved into ./model/ by hand — confirm.
model.save('model_7746.h5')

In [None]:
# Wrap the freshly trained in-memory model; no signal is needed for evaluate().
m = StylizeModel(eeg=None, hyp=None, model=model)
m.evaluate()

# Let's get started

In [None]:
# Stylize the target recording with the model loaded from disk.
m = StylizeModel(eeg=eeg, hyp=hyp)
# m.make_edf(dir_name='hoge')
m.compare(idx=123, extent=25)
m.show_rmse_transition()
m.show_rec_transition(row=1)

# U-Time evaluation

## Accuracy and F1 transition

In [None]:
def atoi(text):
    """Return `text` as an int when it consists only of decimal digits, else unchanged."""
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as superscript digits, which int() rejects with a ValueError.
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Return a sort key that orders embedded numbers numerically.

    `text` is split into alternating non-digit and digit chunks and the digit
    chunks are converted to int, so 'squeeze_2' sorts before 'squeeze_10'.
    """
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

# Per-directory U-Time metrics: one row per squeeze_* directory, five per-stage
# values followed by one overall value.
acc_rows, f1_rows = [], []
# Fixed label set so every matrix/score vector has one entry per stage even
# when a stage is missing from both the true and the predicted labels.
stage_ids = list(range(len(_stage_label)))
dir_names = sorted(glob.glob('/***/squeeze_*'), key=natural_keys)
for i, result_dir in enumerate(dir_names):  # avoid shadowing the builtin `dir`
    with np.load(os.path.join(result_dir, 'true.npz')) as npz:
        true = npz['arr_0']
    with np.load(os.path.join(result_dir, 'pred.npz')) as npz:
        pred = npz['arr_0']
    cm = confusion_matrix(true, pred, labels=stage_ids)
    # Diagonal over row sums, i.e. the fraction of each true stage predicted
    # correctly; a stage without true samples yields NaN.
    acc_6 = np.diag(cm) / np.sum(cm, axis=-1)
    acc_6 = np.append(acc_6, accuracy_score(true, pred))
    acc_rows.append(acc_6)
    f1_6 = f1_score(true, pred, labels=stage_ids, average=None)
    f1_6 = np.append(f1_6, np.mean(f1_6))
    f1_rows.append(f1_6)

    if i == 25:
        print(i)
        print(cm)
        print(classification_report(true, pred, labels=stage_ids, target_names=_stage_label, digits=4, zero_division=0))

# reshape keeps the (0, 6) shape when no directory matched.
acc_utime = np.array(acc_rows).reshape((-1, 6))
f1_utime = np.array(f1_rows).reshape((-1, 6))

# One figure per metric table: the overall column (index 5) drawn with markers,
# followed by one line per sleep stage.
for metric_table, metric_name in ((acc_utime, 'U-Time Accuracy'), (f1_utime, 'U-Time F1 score')):
    plt.figure(figsize=(15, 3))
    plt.plot(metric_table[:, 5], '-x', label='mean')
    for i, stage_name in enumerate(_stage_label):
        plt.plot(metric_table[:, i], label=stage_name)
    plt.xticks(range(len(metric_table)))
    plt.xlabel('Number of latent vector removed', fontsize=20)
    plt.ylabel(metric_name, fontsize=20)
    plt.grid()
    plt.legend(bbox_to_anchor=(1.0, 0.85, 0.3, 0.2), loc='upper left')
    plt.show()

## Kappa transition

In [None]:
# U-Time trained by IIIS+MASS
# Rename the per-class columns to stage names.
stage_columns = {'cls 0': 'W', 'cls 1': 'N1', 'cls 2': 'N2', 'cls 3': 'N3', 'cls 4': 'REM'}
df = pd.read_csv('/***/evaluation_kappa.csv').rename(columns=stage_columns)
# Drop the last row of the file before plotting.
df = df.drop(index=len(df) - 1)

kappa_ax = df.plot(
    figsize=(15, 3),
    grid=True,
    style=['x--', '-', '-', '-', '-', '-'],
    xticks=range(len(df)),
    xlabel='Number of latent vector removed',
    ylim=[-0.1, 0.7],
    ylabel='Kappa coefficient',
)
kappa_ax.legend(loc='upper right')
plt.show()