In [1]:

import os
import math
import time
from datetime import datetime
import logging
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import yaml
import torch
from torch import nn
import torch.nn.functional as F

from sklearn.model_selection import KFold

import datasource, causal_cnn_models, modules, net_utils

In [2]:
logger = logging.getLogger(__name__)

def log(msg):
    """Forward ``msg`` to the module-level ``logger`` at DEBUG level.

    This is the ``log`` callback handed to the data source, the model and
    ``net_utils`` helpers elsewhere in the notebook.
    """
    logger.debug(msg)


def config_logger(log_file=None):
    r"""Configure the module-level ``logger`` to emit DEBUG messages.

    Handlers attached by a previous call are closed and removed first, so
    re-running the cell neither duplicates output nor leaks open log files.
    A console handler (``logging.StreamHandler``, stderr by default) is always
    attached; a file handler is added when ``log_file`` is given.

    Args:
        log_file: optional path of a file that also receives every record.
    """
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Close before detaching: ``handlers.clear()`` alone would leave a previous
    # FileHandler's stream open until garbage collection.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()
    logger.setLevel(logging.DEBUG)

    # Console handler.
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # Optional file handler receiving the same (debug-level) records.
    if log_file:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)


def viz_epoch_batch(epoch, i_batch, x_batch, x_hat_batch, log_path, n_items=1):
    """Save overlay plots of original vs. reconstructed signal for a batch.

    For each of the first ``n_items`` samples, channel 0 of ``x_batch`` and of
    ``x_hat_batch`` is drawn on the same axes (original first, reconstruction
    second) and written to
    ``<log_path>/recon_vae/epoch{epoch}_batch{i_batch}_item{i}.png``.

    Args:
        epoch: epoch number, used in the file name.
        i_batch: batch index, used in the file name.
        x_batch: input tensor indexed as ``[item, channel, time]``.
        x_hat_batch: reconstruction tensor indexed like ``x_batch``.
        log_path: run directory under which ``recon_vae/`` is created.
        n_items: number of leading samples to plot (default 1; clamped to
            the batch size).
    """
    folder = os.path.join(log_path, "recon_vae")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(folder, exist_ok=True)
    x_batch = x_batch.detach().cpu().numpy()
    x_hat_batch = x_hat_batch.detach().cpu().numpy()
    for i in range(min(n_items, len(x_batch))):
        orig = x_batch[i, 0, :]
        recon = x_hat_batch[i, 0, :]
        fig, ax = plt.subplots()
        ax.plot(range(len(orig)), orig)
        ax.plot(range(len(recon)), recon)
        fig.savefig(
            f"{folder}/epoch{epoch}_batch{i_batch}_item{i}.png",
            format='png', dpi=300, bbox_inches='tight')
        # Close every figure explicitly; a single plt.close() after the loop
        # would leak all figures but the last when n_items > 1.
        plt.close(fig)

    
def save_models(model_file_instance_pairs):
    """Write each model's ``state_dict`` to its paired file.

    Args:
        model_file_instance_pairs: mapping of destination path -> module.
    """
    for path, module in model_file_instance_pairs.items():
        torch.save(module.state_dict(), path)

def load_models(model_file_instance_pairs, device="cpu"):
    """Load saved ``state_dict`` files into their paired models, in place.

    Args:
        model_file_instance_pairs: mapping of checkpoint path -> module whose
            parameters are overwritten.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        None; the given modules are updated in memory.
    """
    for path, module in model_file_instance_pairs.items():
        state = torch.load(path, map_location=device, weights_only=True)
        module.load_state_dict(state)

# Project helper from net_utils; presumably seeds the RNGs (numpy/torch) so runs
# are reproducible — TODO confirm in net_utils.
net_utils.fix_randomness()

In [None]:
# Run configuration: load the YAML config, derive the segment length, create the
# per-run log/model directories and route the logger to this run's log file.
SIM_FILE = 'simMultimodalEcgNoAge'
CFG_FILE = 'config_multimodal_ecg.yml'
with open(CFG_FILE, 'r') as stream:
    params = yaml.safe_load(stream)
# Segment length in samples = params['hz'] * params['seg_len_sec'];
# the decoder's output width is set to the same value.
params['seg_len'] = params['hz'] * params['seg_len_sec']
params['decoder']['width'] = params['seg_len']

# if bool(params['age_classif']):
#     SIM_FILE = 'FoldVaeClassifFoldWeightZ'

tm_sim_start = f"{datetime.now():%Y%m%d%H%M%S}"
params['tm_sim_start'] = tm_sim_start
# splitext (instead of CFG_FILE[:-4]) also strips a '.yaml' extension correctly.
cfg_stem = os.path.splitext(CFG_FILE)[0]
log_path = f"logs/{SIM_FILE}_{cfg_stem}_{params['data_path'].replace('/','')}_split{params['n_split']}_ecg{params['input_ecg']}_rr{params['input_rr']}_rsp{params['input_rsp']}_{tm_sim_start}"
model_path = f"{log_path}/models"
log_file = f"{log_path}/{tm_sim_start}.log"
# model_path lies inside log_path, so this creates both; exist_ok replaces the
# racy exists()-then-makedirs() pattern.
os.makedirs(model_path, exist_ok=True)
config_logger(log_file)

DEVICE = torch.device(f"cuda:{params['cuda']}" if torch.cuda.is_available() else "cpu")
# DEVICE = "cpu"  # uncomment to force CPU

log(params)

In [None]:
"""Data source"""
# Collapse the six raw labels (0..5) into three classes: 0 -> 0, 1..4 -> 1, 5 -> 2.
# Alternative 4-class grouping: {0:0, 1:1, 2:1, 3:2, 4:2, 5:3}
class_map = {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 2}
n_class = len(frozenset(class_map.values()))
params['n_class'] = n_class
log(f"class-map: {class_map}")

# NOTE(review): hz is hard-coded to 100 here, while params['seg_len'] is derived
# from params['hz']; the data root is also fixed rather than taken from
# params['data_path'] (which only names the log directory) — confirm both agree.
ds = datasource.MesaDb(
    f"{os.path.expanduser('~')}/data/mesa/polysomnography",
    data_subdir="set1x20",
    hz=100,
    n_subjects=-1,
    hz_rr=params['hz_rr'],
    class_map=class_map,
    is_rr_sig=bool(params['input_rr']),
    is_rsp=bool(params['input_rsp']),
    is_ecg_beats=False,
    log=log,
)

# ds = datasource.MesaDbCsv(
#     f"{os.path.expanduser('~')}/data/mesa/polysomnography", data_subdir="set1x20",
#     hz=100, class_map=class_map, n_subjects=-1, hz_rr=params['hz_rr'],
#     is_rr_sig=bool(params['input_rr']), is_rsp=bool(params['input_rsp']), is_ecg_beats=False, log=log,
# )

In [None]:
"""prepare model"""

params['n_class'] = n_class
params_decoder = params['decoder'].copy()
# Decoder output width = segment length in samples (same formula as params['seg_len']).
params_decoder['width'] = params['hz'] * params['seg_len_sec']

net = causal_cnn_models.FoldVaeClassifFoldWeight(
    params['encoder'], params_decoder, n_split=params['n_split'], 
    n_class=params['n_class'], log=log, debug=True,
)

# net = causal_cnn_models.FoldVaeClassifFoldWeightZ(
#     params['encoder'], params_decoder, n_split=params['n_split'], 
#     n_class=params['n_class'], log=log, debug=True,
# )

log(net)
log(f"# params total: {net_utils.count_parameters(net)}")

# Output-shape smoke test on random input. The outputs are only inspected, so run
# under no_grad and avoid building an autograd graph that is never used.
# NOTE(review): net is still in train mode here; if it contains BatchNorm layers,
# this pass updates their running statistics with noise — confirm.
x = torch.randn(32, 1, params['seg_len'])
with torch.no_grad():
    outputs = net(x)
recon_x = outputs['x_hat']
clz_proba = outputs['clz_proba']
clz_proba_age = outputs['clz_proba_age']

print(f"recon: {recon_x.shape}, proba-1:{clz_proba.shape}, proba-2:{clz_proba_age.shape}")


In [None]:
ds.record_names.sort()
log(ds.record_names)

# Record-wise split (every segment of a record lands in one split only):
# ceil(80%) of records for train+val, 10% (ceil) of those held out for validation,
# the remainder for test. NB: the *_frac variables hold record COUNTS.
n_rec = len(ds.record_names)
train_frac = math.ceil(n_rec * 0.8)
test_frac = n_rec - train_frac
validation_frac = math.ceil(train_frac * 0.1)
train_frac = train_frac - validation_frac
print(n_rec, train_frac, validation_frac, test_frac)

train_rec_names = ds.record_names[:train_frac]
validation_rec_names = ds.record_names[train_frac:train_frac + validation_frac]
test_rec_names = ds.record_names[train_frac + validation_frac:]
log(f"N ({n_rec}) train/val/test: {train_frac}/{validation_frac}/{test_frac}")
log(f"Train: {train_rec_names}, val: {validation_rec_names}, test:{test_rec_names}")

# Flatten each split's records into the list of their segment indices.
train_idx = [seg for rec in train_rec_names for seg in ds.record_wise_segments[rec]]
validation_idx = [seg for rec in validation_rec_names for seg in ds.record_wise_segments[rec]]
test_idx = [seg for rec in test_rec_names for seg in ds.record_wise_segments[rec]]

# Data loaders. Training keeps drop_last=True (only full batches are optimised).
# Validation/test use drop_last=False so every held-out segment is evaluated and
# a validation set smaller than one batch still yields a batch.
train_dataset = datasource.PartialDataset(ds, seg_index=train_idx, shuffle=True)
data_loader_train = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True)

val_dataset = datasource.PartialDataset(ds, seg_index=validation_idx, shuffle=True)
data_loader_val = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=False)

test_dataset = datasource.PartialDataset(ds, seg_index=test_idx, shuffle=False)
data_loader_test = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=params['batch_size'], shuffle=False, drop_last=False)
log(f"Data-loader size train: {len(data_loader_train)}, val: {len(data_loader_val)}, test: {len(data_loader_test)}")

# The training loop averages losses by these lengths; fail early with a clear message.
assert len(data_loader_train) > 0 and len(data_loader_val) > 0, (
    "empty train/validation loader: too few records or segments for the split")

In [None]:

def calculate_recon_loss(
        criteria_recon, recon_net, input, x_hat):
    """Reconstruction loss plus the model's KL term.

    Args:
        criteria_recon: loss callable, called as
            ``criteria_recon(x_hat.squeeze(1), input.squeeze(1))``.
        recon_net: model whose ``kl`` attribute is added to the loss
            (presumably set by its most recent forward pass — confirm).
        input: target signal; dim 1 is squeezed before the criterion.
        x_hat: reconstruction, laid out like ``input``.

    Returns:
        ``criteria_recon(...) + recon_net.kl``.
    """
    loss_recon = criteria_recon(x_hat.squeeze(1), input.squeeze(1))
    # Use the model passed in, not the notebook-global ``net`` (previously the
    # ``recon_net`` argument was ignored).
    kl = recon_net.kl
    return loss_recon + kl


def calculate_classif_loss(criteria_classif, cls_proba, labels):
    """Apply a classification criterion to model outputs and targets.

    Args:
        criteria_classif: loss callable, called as ``criteria_classif(cls_proba, labels)``.
        cls_proba: model class scores.
        labels: target class indices.

    Returns:
        The criterion's value (with a unit weight).
    """
    loss = criteria_classif(cls_proba, labels)
    return 1 * loss


r"Prepare model training"
model_files = [
    f"{model_path}/fold0_net.pt", 
]
model_instances = [
    net, 
]

# Class weights computed (by net_utils.get_class_weights, last element of its
# result) from the TRAINING segments only.
class_weights = torch.from_numpy(net_utils.get_class_weights(
    [ds.seg_labels[i] for i in train_idx], n_class=n_class, log=log
)[-1]).type(torch.FloatTensor).to(DEVICE)

age_class_weights = torch.from_numpy(net_utils.get_class_weights(
    [ds.segments[i]['age'] for i in train_idx], n_class=2, log=log
)[-1]).type(torch.FloatTensor).to(DEVICE)

# Move the model to its device before constructing the optimizer, as the
# PyTorch docs recommend, so the optimizer is built over the final parameters.
net.to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=params['lr'])
# optimizer = torch.optim.SGD(net.parameters(), lr=params['lr'], momentum=0.9)
criteria_classif = nn.CrossEntropyLoss(weight=class_weights)
criteria_age_classif = nn.CrossEntropyLoss(weight=age_class_weights)
criteria_recon = nn.BCELoss()

# Model training
age_classif = bool(params['age_classif'])
alpha = 500.     # weight of the reconstruction (+KL) loss
lambda_1 = 200.  # weight of the 'label' classification loss
lambda_2 = 0.    # weight of the age classification loss
# Start at +inf so the first epoch always checkpoints, whatever the loss scale
# (a fixed 1000. would never save if the validation loss stayed above it).
min_val_loss = math.inf
for epoch in range(params['max_epoch']):
    since = time.time()
    epoch_recon_loss, epoch_classif_loss, epoch_age_classif_loss = 0., 0., 0.

    net.train()
    for i_batch, batch_data in enumerate(data_loader_train):
        inputs = batch_data['ecg'].to(DEVICE)
        labels = batch_data['label'].to(DEVICE)
        labels_age = batch_data['age'].to(DEVICE) if age_classif else None

        optimizer.zero_grad()

        outputs = net(inputs)
        x_hats = outputs['x_hat']
        z = outputs['z']
        cls_proba = outputs['clz_proba']
        cls_proba_age = outputs['clz_proba_age']

        if epoch % 5 == 0 and i_batch == 0:
            viz_epoch_batch(epoch, i_batch, inputs, x_hats, log_path)

        loss_recon = calculate_recon_loss(criteria_recon, net, inputs, x_hats)
        loss_classif = calculate_classif_loss(criteria_classif, cls_proba, labels)
        loss_age_classif = (
            calculate_classif_loss(criteria_age_classif, cls_proba_age, labels_age)
            if age_classif else 0.)

        total_loss = alpha*loss_recon + lambda_1*loss_classif + lambda_2*loss_age_classif
        total_loss.backward()
        optimizer.step()

        epoch_recon_loss += loss_recon.item()
        epoch_classif_loss += loss_classif.item()
        if age_classif:
            epoch_age_classif_loss += loss_age_classif.item()

    # Training time only; validation below is not included.
    time_elapsed = time.time() - since
    n_train_batches = len(data_loader_train)
    epoch_recon_loss = epoch_recon_loss / n_train_batches
    epoch_classif_loss = epoch_classif_loss / n_train_batches
    epoch_age_classif_loss = epoch_age_classif_loss / n_train_batches

    # Model selection: mean per-batch classification loss on the validation set
    # (no reconstruction/KL term); the age loss is added unweighted when enabled.
    val_loss = 0.
    net.eval()
    with torch.no_grad():
        for batch_data in data_loader_val:
            inputs = batch_data['ecg'].to(DEVICE)
            labels = batch_data['label'].to(DEVICE)
            labels_age = batch_data['age'].to(DEVICE) if age_classif else None

            outputs = net(inputs)
            x_hats = outputs['x_hat']
            z = outputs['z']
            cls_proba = outputs['clz_proba']
            cls_proba_age = outputs['clz_proba_age'] if age_classif else None

            loss = calculate_classif_loss(criteria_classif, cls_proba, labels)
            loss_age = (
                calculate_classif_loss(criteria_age_classif, cls_proba_age, labels_age)
                if age_classif else 0.)
            val_loss += (loss + loss_age).item()
        val_loss = val_loss / len(data_loader_val)

    if val_loss < min_val_loss:
        save_models({
            model_files[0]: model_instances[0],
        })
        log(f"Val loss updated {min_val_loss} -> {val_loss}")
        min_val_loss = val_loss

    log(
        f"epoch:{epoch}, loss-recon:{epoch_recon_loss:.5f}, loss-classif:{epoch_classif_loss:.5f}, "
        f"loss-age-classif::{epoch_age_classif_loss:.5f}, val-loss:{val_loss:.5f}, "
        f"time:{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")