# Imports

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import glob

from utils import read_mesa_data, set_random_seed
from sklearn.model_selection import train_test_split

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, cohen_kappa_score, roc_auc_score, average_precision_score
from tensorflow.keras.utils import to_categorical

from score import Score
from models import ModelCLA
from models.utils import cat_crossentropy_cut

from tensorflow.keras import optimizers

# Settings

In [2]:
# paths to the MESA EDF recordings; sorted because glob.glob returns files in
# filesystem-dependent order, which would make the seeded train/val/test split
# below differ from machine to machine
data_paths = sorted(glob.glob("mesa/data/*.edf"))

# sampling frequency (Hz); also used to size the 30 s scoring epochs at evaluation
fs = 256

# recordings per training/validation batch (the test generator always uses 1)
batch_size = 1

# global random seed
seed = 10
set_random_seed(seed)

# Model

In [3]:
# the model sees a single input channel
num_features = 1

# fixed model input length in samples (recordings are padded to this length)
input_len = 2**23

model = ModelCLA((input_len, num_features))

# threshold-free ranking metrics for the binary arousal head
AUPRC = tf.keras.metrics.AUC(curve="PR", name="AUPRC")
AUROC = tf.keras.metrics.AUC(curve="ROC", name="AUROC")

# two heads, equally weighted: binary arousal detection and sleep-stage classification
model.compile(
    loss={"arousal": "binary_crossentropy", "stage": cat_crossentropy_cut},
    optimizer=optimizers.Adam(learning_rate=1e-4),
    metrics={"arousal": [AUPRC, AUROC], "stage": "accuracy"},
    loss_weights={"arousal": 1.0, "stage": 1.0},
)

# one dummy forward pass to find how much the model downsamples its input:
# r = input length // length of the first model output
probe_in = tf.random.uniform((1, input_len, num_features))
probe_out = model(probe_in)[0]
out_shape = probe_out.shape[1]
r = probe_in.shape[1] // out_shape
del probe_in, probe_out

# Dataset generator

In [4]:
class CreateDS(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads MESA recordings and their labels from disk.

    Each item is one batch ``(signals, (arousals, one_hot_stages))``.

    * ``"train"`` / ``"val"``: the signal is zero-padded at the end to ``pad_len``
      samples. The arousal and stage tracks are zero-padded to ``2 * pad_len``
      samples and then every ``2 * r``-th sample is kept. This yields
      ``ceil(pad_len / r)`` label points per recording.
    * ``"test"``: only the signal is padded; labels keep their native length.
    * Training recordings also get a random amplitude scaling (``augment_data``).
      The file order is reshuffled at the end of every epoch.

    Args:
        data_paths: paths to the ``.edf`` signal files.
        r: model downsampling factor (input length // output length).
        batch_size: recordings per batch; forced to 1 when ``ds_type == "test"``.
        ds_type: ``"train"``, ``"val"`` or ``"test"``.
        pad_len: length, in signal samples, that every recording is padded to.
    """

    def __init__(self, data_paths, r, batch_size, ds_type="train", pad_len=2**22):
        self.data_paths = data_paths
        # label file derived from the signal path by string substitution;
        # NOTE(review): relies on the on-disk naming of the profusion XML files — confirm
        self.label_paths = [
            path.replace("data", "labels")
                .replace(".edf", "-profusion.xml")
                .replace("mesa-sleep", "mesa-sleep-mesa-sleep")
            for path in data_paths
        ]
        self.ds_type = ds_type
        self.r = int(r)
        self.pad_len = pad_len
        self.batch_size = 1 if ds_type == "test" else batch_size
        self.shuffle = ds_type == "train"

    def __len__(self):
        """Number of batches (the last one may be smaller than ``batch_size``)."""
        return int(np.ceil(len(self.data_paths) / self.batch_size))

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(signals, (arousals, one-hot stages with 6 classes))``."""
        start_idx = self.batch_size * idx
        end_idx = min(self.batch_size * (idx + 1), len(self.data_paths))

        data_batch, arousal_batch, stage_batch = [], [], []
        for i in range(start_idx, end_idx):
            data, arousals, stages = self.getitem(i)
            data_batch.append(data)
            arousal_batch.append(arousals)
            stage_batch.append(stages)

        return (
            np.stack(data_batch, axis=0),
            (
                np.stack(arousal_batch, axis=0),
                to_categorical(np.stack(stage_batch, axis=0), num_classes=6),
            ),
        )

    def getitem(self, idx):
        """Load recording ``idx`` and apply the split-specific augmentation and padding.

        Raises:
            ValueError: if the signal (or, for train/val, a label track) is longer
                than the length it must be padded to.
        """
        data_path = self.data_paths[idx]
        data, arousals, stages = read_mesa_data(data_path, self.label_paths[idx])

        if self.ds_type == "train":
            data, arousals, stages = self.augment_data(data, arousals, stages)

        if self.ds_type in ("train", "val"):
            data = self._pad_end(data, self.pad_len, "signal", data_path)
            # label tracks are padded to twice the signal length and subsampled by 2*r
            # to line up with the model output; presumably read_mesa_data returns labels
            # at twice the signal rate — TODO confirm
            arousals = self._pad_end(arousals, 2 * self.pad_len, "arousal labels", data_path)
            stages = self._pad_end(stages, 2 * self.pad_len, "stage labels", data_path)
            arousals = arousals[::2 * self.r]
            stages = stages[::2 * self.r]
        elif self.ds_type == "test":
            # labels stay at native length; only the model input must be fixed-size
            data = self._pad_end(data, self.pad_len, "signal", data_path)

        return data, arousals, stages

    @staticmethod
    def _pad_end(array, target_len, what, path):
        """Zero-pad ``array`` at the end of axis 0 so it has ``target_len`` rows.

        Raises:
            ValueError: if ``array`` already has more than ``target_len`` rows.
        """
        length = array.shape[0]
        if length > target_len:
            raise ValueError(
                f"{what} of {path} has {length} samples, more than the padded length {target_len}"
            )
        pad_width = [(0, target_len - length)] + [(0, 0)] * (array.ndim - 1)
        return np.pad(array, pad_width=pad_width)

    def on_epoch_end(self):
        """Reshuffle signal and label paths together (training split only)."""
        if self.shuffle:
            idx = np.random.permutation(len(self.data_paths))
            self.data_paths = np.array(self.data_paths)[idx]
            self.label_paths = np.array(self.label_paths)[idx]

    def augment_data(self, data, arousals, stages):
        """Scale the signal by a random factor drawn uniformly from [0.9, 1.1); labels unchanged."""
        low = 0.9
        high = 1.1
        scale = np.random.uniform(low, high)
        data = data * scale

        return data, arousals, stages

## Train / validation / test split

In [5]:
# total number of recordings found by the glob in the settings cell
num_files = len(data_paths)
if num_files == 0:
    # fail here with a clear message instead of an opaque error from train_test_split
    raise FileNotFoundError("no .edf recordings found; check the data path and the working directory")

# 50 % train, 20 % validation, the remainder (~30 %) test
train_split = int(0.5 * num_files)
val_split   = int(0.2 * num_files)
test_split  = num_files - train_split - val_split

# two-stage seeded split: hold out val+test first, then divide the held-out files
train_files, tmp = train_test_split(data_paths, test_size=val_split + test_split, random_state=seed)
val_files, test_files = train_test_split(tmp, test_size=test_split, random_state=seed)

# every recording is padded to the model input length (2**23 samples, as in the model cell)
ds_train = CreateDS(train_files, r, batch_size, ds_type="train", pad_len=2**23)
ds_val = CreateDS(val_files, r, batch_size, ds_type="val", pad_len=2**23)
ds_test = CreateDS(test_files, r, batch_size, ds_type="test", pad_len=2**23)

# Training

In [None]:
# stop once validation loss has not decreased for 20 epochs and roll back to the best
# (lowest-loss) weights; mode must be "min" for a loss, "max" would keep the worst epoch
early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True, mode="min")
# per-epoch metrics to CSV (overwrites any previous log)
save = tf.keras.callbacks.CSVLogger("log_mesa.csv", separator=',', append=False)

# re-seed right before training with the configured seed so this cell is reproducible on its own
np.random.seed(seed)
tf.random.set_seed(seed)

history = model.fit(x=ds_train,
                    validation_data=ds_val,
                    epochs=200,
                    callbacks=[early_stop, save],
                    validation_freq=1,
                    workers=10,
                    use_multiprocessing=True)

# learning curves
perf = history.history["loss"]
perf_val = history.history["val_loss"]
ep = np.arange(1, len(perf) + 1)

fig, ax = plt.subplots(dpi=200)
ax.plot(ep, perf, label="Training")
ax.plot(ep, perf_val, label="Validation")
ax.set(xlabel="Epochs", ylabel="Loss", title="MESA training / validation loss")
ax.legend()
fig.tight_layout()
plt.show()

model.save_weights("weights/mesa_model-C8-L3A")

# Testing

In [None]:
# sample-level arousal scorer, fed one record at a time
scr = Score()

# label samples per 30 s scoring epoch
epoch_len = fs * 30

# epoch-level results pooled over all test recordings
y_true_all = []     # sleep stage, ground truth
y_pred_all = []     # sleep stage, predicted
y_true_all_a = []   # arousal in epoch, ground truth (max over the epoch)
y_pred_all_a = []   # arousal in epoch, predicted (any sample with p >= 0.5)
y_score_all_a = []  # mean predicted arousal probability over the epoch


def to_epochs(x, epoch_len):
    """Reshape 1-D ``x`` to ``(n_epochs, epoch_len)``, dropping a trailing partial epoch."""
    n_epochs = x.shape[0] // epoch_len
    return x[:n_epochs * epoch_len].reshape((n_epochs, epoch_len))


def epoch_majority(x, epoch_len):
    """Most frequent label in each full epoch; ``x`` must hold non-negative integers.

    Unlike ``np.apply_along_axis`` this also works when there is no full epoch.
    """
    return np.array([np.bincount(row).argmax() for row in to_epochs(x, epoch_len)], dtype=np.intp)


for i, (X, (y_true_a, y_true_s)) in enumerate(ds_test):
    # ground truth at native label resolution
    y_true_a = y_true_a.ravel()
    y_true_s = y_true_s.argmax(axis=-1).ravel()

    # predictions at model-output resolution
    y_pred_a, y_pred_s = model.predict(X)
    y_pred_a = y_pred_a.ravel()
    y_pred_s = y_pred_s.argmax(axis=-1).ravel()

    # upsample predictions to label resolution (inverse of the ::2*r subsampling in CreateDS)
    y_pred_a = np.repeat(y_pred_a, 2 * r)
    y_pred_s = np.repeat(y_pred_s, 2 * r)

    # predictions cover the padded signal; cut them to the recording's label length
    len_true = y_true_a.shape[0]
    if y_pred_a.shape[0] < len_true:
        raise ValueError(f"record {i}: {y_pred_a.shape[0]} predictions for {len_true} labels")
    y_pred_a = y_pred_a[:len_true]
    y_pred_s = y_pred_s[:len_true]

    # epoch-level sleep stages: majority label per epoch
    y_true_s = epoch_majority(y_true_s, epoch_len)
    y_pred_s = epoch_majority(y_pred_s, epoch_len)

    # epoch-level arousals
    y_true_a2 = to_epochs(y_true_a, epoch_len).max(axis=1)
    y_score_a2 = to_epochs(y_pred_a, epoch_len).mean(axis=1)
    y_pred_a2 = to_epochs(y_pred_a >= 0.5, epoch_len).max(axis=1)

    # drop epochs whose true stage is class 5 (treated as undefined)
    defined = y_true_s != 5
    y_pred_s = y_pred_s[defined]
    y_true_s = y_true_s[defined]

    y_true_all += y_true_s.tolist()
    y_pred_all += y_pred_s.tolist()

    y_true_all_a += y_true_a2.tolist()
    y_pred_all_a += y_pred_a2.tolist()
    y_score_all_a += y_score_a2.tolist()

    # sample-level arousal scoring for this record
    scr.score_record(y_true_a, y_pred_a, record_name=str(i))

# sample-level arousal scores
print("------------- Results for Arousals -------------")
print(f"AUPRC: {scr.gross_auprc():.3f}, AUROC: {scr.gross_auroc():.3f}")

# epoch-level arousal scores
print("\n------------- Results for Arousals (Epoch-Based) -------------")
print(classification_report(y_true_all_a, y_pred_all_a, digits=3))
print(f"AUPRC: {average_precision_score(y_true_all_a, y_score_all_a):.3f}, AUROC: {roc_auc_score(y_true_all_a, y_score_all_a):.3f}")
print(f"Cohen's Kappa: {cohen_kappa_score(y_true_all_a, y_pred_all_a):.3f}")

# epoch-level sleep-stage scores
print("\n-------------Results for Sleep Stages -------------")
print(classification_report(y_true_all, y_pred_all, digits=3))
print(f"Cohen's Kappa: {cohen_kappa_score(y_true_all, y_pred_all):.3f}")