In [19]:

import os
import math
import time
from datetime import datetime
import logging
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import yaml
import torch
from torch import nn
import torch.nn.functional as F

from sklearn.model_selection import KFold

import datasource, causal_cnn_models, modules, net_utils

# Set-up

## Logger

In [2]:
logger = logging.getLogger(__name__)

def log(msg):
    logger.debug(msg)


def config_logger(log_file=None):
    r"""Config logger."""
    global logger
    logger.handlers.clear()
    logger.setLevel(logging.DEBUG)
    
    format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # create console handler and set level to debug
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(format)
    logger.addHandler(ch)

    # create file handler which logs even debug messages
    if log_file:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(format)
        logger.addHandler(fh)


def viz_epoch_batch(epoch, i_batch, x_batch, x_hat_batch, log_path):
    # folder = os.path.join("logs", "recon_vae", log_filename)
    folder = os.path.join(log_path, "recon_vae")
    if not os.path.exists(folder):
        os.makedirs(folder)
    x_batch = x_batch.detach().cpu().numpy()
    x_hat_batch = x_hat_batch.detach().cpu().numpy()
    for i in range(1):
        orig = x_batch[i, 0, :]
        recon = x_hat_batch[i, 0, :]
        _, ax = plt.subplots()
        ax.plot(range(len(orig)), orig)
        # plt.savefig(
        #     f"{folder}/epoch{epoch}_item{i}_orig.png",
        #     format='png', dpi=300, bbox_inches='tight')
        ax.plot(range(len(recon)), recon)
        # plt.ylim((0, 2)) 
        plt.savefig(
            f"{folder}/epoch{epoch}_batch{i_batch}_item{i}.png",
            format='png', dpi=300, bbox_inches='tight')
    plt.close()

    
def save_models(model_file_instance_pairs):
    for model_file in model_file_instance_pairs.keys():
        net = model_file_instance_pairs.get(model_file)
        torch.save(net.state_dict(), model_file)

def load_models(model_file_instance_pairs, device="cpu"):
    for model_file in model_file_instance_pairs.keys():
        net = model_file_instance_pairs.get(model_file)
        net.load_state_dict(
            torch.load(model_file, map_location=device, weights_only=True))
    # No return seems necessary, in-memory models updated.

net_utils.fix_randomness()

In [24]:
params = None 
SIM_FILE = 'simSplitVaeParameterizer'
CFG_FILE = 'config.yml'
with open(CFG_FILE, 'r') as stream:
        params = yaml.safe_load(stream)
        params['seg_len'] = params['hz'] * params['seg_len_sec']
        params['decoder']['width'] = params['seg_len']

tm_sim_start = f"{datetime.now():%Y%m%d%H%M%S}"
params['tm_sim_start'] = tm_sim_start
log_path = f"logs/{SIM_FILE}_{CFG_FILE[:-4]}_{params['data_path'].replace('/','')}_split{params['n_split']}_ecg_{tm_sim_start}"
model_path = f"{log_path}/models"
log_file = (f"{log_path}/{tm_sim_start}.log")
if not os.path.exists(log_path):
        os.makedirs(log_path)
if not os.path.exists(model_path):
        os.makedirs(model_path)
logger = logging.getLogger(__name__)
config_logger(log_file)

# config_logger()

# DEVICE = torch.device(f"cuda:{params['cuda']}" if torch.cuda.is_available() else "cpu")
# DEVICE = "cpu"
DEVICE = "mps"

log(params)

2024-09-25 17:19:57,491 - __main__ - DEBUG - {'batch_size': 32, 'cuda': 0, 'data_path': 'data/mesa/polysomnography/set2x10', 'early_stop_delta': 0.001, 'early_stop_patience': 30, 'hz': 100, 'hz_rr': 5, 'lr': 0.001, 'lr_scheduler_patience': 10, 'max_epoch': 10, 'min_lr': '1e-6', 'seg_len_sec': 30, 'seg_len': 3000, 'n_split': 10, 'n_class': 2, 'in_channels': 1, 'encoder': {'in_channels': 1, 'channels': 128, 'depth': 5, 'reduced_size': 64, 'out_channels': 32, 'kernel_size': 5, 'dropout': 0.3, 'softplus_eps': 0.0001, 'sd_output': True}, 'decoder': {'k': 32, 'width': 3000, 'in_channels': 64, 'channels': 128, 'depth': 5, 'out_channels': 1, 'kernel_size': 5, 'gaussian_out': False, 'softplus_eps': 0.0001, 'dropout': 0.0}, 'tm_sim_start': '20240925171957'}


## Data

In [15]:
"""Data source"""
class_map = {0:0, 1:1, 2:1, 3:1, 4:1, 5:2}
# class_map = {0:0, 1:1, 2:1, 3:2, 4:2, 5:3}
n_class = len(set(class_map.values()))
params['n_class'] = n_class
log(f"class-map: {class_map}")
ds = datasource.MesaDb(
    f"{os.path.expanduser('~')}/data/mesa/polysomnography", data_subdir="set1x20",
    hz=100, n_subjects=-1, hz_rr=params['hz_rr'], class_map=class_map, 
    is_rr_sig=False, is_rsp=False, is_ecg_beats=True, log=log,
)

2024-09-25 16:54:21,289 - __main__ - DEBUG - class-map: {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 2}
2024-09-25 16:54:21,290 - __main__ - DEBUG - Data base-dir:/Users/brenton/data/mesa/polysomnography, data:set1x20, annot:annotations-events-nsrr, hz:100, seg_sec:30, class_map:{0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 2}, n_classes:3
2024-09-25 16:54:21,291 - __main__ - DEBUG - Loading mesa-sleep-1044...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1044.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:54:21,427 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:54:22,533 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:54:27,622 - __main__ - DEBUG - [mesa-sleep-1044] 244 events, age:0
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
2024-09-25 16:54:52,691 - __main__ - DEBUG - 	n_seg:1368, n_evt:244, annot_dist:{0: 670, 1: 98, 2: 424, 3: 170, 5: 77}, clz_lbl_dist:{0: 599, 1: 692, 2: 77}, remain:2900
2024-09-25 16:54:52,691 - __main__ - DEBUG - Loading mesa-sleep-0558...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0558.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:54:52,698 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:54:52,813 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:54:54,488 - __main__ - DEBUG - [mesa-sleep-0558] 131 events, age:1
2024-09-25 16:55:09,609 - __main__ - DEBUG - 	n_seg:863, n_evt:131, annot_dist:{0: 658, 1: 44, 2: 293, 3: 194, 5: 70}, clz_lbl_dist:{0: 430, 1: 365, 2: 68}, remain:2900
2024-09-25 16:55:09,609 - __main__ - DEBUG - Loading mesa-sleep-1901...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1901.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:55:09,620 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:55:09,768 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:55:09,971 - __main__ - DEBUG - [mesa-sleep-1901] 214 events, age:1
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
2024-09-25 16:55:29,666 - __main__ - DEBUG - No label for annot '9' in mesa-sleep-1901.edf
2024-09-25 16:55:29,667 - __main__ - DEBUG - 	n_seg:1436, n_evt:214, annot_dist:{0: 665, 1: 178, 2: 483, 3: 9, 5: 104, 9: 1}, clz_lbl_dist:{0: 662, 1: 670, 2: 104}, remain:3000
2024-09-25 16:55:29,668 - __main__ - DEBUG - Loading mesa-sleep-1917...


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 16:55:29,675 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:55:29,803 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:55:31,755 - __main__ - DEBUG - [mesa-sleep-1917] 171 events, age:0
  mrrs /= th2
  mrrs /= th2
2024-09-25 16:55:56,233 - __main__ - DEBUG - 	n_seg:1435, n_evt:171, annot_dist:{0: 665, 1: 57, 2: 571, 3: 23, 5: 123}, clz_lbl_dist:{0: 661, 1: 651, 2: 123}, remain:2800
2024-09-25 16:55:56,234 - __main__ - DEBUG - Loading mesa-sleep-0312...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0312.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:55:56,244 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:55:56,402 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:55:56,580 - __main__ - DEBUG - [mesa-sleep-0312] 189 events, age:1
2024-09-25 16:56:16,915 - __main__ - DEBUG - 	n_seg:1197, n_evt:189, annot_dist:{0: 712, 2: 315, 1: 124, 3: 2, 5: 47}, clz_lbl_dist:{0: 709, 1: 441, 2: 47}, remain:0
2024-09-25 16:56:16,916 - __main__ - DEBUG - Loading mesa-sleep-0271...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0271.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:56:16,924 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:56:17,045 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:56:18,580 - __main__ - DEBUG - [mesa-sleep-0271] 168 events, age:0
  mrrs /= th2
  mrrs /= th2
2024-09-25 16:56:37,063 - __main__ - DEBUG - 	n_seg:1074, n_evt:168, annot_dist:{0: 348, 1: 51, 2: 622, 3: 49, 5: 129}, clz_lbl_dist:{0: 223, 1: 722, 2: 129}, remain:2900
2024-09-25 16:56:37,064 - __main__ - DEBUG - Loading mesa-sleep-1803...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1803.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:56:37,070 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:56:37,160 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:56:37,671 - __main__ - DEBUG - [mesa-sleep-1803] 89 events, age:1
2024-09-25 16:56:50,706 - __main__ - DEBUG - 	n_seg:632, n_evt:89, annot_dist:{0: 729, 1: 29, 2: 201, 3: 72, 5: 48}, clz_lbl_dist:{0: 282, 1: 302, 2: 48}, remain:2900
2024-09-25 16:56:50,706 - __main__ - DEBUG - Loading mesa-sleep-1790...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1790.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:56:50,716 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:56:50,819 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:56:52,355 - __main__ - DEBUG - [mesa-sleep-1790] 334 events, age:0
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 16:57:12,369 - __main__ - DEBUG - 	n_seg:1196, n_evt:334, annot_dist:{0: 544, 2: 358, 1: 145, 3: 66, 5: 86}, clz_lbl_dist:{0: 541, 1: 569, 2: 86}, remain:2900
2024-09-25 16:57:12,370 - __main__ - DEBUG - Loading mesa-sleep-0934...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0934.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:57:12,382 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:57:12,514 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:57:13,908 - __main__ - DEBUG - [mesa-sleep-0934] 266 events, age:0
  mrrs /= th2
  mrrs /= th2
2024-09-25 16:57:35,860 - __main__ - DEBUG - 	n_seg:1192, n_evt:266, annot_dist:{0: 488, 1: 148, 2: 477, 3: 1, 5: 85}, clz_lbl_dist:{0: 481, 1: 626, 2: 85}, remain:2900
2024-09-25 16:57:35,861 - __main__ - DEBUG - Loading mesa-sleep-0643...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0643.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:57:35,871 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:57:35,987 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:57:37,696 - __main__ - DEBUG - [mesa-sleep-0643] 151 events, age:0
  warn(
  warn(
  warn(
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  mrrs /= th2
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  mrrs /= th2
  warn(
  warn(
  warn(
  warn(
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(
  warn(
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  warn(
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 16:57:43,369 - __main__ - DEBUG - 	n_seg:238, n_evt:151, annot_dist:{0: 828, 1: 73, 2: 307, 3: 4, 5: 107}, clz_lbl_dist:{0: 235, 1: 3}, remain:2900
2024-09-25 16:57:43,370 - __main__ - DEBUG - Loading mesa-sleep-1010...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1010.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:57:43,378 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:57:43,475 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:57:43,632 - __main__ - DEBUG - [mesa-sleep-1010] 222 events, age:1
2024-09-25 16:57:52,470 - __main__ - DEBUG - 	n_seg:388, n_evt:222, annot_dist:{0: 261, 1: 81, 2: 449, 3: 113, 5: 176}, clz_lbl_dist:{0: 13, 1: 264, 2: 111}, remain:0
2024-09-25 16:57:52,471 - __main__ - DEBUG - Loading mesa-sleep-0332...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0332.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:57:52,484 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:57:52,648 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:57:53,794 - __main__ - DEBUG - [mesa-sleep-0332] 562 events, age:1
2024-09-25 16:58:31,393 - __main__ - DEBUG - 	n_seg:1281, n_evt:562, annot_dist:{0: 748, 2: 406, 1: 398, 5: 127}, clz_lbl_dist:{0: 540, 1: 657, 2: 84}, remain:2900
2024-09-25 16:58:31,393 - __main__ - DEBUG - Loading mesa-sleep-1789...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1789.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:58:31,401 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:58:31,557 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:58:33,987 - __main__ - DEBUG - [mesa-sleep-1789] 251 events, age:0
2024-09-25 16:58:55,693 - __main__ - DEBUG - 	n_seg:1526, n_evt:251, annot_dist:{0: 693, 1: 231, 2: 440, 5: 194, 3: 1}, clz_lbl_dist:{0: 660, 1: 672, 2: 194}, remain:2800
2024-09-25 16:58:55,694 - __main__ - DEBUG - Loading mesa-sleep-0046...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0046.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:58:55,702 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:58:55,820 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:58:57,438 - __main__ - DEBUG - [mesa-sleep-0046] 208 events, age:1
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(
  warn(
  warn(
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  warn(
  warn(
2024-09-25 16:59:15,558 - __main__ - DEBUG - 	n_seg:1120, n_evt:208, annot_dist:{0: 426, 1: 109, 2: 330, 3: 214, 5: 120}, clz_lbl_dist:{0: 347, 1: 653, 2: 120}, remain:2900
2024-09-25 16:59:15,559 - __main__ - DEBUG - Loading mesa-sleep-0537...


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 16:59:15,567 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:59:15,658 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:59:15,824 - __main__ - DEBUG - [mesa-sleep-0537] 98 events, age:0
  mrrs /= th2
  mrrs /= th2
2024-09-25 16:59:35,213 - __main__ - DEBUG - 	n_seg:1060, n_evt:98, annot_dist:{0: 539, 1: 44, 2: 369, 3: 13, 5: 115}, clz_lbl_dist:{0: 519, 1: 426, 2: 115}, remain:0
2024-09-25 16:59:35,216 - __main__ - DEBUG - Loading mesa-sleep-1604...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-1604.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:59:35,228 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:59:35,352 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:59:36,875 - __main__ - DEBUG - [mesa-sleep-1604] 253 events, age:1
  mrrs /= th2
  warn(
  mrrs /= th2
  warn(
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  mrrs /= th2
  warn(
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 16:59:56,609 - __main__ - DEBUG - 	n_seg:1165, n_evt:253, annot_dist:{1: 124, 2: 392, 0: 590, 5: 93}, clz_lbl_dist:{1: 516, 0: 556, 2: 93}, remain:2900
2024-09-25 16:59:56,610 - __main__ - DEBUG - Loading mesa-sleep-0236...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0236.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 16:59:56,618 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 16:59:56,693 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 16:59:57,758 - __main__ - DEBUG - [mesa-sleep-0236] 158 events, age:1
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rc

Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
2024-09-25 17:00:11,057 - __main__ - DEBUG - 	n_seg:871, n_evt:158, annot_dist:{0: 477, 1: 60, 2: 348, 5: 71, 3: 3}, clz_lbl_dist:{0: 389, 1: 411, 2: 71}, remain:2900
2024-09-25 17:00:11,060 - __main__ - DEBUG - Loading mesa-sleep-0155...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0155.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 17:00:11,069 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 17:00:11,161 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 17:00:11,310 - __main__ - DEBUG - [mesa-sleep-0155] 206 events, age:1
  mrrs /= th2
  mrrs /= th2
2024-09-25 17:00:29,120 - __main__ - DEBUG - 	n_seg:1080, n_evt:206, annot_dist:{0: 271, 1: 101, 2: 531, 3: 11, 5: 166}, clz_lbl_dist:{0: 271, 1: 643, 2: 166}, remain:0
2024-09-25 17:00:29,120 - __main__ - DEBUG - Loading mesa-sleep-0419...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0419.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 17:00:29,128 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 17:00:29,243 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 17:00:34,076 - __main__ - DEBUG - [mesa-sleep-0419] 158 events, age:0
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
2024-09-25 17:00:57,514 - __main__ - DEBUG - 	n_seg:1321, n_evt:158, annot_dist:{0: 632, 1: 81, 2: 513, 3: 45, 5: 168}, clz_lbl_dist:{0: 514, 1: 639, 2: 168}, remain:2900
2024-09-25 17:00:57,516 - __main__ - DEBUG - Loading mesa-sleep-0744...


Extracting EDF parameters from /Users/brenton/data/mesa/polysomnography/set1x20/mesa-sleep-0744.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


2024-09-25 17:00:57,528 - __main__ - DEBUG - channels: ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3', 'Pres', 'Flow', 'Snore', 'Thor', 'Abdo', 'Leg', 'Therm', 'Pos', 'EKG_Off', 'EOG-L_Off', 'EOG-R_Off', 'EMG_Off', 'EEG1_Off', 'EEG2_Off', 'EEG3_Off', 'Pleth', 'OxStatus', 'SpO2', 'HR', 'DHR']


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


2024-09-25 17:00:57,655 - __main__ - DEBUG - channels: ['EKG'], search:ekg, src_hz:256
2024-09-25 17:00:57,847 - __main__ - DEBUG - [mesa-sleep-0744] 396 events, age:0
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  warn(
  warn(


Traceback (most recent call last):
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 428, in _initialise
    seg = self.process_validate_segment(seg, f)
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 485, in process_validate_segment
    'beats': get_beats(ecg_cleaned, self.hz, self.rr_seg_dim, n_beat=2) if self.is_ecg_beats else [],
  File "/Users/brenton/Library/CloudStorage/OneDrive-Personal/Documents/_UNIVERSITY/_MASTEROFAI/2024Tri2/split_vae/datasource.py", line 82, in get_beats
    beats = nk.ecg_segment(y, rpeaks=None, sampling_rate=hz, show=False)
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/ecg/ecg_segment.py", line 64, in ecg_segment
    heartbeats = epochs_create(
  File "/Users/brenton/miniconda3/envs/split_vae/lib/python3.10/site-packages/neurokit2/epochs/epoc

2024-09-25 17:01:21,665 - __main__ - DEBUG - 	n_seg:1299, n_evt:396, annot_dist:{0: 649, 1: 179, 2: 406, 3: 81, 5: 125}, clz_lbl_dist:{0: 508, 1: 666, 2: 125}, remain:0
2024-09-25 17:01:21,668 - __main__ - DEBUG - Total files:20, n_seg:21742, distribution:(array([0, 1, 2]), array([ 9140, 10588,  2014]))


In [16]:
ds.record_names.sort()
log(ds.record_names)

n_rec = len(ds.record_names)
train_frac = math.ceil(n_rec*0.8)
test_frac = n_rec - train_frac
validation_frac = math.ceil(train_frac*0.1)
train_frac = train_frac - validation_frac
print(n_rec, train_frac, validation_frac, test_frac)

train_rec_names = ds.record_names[:train_frac]
validation_rec_names = ds.record_names[train_frac:train_frac+validation_frac]
test_rec_names = ds.record_names[train_frac+validation_frac:]
log(f"N ({n_rec}) train/val/test: {train_frac}/{validation_frac}/{test_frac}")
log(f"Train: {train_rec_names}, val: {validation_rec_names}, test:{test_rec_names}")

train_idx = []
for rec in train_rec_names:
    train_idx.extend(ds.record_wise_segments[rec])
validation_idx = []
for rec in validation_rec_names:
    validation_idx.extend(ds.record_wise_segments[rec])
test_idx = []
for rec in test_rec_names:
    test_idx.extend(ds.record_wise_segments[rec])

r"Data loaders"
train_dataset = datasource.PartialDataset(ds, seg_index=train_idx, shuffle=True)
data_loader_train = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True)

val_dataset = datasource.PartialDataset(ds, seg_index=validation_idx, shuffle=True)
data_loader_val = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True)

test_dataset = datasource.PartialDataset(ds, seg_index=test_idx, shuffle=False)
data_loader_test = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=params['batch_size'], shuffle=False, drop_last=True)
log(f"Data-loader size train: {len(data_loader_train)}, val: {len(data_loader_val)}, test: {len(data_loader_test)}")

2024-09-25 17:02:20,390 - __main__ - DEBUG - ['mesa-sleep-0046', 'mesa-sleep-0155', 'mesa-sleep-0236', 'mesa-sleep-0271', 'mesa-sleep-0312', 'mesa-sleep-0332', 'mesa-sleep-0419', 'mesa-sleep-0537', 'mesa-sleep-0558', 'mesa-sleep-0643', 'mesa-sleep-0744', 'mesa-sleep-0934', 'mesa-sleep-1010', 'mesa-sleep-1044', 'mesa-sleep-1604', 'mesa-sleep-1789', 'mesa-sleep-1790', 'mesa-sleep-1803', 'mesa-sleep-1901', 'mesa-sleep-1917']
2024-09-25 17:02:20,393 - __main__ - DEBUG - N (20) train/val/test: 14/2/4
2024-09-25 17:02:20,394 - __main__ - DEBUG - Train: ['mesa-sleep-0046', 'mesa-sleep-0155', 'mesa-sleep-0236', 'mesa-sleep-0271', 'mesa-sleep-0312', 'mesa-sleep-0332', 'mesa-sleep-0419', 'mesa-sleep-0537', 'mesa-sleep-0558', 'mesa-sleep-0643', 'mesa-sleep-0744', 'mesa-sleep-0934', 'mesa-sleep-1010', 'mesa-sleep-1044'], val: ['mesa-sleep-1604', 'mesa-sleep-1789'], test:['mesa-sleep-1790', 'mesa-sleep-1803', 'mesa-sleep-1901', 'mesa-sleep-1917']
2024-09-25 17:02:20,398 - __main__ - DEBUG - Data-lo

20 14 2 4
label distribution: ['0:5778', '1:7208', '2:1366']
label distribution: ['0:1216', '1:1188', '2:287']
label distribution: ['0:2146', '1:2192', '2:361']


## Model

In [27]:
"""prepare model"""
reload(causal_cnn_models)

params['n_class'] = n_class
params_decoder = params['decoder'].copy()
params_decoder['width'] = params['hz'] * params['seg_len_sec']

net = causal_cnn_models.FoldVaeClassifFoldWeightParameterizer(
    params['encoder'], params_decoder, n_split=params['n_split'], 
    n_class=params['n_class'], log=log, debug=True,
)

log(net)
log(f"# params total: {net_utils.count_parameters(net)}")

x = torch.randn(32, 1, params['seg_len'])

outputs = net(x)    
recon_x = outputs['x_hat']
clz_proba = outputs['clz_proba']

print(f"recon: {recon_x.shape}, proba:{clz_proba.shape}")


2024-09-25 17:24:05,086 - __main__ - DEBUG - FoldVaeClassifFoldWeightParameterizer(
  (encoder): CausalCNNVEncoder(
    (network): Sequential(
      (0): CausalCNN(
        (network): Sequential(
          (0): CausalConvolutionBlock(
            (causal): Sequential(
              (0): Conv1d(1, 128, kernel_size=(5,), stride=(1,), padding=(4,))
              (1): Chomp1d()
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(4,))
              (4): Chomp1d()
              (5): LeakyReLU(negative_slope=0.01)
            )
            (upordownsample): Conv1d(1, 128, kernel_size=(1,), stride=(1,))
          )
          (1): CausalConvolutionBlock(
            (causal): Sequential(
              (0): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(8,), dilation=(2,))
              (1): Chomp1d()
              (2): LeakyReLU(negative_slope=0.01)
              (3): Conv1d(128, 128, kernel_size=(5,), stride=(1

--CausalCNNVDecoder linear1:torch.Size([32, 64])
--CausalCNNVDecoder linear2:torch.Size([32, 19200])
--CausalCNNVDecoder reshape:torch.Size([32, 64, 300])


2024-09-25 17:24:05,632 - __main__ - DEBUG - DEC x_split[0]:torch.Size([32, 1, 300]), _x_hat:torch.Size([32, 1, 300]), x_hat:torch.Size([32, 1, 300])


--CausalCNNVDecoder out_causal_cnn:torch.Size([32, 1, 300])


2024-09-25 17:24:05,951 - __main__ - DEBUG - --VAE enc_mu:torch.Size([32, 32]), enc_sd:torch.Size([32, 32])
2024-09-25 17:24:05,952 - __main__ - DEBUG - --VAE _z:torch.Size([32, 32]), enc_mu:torch.Size([32, 2, 32]), enc_sd:torch.Size([32, 2, 32])
2024-09-25 17:24:05,952 - __main__ - DEBUG - ENC x_split[1]:torch.Size([32, 1, 300]), _z:torch.Size([32, 32]), z:torch.Size([32, 64])
2024-09-25 17:24:06,184 - __main__ - DEBUG - DEC x_split[1]:torch.Size([32, 1, 300]), _x_hat:torch.Size([32, 1, 300]), x_hat:torch.Size([32, 1, 600])
2024-09-25 17:24:06,495 - __main__ - DEBUG - --VAE enc_mu:torch.Size([32, 32]), enc_sd:torch.Size([32, 32])
2024-09-25 17:24:06,496 - __main__ - DEBUG - --VAE _z:torch.Size([32, 32]), enc_mu:torch.Size([32, 3, 32]), enc_sd:torch.Size([32, 3, 32])
2024-09-25 17:24:06,496 - __main__ - DEBUG - ENC x_split[2]:torch.Size([32, 1, 300]), _z:torch.Size([32, 32]), z:torch.Size([32, 96])
2024-09-25 17:24:06,707 - __main__ - DEBUG - DEC x_split[2]:torch.Size([32, 1, 300]), _x

recon: torch.Size([32, 1, 3000]), proba:torch.Size([32, 3])


# Model Training

In [26]:
def calculate_recon_loss(
        criteria_recon, recon_net, input, x_hat):
    loss_recon = 1. * criteria_recon(x_hat.squeeze(1), input.squeeze(1))
    elbo = 1. * net.kl
    loss_recon =  loss_recon + elbo    
    return loss_recon


def calculate_classif_loss(criteria_classif, cls_proba, labels):
    loss_classif = 1*criteria_classif(cls_proba, labels)
    return loss_classif


r"Prepare model training"
model_files = [
    f"{model_path}/fold0_net.pt", 
]
model_instances = [
    net, 
]

class_weights = torch.from_numpy(net_utils.get_class_weights(
    [ds.seg_labels[i] for i in train_idx], n_class=n_class, log=log
)[-1]).type(torch.FloatTensor).to(DEVICE)

optimizer = torch.optim.Adam(net.parameters(), lr=params['lr'])

criteria_classif = nn.CrossEntropyLoss(weight=class_weights)
criteria_recon = nn.BCELoss()

r"Model training"
alpha = 500.
lambda_1 = 200.
lambda_2 = 0.
min_val_loss = 1000.
net.to(DEVICE)

for epoch in range(params['max_epoch']):
    since = time.time()
    epoch_recon_loss, epoch_classif_loss = 0., 0.
    
    net.train()
    # net_classif.train()
    for i_batch, batch_data in enumerate(data_loader_train):

        inputs = batch_data['ecg'].to(DEVICE)
        labels = batch_data['label'].to(DEVICE)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        x_hats = outputs['x_hat']
        z = outputs['z']
        cls_proba = outputs['clz_proba'] 

        if epoch % 5 == 0 and i_batch == 0:
            viz_epoch_batch(epoch, i_batch, inputs, x_hats, log_path)

        loss_recon = calculate_recon_loss(criteria_recon, net, inputs, x_hats)
        loss_classif = calculate_classif_loss(criteria_classif, cls_proba, labels)
        
        total_loss = alpha*loss_recon + lambda_1*loss_classif
        total_loss.backward()
        optimizer.step()

        epoch_recon_loss += loss_recon.detach().cpu().numpy()
        epoch_classif_loss += loss_classif.detach().cpu().numpy()

    
    time_elapsed = time.time() - since
    epoch_recon_loss = epoch_recon_loss / len(data_loader_train)
    epoch_classif_loss = epoch_classif_loss / len(data_loader_train)

    val_loss = 0.
    
    net.eval()
    # net_classif.eval()
    with torch.no_grad():
        for batch_data in data_loader_val:
            inputs = batch_data['ecg'].to(DEVICE)
            labels = batch_data['label'].to(DEVICE)

            outputs = net(inputs)
            x_hats = outputs['x_hat']
            z = outputs['z']
            cls_proba = outputs['clz_proba'] 

            loss = calculate_classif_loss(criteria_classif, cls_proba, labels)
            val_loss += loss.detach().cpu().numpy()
        val_loss = val_loss / len(data_loader_val)

    if val_loss < min_val_loss:
        save_models({
            model_files[0]: model_instances[0],
        })
        # torch.save(net.state_dict(), model_file)
        log(f"Val loss updated {min_val_loss} -> {val_loss}")
        min_val_loss = val_loss

    log(
        f"epoch:{epoch}, loss-recon:{epoch_recon_loss:.5f}, loss-classif:{epoch_classif_loss:.5f}, "
        f"time:{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s"
    )

2024-09-25 17:24:13,812 - __main__ - DEBUG - freq:[5778. 7208. 1366.], weights:[0.23641398 0.18951165 1.        ]
