In [1]:
import os
os.environ["HTTP_PROXY"] = "http://192.168.45.100:3128"
os.environ["HTTPS_PROXY"] = "http://192.168.45.100:3128"
!pip install monai neurokit2 wfdb monai pytorch_lightning==1.7.7 wandb libauc==1.2.0 --upgrade --quiet

gpus= "0,1,2,3"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
os.environ["WANDB_API_KEY"] = '6cd6a2f58c8f4625faaea5c73fe110edab2be208'
%env WANDB_SILENT=true

!nvcc -V
!nvidia-smi

env: WANDB_SILENT=true
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
Thu Nov 10 08:20:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN X ...  Off  | 00000000:05:00.0 Off |                  N/A |
| 37%   60C    P5    24W / 250W |      2MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+---------------------

In [2]:
config_defaults = dict(
    dataSeed = 2,
    srTarget = 250,
    featureLength = 1024,
    sampler = False, #True,
    inChannels = 1,
    outChannels = 2,
    modelName='efficientnet-b0',
    norm = 'instance',
    upsample = 'deconv', #'pixelshuffle', # 'nontrainable'
    supervision = "NONE",
    skipModule = "NONE",
    trainaug = 'NONE',
    
    project = 'PVC_NET',
    path_logRoot = '20221110_final',
    spatial_dims = 1,
    learning_rate = 4e-3,
    batch_size = 256,# 256
    dropout=0,
    thresholdRPeak= 0.5,
    skipASPP = "NONE",
    lossFn='focalloss',
    se = 'se',
)

sweep_config = {
  "project" : config_defaults['project'],
  "name" : config_defaults["path_logRoot"], # sweep run name!!
  # "method" : "bayes",
  "method" : "grid",
  "metric": {
      # "name":"val_loss",
      # "goal":"minimize"},      
      "name":"testMIT_AUPRC_Class1Raw",
      "goal":"maximize",
  },      
  "parameters" : 
    {
    "srTarget":{"values": [125, 250, 360]}, # True, False
    "featureLength":{"values": [512, 1024]}, # True, False
    "dataSeed":{"values": [2]}, # True, False
    "sampler":{"values": [True]}, # True, False
    "lossFn":{"values": ['bceloss']}, # ,'bceloss'
    "modelName":{"values": ['efficientnet-b0']}, # 'resnet18', 'resnet34', 'resnet50', 'efficientnet-b0','efficientnet-b1','efficientnet-b2','efficientnet-b3','efficientnet-b4'
    "norm":{"values": ['instance','batch']}, # 'instance','batch'
    "se":{"values": ['se']}, # 'se','acm','nlnn','deeprft','cbam'
    "dropout":{"values": [0.1]}, 
    "upsample":{"values": ['pixelshuffle']}, # ['pixelshuffle','deconv','nontrainable']
    "supervision":{"values": ['TYPE2']}, # 'NONE','TYPE1','TYPE2'
    "outChannels":{"values": [2]}, # 2,3,4
    "trainaug":{'values':['NEUROKIT']}, #"NONE",'NEUROKIT','NEUROKIT2',"AUDIOMENTATION"
    "skipModule":{"values": ['NONE','ACM2_0_BOTTOM5','ACM4_0_BOTTOM5','ACM8_0_BOTTOM5','SE_BOTTOM5','NLNN_BOTTOM5','CBAM_BOTTOM5','FFC_BOTTOM5','DEEPRFT_BOTTOM5']}, # ,'SE_BOTTOM5','NLNN_BOTTOM5','CBAM_BOTTOM5','FFC_BOTTOM5','DEEPRFT_BOTTOM5','ACM2_BOTTOM5','ACM4_BOTTOM5','ACM8_BOTTOM5'
    "skipASPP":{"values": ['NONE']}, # 'NONE','BOTTOM5'
  }
}

In [5]:
%matplotlib inline
import warnings
warnings.filterwarnings(action='ignore')

import os, sys, shutil
import multiprocessing
import random
import time

import pickle
import pylab as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import sklearn
import sklearn.metrics

from glob import glob
from tqdm.notebook import tqdm, trange
from natsort import natsorted

import scipy
import scipy.io as sio
from scipy.signal import butter, filtfilt, lfilter
from scipy.signal import kaiserord, firwin, filtfilt, butter
from scipy.ndimage import label, binary_closing
from skimage import morphology
from scipy import ndimage

import kornia
import neurokit2 as nk
import librosa as lb
import soundfile as sf

import sklearn
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve, classification_report, auc

import cv2
import monai
from monai.inferers import sliding_window_inference
from monai.config import print_config

import pytorch_lightning as pl
from pytorch_lightning.callbacks import *
from pytorch_lightning.loggers import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import *

def set_seed(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    monai.utils.misc.set_determinism(seed=seed)
    pl.seed_everything(seed,True)    
    
def get_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed()
device = get_device()

NUM_WORKERS = os.cpu_count()
print("Number of workers:", NUM_WORKERS)
print('multiprocessing.cpu_count()', multiprocessing.cpu_count())
print('cuda.is_available', torch.cuda.is_available())
print(device)
print_config()

2022-11-10 08:20:29,691 - Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.
2022-11-10 08:20:30,088 - Global seed set to 42
Number of workers: 12
multiprocessing.cpu_count() 12
cuda.is_available True
cuda
MONAI version: 1.0.1
Numpy version: 1.22.3
Pytorch version: 1.10.2+cu111
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8271a193229fe4437026185e218d5b06f7c8ce69
MONAI __file__: /home/kevin/.local/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.2
scikit-image version: 0.19.2
Pillow version: 9.1.1
Tensorboard version: 2.10.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.11.3+cu111
tqdm version: 4.63.0
lmdb version: NOT INSTALLED or UNKN

In [6]:
def Youden_index(y_true, y_score):
    '''Find data-driven cut-off for classification    
    Cut-off is determied using Youden's index defined as sensitivity + specificity - 1.    
    Parameters
    ----------
    
    y_true : array, shape = [n_samples]
        True binary labels.
        
    y_score : array, shape = [n_samples]
        Target scores, can either be probability estimates of the positive class,
        confidence values, or non-thresholded measure of decisions (as returned by
        “decision_function” on some classifiers).

    === Example ===
    y = [0,0,0,1,1,1]
    yhat = [0.3,0.6,0.4,.7,.9,.8]
    Youden_index(y, yhat)

    References
    ----------    
    Ewald, B. (2006). Post hoc choice of cut points introduced bias to diagnostic research.
    Journal of clinical epidemiology, 59(8), 798-801.
    
    Steyerberg, E.W., Van Calster, B., & Pencina, M.J. (2011). Performance measures for
    prediction models and markers: evaluation of predictions and classifications.
    Revista Espanola de Cardiologia (English Edition), 64(9), 788-794.
    
    Jiménez-Valverde, A., & Lobo, J.M. (2007). Threshold criteria for conversion of probability
    of species presence to either–or presence–absence. Acta oecologica, 31(3), 361-369.
    '''
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    idx = np.argmax(tpr - fpr)
    return thresholds[idx]

import nets
class PVC_NET(pl.LightningModule):
    def __init__(self,hyperparameters):
        super(PVC_NET, self).__init__()
        
        self.hyperparameters = hyperparameters
        self.experiment_name = str(self.hyperparameters).replace("{","").replace("}","").replace("'","").replace(": ","").replace(", ","_").split('_project')[0] # cut name as it is too long
        path = f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/"
        print(f'saving path : {path}') 
        os.makedirs(path, mode=0o777, exist_ok=True)
        
        self.srTarget = self.hyperparameters['srTarget']
        self.featureLength= self.hyperparameters['featureLength']
        self.thresholdRPeak = self.hyperparameters['thresholdRPeak'] # 0.7
        self.learning_rate = self.hyperparameters['learning_rate']
        self.youden_index = 0.5
        self.testPlot = False
        
        # define model using hyperparamters
        if 'efficient' in hyperparameters['modelName'] or 'resnet' in hyperparameters['modelName']:
            self.net = nets.UNet(modelName = hyperparameters['modelName'], 
                            spatial_dims = hyperparameters['spatial_dims'],
                            in_channels = hyperparameters['inChannels'],
                            out_channels = hyperparameters['outChannels'],
                            norm = hyperparameters['norm'],
                            upsample = hyperparameters['upsample'],
                            dropout = hyperparameters['dropout'],
                            supervision = hyperparameters['supervision'],
                            skipModule = hyperparameters['skipModule'],
                            skipASPP =  hyperparameters['skipASPP'],
                            se_module= hyperparameters['se'],
                           )
            
#         elif 'U2' in hyperparameters['modelName']:
#             self.net = nets.U2NET(in_ch=hyperparameters['in_channels'],
#                                   out_ch=hyperparameters['out_channels'],
#                                   nnblock = hyperparameters['nnblock'],
#                                   ASPP = hyperparameters['ASPP'],
#                                   FFC = hyperparameters['FFC'],
#                                   acm = hyperparameters['acm'],
#                                   dropout = hyperparameters['dropout'],
#                                   temperature=1,
#                                   norm = hyperparameters['norm'],
#                                  )
            
#         elif 'unetr' in hyperparameters['modelName']:
#             self.net= monai.networks.nets.UNETR(hyperparameters['in_channels'], 
#                                                 hyperparameters['out_channels'], 
#                                                 2048,
#                                                 feature_size = 16,
#                                                 hidden_size = 768,
#                                                 mlp_dim = 3072,
#                                                 num_heads = 12,
#                                                 pos_embed = 'conv',
#                                                 norm_name= hyperparameters['norm'],
#                                                 conv_block = True,
#                                                 res_block = True,
#                                                 dropout_rate = 0.0,
#                                                 spatial_dims = hyperparameters['spatial_dims'],)
            
        # define loss using hyperparameters
        if hyperparameters['lossFn']=='bceloss':
            self.lossFn = nn.BCELoss()
        elif hyperparameters['lossFn']=='wbceloss':
            self.lossFn = BCELoss_class_weighted([.2, .8])
        elif hyperparameters['lossFn']=='diceceloss':
            self.lossFn = monai.losses.DiceCELoss()
        elif hyperparameters['lossFn']=='focalloss':
            self.lossFn = FocalLoss()
        elif hyperparameters['lossFn']=='dicefocalloss':
            self.lossFn = monai.losses.DiceFocalLoss()
        elif hyperparameters['lossFn']=='weightedfocalloss':
            self.lossFn = WeightedFocalLoss()
        elif hyperparameters['lossFn']=='propotionalLoss':
            self.lossFn = PropotionalLoss(per_image=False, smooth=1e-7, beta=0.7, bce=True)
            
        self.save_hyperparameters()
        
    def compute_loss(self, yhat, y):
        if isinstance(yhat,list) or isinstance(yhat,tuple):
            yhat, loss_dp = yhat
            loss = self.lossFn(yhat,y)
            loss = loss + loss_dp
        else:
            loss = self.lossFn(yhat, y)
        return loss
    
    def forward(self, x):
        result = self.net(x)
        return result
        
    def sliding_window_inference(self, x): # Inference Sliding window using MONAI API: Using this only valid and test when size of input is larger than 2048
        def predictor(x, return_idx = 0): # in case of network gets multiple output, we will use only 1st output
            result = self.forward(x)
            if isinstance(result, list) or isinstance(result, tuple):
                return result[return_idx]
            else:
                return result        
        return sliding_window_inference(x, self.featureLength, 8, predictor, mode='gaussian', overlap=0.75)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # return {'optimizer': optimizer}
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4, min_lr=1e-6)
        # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-3, pct_start=0.02, total_steps=self.trainer.estimated_stepping_batches)
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'}}
    
    def pipeline(self, batch, batch_idx):
        
        x = batch['signal'].float()
        y = batch['y_seg'].float()
        
        dataSource = batch['dataSource'][0]
        
        fname = batch['fname']
        pid = batch['pid']
        time = batch['time']
        
        yhat = self.sliding_window_inference(x) if x.shape[-1] > self.featureLength else self.forward(x)
        loss = self.compute_loss(yhat, y)

        if isinstance(yhat,tuple) or isinstance(yhat,list): # in case multi output model such as U2NET while training
            yhat = yhat[0]
                
        return {'loss':loss, "x": x, "y": y, "yhat":yhat, "dataSource":dataSource,'fname':fname}
    
    def training_step(self, batch, batch_idx):        
        result = self.pipeline(batch, batch_idx)
        self.log('loss', result['loss'], on_step=False, on_epoch=True, prog_bar=True)
        return {"loss":result['loss'], "x": result['x'], "y": result['y'], "yhat":result['yhat'], "dataSource":result['dataSource'], 'fname':result['fname']}

    def validation_step(self, batch, batch_idx):
        result = self.pipeline(batch, batch_idx)
        self.log('val_loss', result['loss'], on_step=False, on_epoch=True, prog_bar=True)
        return {"val_loss":result['loss'], "x": result['x'], "y": result['y'], "yhat":result['yhat'], "dataSource":result['dataSource'], 'fname':result['fname']}
    
    def test_step(self, batch, batch_idx):
        result = self.pipeline(batch, batch_idx)
        self.log('test_loss', result['loss'], on_step=False, on_epoch=True)
        return {"test_loss":result['loss'], "x": result['x'], "y": result['y'], "yhat":result['yhat'], "dataSource":result['dataSource'], 'fname':result['fname']}
    
    def validation_epoch_end(self, outputs):
        
        self.fnames = []
        data = valid_data
        for d in data:
            self.fnames.append(f"{d['pid']}_{d['time']}")

        self.evaluations(outputs,'val', False)

    def test_epoch_end(self, outputs):        
        
        self.fnames = []
        if outputs[0]['dataSource'][0]==3:
            self.dataSource = 'testMIT'
            data = test_data
        elif outputs[0]['dataSource'][0]==11:
            self.dataSource = 'testAMC'
            data = AMC_data
        elif outputs[0]['dataSource'][0]==12:
            self.dataSource = 'testCPSC2020'
            data = CPSC2020_data
        elif outputs[0]['dataSource'][0]==13:
            self.dataSource = 'testINCART'
            data = INCART_data
        
        for d in data:
            self.fnames.append(f"{d['pid']}_{d['time']}")

        self.evaluations(outputs, self.dataSource, True)
    
    def apply_threshold(self, pred, t):
        try:
            result = pred.clone()
        except:
            result = pred.copy()
        result[result>=t]= 1
        result[result<t]= 0
        return result
    
    def evaluations(self, outputs, dataSource, plot=False):
        
        # fnames = []
        xs = []
        ys= []
        yhatsRaw = []                    
        yhatsRefined = []                    
        
        ysClass1Raw = []
        ysClass2Raw = []
        ysClass3Raw = []
        
        yhatsClass1Raw = []
        yhatsClass2Raw = []
        yhatsClass3Raw = []
        
        ysClass1Refined = []
        ysClass2Refined = []
        ysClass3Refined = []
        
        yhatsClass1Refined = []
        yhatsClass2Refined = []
        yhatsClass3Refined = []

        RRaw_TP = 0
        RRaw_FP = 0
        RRaw_FN = 0
        RRefined_TP = 0
        RRefined_FP = 0
        RRefined_FN = 0
        
        for output in outputs:
            # fnames.extend(output['fname'])
            xs.extend(output["x"].cpu().detach().numpy())
            ys.extend(output["y"].cpu().detach().numpy())

            for i in range(len(output["y"])):
                y = output["y"][i].cpu().detach().numpy()
                yhatRaw = output["yhat"][i].cpu().detach().numpy()
                yhatRefined = self.postProcessByRPeak(output["yhat"][i].cpu().detach().numpy())
                
                yhatsRaw.append(yhatRaw)
                yhatsRefined.append(yhatRefined)
                
                yhatRaw_eval = self.eval_Peak(yhatRaw, y)                
                RRaw_TP += yhatRaw_eval['R_TPs']
                RRaw_FP += yhatRaw_eval['R_FPs']
                RRaw_FN += yhatRaw_eval['R_FNs']
                ysClass1Raw.extend(yhatRaw_eval['ys_class1'])
                yhatsClass1Raw.extend(yhatRaw_eval['yhats_class1'])                
                ysClass2Raw.extend(yhatRaw_eval['ys_class2'])
                yhatsClass2Raw.extend(yhatRaw_eval['yhats_class2'])                
                ysClass3Raw.extend(yhatRaw_eval['ys_class3'])
                yhatsClass3Raw.extend(yhatRaw_eval['yhats_class3'])                

                yhatRefined_eval = self.eval_Peak(yhatRefined, y)                
                RRefined_TP += yhatRefined_eval['R_TPs']
                RRefined_FP += yhatRefined_eval['R_FPs']
                RRefined_FN += yhatRefined_eval['R_FNs']
                ysClass1Refined.extend(yhatRefined_eval['ys_class1'])
                yhatsClass1Refined.extend(yhatRefined_eval['yhats_class1'])
                ysClass2Refined.extend(yhatRefined_eval['ys_class2'])
                yhatsClass2Refined.extend(yhatRefined_eval['yhats_class2'])
                ysClass3Refined.extend(yhatRefined_eval['ys_class3'])
                yhatsClass3Refined.extend(yhatRefined_eval['yhats_class3'])
                
        del output
        
        fnames = np.array(self.fnames)
        xs = np.array(xs)
        ys = np.array(ys)
        yhatsRaw = np.array(yhatsRaw)
        yhatsRefined = np.array(yhatsRefined)
        
        ysClass1Raw = np.array(ysClass1Raw)
        ysClass2Raw = np.array(ysClass2Raw)
        ysClass3Raw = np.array(ysClass3Raw)
        ysClass1Refined = np.array(ysClass1Refined)
        ysClass2Refined = np.array(ysClass2Refined)
        ysClass3Refined = np.array(ysClass3Refined)
        yhatsClass1Raw = np.array(yhatsClass1Raw)
        yhatsClass2Raw = np.array(yhatsClass2Raw)
        yhatsClass3Raw = np.array(yhatsClass3Raw)
        yhatsClass1Refined = np.array(yhatsClass1Refined)
        yhatsClass2Refined = np.array(yhatsClass2Refined)
        yhatsClass3Refined = np.array(yhatsClass3Refined)
        
#         print('Raw',ys.shape,yhatsRaw.shape,yhatsRefined.shape)
#         print('picked',ysClass1Raw.shape,yhatsClass1Raw.shape,ysClass1Refined.shape,yhatsClass1Refined.shape,
#                       ysClass2Raw.shape,yhatsClass2Raw.shape,ysClass2Refined.shape,yhatsClass2Refined.shape,
#                       ysClass3Raw.shape,yhatsClass3Raw.shape,ysClass3Refined.shape,yhatsClass3Refined.shape)
        
        def eval_cm(ys, yhats, name):
            """
            input shape shoud be only [ B ]
            """
            
            TP = 0
            FN = 0
            FP = 0
            TN = 0

            auc = sklearn.metrics.roc_auc_score(ys, yhats)
            ap  = sklearn.metrics.average_precision_score(ys, yhats)
            self.youden_index = Youden_index(ys, yhats)
            
            negativeIdx = np.where(ys == 0)
            positiveIdx = np.where(ys != 0)
            
            for i in range(len(negativeIdx[0])):
                z = negativeIdx[0][i]
                FP = FP+1 if yhats[z]>=self.youden_index else FP
                TN = TN+1 if yhats[z]<self.youden_index else TN
            for i in range(len(positiveIdx[0])):
                n = positiveIdx[0][i]
                TP = TP+1 if yhats[n]>=self.youden_index else TP
                FN = FN+1 if yhats[n]<self.youden_index else FN
            
            # print(TP,TN,FP,FN)
            sen = TP/(TP+FN)
            spe = TN/(TN+FP)
            acc = (TP+TN)/(TP+FN+FP+TN)
            bacc = (sen+spe)/2
            
            # f1 = sklearn.metrics.f1_score(gt1, apply_threshold(yhat1,yi))
            # pr  = sklearn.metrics.precision_score(gt1,apply_threshold(yhat1,yi))
            f1 = 2*TP/(2*TP+FP+FN)
            ppv = TP/(TP+FP)
            npv = TN/(TN+FN)

            # plot performance
            plt.figure(figsize=(20,4))
            plt.subplot(131)           
            plt.title(f'Histogram of likelihood (Youden index : {self.youden_index:.3f})')
            plt.xlabel('Likelihood')
            plt.ylabel('Normalized samples')
            plt.hist(yhats[negativeIdx],bins=50,density=True,label='Likelihood of Negative Cases',alpha=0.5)
            plt.hist(yhats[positiveIdx],bins=50,density=True,label='Likelihood of Positive Cases',alpha=0.5)
            plt.vlines(self.youden_index,0,10,label='Youden index',color='r',alpha=0.5)
            plt.xlim([0,1])
            plt.legend()
            
            plt.subplot(132)
            plt.title(f'ROC (AUROC : {auc:.3f})')
            plt.xlabel('1-Specificity')
            plt.ylabel('Sensitivity')
            fpr, tpr, _ = roc_curve(ys, yhats)
            plt.plot(fpr, tpr)
            
            plt.subplot(133)        
            plt.title(f'Precision-Recall (AP : {ap:.3f})')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            prec, recall, _ = precision_recall_curve(ys, yhats)
            plt.plot(recall, prec)
            # plt.show()
            os.makedirs(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/", mode=0o777, exist_ok=True)
            plt.savefig(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/performance_{dataSource}_{name}.png")
            plt.close()
            
            return {"TP":TP,"TN":TN,"FN":FN,"FP":FP,
                    "auc":auc,"ap":ap,"sen":sen,"spe":spe,"acc":acc,"bacc":bacc,"f1":f1,"ppv":ppv,"npv":npv,"youdenIndex":self.youden_index}
        
        try:
            cmClass1Raw = eval_cm(ysClass1Raw, yhatsClass1Raw,'Class1Raw')
            cmClass1Refined = eval_cm(ysClass1Refined, yhatsClass1Refined,'Class1Refined')
            cmClass2Raw = eval_cm(ysClass2Raw, yhatsClass2Raw,'Class2Raw')
            cmClass2Refined = eval_cm(ysClass2Refined, yhatsClass2Refined,'Class2Refined')
            cmClass3Raw = eval_cm(ysClass3Raw, yhatsClass3Raw,'Class3Raw')
            cmClass3Refined = eval_cm(ysClass3Refined, yhatsClass3Refined,'Class3Refined')
        except:
            pass
        
        RRaw_sen = RRaw_TP/(RRaw_TP+RRaw_FN)
        RRaw_pp  = RRaw_TP/(RRaw_TP+RRaw_FP)
        RRaw_err = (RRaw_FP+RRaw_FN)/(RRaw_TP+RRaw_FP+RRaw_FN)

        RRefined_sen = RRefined_TP/(RRefined_TP+RRefined_FN)
        RRefined_pp  = RRefined_TP/(RRefined_TP+RRefined_FP)
        RRefined_err = (RRefined_FP+RRefined_FN)/(RRefined_TP+RRefined_FP+RRefined_FN)
            
        def logcmResult(cmPVC, refined):
            
            self.log(f'{dataSource}_TP_{refined}',cmPVC['TP'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_FN_{refined}',cmPVC['FN'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_FP_{refined}',cmPVC['FP'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_TN_{refined}',cmPVC['TN'],on_step=False,on_epoch=True)

            self.log(f'{dataSource}_AUPRC_{refined}',cmPVC['ap'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_AUROC_{refined}',cmPVC['auc'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_ACC_{refined}',cmPVC['acc'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_SEN_{refined}',cmPVC['sen'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_SPE_{refined}',cmPVC['spe'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_BACC_{refined}',cmPVC['bacc'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_F1_{refined}',cmPVC['f1'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_PPV_{refined}',cmPVC['ppv'],on_step=False,on_epoch=True)
            self.log(f'{dataSource}_NPV_{refined}',cmPVC['npv'],on_step=False,on_epoch=True)
            
            self.log(f'{dataSource}_YoudenIndex_{refined}',cmPVC['youdenIndex'],on_step=False,on_epoch=True)
            
        def logRResult(R_sen,R_ppv,R_der, refined):
            self.log(f'{dataSource}_R-DER_{refined}',R_der,on_step=False,on_epoch=True)
            self.log(f'{dataSource}_R-PPV_{refined}',R_ppv,on_step=False,on_epoch=True)
            self.log(f'{dataSource}_R-SEN_{refined}',R_sen,on_step=False,on_epoch=True)
        
        try:
            logcmResult(cmClass1Raw,'Class1Raw')
            logcmResult(cmClass1Refined,'Class1Refined')
        except:
            pass
        try:
            logcmResult(cmClass2Raw,'Class2Raw')
            logcmResult(cmClass2Refined,'Class2Refined')
        except:
            pass
        try:
            logcmResult(cmClass3Raw,'Class3Raw')
            logcmResult(cmClass3Refined,'Class3Refined')
        except:
            pass
        
        logRResult(RRaw_sen,RRaw_pp,RRaw_err, 'Raw')
        logRResult(RRefined_sen,RRefined_pp,RRefined_err,'Refined')
        
        os.makedirs(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/", mode=0o777, exist_ok=True)

        if len(ysClass3Raw)!=0 and len(ysClass2Raw)!=0:
            df = pd.DataFrame([ysClass1Raw, yhatsClass1Raw, ysClass2Raw, yhatsClass2Raw, ysClass3Raw, yhatsClass3Raw],
                              ['ysClass1Raw','yhatsClass1Raw','ysClass2Raw','yhatsClass2Raw','ysClass3Raw','yhatsClass3Raw'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Raw.csv",index=False)

            df = pd.DataFrame([ysClass1Refined, yhatsClass1Refined,ysClass2Refined, yhatsClass2Refined,ysClass3Refined, yhatsClass3Refined],
                              ['ysClass1Refined','yhatsClass1Refined','ysClass2Refined','yhatsClass2Refined','ysClass3Refined','yhatsClass3Refined'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Refined.csv",index=False)
        elif len(ysClass2Raw)!=0:
            df = pd.DataFrame([ysClass1Raw, yhatsClass1Raw, ysClass2Raw, yhatsClass2Raw],['ysClass1Raw','yhatsClass1Raw','ysClass2Raw','yhatsClass2Raw'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Raw.csv",index=False)

            df = pd.DataFrame([ysClass1Refined, yhatsClass1Refined,ysClass2Refined, yhatsClass2Refined],
                              ['ysClass1Refined','yhatsClass1Refined','ysClass2Refined','yhatsClass2Refined'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Refined.csv",index=False)
        else:
            df = pd.DataFrame([ysClass1Raw, yhatsClass1Raw],['ysClass1Raw','yhatsClass1Raw'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Raw.csv",index=False)

            df = pd.DataFrame([ysClass1Refined, yhatsClass1Refined],['ysClass1Refined','yhatsClass1Refined'])
            df = df.T
            df.to_csv(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/metric/likelihood_{dataSource}_Refined.csv",index=False)
                
        for i in range(len(yhatsRefined)):
            yhatsRefined[i] = self.postProcessByYoudenIndex(yhatsRefined[i], cmClass1Refined['youdenIndex'])
        # try:
        #     yhatsRefined[:,1] = self.postProcessByYoudenIndex(yhatsRefined[:,1], cmClass1Refined['youdenIndex'])
        # except:
        #     pass
        self.plotCaseResult(xs, ys, yhatsRaw, self.apply_threshold(yhatsRefined,cmClass1Raw['youdenIndex']), fnames) if self.testPlot else 0

    def eval_Peak(self, yhat, y):
        
        """
        input y: CxSignal
        input yhat : CXSignal
        """
        classes = yhat.shape[0]-1

        try:
            yhat_ = yhat.clone()
        except:
            yhat_ = yhat.copy()
            
        yhat_[0] = self.apply_threshold(yhat_[0],self.thresholdRPeak)
            
        ys_class1 = []
        ys_class2 = []
        ys_class3 = []
        yhats_class1 = []
        yhats_class2 = []
        yhats_class3 = []

        # evalutation of R-peak
        R_TP = []
        R_FP = []
        R_FN = []
        
        result_y, count_y = label(y[0])

        for j in range(1, count_y+1):
            index = np.where(result_y == j)[0]
            start = index[0]
            end = index[-1]
            
            try:
                yhat0_mean = torch.nanmean(yhat_[0,start:end+1])
                yhat1_mean = torch.nanmean(yhat_[1,start:end+1])
                yhat2_mean = torch.nanmean(yhat_[2,start:end+1])
                yhat3_mean = torch.nanmean(yhat_[3,start:end+1])
            except:
                try:
                    yhat0_mean = np.nanmean(yhat_[0,start:end+1])
                    yhat1_mean = np.nanmean(yhat_[1,start:end+1])
                    yhat2_mean = np.nanmean(yhat_[2,start:end+1])
                    yhat3_mean = np.nanmean(yhat_[3,start:end+1])
                except:
                    pass
                
            # evalutation of R-peak : TP, FN
            if 1 in y[0,start:end+1] and yhat0_mean>=self.thresholdRPeak:
                R_TP.append(1)
            elif 1 in y[0,start:end+1] and yhat0_mean<self.thresholdRPeak:
                R_FN.append(1)            
                
            # evalutation of PVC : just return likelihood
            if 0 in y[1,start:end]:
                ys_class1.append(0)
                yhats_class1.append(yhat1_mean)
            elif 1 in y[1,start:end]:
                ys_class1.append(1)
                yhats_class1.append(yhat1_mean)
                
            try:
                if 0 in y[2,start:end]:
                    ys_class2.append(0)
                    yhats_class2.append(yhat2_mean)
                elif 1 in y[2,start:end]:
                    ys_class2.append(1)
                    yhats_class2.append(yhat2_mean)

                if 0 in y[3,start:end]:
                    ys_class3.append(0)
                    yhats_class3.append(yhat3_mean)
                elif 1 in y[3,start:end]:
                    ys_class3.append(1)
                    yhats_class3.append(yhat3_mean)
            except:
                pass
            
        # print('B',count_y,np.array(ys).shape, np.array(yhats).shape, len(R_TP),len(R_FP),len(R_FN))

        result_yhat, count_yhat = label(yhat_[0])

        for j in range(1,count_yhat+1):
            index = np.where(result_yhat == j)[0]
            start = index[0]
            end = index[-1]
            
            try:
                yhat0_mean = torch.nanmean(yhat_[0,start:end+1])
                yhat1_mean = torch.nanmean(yhat_[1,start:end+1])
                yhat2_mean = torch.nanmean(yhat_[2,start:end+1])
                yhat3_mean = torch.nanmean(yhat_[3,start:end+1])
            except:                
                try:
                    yhat0_mean = np.nanmean(yhat_[0,start:end+1])
                    yhat1_mean = np.nanmean(yhat_[1,start:end+1])
                    yhat2_mean = np.nanmean(yhat_[2,start:end+1])
                    yhat3_mean = np.nanmean(yhat_[3,start:end+1])
                except:
                    pass
                

            # evalutation of R-peak : FP
            if 1 not in y[0,start:end+1]:
                R_FP.append(1)

            # evalutation of PVC : FP
            if 1 not in y[1,start:end+1] and 1 in yhat_[1,start:end+1]:
                ys_class1.append(0)
                yhats_class1.append(yhat1_mean)
            
            try:
                if 1 not in y[2,start:end+1] and 1 in yhat_[2,start:end+1]:
                    ys_class2.append(0)
                    yhats_class2.append(yhat2_mean)
                    
                if 1 not in y[3,start:end+1] and 1 in yhat_[3,start:end+1]:
                    ys_class3.append(0)
                    yhats_class3.append(yhat3_mean)
            except:
                pass
        
        return {
                'R_TPs':np.sum(R_TP), 'R_FNs':np.sum(R_FN), 'R_FPs':np.sum(R_FP),
                'ys_class1':np.array(ys_class1), 'yhats_class1':np.array(yhats_class1),
                'ys_class2':np.array(ys_class2), 'yhats_class2':np.array(yhats_class2),
                'ys_class3':np.array(ys_class3), 'yhats_class3':np.array(yhats_class3)
               }
    
    def postProcessByRPeak(self, yhat):
        """
        input : yhat [C x S]
        output : yhat [C x S]
        
        Rule 0. R-peak는 self.thresholdRPeak로 binarize한다.
        Rule 1. R-peak의 간격이 특정 간격보다 작으면 무시한다. 
        Rule 2. R-peak가 아니면 PVC, AFIB도 아니다. # R-peak에 살짝 마진을 주고 PVC, AFIB과 곱해준다
        threshold = int(srTarget*.05)
        """

        yhat_ = yhat.copy()
        # Rule 0.
        yhat_[0] = self.apply_threshold(yhat_[0], self.thresholdRPeak)
        
        # Rule 1. fill in and remove R- peak
        threshold_hole = int(self.srTarget*.2*0.2) # 20% of R-peak 
        yhat_[0] = morphology.remove_small_holes(yhat_[0].astype(bool), threshold_hole).astype(float) # fill in
        
        threshold_object = int(self.srTarget*.2*0.7) # 70% of R-peak seg
        yhat_[0] = morphology.remove_small_objects(yhat_[0].astype(bool), threshold_object).astype(float) # remove small R-peak
        yhat_0_dilated = ndimage.binary_dilation(yhat_[0],iterations=int(self.srTarget*.1*0.1))         #dilated
        
        # Rule 2.
        yhat_[1] = yhat_0_dilated*yhat_[1] 
        try:
            yhat_[2] = yhat_0_dilated*yhat_[2] 
            yhat_[3] = yhat_0_dilated*yhat_[3] 
        except:
            pass
        return yhat_
    
    def postProcessByYoudenIndex(self, yhat, threshold):
        """
        input : yhat [C x S]
        output : yhat [C x S]        
        """
        result, count_yhat = label(yhat[0])
        for j in range(1, count_yhat+1):
            index = np.where(result == j)[0]
            start = index[0]
            end = index[-1]
            margin = int(self.srTarget*.2*.1)
            
            yhat1_mean = np.nanmean(yhat[1,start:end+1])
            
            # evalutation of PVC : FP
            if yhat1_mean >= threshold:
                yhat[1,start-margin:end+1+margin] = 1
            else:
                yhat[1,start-margin:end+1+margin] = 0
                
        yhat = yhat.round()
        return yhat
    
    def plotCaseResult(self, x, y, yhat1, yhat2, fname):
        t = np.linspace(0,x.shape[-1]/self.srTarget, x.shape[-1]) # for x-axis ticks
        
        for idx in range(len(y)):
            plt.figure(figsize=(20,12))
            plt.subplot(221)
            plt.title(f'Prediction result of {self.dataSource}')
            plt.xlabel('Time (s)')
            plt.ylabel('Normalized ECG')
            plt.plot(t,x[idx,0],alpha=0.9,color='black',label='ECG signal')
            plt.plot(t,yhat1[idx,0],alpha=0.7,color='b',label='R Peak prediction (Likelihood)')
            plt.plot(t,yhat1[idx,1],alpha=0.7,color='r',label='PVC prediction (Likelihood)')
            # other annotation            
            try:
                plt.plot(t,yhat1[idx,2],alpha=0.7,color='g',label='AFIB prediction (Likelihood)')
                plt.plot(t,yhat1[idx,3],alpha=0.7,color='orange',label='Others prediction (Likelihood)')
            except:
                pass
            plt.xticks(np.arange(0, len(t)/self.srTarget, step=1))
            plt.ylim([0,1.5])
            plt.legend(loc=1)
            
            plt.subplot(222)
            plt.title(f'Raw signal of {self.dataSource}')
            plt.xlabel('Time (s)')
            plt.ylabel('Normalized ECG')
            plt.plot(t,x[idx,0],alpha=0.9,color='black',label='ECG signal')
            plt.xticks(np.arange(0, len(t)/self.srTarget, step=1))
            plt.ylim([0,1.5])
            plt.legend(loc=1)

            
            plt.subplot(223)
            plt.title(f'Refined Prediction result of {self.dataSource}')
            plt.xlabel('Time (s)')
            plt.ylabel('Normalized ECG')
            plt.plot(t,x[idx,0],alpha=0.9,color='black',label='ECG signal')
            plt.plot(t,yhat2[idx,0],alpha=0.7,color='b',label='R Peak prediction (Binarized)')
            plt.plot(t,yhat2[idx,1],alpha=0.7,color='r',label='PVC prediction (Binarized)')
            # other annotation
            try:
                plt.plot(t,ndimage.binary_dilation(yhat2[idx,2],iterations=int(srTarget*.1*0.2)),alpha=0.7,color='g',label='AFIB prediction (Likelihood)')
                plt.plot(t,ndimage.binary_dilation(yhat2[idx,3],iterations=int(srTarget*.1*0.3)),alpha=0.7,color='orange',label='Others prediction (Likelihood)')
            except:
                pass
            plt.xticks(np.arange(0, len(t)/self.srTarget, step=1))
            plt.ylim([0,1.5])
            plt.legend(loc=1)

            plt.subplot(224)
            plt.title(f'Ground truth of {self.dataSource}')
            plt.xlabel('Time (s)')
            plt.ylabel('Normalized ECG')
            plt.plot(t,x[idx,0],alpha=0.9,color='black',label='ECG signal')
            plt.plot(t,y[idx,0],alpha=0.7,color='b',label='R Peak GT') if y is not None else y
            plt.plot(t,y[idx,1],alpha=0.7,color='r',label='PVC GT') if y is not None else y
            # other annotation            
            try:
                plt.plot(t,ndimage.binary_dilation(y[idx,2],iterations=int(srTarget*.1*0.2)),alpha=0.7,color='g',label='AFIB prediction (Likelihood)')
                plt.plot(t,ndimage.binary_dilation(y[idx,3],iterations=int(srTarget*.1*0.3)),alpha=0.7,color='orange',label='Others prediction (Likelihood)')
            except:
                pass
            plt.xticks(np.arange(0, len(t)/self.srTarget, step=1))
            plt.ylim([0,1.5])
            plt.legend(loc=1)
            plt.tight_layout()
            # plt.show()

            os.makedirs(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/result/", mode=0o777, exist_ok=True)
            os.makedirs(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/result/{self.dataSource}/", mode=0o777, exist_ok=True)
            plt.savefig(f"{self.hyperparameters['path_logRoot']}/{self.experiment_name}/result/{self.dataSource}/{str(fname[idx])}.png")
            plt.close()

In [7]:
train_data = np.load('dataset/mit-bih-arrhythmia-database-1.0.0_trainSeg_seed4.npy',allow_pickle=True) # B x (C) x Signal
valid_data = np.load('dataset/mit-bih-arrhythmia-database-1.0.0_validSeg_seed4.npy',allow_pickle=True) # B x (C) x Signal
test_data = np.load('dataset/mit-bih-arrhythmia-database-1.0.0_testSeg.npy',allow_pickle=True) # B x (C) x Signal

Fantasia_data = np.load('dataset/fantasia-database-1.0.0_testSeg.npy', allow_pickle=True) # B x (C) x Signal
AMC_data  = np.load('dataset/AMC_PeakLabel_3rd_125Hz.npy',allow_pickle=True) # 497 samples
INCART_data  = np.load('dataset/INCART_testSeg.npy',allow_pickle=True)
CPSC2020_data  = np.load('dataset/CPSC2020_testSeg.npy',allow_pickle=True)
AMCREAL_data = np.load('dataset/AMCREAL_testSeg.npy',allow_pickle=True)
print(len(train_data), len(valid_data), len(test_data), len(Fantasia_data), len(AMC_data), len(CPSC2020_data), len(AMCREAL_data))


def add_datainfo(data, info_string):
    new_data = []
    for d in data:
        d['dataSource'] = info_string
        new_data.append(d)
    return np.array(new_data)

# add_datainfo(train_data,'train')
# add_datainfo(valid_data,'val')
# add_datainfo(test_data,'test')

# add_datainfo(AMC_data,'AMC')
# add_datainfo(CPSC2020_data,'CPSC2020')
# add_datainfo(INCART_data,'INCART')
# add_datainfo(Fantasia_data,'Fantasia')
# add_datainfo(AMCREAL_data,'AMCREAL')

add_datainfo(train_data,1)
add_datainfo(valid_data,2)
add_datainfo(test_data,3)

add_datainfo(AMC_data,11)
add_datainfo(CPSC2020_data,12)
add_datainfo(INCART_data,13)
add_datainfo(Fantasia_data,14)
add_datainfo(AMCREAL_data,15)
print()


from warnings import warn
import numpy as np

from neurokit2.misc import NeuroKitWarning, listify
from neurokit2.signal.signal_resample import signal_resample
from neurokit2.signal.signal_simulate import signal_simulate

def signal_distort(
    signal,
    sampling_rate=1000,
    noise_shape="laplace",
    noise_amplitude=0,
    noise_frequency=100,
    powerline_amplitude=0,
    powerline_frequency=50,
    artifacts_amplitude=0,
    artifacts_frequency=100,
    artifacts_number=5,
    linear_drift=False,
    random_state=None,
    silent=False,
):
    """**Signal distortion**

    Add noise of a given frequency, amplitude and shape to a signal.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    noise_shape : str
        The shape of the noise. Can be one of ``"laplace"`` (default) or
        ``"gaussian"``.
    noise_amplitude : float
        The amplitude of the noise (the scale of the random function, relative
        to the standard deviation of the signal).
    noise_frequency : float
        The frequency of the noise (in Hz, i.e., samples/second).
    powerline_amplitude : float
        The amplitude of the powerline noise (relative to the standard
        deviation of the signal).
    powerline_frequency : float
        The frequency of the powerline noise (in Hz, i.e., samples/second).
    artifacts_amplitude : float
        The amplitude of the artifacts (relative to the standard deviation of
        the signal).
    artifacts_frequency : int
        The frequency of the artifacts (in Hz, i.e., samples/second).
    artifacts_number : int
        The number of artifact bursts. The bursts have a random duration
        between 1 and 10% of the signal duration.
    linear_drift : bool
        Whether or not to add linear drift to the signal.
    random_state : int
        Seed for the random number generator. Keep it fixed for reproducible
        results.
    silent : bool
        Whether or not to display warning messages.

    Returns
    -------
    array
        Vector containing the distorted signal.

    Examples
    --------
    .. ipython:: python

      import numpy as np
      import pandas as pd
      import neurokit2 as nk

      signal = nk.signal_simulate(duration=10, frequency=0.5)

      # Noise
      @savefig p_signal_distort1.png scale=100%
      noise = pd.DataFrame({"Freq100": nk.signal_distort(signal, noise_frequency=200),
                           "Freq50": nk.signal_distort(signal, noise_frequency=50),
                           "Freq10": nk.signal_distort(signal, noise_frequency=10),
                           "Freq5": nk.signal_distort(signal, noise_frequency=5),
                           "Raw": signal}).plot()
      @suppress
      plt.close()

    .. ipython:: python

      # Artifacts
      @savefig p_signal_distort2.png scale=100%
      artifacts = pd.DataFrame({"1Hz": nk.signal_distort(signal, noise_amplitude=0,
                                                        artifacts_frequency=1,
                                                        artifacts_amplitude=0.5),
                               "5Hz": nk.signal_distort(signal, noise_amplitude=0,
                                                        artifacts_frequency=5,
                                                        artifacts_amplitude=0.2),
                               "Raw": signal}).plot()
      @suppress
      plt.close()

    """
    # Seed the random generator for reproducible results.
    # np.random.seed(random_state)

    # Make sure that noise_amplitude is a list.
    if isinstance(noise_amplitude, (int, float)):
        noise_amplitude = [noise_amplitude]

    signal_sd = np.std(signal, ddof=1)
    if signal_sd == 0:
        signal_sd = None

    noise = 0

    # Basic noise.
    if min(noise_amplitude) > 0:
        noise += _signal_distort_noise_multifrequency(
            signal,
            signal_sd=signal_sd,
            sampling_rate=sampling_rate,
            noise_amplitude=noise_amplitude,
            noise_frequency=noise_frequency,
            noise_shape=noise_shape,
            silent=silent,
        )

    # Powerline noise.
    if powerline_amplitude > 0:
        noise += _signal_distort_powerline(
            signal,
            signal_sd=signal_sd,
            sampling_rate=sampling_rate,
            powerline_frequency=powerline_frequency,
            powerline_amplitude=powerline_amplitude,
            silent=silent,
        )

    # Artifacts.
    if artifacts_amplitude > 0:
        noise += _signal_distort_artifacts(
            signal,
            signal_sd=signal_sd,
            sampling_rate=sampling_rate,
            artifacts_frequency=artifacts_frequency,
            artifacts_amplitude=artifacts_amplitude,
            artifacts_number=artifacts_number,
            silent=silent,
        )

    if linear_drift:
        noise += _signal_linear_drift(signal)

    distorted = signal + noise

    # Reset random seed (so it doesn't affect global)
    # np.random.seed(None)

    return distorted

def _signal_distort_artifacts(
    signal,
    signal_sd=None,
    sampling_rate=1000,
    artifacts_frequency=0,
    artifacts_amplitude=0.1,
    artifacts_number=5,
    artifacts_shape="laplace",
    silent=False,
):

    # Generate artifact burst with random onset and random duration.
    artifacts = _signal_distort_noise(
        len(signal),
        sampling_rate=sampling_rate,
        noise_frequency=artifacts_frequency,
        noise_amplitude=artifacts_amplitude,
        noise_shape=artifacts_shape,
        silent=silent,
    )
    if artifacts.sum() == 0:
        return artifacts

    min_duration = int(np.rint(len(artifacts) * 0.001))
    max_duration = int(np.rint(len(artifacts) * 0.01))
    artifact_durations = np.random.randint(min_duration, max_duration, artifacts_number)

    artifact_onsets = np.random.randint(0, len(artifacts) - max_duration, artifacts_number)
    artifact_offsets = artifact_onsets + artifact_durations

    artifact_idcs = np.array([False] * len(artifacts))
    for i in range(artifacts_number):
        artifact_idcs[artifact_onsets[i] : artifact_offsets[i]] = True

    artifacts[~artifact_idcs] = 0

    # Scale amplitude by the signal's standard deviation.
    if signal_sd is not None:
        artifacts_amplitude *= signal_sd
    artifacts *= artifacts_amplitude

    return artifacts

def _signal_distort_noise_multifrequency(
    signal,
    signal_sd=None,
    sampling_rate=1000,
    noise_amplitude=0.1,
    noise_frequency=100,
    noise_shape="laplace",
    silent=False,
):
    base_noise = np.zeros(len(signal))
    params = listify(
        noise_amplitude=noise_amplitude, noise_frequency=noise_frequency, noise_shape=noise_shape
    )

    for i in range(len(params["noise_amplitude"])):

        freq = params["noise_frequency"][i]
        amp = params["noise_amplitude"][i]
        shape = params["noise_shape"][i]

        if signal_sd is not None:
            amp *= signal_sd

        # Make some noise!
        _base_noise = _signal_distort_noise(
            len(signal),
            sampling_rate=sampling_rate,
            noise_frequency=freq,
            noise_amplitude=amp,
            noise_shape=shape,
            silent=silent,
        )
        base_noise += _base_noise

    return base_noise


def _signal_distort_noise(
    n_samples,
    sampling_rate=1000,
    noise_frequency=100,
    noise_amplitude=0.1,
    noise_shape="laplace",
    silent=False,
):

    _noise = np.zeros(n_samples)
    # Apply a very conservative Nyquist criterion in order to ensure
    # sufficiently sampled signals.
    nyquist = sampling_rate * 0.1
    if noise_frequency > nyquist:
        if not silent:
            warn(
                f"Skipping requested noise frequency "
                f" of {noise_frequency} Hz since it cannot be resolved at "
                f" the sampling rate of {sampling_rate} Hz. Please increase "
                f" sampling rate to {noise_frequency * 10} Hz or choose "
                f" frequencies smaller than or equal to {nyquist} Hz.",
                category=NeuroKitWarning,
            )
        return _noise
    # Also make sure that at least one period of the frequency can be
    # captured over the duration of the signal.
    duration = n_samples / sampling_rate
    if (1 / noise_frequency) > duration:
        if not silent:
            warn(
                f"Skipping requested noise frequency "
                f" of {noise_frequency} Hz since its period of {1 / noise_frequency} "
                f" seconds exceeds the signal duration of {duration} seconds. "
                f" Please choose noise frequencies larger than "
                f" {1 / duration} Hz or increase the duration of the "
                f" signal above {1 / noise_frequency} seconds.",
                category=NeuroKitWarning,
            )
        return _noise

    noise_duration = int(duration * noise_frequency)

    if noise_shape in ["normal", "gaussian"]:
        _noise = np.random.normal(0, noise_amplitude, noise_duration)
    elif noise_shape == "laplace":
        _noise = np.random.laplace(0, noise_amplitude, noise_duration)
    else:
        raise ValueError(
            "NeuroKit error: signal_distort(): 'noise_shape' should be one of 'gaussian' or 'laplace'."
        )

    if len(_noise) != n_samples:
        _noise = signal_resample(_noise, desired_length=n_samples, method="interpolation")
    return _noise


def _signal_distort_powerline(
    signal,
    signal_sd=None,
    sampling_rate=1000,
    powerline_frequency=50,
    powerline_amplitude=0.1,
    silent=False,
):

    duration = len(signal) / sampling_rate
    powerline_noise = signal_simulate(
        duration=duration,
        sampling_rate=sampling_rate,
        frequency=powerline_frequency,
        amplitude=1,
        silent=silent,
    )

    if signal_sd is not None:
        powerline_amplitude *= signal_sd
    powerline_noise *= powerline_amplitude

    return powerline_noise
import scipy
import scipy.io as sio
from scipy.signal import butter, filtfilt, lfilter
from scipy.signal import kaiserord, firwin, filtfilt, butter

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices: a list of indices
        num_samples: number of samples to draw
        callback_get_label: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(self, dataset, indices = None, num_samples = None, callback_get_label = None):
        self.indices = list(range(len(dataset))) if indices is None else indices        # if indices is not provided, all elements in the dataset will be considered
        self.callback_get_label = callback_get_label                                    # define custom callback
        self.num_samples = len(self.indices) if num_samples is None else num_samples    # if num_samples is not provided, draw `len(indices)` samples in each iteration

        df = pd.DataFrame()                                                             # distribution of classes in the dataset
        
        label = []
        for idx in trange(len(dataset), desc="Sampling"):
            ########## customize here ###############
            l = dataset[idx]['y_PVC_seg']
            if 1 in l:
                label.append(1)
            else:
                label.append(0)                
            ########## customize here ###############
        label = torch.tensor(label)
        
        df["label"] = label
        df.index = self.indices
        df = df.sort_index()

        label_to_count = df["label"].value_counts()

        weights = 1.0 / label_to_count[df["label"]] # almost equally
        # weights = 1.0 / (label_to_count[df["label"]])**2 # slightly weighted to 1
        self.weights = torch.DoubleTensor(weights.to_list())

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples

def IIRRemoveBL(ecgy,Fs, Fc):    
    #    ecgy:        the contamined signal (must be a list)
    #    Fc:          cut-off frequency
    #    Fs:          sample frequiency
    #    ECG_Clean :  processed signal without BLW
    
    # getting the length of the signal
    signal_len = len(ecgy)
    
    # fixed order
    N = 4
    
    # Normalized Cutt of frequency
    Wn = Fc/(Fs/2)    
    
    # IIR butterworth coefficients
    b, a = butter(N, Wn, 'highpass', analog=False)
    
    # Check filtfilt condition
    if N*3 > signal_len:
        diff = N*3 - signal_len
        ecgy = list(reversed(ecgy)) + list(ecgy) + list(ecgy[-1] * np.ones(diff))
        
        # Filtering with filtfilt
        ECG_Clean = filtfilt(b, a, ecgy)
        ECG_Clean = ECG_Clean[signal_len: signal_len + signal_len]
        
    else:
        ECG_Clean = filtfilt(b, a, ecgy)
                   
    return ECG_Clean

def IIRRemoveHF(ecgy, Fs, Fc):
    #    ecgy:        the contamined signal (must be a list)
    #    Fc:          cut-off frequency
    #    Fs:          sample frequiency
    #    ECG_Clean :  processed signal without BLW

    # getting the length of the signal
    signal_len = len(ecgy)

    # fixed order
    N = 4

    # Normalized Cutt of frequency
    Wn = Fc / (Fs / 2)

    # IIR butterworth coefficients
    b, a = butter(N, Wn, 'lowpass', analog=False)

    # Check filtfilt condition
    if N * 3 > signal_len:
        diff = N * 3 - signal_len
        ecgy = list(reversed(ecgy)) + list(ecgy) + list(ecgy[-1] * np.ones(diff))

        # Filtering with filtfilt
        ECG_Clean = filtfilt(b, a, ecgy)
        ECG_Clean = ECG_Clean[signal_len: signal_len + signal_len]

    else:
        ECG_Clean = filtfilt(b, a, ecgy)

    return ECG_Clean

def remove_baseline_wander(signal, fs):    
    Fc_l = 0.5
    Fc_h = 40.0

    signal_IIR = IIRRemoveBL(signal,fs,Fc_l)
    signal_IIR = IIRRemoveHF(signal_IIR,fs,Fc_h)
    return signal_IIR

# old preprocessing
def remove_baseline_wander(signal, fs):    
    order = 4
    nyq = 0.5 * fs
    lowcut = 0.67 #0.5
    highcut = 40
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    
    res = filtfilt(b, a, signal)
    # res = lfilter(b, a, signal)
    return res

from audiomentations import *
p=.2
augment_audiomentation = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=p),
    AddGaussianSNR(min_snr_in_db=5, max_snr_in_db=40.0, p=p),
    Gain(min_gain_in_db=-12, max_gain_in_db=12, p=p),
    FrequencyMask(min_frequency_band=0.0, max_frequency_band=.5, p=p),
    TanhDistortion(min_distortion= 0.01, max_distortion = 0.4, p=p),
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=p),
])

import neurokit2 as nk
def augment_neurokit(ecg_signal, sr):
    noise_shape = ['gaussian', 'laplace']
    n_noise_shape = np.random.randint(0,2)

    powerline_frequency = np.random.randint(50,60)
    noise_frequency = np.random.randint(2,20)
    artifacts_frequency= np.random.randint(2,20)
    # artifacts_number = np.random.randint(2,20)
    artifacts_number = 1

    powerline_amplitude = np.random.rand(1)*.3 #/ powerline_frequency
    noise_amplitude = np.random.rand(1)*.2 #/ noise_frequency
    artifacts_amplitude = np.random.rand(1)*1 #/ artifacts_frequency
    
    ecg_signal = signal_distort(ecg_signal,
                                sampling_rate=sr,
                                noise_shape=noise_shape[n_noise_shape],
                                noise_amplitude=noise_amplitude,
                                noise_frequency=noise_frequency,
                                powerline_amplitude=powerline_amplitude,
                                powerline_frequency=powerline_frequency,
                                artifacts_amplitude=artifacts_amplitude,
                                artifacts_frequency=artifacts_frequency,
                                artifacts_number=artifacts_number,
                                linear_drift=False,
                                random_state=None,#42,
                                silent=True)
    return ecg_signal

def minmax(arr):
    """
    numpy
    """
    return (arr-np.min(arr))/(np.max(arr)-np.min(arr))

def augment_neurokit2(sig,sr):
    beta = (np.random.rand(1)-.5)*4
    amp = np.random.rand(1)
    
    noise = nk.signal.signal_noise(duration=len(sig)/sr, sampling_rate=sr, beta=beta)*amp
    noise = minmax(noise) * (np.random.rand(1)) / 4

    result = augment_neurokit(noise, sr=sr)
    result = sig + result
    result = minmax(result)
    return result

class MIT_DATASET():
    def __init__(self, data, featureLength, srTarget, classes=4, augmentation="NONE", random_crop=False):
        self.data = data
        self.classes = classes
        self.augmentation = augmentation
        if augmentation == "NONE":
            self.augmentation = False
        elif augmentation =='NEUROKIT':
            self.augmentation=augment_neurokit
        elif augmentation =='NEUROKIT2':
            self.augmentation=augment_neurokit2
        elif augmentation =='AUDIOMENTATION':
            self.augmentation=augment_audiomentation

        self.random_crop = random_crop
        self.srTarget = srTarget
        self.featureLength = featureLength
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        pid = self.data[idx]['pid']
        signal = self.data[idx]['signal']
        srOriginal = self.data[idx]['sr']
        time = self.data[idx]['time']
        idx_Normal = self.data[idx]['idx_Normal']
        idx_PVC = self.data[idx]['idx_PVC']
        idx_AFIB = self.data[idx]['idx_Afib']
        idx_Others = self.data[idx]['idx_Others']
        idx_Artifact = self.data[idx]['idx_Artifact']
        dataSource = self.data[idx]['dataSource']
        
        y_Normal_seg = np.zeros_like(signal)
        y_PVC_seg = np.zeros_like(signal)
        y_AFIB_seg = np.zeros_like(signal)
        y_Others_seg = np.zeros_like(signal)
        
        interval = int(srOriginal * 0.1) # this is to set peak interval
        
        # grab annotations
        for idx_ in idx_Normal:  
            y_Normal_seg[idx_-interval:idx_+interval] = 1
        for idx_ in idx_PVC:
            y_PVC_seg[idx_-interval:idx_+interval] = 1
        for idx_ in idx_AFIB:
            y_AFIB_seg[idx_-interval:idx_+interval] = 1
        for idx_ in idx_Others:
            y_Others_seg[idx_-interval:idx_+interval] = 1
    
        # resampling
        if self.augmentation:
            srTarget = np.random.randint(int(self.srTarget*0.97),int(self.srTarget*1.03)) # time stretching, you need to carefully check here
        else: 
            srTarget = self.srTarget
            
        signal = lb.resample(signal, orig_sr=srOriginal, target_sr=srTarget) if srTarget != srOriginal else signal # resample
        y_Normal_seg = scipy.ndimage.zoom(y_Normal_seg, srTarget/srOriginal, order=0, mode='nearest',) if srTarget != srOriginal else y_Normal_seg # resample
        y_PVC_seg = scipy.ndimage.zoom(y_PVC_seg, srTarget/srOriginal, order=0, mode='nearest',) if srTarget != srOriginal else y_PVC_seg # resample
        y_AFIB_seg = scipy.ndimage.zoom(y_AFIB_seg, srTarget/srOriginal, order=0, mode='nearest',) if srTarget != srOriginal else y_AFIB_seg # resample
        y_Others_seg = scipy.ndimage.zoom(y_Others_seg, srTarget/srOriginal, order=0, mode='nearest',) if srTarget != srOriginal else y_Others_seg # resample
        
        if self.random_crop:
            if int(len(signal)) > self.featureLength:  # randomly crop 
                randnum = np.random.randint(0,len(signal)-self.featureLength)
                start = randnum if self.random_crop else 0
                end = start+self.featureLength
            elif int(len(signal)) == self.featureLength:
                start = 0
                end = 0 + self.featureLength
            else:
                print('too short data:: need check sampling rate or featureLength', int(len(signal)),self.featureLength)
                
            signal = signal[start:end]
            y_Normal_seg = y_Normal_seg[start:end]
            y_PVC_seg = y_PVC_seg[start:end]
            y_AFIB_seg = y_AFIB_seg[start:end]
            y_Others_seg = y_Others_seg[start:end]
        # print('after crop',signal.shape)

        y_peak_seg = y_Normal_seg + y_PVC_seg + y_Others_seg + y_AFIB_seg # R-peak
        y_peak_seg[y_peak_seg!=0] =1

        if self.classes == 1:
            y_seg = np.expand_dims(y_PVC_seg,0) # 1 class
        elif self.classes == 2:
            y_seg = np.stack((y_peak_seg, y_PVC_seg), axis=0).astype(float) # 2 multi class
        elif self.classes == 3:
            y_Others_seg = y_Others_seg + y_AFIB_seg # non PVC
            y_Others_seg[y_Others_seg!=0] =1
            y_seg = np.stack((y_peak_seg, y_PVC_seg, y_Others_seg), axis=0).astype(float) # 3 multi class    
            # y_seg = np.stack((y_peak_seg, y_PVC_seg, y_AFIB_seg), axis=0).astype(float) # 3 multi class    
        elif self.classes == 4:
            y_seg = np.stack((y_peak_seg, y_PVC_seg, y_Others_seg, y_AFIB_seg), axis=0).astype(float) # 4 multi class
        
        y_Normal = np.array([0]) if 1 in y_Normal_seg else np.array([1]) # classification task
        y_Others = np.array([0]) if 1 in y_Others_seg else np.array([1]) # classification task
        y_PVC    = np.array([0]) if 1 in y_PVC_seg else np.array([1]) # classification task
        y_AFIB   = np.array([0]) if 1 in y_AFIB_seg else np.array([1]) # classification task
        
        signal_original = signal.copy()
        signal_original = np.expand_dims(signal_original,0)
        
        # augmentation
        signal = signal if not self.augmentation else self.augmentation(signal, srTarget)
        
        signal = remove_baseline_wander(signal,srTarget)
        signal = np.expand_dims(signal,0)
                      
        signal = (signal - np.min(signal)) / (np.max(signal) - np.min(signal)) # normalize  
        signal = torch.tensor(signal).float() # shape should be Channel X Signal
        
        return {'dataSource':dataSource,
                'pid':pid,
                'srOriginal': srOriginal,
                'srTarget':srTarget,
                'time':time,
                'fname':f'{pid}_time{time}',
                'signal':signal,
                'signal_original':signal_original,
                'y_AFIB':y_AFIB, 
                'y_PVC':y_PVC,
                'y_AFIB_seg':y_AFIB_seg,
                'y_PVC_seg':y_PVC_seg, 
                'y_Normal_seg':y_Normal_seg,
                'y_Others_seg':y_Others_seg, 
                'y_seg':y_seg,}

5610 1431 905 14262 497 41753 10811



In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=.1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha 
        self.gamma = gamma 

    def forward(self, inputs, targets):
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #first compute binary cross-entropy 
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Pt = torch.exp(-BCE)
        alpha_tensor = (1 - self.alpha) + targets * (2 * self.alpha - 1)  # alpha if target = 1 and 1 - alpha if target = 0
        focal_loss = self.alpha * (1-Pt)**self.gamma * BCE
                       
        return focal_loss

# class WeightedFocalLoss(nn.Module):
#     "Non weighted version of Focal Loss"
#     def __init__(self, alpha=.1, gamma=2):
#         super(WeightedFocalLoss, self).__init__()
#         self.alpha = torch.tensor([alpha, 1-alpha])
#         self.gamma = gamma

#     def forward(self, inputs, targets):
#         BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
#         targets = targets.type(torch.long)
#         at = self.alpha.gather(0, targets.data.view(-1))
#         pt = torch.exp(-BCE_loss)
#         F_loss = at*(1-pt)**self.gamma * BCE_loss
#         return F_loss.mean()
    
class FocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2, weight=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha 
        self.gamma = gamma 

    def forward(self, inputs, targets):
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #first compute binary cross-entropy 
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = self.alpha * (1-BCE_EXP)**self.gamma * BCE
                       
        return focal_loss
    
def get_dice_loss(y_pred, y_true, log=False, per_image=False, smooth=1e-7):
    tp = torch.sum(y_true * y_pred, axis=AXIS)
    fp = torch.sum(y_pred, axis=AXIS) - tp
    fn = torch.sum(y_true, axis=AXIS) - tp
    dice_score_per_image = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    if log:
        dice_score_per_image = -1 * torch.log(dice_score_per_image)
    else:
        dice_score_per_image = 1 - dice_score_per_image
    if per_image:
        return dice_score_per_image
    else:
        return torch.mean(dice_score_per_image)
    
def get_tversky_loss(y_pred, y_true, beta=0.7, log=False, per_image=False, smooth=1e-7):
    alpha = 1 - beta
    tp = torch.sum(y_true * y_pred, axis=AXIS)
    fp = torch.sum(y_pred, axis=AXIS) - tp
    fn = torch.sum(y_true, axis=AXIS) - tp
    dice_score_per_image = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    if log:
        dice_score_per_image = -1 * torch.log(dice_score_per_image)
    else:
        dice_score_per_image = 1 - dice_score_per_image
    if per_image:
        return dice_score_per_image
    else:
        return torch.mean(dice_score_per_image)

class PropotionalLoss(nn.Module):
    def __init__(self, log=False, per_image=False, smooth=1e-7, beta=0.7, bce=False):
        super(PropotionalLoss, self).__init__()
        self.beta = beta
        self.smooth = smooth 
        self.log = log
        self.per_image = per_image
        self.bce = bce
        
    def forward(self, inputs, targets):
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        AXIS = [-1]
        self.alpha = 1 - self.beta
        y_true = targets
        y_pred = inputs
        
        prevalence = torch.mean(y_true, axis=AXIS)
        tp = torch.sum(y_true * y_pred, axis=AXIS)
        tn = torch.sum((1 - y_true) * (1 - y_pred), axis=AXIS)
        fp = torch.sum(y_pred, axis=AXIS) - tp
        fn = torch.sum(y_true, axis=AXIS) - tp
        negative_score = (tn + self.smooth) / (tn + self.beta * fn + self.alpha * fp + self.smooth) * (self.smooth + 1 - prevalence)
        positive_score = (tp + self.smooth) / (tp + self.alpha * fn + self.beta * fp + self.smooth) * (self.smooth + prevalence)
        score_per_image = negative_score + positive_score
        
        if self.log:
            score_per_image = -1 * torch.log(score_per_image)
        else:
            score_per_image = 1 - score_per_image
            
        if self.per_image == False:
            score_per_image = torch.mean(score_per_image)
        
        if self.bce:
            return score_per_image + F.binary_cross_entropy(y_pred, y_true)       
        
        
# def get_propotional_loss(y_pred, y_true, log=False, per_image=False, smooth=SMOOTH, beta=0.7):
#     alpha = 1 - beta
#     prevalence = torch.mean(y_true, axis=AXIS)
#     tp = torch.sum(y_true * y_pred, axis=AXIS)
#     tn = torch.sum((1 - y_true) * (1 - y_pred), axis=AXIS)
#     fp = torch.sum(y_pred, axis=AXIS) - tp
#     fn = torch.sum(y_true, axis=AXIS) - tp
#     negative_score = (tn + smooth) \
#         / (tn + beta * fn + alpha * fp + smooth) * (smooth + 1 - prevalence)
#     positive_score = (tp + smooth) \
#         / (tp + alpha * fn + beta * fp + smooth) * (smooth + prevalence)
#     score_per_image = negative_score + positive_score
#     if log:
#         score_per_image = -1 * torch.log(score_per_image)
#     else:
#         score_per_image = 1 - score_per_image
#     if per_image:
#         return score_per_image
#     else:
#         return torch.mean(score_per_image)
    
# get_bce_loss = torch.nn.BCELoss()
# get_loss = lambda y_pred, y_true: get_propotional_loss(y_pred, y_true) + get_bce_loss(y_pred, y_true)
    
# def BCELoss_class_weighted(weights):

#     def loss(input, target):
#         # input = torch.clamp(input, min=1e-7, max=1-1e-7)
#         bce = - weights[1] * target * torch.log(input) - weights[0] * (1 - target) * torch.log(1 - input)
#         return torch.mean(bce)

#     return loss

# from libauc.losses import APLoss
# from libauc.optimizers import SOAP
# from libauc.models import resnet18 as ResNet18
# from libauc.datasets import CIFAR10
# from libauc.utils import ImbalancedDataGenerator
# from libauc.sampler import DualSampler
# from libauc.metrics import auc_prc_score

# import torchvision.transforms as transforms
# from torch.utils.data import Dataset
# import numpy as np
# import torch
# from PIL import Image

# model = ResNet18(pretrained=False, last_activation=None) 
# model = model.cuda()

# lr= 1e-3
# pos_len=6018
# margin = 1.0
# gamma = 0.1

# weight_decay = 0
# total_epoch = 60
# decay_epoch = [30]
# SEED = 2022

# lossFn = APLoss(pos_len=pos_len, margin=margin, gamma=gamma)
# optimizer = SOAP(model.parameters(), lr=lr, mode='adam', weight_decay=weight_decay)

# lossFn = FocalLoss(alpha=1, gamma=0)
# y = torch.zeros(1,4,64)
# y[:,:1,:]=1

# print(np.unique(y.numpy(),return_counts=True))
# yhat = torch.rand(1,4,64)
# print(lossFn(yhat,y))

# lossFn = nn.BCELoss()
# print(lossFn(yhat,y))

# lossFn = BCELoss_class_weighted([.2, 1])
# print(lossFn(yhat,y))

In [9]:
import wandb
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, LambdaCallback, EarlyStopping

def train():
    set_seed()
    wandb.init(config=config_defaults)
    
    hyperparameters = dict(wandb.config)
    model = PVC_NET(hyperparameters)
    # model.prepare_data()
    # model.train_dataloader()

    classes = model.hyperparameters['outChannels']
    srTarget = model.hyperparameters['srTarget']
    featureLength = model.hyperparameters['featureLength']   
#     classes = wandb.config.outChannels
#     srTarget = wandb.config.srTarget
#     featureLength = wandb.config.featureLength

    files = glob('dataset/MIT-BIH_NPY/*.npy')
    
    def seed_MITBIH(files, seed):
        train_files, valid_files = sklearn.model_selection.train_test_split(files, test_size=.2, random_state= seed)

        train_seg = []
        for f in train_files:
            data = np.load(f,allow_pickle=True)
            train_seg.extend(data)

        valid_seg = []
        for f in valid_files:
            data = np.load(f,allow_pickle=True)
            valid_seg.extend(data)
            
        print('seed:',seed)
        # print('train_files:',train_files)
        # print('valid_files:',valid_files)
        add_datainfo(train_seg,1)
        add_datainfo(valid_seg,2)
        return train_seg, valid_seg
    
    train_data, valid_data = seed_MITBIH(files, wandb.config.dataSeed)
        
    train_dataset = MIT_DATASET(train_data,featureLength,srTarget,classes,wandb.config.trainaug, True)
    valid_dataset = MIT_DATASET(valid_data,featureLength,srTarget,classes,False)
    test_dataset = MIT_DATASET(test_data,featureLength,srTarget,classes,False)
    AMC_dataset = MIT_DATASET(AMC_data,featureLength,srTarget,classes,False)
    INCART_dataset = MIT_DATASET(INCART_data,featureLength, srTarget, classes, False)
    Fantasia_dataset = MIT_DATASET(Fantasia_data,featureLength, srTarget, classes, False)
    CPSC2020_dataset = MIT_DATASET(CPSC2020_data,featureLength, srTarget, classes, False)

    if wandb.config.sampler:
        train_loader = DataLoader(train_dataset, batch_size = wandb.config.batch_size, shuffle = False, num_workers=4, pin_memory=True, sampler=ImbalancedDatasetSampler(train_dataset))
    else:
        train_loader = DataLoader(train_dataset, batch_size = wandb.config.batch_size, shuffle = True, num_workers=4, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size = 64, shuffle = False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size = 64, num_workers=2, shuffle = False)
    AMC_loader = DataLoader(AMC_dataset,batch_size = 64, num_workers=2, shuffle = False)
    INCART_loader = DataLoader(INCART_dataset, batch_size = 64, num_workers=2, shuffle = False)
    Fantasia_loader = DataLoader(Fantasia_dataset,batch_size = 64, num_workers=2, shuffle = False)
    CPSC2020_loader = DataLoader(CPSC2020_dataset,batch_size = 64, num_workers=2, shuffle = False)

    wandb_logger = pl_loggers.WandbLogger(save_dir=f"{wandb.config.path_logRoot}/{model.experiment_name}", name=model.experiment_name, project=wandb.config.project, offline=False)

    lr_monitor_callback = LearningRateMonitor(logging_interval='epoch',)
    early_stop_callback = EarlyStopping(monitor='val_loss', mode="min", patience=10, verbose=False)
    loss_checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', dirpath=f"{wandb.config.path_logRoot}/{model.experiment_name}/weight/", filename="best_val_loss", save_top_k=1, verbose=False)
    # metric_checkpoint_callback = ModelCheckpoint(monitor='val_AUPRC_Class1Raw', mode='max', dirpath=f"{wandb.config.path_logRoot}/{model.experiment_name}/weight/", filename="best_val_metric", save_top_k=1, verbose=False)

    trainer = pl.Trainer(accumulate_grad_batches=8,
                        gradient_clip_val=0.1,
                        accelerator='gpu',
                        devices=-1,
                        strategy ='dp',
                        max_epochs=100, # 80
                        sync_batchnorm=True,
                        benchmark=False,
                        deterministic=True,
                        check_val_every_n_epoch=1,
                        # callbacks=[loss_checkpoint_callback, metric_checkpoint_callback, lr_monitor_callback, early_stop_callback],# StochasticWeightAveraging(swa_lrs=0.05)], #
                        callbacks=[loss_checkpoint_callback, lr_monitor_callback, early_stop_callback],# , StochasticWeightAveraging(swa_lrs=0.0001)], #
                        logger = wandb_logger,
                        precision= 32 # 'bf16', 16, 32
    )
    
    trainer.fit(model, train_loader, valid_loader)
    
    # weights = glob(f"{hyperparameters['path_logRoot']}/{net.experiment_name}/weight/*val_loss*")[-1]
    # result_test = trainer.test(net, test_loader, ckpt_path=weights)
    # result_AMC = trainer.test(net, AMC_loader, ckpt_path=weights)
    # result_INCART = trainer.test(net, INCART_loader,ckpt_path=weights)
    # result_CPSC2020 = trainer.test(net, CPSC2020_loader,ckpt_path=weights)

    result_test = trainer.test(model, test_loader, ckpt_path='best')
    result_AMC = trainer.test(model, AMC_loader, ckpt_path='best')
    result_CPSC2020 = trainer.test(model, CPSC2020_loader,ckpt_path='best')
    result_INCART = trainer.test(model, INCART_loader,ckpt_path='best')
    

def test(path):
    # set_seed()
    model = PVC_NET.load_from_checkpoint(path)

    classes = model.hyperparameters['outChannels']
    srTarget = model.hyperparameters['srTarget']
    featureLength = model.hyperparameters['featureLength']
    
    test_dataset = MIT_DATASET(test_data,featureLength,srTarget,classes,False)
    AMC_dataset = MIT_DATASET(AMC_data,featureLength,srTarget,classes,False)
    INCART_dataset = MIT_DATASET(INCART_data,featureLength, srTarget, classes, False)
    Fantasia_dataset = MIT_DATASET(Fantasia_data,featureLength, srTarget, classes, False)
    CPSC2020_dataset = MIT_DATASET(CPSC2020_data,featureLength, srTarget, classes, False)
    
    test_loader = DataLoader(test_dataset, batch_size = 64, num_workers=2, shuffle = False)
    AMC_loader = DataLoader(AMC_dataset,batch_size = 64, num_workers=2, shuffle = False)
    INCART_loader = DataLoader(INCART_dataset, batch_size = 64, num_workers=2, shuffle = False)
    Fantasia_loader = DataLoader(Fantasia_dataset,batch_size = 64, num_workers=2, shuffle = False)
    CPSC2020_loader = DataLoader(CPSC2020_dataset,batch_size = 64, num_workers=2, shuffle = False)

    trainer = pl.Trainer(accumulate_grad_batches=8,
                        gradient_clip_val=0.1,
                        accelerator='gpu',
                        devices=-1,
                        strategy ='dp',
                        max_epochs=100, # 80
                        sync_batchnorm=True,
                        benchmark=False,
                        deterministic=True,
                        check_val_every_n_epoch=1,
                        # callbacks=[loss_checkpoint_callback, lr_monitor_callback, early_stop_callback],# , StochasticWeightAveraging(swa_lrs=0.0001)], #
                        # logger = wandb_logger,
                        precision= 32 # 'bf16', 16, 32
    )
    model.testPlot=True
    
# #     result_test = trainer.test(model, test_loader, ckpt_path='best')
# #     result_AMC = trainer.test(model, AMC_loader, ckpt_path='best')
# #     result_CPSC2020 = trainer.test(model, CPSC2020_loader,ckpt_path='best')
# #     result_INCART = trainer.test(model, INCART_loader,ckpt_path='best')

    result_test = trainer.test(model, test_loader)
    result_AMC = trainer.test(model, AMC_loader)
    result_CPSC2020 = trainer.test(model, CPSC2020_loader)
    result_INCART = trainer.test(model, INCART_loader)


# def test(path_ckpt, test_dataloader='test', path_root=None):
    
#     weight = torch.load(path_ckpt)
#     # print(weight['hyper_parameters']['hyperparameters'])
        
#     classes = weight['hyper_parameters']['hyperparameters']['out_channels']
#     test_dataset = MIT_DATASET(test_data, sr_target, classes, False)
#     AMC_dataset = MIT_DATASET(AMC_data, sr_target, classes, False)
#     AMCREAL_dataset = MIT_DATASET(AMCREAL_data, sr_target, classes, False)

#     INCART_dataset = MIT_DATASET(INCART_data, sr_target, classes, False)
#     Fantasia_dataset = MIT_DATASET(Fantasia_data, sr_target, classes, False)
#     CPSC2020_dataset = MIT_DATASET(CPSC2020_data, sr_target, classes, False)

#     test_loader = DataLoader(test_dataset, batch_size = 16, shuffle = False)
#     AMC_loader = DataLoader(AMC_dataset, batch_size = 16, shuffle = False)
#     AMCREAL_loader = DataLoader(AMCREAL_dataset, batch_size = 16, shuffle = False)
#     INCART_loader = DataLoader(INCART_dataset, batch_size = 16, shuffle = False)
#     CPSC2020_loader = DataLoader(CPSC2020_dataset,batch_size = 16, shuffle = False)
    
#     if test_dataloader =='testMIT':
#         dataloader = test_loader
#     elif test_dataloader =='testAMC':
#         dataloader = AMC_loader
#     elif test_dataloader =='testAMCREAL':
#         dataloader = AMCREAL_loader
#     elif test_dataloader =='testINCART':
#         dataloader = INCART_loader
#     elif test_dataloader =='testFantasia':
#         dataloader = Fantasia_loader
#     elif test_dataloader =='testCPSC2020':
#         dataloader = CPSC2020_loader
        
#     net = PVCDetection(weight['hyper_parameters']['hyperparameters'])
#     net.load_state_dict(weight['state_dict'], strict=True)
#     net.testPlot=True
#     trainer = pl.Trainer(
#         accelerator='gpu',
#         devices=1,
#         strategy ='dp',
#         max_epochs=1,
#         check_val_every_n_epoch=1,
#         precision=32
#     )
    
#     return trainer.test(net, dataloader)

In [10]:
# !sudo rm -r 20221110_final/

In [11]:
# train()

In [None]:
import wandb
sweep_id = wandb.sweep(sweep_config, project=config_defaults['project'])

wandb.agent(sweep_id,function=train)

Create sweep with ID: z1it9q0m
Sweep URL: https://wandb.ai/keewonshin/PVC_NET/sweeps/z1it9q0m
2022-11-10 08:20:36,830 - Starting sweep agent: entity=None, project=None, count=None
2022-11-10 08:20:37,860 - Global seed set to 42
saving path : 20221110_final/dataSeed2_dropout0.1_featureLength512_lossFnbceloss_modelNameefficientnet-b0_norminstance_outChannels2_samplerTrue_sese_skipASPPNONE_skipModuleNONE_srTarget125_supervisionTYPE2_trainaugNEUROKIT_upsamplepixelshuffle_inChannels1/
[16, 24, 40, 112, 320]
skipModule:NONE
skipASPP:ASPP NONE
seed: 2


Sampling:   0%|          | 0/3077 [00:00<?, ?it/s]

2022-11-10 08:21:10,870 - GPU available: True (cuda), used: True
2022-11-10 08:21:10,872 - TPU available: False, using: 0 TPU cores
2022-11-10 08:21:10,873 - IPU available: False, using: 0 IPUs
2022-11-10 08:21:10,874 - HPU available: False, using: 0 HPUs
2022-11-10 08:21:13,543 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
2022-11-10 08:21:13,554 - 
  | Name   | Type    | Params
-----------------------------------
0 | net    | UNet    | 5.6 M 
1 | lossFn | BCELoss | 0     
-----------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.470    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# path = '20221110_final/dataSeed2_srTarget360_featureLength2048_samplerFalse_outChannels4_modelNameefficientnet-b0_norminstance_upsampledeconv_supervisionNONE_skipModuleNONE_trainaugNONE/weight/best_val_loss.ckpt'
# model = PVC_NET.load_from_checkpoint(path)

In [None]:
# test('20221110_final/dataSeed2_srTarget360_featureLength2048_samplerFalse_outChannels4_modelNameefficientnet-b0_norminstance_upsampledeconv_supervisionNONE_skipModuleNONE_trainaugNONE/weight/best_val_loss.ckpt')

In [None]:
# path = '20221110_final/dataSeed2_srTarget360_featureLength2048_samplerFalse_outChannels4_modelNameefficientnet-b0_norminstance_upsampledeconv_supervisionNONE_skipModuleNONE_trainaugNONE//weight/best_val_loss.ckpt'

# # weight = torch.load(path)
# # hyperparameters = weight['hyper_parameters']['hyperparameters']
# model = PVC_NET.load_from_checkpoint(path)

In [None]:
# hyperparameters = config_defaults
# PVC_NET(hyperparameters)