In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path
import skimage
import skimage.segmentation
import sklearn.preprocessing
import sklearn.model_selection
import math
import shutil
import pathlib
import glob
import shutil
import uuid
import random
import platform
import torch
import torchvision
import numpy as np
import scipy as sp
import scipy.io
import scipy.signal
import pandas as pd
import networkx
import wfdb
import json
import tqdm
import dill
import pickle
import matplotlib.pyplot as plt

import scipy.stats

import src.data
import sak
import sak.wavelet
import sak.data
import sak.data.augmentation
import sak.visualization
import sak.visualization.plot
import sak.torch
import sak.torch.nn
import sak.torch.nn as nn
import sak.torch.train
import sak.torch.data
import sak.data.preprocessing
import sak.torch.models
import sak.torch.models.lego
import sak.torch.models.variational
import sak.torch.models.classification

from sak.signal import StandardHeader

def smooth(x: np.ndarray, window_size: int, conv_mode: str = 'same'):
    x = np.pad(np.copy(x),(window_size,window_size),'edge')
    window = np.hamming(window_size)/(window_size//2)
    x = np.convolve(x, window, mode=conv_mode)
    x = x[window_size:-window_size]
    return x

# Training the network

# Load execution configuration

In [3]:
with open('./configurations/MultiScaleUNet5Levels.json', 'r') as f:
    execution = json.load(f)

### 1. Load data
#### 1.1. Load individual segments

In [4]:
P = sak.pickleload(os.path.join('.','pickle','Psignal_new.pkl'))
PQ = sak.pickleload(os.path.join('.','pickle','PQsignal_new.pkl'))
QRS = sak.pickleload(os.path.join('.','pickle','QRSsignal_new.pkl'))
ST = sak.pickleload(os.path.join('.','pickle','STsignal_new.pkl'))
T = sak.pickleload(os.path.join('.','pickle','Tsignal_new.pkl'))
TP = sak.pickleload(os.path.join('.','pickle','TPsignal_new.pkl'))

Pamplitudes = sak.pickleload(os.path.join('.','pickle','Pamplitudes_new.pkl'))
PQamplitudes = sak.pickleload(os.path.join('.','pickle','PQamplitudes_new.pkl'))
QRSamplitudes = sak.pickleload(os.path.join('.','pickle','QRSamplitudes_new.pkl'))
STamplitudes = sak.pickleload(os.path.join('.','pickle','STamplitudes_new.pkl'))
Tamplitudes = sak.pickleload(os.path.join('.','pickle','Tamplitudes_new.pkl'))
TPamplitudes = sak.pickleload(os.path.join('.','pickle','TPamplitudes_new.pkl'))

#### 1.2. Get amplitude distribution

In [5]:
Pdistribution   = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(Pamplitudes.values()))))
PQdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(PQamplitudes.values()))))
QRSdistribution = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.hstack((np.array(list(QRSamplitudes.values())), 2-np.array(list(QRSamplitudes.values()))))))
# QRSdistribution = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(QRSamplitudes.values()))))
STdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(STamplitudes.values()))))
Tdistribution   = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(Tamplitudes.values()))))
TPdistribution  = scipy.stats.lognorm(*scipy.stats.lognorm.fit(np.array(list(TPamplitudes.values()))))

#### 1.3. Smooth all waves

In [6]:
# Smooth all
window = 5
P   = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(  P[k],window)),metric=sak.signal.abs_max) for k in   P}
PQ  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( PQ[k],window)),metric=sak.signal.abs_max) for k in  PQ}
QRS = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(QRS[k],window)),metric=sak.signal.abs_max) for k in QRS}
ST  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( ST[k],window)),metric=sak.signal.abs_max) for k in  ST}
T   = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth(  T[k],window)),metric=sak.signal.abs_max) for k in   T}
TP  = {k: sak.data.ball_scaling(sak.signal.on_off_correction(smooth( TP[k],window)),metric=sak.signal.abs_max) for k in  TP}

#### 1.4. Split into train and test

In [7]:
all_keys = {}
for k in list(P) + list(PQ) + list(QRS) + list(ST) + list(T) + list(TP):
    uid = k.split('###')[0].split('_')[0].split('-')[0]
    if uid not in all_keys:
        all_keys[uid] = [k]
    else:
        all_keys[uid].append(k)
        
# Get database and file
filenames = []
database = []
for k in all_keys:
    filenames.append(k)
    if k.startswith('SOO'):
        database.append(0)
    elif k.startswith('sel'):
        database.append(1)
    else:
        database.append(2)
filenames = np.array(filenames)
database = np.array(database)

In [24]:
# config_file = './configurations/UNet6Levels.json'
config_file = './configurations/MultiScaleUNet5Levels.json'
model_name = 'relativeampl_MultiUNet5_2_Uwave_NoShortP_SinusArrest'

with open(config_file, 'r') as f:
    execution = json.load(f)

random.seed(execution['seed'])
np.random.seed(execution['seed'])
torch.random.manual_seed(execution['seed'])
splitter = sklearn.model_selection.StratifiedKFold(5).split(filenames,database)

In [9]:
original_path = execution['save_directory']
all_folds_test = {}

for i,(ix_train,ix_valid) in enumerate(splitter):
    train_keys, valid_keys = ([],[])
    for k in np.array(filenames)[ix_train]: train_keys += all_keys[k]
    for k in np.array(filenames)[ix_valid]: valid_keys += all_keys[k]

    # Save fold"s validation files for later usage
    all_folds_test["fold_{}".format(i+1)] = np.array(filenames)[ix_valid]
    
    # Divide train/valid segments
    Ptrain   = {k: P[k] for k in P if k in train_keys}
    PQtrain  = {k: PQ[k] for k in PQ if k in train_keys}
    QRStrain = {k: QRS[k] for k in QRS if k in train_keys}
    STtrain  = {k: ST[k] for k in ST if k in train_keys}
    Ttrain   = {k: T[k] for k in T if k in train_keys}
    TPtrain  = {k: TP[k] for k in TP if k in train_keys}

    Pvalid   = {k: P[k] for k in P if k in valid_keys}
    PQvalid  = {k: PQ[k] for k in PQ if k in valid_keys}
    QRSvalid = {k: QRS[k] for k in QRS if k in valid_keys}
    STvalid  = {k: ST[k] for k in ST if k in valid_keys}
    Tvalid   = {k: T[k] for k in T if k in valid_keys}
    TPvalid  = {k: TP[k] for k in TP if k in valid_keys}

    # Prepare folders
    execution['save_directory'] = os.path.join(original_path,model_name,'fold_{}'.format(i+1))
    if not os.path.isdir(execution['save_directory']):
        pathlib.Path(execution['save_directory']).mkdir(parents=True, exist_ok=True)

    # Define datasets
    # dataset_train = src.data.Dataset(Ptrain, QRStrain, Ttrain, PQtrain, STtrain, TPtrain, 
    #                                  Pamplitudes, QRSamplitudes, Tamplitudes, PQamplitudes, 
    #                                  STamplitudes, TPamplitudes, 300*execution['loader']['batch_size'],
    #                                  interp_std = 0.25,labels_as_masks=True)
    # dataset_valid = src.data.Dataset(Pvalid, QRSvalid, Tvalid, PQvalid, STvalid, TPvalid, 
    #                                  Pamplitudes, QRSamplitudes, Tamplitudes, PQamplitudes, 
    #                                  STamplitudes, TPamplitudes, 100*execution['loader']['batch_size'],
    #                                  interp_std = 0.25,labels_as_masks=True)

    # # Create dataloaders
    # loader_train = torch.utils.data.DataLoader(dataset_train,**execution['loader'])
    # loader_valid = torch.utils.data.DataLoader(dataset_valid,**execution['loader'])

    # # Define model
    # model = nn.ModelGraph(execution['model']).float()
    
    # 
    break
    
sak.save_data(all_folds_test,os.path.join(original_path,model_name,"validation_files.csv"))


In [5]:
# Define model
model = nn.ModelGraph(execution['model']).float()

AttributeError: module 'torch.nn' has no attribute 'Concatenate'

In [11]:
dataset_train = src.data.Dataset(Ptrain, QRStrain, Ttrain, PQtrain, STtrain, TPtrain, 
                                 Pdistribution, QRSdistribution, Tdistribution, PQdistribution, 
                                 STdistribution, TPdistribution, **execution['dataset'])
execution['dataset'][]
dataset_valid = src.data.Dataset(Pvalid, QRSvalid, Tvalid, PQvalid, STvalid, TPvalid, 
                                 Pdistribution, QRSdistribution, Tdistribution, PQdistribution, 
                                 STdistribution, TPdistribution, **execution['dataset'])

# Create dataloaders
loader_train = torch.utils.data.DataLoader(dataset_train,**execution['loader'])
loader_valid = torch.utils.data.DataLoader(dataset_valid,**execution['loader'])

In [2]:
import ast

In [5]:
%%timeit
eval("lambda X,y,y_pred: sak.torch.nn.DiceLoss()(y_pred, y)")

11.7 µs ± 95.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [12]:
### Loss
# criterion = lambda X,y,y_pred: sak.torch.nn.CrossEntropyLoss()(y_pred, y.long())
# metric = lambda X,y,y_pred: sak.torch.nn.CrossEntropyLoss()(y_pred, y.long())
criterion = lambda X,y,y_pred: sak.torch.nn.DiceLoss()(y_pred, y)
metric = lambda X,y,y_pred: sak.torch.nn.DiceLoss()(y_pred, y)

state = {
    'epoch'         : 0,
    'device'        : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'optimizer'     : sak.class_selector('torch.optim',execution['optimizer']['class'])(model.parameters(), **execution['optimizer']['arguments']),
    'root_dir'      : './'
}
if 'scheduler' in execution:
    state['scheduler'] = sak.class_selector('torch.optim.lr_scheduler',execution['scheduler']['class'])(state['optimizer'], **execution['scheduler']['arguments'])

# Save model-generating files
shutil.copyfile('./src/data.py',os.path.join(execution['save_directory'],'data.py'))
shutil.copyfile('./src/metrics.py',os.path.join(execution['save_directory'],'metrics.py'))
shutil.copyfile(config_file,os.path.join(execution['save_directory'],os.path.split(config_file)[1]))

# Train model (auto-saves to same location as above)
state = sak.torch.train.train_model(model,
                                      state,
                                      execution,
                                      loader_train,
                                      loader_valid,
                                      criterion,
                                      metric,
                                      smaller=True)

(Train) Epoch   1/100, Loss      0.301: 100%|██████████| 512/512 [04:33<00:00,  1.87it/s]
(Valid) Epoch   1/100, Loss      0.165: 100%|██████████| 512/512 [03:57<00:00,  2.15it/s]
(Train) Epoch   2/100, Loss      0.216: 100%|██████████| 512/512 [04:17<00:00,  1.99it/s]
(Valid) Epoch   2/100, Loss      0.146: 100%|██████████| 512/512 [03:58<00:00,  2.15it/s]
(Train) Epoch   3/100, Loss      0.223: 100%|██████████| 512/512 [04:18<00:00,  1.98it/s]
(Valid) Epoch   3/100, Loss      0.149: 100%|██████████| 512/512 [03:58<00:00,  2.14it/s]
(Train) Epoch   4/100, Loss      0.167: 100%|██████████| 512/512 [04:16<00:00,  1.99it/s]
(Valid) Epoch   4/100, Loss      0.142: 100%|██████████| 512/512 [03:58<00:00,  2.14it/s]
(Train) Epoch   5/100, Loss      0.170: 100%|██████████| 512/512 [04:18<00:00,  1.98it/s]
(Valid) Epoch   5/100, Loss      0.126: 100%|██████████| 512/512 [04:03<00:00,  2.10it/s]
(Train) Epoch   6/100, Loss      0.169: 100%|██████████| 512/512 [04:19<00:00,  1.97it/s]
(Valid) Ep

(Train) Epoch  92/100, Loss      0.132: 100%|██████████| 512/512 [04:19<00:00,  1.97it/s]
(Valid) Epoch  92/100, Loss      0.100: 100%|██████████| 512/512 [04:00<00:00,  2.13it/s]
(Train) Epoch  93/100, Loss      0.125: 100%|██████████| 512/512 [04:31<00:00,  1.89it/s]
(Valid) Epoch  93/100, Loss      0.126: 100%|██████████| 512/512 [04:25<00:00,  1.93it/s]
(Train) Epoch  94/100, Loss      0.132: 100%|██████████| 512/512 [04:34<00:00,  1.86it/s]
(Valid) Epoch  94/100, Loss      0.089: 100%|██████████| 512/512 [04:14<00:00,  2.01it/s]
(Train) Epoch  95/100, Loss      0.135: 100%|██████████| 512/512 [04:50<00:00,  1.76it/s]
(Valid) Epoch  95/100, Loss      0.107: 100%|██████████| 512/512 [04:35<00:00,  1.86it/s]
(Train) Epoch  96/100, Loss      0.199: 100%|██████████| 512/512 [04:48<00:00,  1.77it/s]
(Valid) Epoch  96/100, Loss      0.112: 100%|██████████| 512/512 [04:33<00:00,  1.87it/s]
(Train) Epoch  97/100, Loss      0.142: 100%|██████████| 512/512 [04:45<00:00,  1.79it/s]
(Valid) Ep

In [13]:
list(set([k.split('###')[0] for k in valid_keys]))

['SOO60-1-1_AVR',
 '12_V4',
 'SOO61-1-1_V5',
 'SOO51-1-1_V1',
 'SOO41-1-1_AVF',
 '40_AVF',
 'SOO58-1-1_V2',
 '37_AVR',
 'SOO50-1-1_III',
 '22_I',
 'SOO27-1-1_AVL',
 '42_V1',
 'SOO58-1-1_V4',
 'SOO69-1-1_AVF',
 'SOO31-1-1_V6',
 '4_V4',
 'SOO59-1-1_V6',
 'SOO46-1-1_V1',
 '7_V3',
 '13_AVR',
 '16_AVL',
 'SOO39-1-1_V2',
 '28_V3',
 'SOO18-1-1_III',
 '4_AVR',
 'SOO10-1-1_AVF',
 'SOO13-1-1_AVR',
 'SOO20-1-1_V4',
 'SOO43-1-1_V4',
 '41_I',
 'SOO21-1-1_II',
 '36_I',
 '24_II',
 '26_V2',
 'SOO2-1-1_V3',
 'SOO23-1-1_AVL',
 'SOO32-1-1_V1',
 'SOO41-1-1_V3',
 'SOO55-1-1_V2',
 'SOO44-1-1_V3',
 'SOO18-1-1_AVF',
 '2_III',
 '6_V6',
 '18_AVR',
 '41_V6',
 'SOO52-1-1_AVL',
 'SOO15-1-1_V6',
 'SOO63-1-1_AVR',
 '7_V6',
 '41_V5',
 'SOO44-1-1_V2',
 '46_I',
 'SOO24-1-1_III',
 '32_AVL',
 '28_II',
 'SOO32-1-1_V3',
 'SOO4-1-1_AVF',
 'SOO21-1-1_I',
 'SOO43-1-1_AVL',
 '25_V2',
 '32_V4',
 'SOO21-1-1_AVF',
 '36_V5',
 'SOO8-1-1_III',
 'SOO19-1-1_AVR',
 'SOO55-1-1_III',
 'SOO55-1-1_V1',
 'SOO63-1-1_III',
 '33_V5',
 '22_III'

In [14]:
%%time
for x,y in loader_valid:
    break

CPU times: user 19 ms, sys: 64 ms, total: 82.9 ms
Wall time: 3.26 s


In [15]:
agkljjsgkldjs

NameError: name 'agkljjsgkldjs' is not defined

In [None]:
(out,) = model(x.cuda())
y2 = torch.clone(y)
y = torch.zeros((64,3,2048),dtype=bool)
y[:,0,:] = (y2 == 1)
y[:,1,:] = (y2 == 2)
y[:,2,:] = (y2 == 3)

In [None]:
i = 2
w = 3
plt.figure(figsize=(15,5))
plt.plot(x[i,0,:])
plt.twinx()
plt.plot(y[i,w-1,:],alpha=0.5)
plt.plot(out[i,w,:].cpu().detach().numpy(),alpha=0.5)


# Test against SoO db

In [None]:
with open('./configurations/UNet5Levels.json', 'r') as f:
    execution = json.load(f)

# Define model
# model = nn.ModelGraph(execution['model']).float().cuda()
model = torch.load('/home/guille/GitHub/DelineatorSwitchAndCompose/Notebooks/modelo5nivsdice.state',pickle_module=dill).eval().float()
# model = torch.load(os.path.join(execution['save_directory'],'modelo5nivsdiceversion6_3','model_best.model'),pickle_module=dill).eval().float()
# model = torch.load(os.path.join(execution['save_directory'],'modelo5nivsdiceversion6_3','checkpoint.model'),pickle_module=dill).eval().float()

In [None]:
#### LOAD DATASET ####
basedir = '/media/guille/DADES/DADES/Delineator/'
Files = os.listdir(os.path.join(basedir,'SoO','RETAG'))
Files = [os.path.splitext(f)[0] for f in Files if os.path.splitext(f)[1] == '.txt']
Segmentations = pd.read_csv(os.path.join(basedir,'SoO','SEGMENTATIONS.csv'),index_col=0,header=None).T
Keys = Segmentations.keys().tolist()
Keys = [k for k in Keys if '-'.join(k.split('-')[:2]) in Files]
database = pd.read_csv(os.path.join(basedir,'SoO','DATABASE_MANUAL.csv'))

# Data storage
QRSsignalSoO = dict()
QRSgroupSoO = dict()

for k in tqdm.tqdm(Keys):
    # Retrieve general information
    fname = '-'.join(k.split('-')[:2]) + '.txt'
    ID = int(k.split('-')[0])
    
    # Read signal and segmentation
    Signal = pd.read_csv(os.path.join(basedir,'SoO','RETAG',fname),index_col=0).values
    (son,soff) = Segmentations[k]
    fs = database['Sampling_Freq'][database['ID'] == int(ID)].values[0]
    
    # Check correct segmentation
    if son > soff:
        print("(!!!) Check file   {:>10s} has onset ({:d}) > offset ({:d})".format(k, son, soff))
        continue

    # Up/downsample to 1000 Hz
    factor = int(fs/250)
    Signal = np.round(sp.signal.decimate(Signal.T, factor)).T
    fs = fs/factor

    # Filter baseline wander and high freq. noise
    Signal = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/fs, 'high'),Signal.T).T
    Signal = sp.signal.filtfilt(*sp.signal.butter(4, 125.0/fs,  'low'),Signal.T).T
    
    for i in range(len(StandardHeader)):
        # Store data
        QRSsignalSoO[k+'###'+str(StandardHeader[i])] = Signal[:,i]


In [None]:
tmp = QRSsignalSoO['49-1-1###III']
ampl = np.median(sak.signal.moving_lambda(tmp,200,lambda x: np.max(x)-np.min(x)))
# tmp = scipy.interpolate.interp1d(np.linspace(0,1,tmp.size),tmp)(np.linspace(0,1,tmp.size))
aaa = (skimage.util.view_as_windows(tmp/ampl,2048,1024)-0)[:,None,:]
bbb = torch.zeros((aaa.shape[0],3,2048),dtype=float)
for i in range(0,aaa.shape[0],64):
    bbb[i:i+64] = model(torch.tensor(aaa[i:i+64]).cuda().float())[0]
bbb = bbb.cpu().detach().numpy()

i = 0
w = 0
f,ax = plt.subplots(nrows=3,figsize=(15,10))
ax[0].plot(aaa[i,0,:])
ax1 = ax[0].twinx()
ax1.plot(bbb[i,0,:]>0.5,alpha=0.5,color='red')
ax[1].plot(aaa[i,0,:])
ax2 = ax[1].twinx()
ax2.plot(bbb[i,1,:]>0.5,alpha=0.5,color='green')
ax[2].plot(aaa[i,0,:])
ax3 = ax[2].twinx()
ax3.plot(bbb[i,2,:]>0.5,alpha=0.5,color='magenta')


# Test against LUDB

In [None]:
LUDBsignal = {}

for i in tqdm.tqdm(range(200)):
    (signal, header) = wfdb.rdsamp(os.path.join(basedir,'ludb','{}'.format(i+1)))
    sortOrder = np.where(np.array([x.upper() for x in header['sig_name']])[:,None] == StandardHeader)[1]
    signal = signal[:,sortOrder]
    signal = sp.signal.decimate(signal,header['fs']//250,axis=0)
    
    for j in range(len(StandardHeader)):
        lead = StandardHeader[j]
        name = str(i+1)+"###"+lead
        LUDBsignal[name] = signal[:,j]

LUDBsignal = pd.DataFrame(LUDBsignal)

In [None]:
list(set([k.split('###')[0] for k in valid_keys]))

In [None]:
tmp = LUDBsignal['39###AVL']
ampl = np.median(sak.signal.moving_lambda(tmp,200,lambda x: np.max(x)-np.min(x)))
tmp = scipy.interpolate.interp1d(np.linspace(0,1,tmp.size),tmp)(np.linspace(0,1,1.0*tmp.size))
aaa = (skimage.util.view_as_windows(tmp/ampl,2048,1024)-0)[:,None,:]
bbb = torch.zeros((aaa.shape[0],3,2048),dtype=float)
for i in range(0,aaa.shape[0],64):
    bbb[i:i+64] = model(torch.tensor(aaa[i:i+64]).cuda().float())[0]
bbb = bbb.cpu().detach().numpy()

i = 0
w = 0
f,ax = plt.subplots(nrows=3,figsize=(15,10))
ax[0].plot(aaa[i,0,:])
ax1 = ax[0].twinx()
ax1.plot(bbb[i,0,:]>0.5,alpha=0.5,color='red')
ax[1].plot(aaa[i,0,:])
ax2 = ax[1].twinx()
ax2.plot(bbb[i,1,:]>0.5,alpha=0.5,color='green')
ax[2].plot(aaa[i,0,:])
ax3 = ax[2].twinx()
ax3.plot(bbb[i,2,:]>0.5,alpha=0.5,color='magenta')


# Test against QTDB

In [None]:
QTDBsignal = pd.read_csv(os.path.join(basedir,'QTDB','Dataset.csv'), index_col=0)
QTDBsignal = dataset.sort_index(axis=1)

In [None]:
list(set([k.split('###')[0] for k in valid_keys]))

In [None]:
tmp = QTDBsignal['sel16272_0']
ampl = np.median(sak.signal.moving_lambda(tmp,200,lambda x: np.max(x)-np.min(x)))
tmp = scipy.interpolate.interp1d(np.linspace(0,1,tmp.size),tmp)(np.linspace(0,1,1.0*tmp.size))
aaa = (skimage.util.view_as_windows(tmp/ampl,2048,1024)-0)[:,None,:]
bbb = torch.zeros((aaa.shape[0],3,2048),dtype=float)
for i in range(0,aaa.shape[0],64):
    bbb[i:i+64] = model(torch.tensor(aaa[i:i+64]).cuda().float())[0]
bbb = bbb.cpu().detach().numpy()

i = 74
w = 0
f,ax = plt.subplots(nrows=3,figsize=(15,10))
ax[0].plot(aaa[i,0,:])
ax1 = ax[0].twinx()
ax1.plot(bbb[i,0,:]>0.5,alpha=0.5,color='red')
ax[1].plot(aaa[i,0,:])
ax2 = ax[1].twinx()
ax2.plot(bbb[i,1,:]>0.5,alpha=0.5,color='green')
ax[2].plot(aaa[i,0,:])
ax3 = ax[2].twinx()
ax3.plot(bbb[i,2,:]>0.5,alpha=0.5,color='magenta')


### 2. Load model definition

In [None]:
### IMPORT EXECUTION CONFIGURATION PARAMETERS (JSON) ###
with open("./parameters.json", 'r') as f:
    execution = json.load(f)

execution["root_directory"] = input_directory
execution["save_directory"] = output_directory

### SET RANDOM SEED ###
torch.manual_seed(execution['seed'])
random.seed(execution['seed'])
np.random.seed(execution['seed'])

### LOAD DATASET ###
# 0) Get classes
# print(list(glob.glob(os.path.join(input_directory,"*.mat"))))
classes = get_classes(input_directory,[os.path.split(f)[-1] for f in glob.glob(os.path.join(input_directory,"*.mat"))])

# 1) Load labels and compute detections
print("########## COMPUTING DETECTIONS ##########")
files = []
labels = []
detections = []
for f in tqdm.tqdm(glob.glob(os.path.join(input_directory,"*.mat"))):
    # Load data
    (signal,header) = wfdb.rdsamp(os.path.join(input_directory,os.path.splitext(f)[0]))
    signal = signal.astype('float32')

    # Use provided function for retrieving the true label
    fname, label_header, label = get_true_labels(f.replace('.mat','.hea'),classes)

    # Detect signal
    detector = ecgdetectors.Detectors(header['fs'])
    index_I = np.where(np.array(list(map(str.upper,header['sig_name']))) == 'I')[0][0]
    qrs = detector.pan_tompkins_detector(signal[:,index_I])

    # Store file name and label
    files.append(fname)
    detections.append(qrs)
    labels.append(label)

labels = np.array(labels)
files = np.array(files)

# 2) Train-test split
labels_train,labels_valid,files_train,files_valid,detections_train,detections_valid = sklearn.model_selection.train_test_split(
    labels,
    files,
    detections,
    stratify=labels.argmax(-1),
    random_state=execution['seed'],
)

# Save into folder
src.utils.pickledump(labels_train, './training/labels_train.pkl')
src.utils.pickledump(labels_valid, './training/labels_valid.pkl')
src.utils.pickledump(files_train, './training/files_train.pkl')
src.utils.pickledump(files_valid, './training/files_valid.pkl')
src.utils.pickledump(detections_train, './training/detections_train.pkl')
src.utils.pickledump(detections_valid, './training/detections_valid.pkl')

print("########## GENERATING TRAIN SET ##########")
# Generate train/test sets
X_train = []
y_train = []
X_valid = []
y_valid = []

for i in tqdm.tqdm(range(len(files_train))):
    # Retrieve the file information
    (signal,_) = wfdb.rdsamp(os.path.join(execution['root_directory'],files_train[i]))
    signal = signal.astype('float32').T

    if not execution['whole_record']:
        for j in range(1,len(detections_train[i])-1):
            onset = detections_train[i][j-1]
            offset = detections_train[i][j+1]
            interp = signal[:,onset:offset]
            interp = sp.interpolate.interp1d(np.linspace(0,1,interp.shape[1]),interp,axis=-1)(np.linspace(0,1,736)).astype('float32')
            X_train.append(interp)
            y_train.append(labels_train[i,:])
    else:
        X_train.append(signal)
        y_train.append(labels_train[i,:])

print("########## GENERATING TRAIN SET ##########")
for i in tqdm.tqdm(range(len(files_valid))):
    # Retrieve the file information
    (signal,_) = wfdb.rdsamp(os.path.join(execution['root_directory'],files_valid[i]))
    signal = signal.astype('float32').T

    if not execution['whole_record']:
        for j in range(1,len(detections_valid[i])-1):
            onset = detections_valid[i][j-1]
            offset = detections_valid[i][j+1]
            interp = signal[:,onset:offset]
            interp = sp.interpolate.interp1d(np.linspace(0,1,interp.shape[1]),interp,axis=-1)(np.linspace(0,1,736)).astype('float32')
            X_valid.append(interp)
            y_valid.append(labels_valid[i,:])
    else:
        X_valid.append(signal)
        y_valid.append(labels_valid[i,:])

y_valid = np.array(y_valid, dtype='float32')
y_train = np.array(y_train, dtype='float32')
try:
    X_train = np.array(X_train, dtype='float32')
    X_valid = np.array(X_valid, dtype='float32')
except:
    pass

### TRAIN MODEL ###
model = src.model.GAPModel(
    torch.nn.Sequential(
        src.model.CNN([12,16,16], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([16,16], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([16,32], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([32,32], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([32,64], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([64,128], regularization=execution['regularization_CNN']),
        torch.nn.MaxPool1d(3),
        src.model.CNN([128,256], regularization=execution['regularization_CNN'], regularize_extrema=False),
    ),
    src.model.DNN([256,128,64,32,9], regularization=execution['regularization_DNN'], regularize_extrema=False),
)

if execution['whole_record']:
    dataset_train = src.data.PaddedDataset(X_train, y_train, padding_length=execution['padding_length'],swapaxes=False, mode='edge')
    dataset_valid = src.data.PaddedDataset(X_valid, y_valid, padding_length=execution['padding_length'],swapaxes=False, mode='edge')
else:
    dataset_train = src.data.Dataset(X_train, y_train)
    dataset_valid = src.data.Dataset(X_valid, y_valid)

sampler_train = src.data.StratifiedSampler(y_train, *execution['sampler'])
sampler_valid = src.data.StratifiedSampler(y_valid, *execution['sampler'])

loader_train  = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train, batch_size=execution['batch_size'], **execution['loader'])
loader_valid  = torch.utils.data.DataLoader(dataset_valid, sampler=sampler_valid, batch_size=execution['batch_size'], **execution['loader'])

# Loss
criterion = lambda X,y,y_pred: torch.nn.MultiLabelSoftMarginLoss(reduction='mean')(y_pred, y.long())
metric = lambda X,y,y_pred: src.evaluate.compute_beta_score(y.long().cpu().detach().numpy(),(torch.nn.functional.softmax(y_pred,-1) > 0.5).cpu().detach().numpy())[-1]

state = {
    'epoch'         : 0,
    'device'        : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'optimizer'     : src.utils.class_selector('torch.optim',execution['optimizer']['name'])(model.parameters(), **execution['optimizer']['arguments']),
    'root_dir'      : './'
}
if 'scheduler' in execution:
    state['scheduler'] = src.utils.class_selector('torch.optim.lr_scheduler',execution['scheduler']['name'])(state['optimizer'], **execution['scheduler']['arguments'])

print("########## TRAINING THE MODEL ##########")
state = src.train.train_model(model,state,execution,loader_train, loader_valid, criterion, metric, smaller=False)


# Boundary Loss

In [None]:
from typing import List

class SurfaceLoss():
    """Adapted from https://github.com/LIVIAETS/surface-loss/blob/master/losses.py#L74"""
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]

    def __call__(self, probs: torch.Tensor, dist_maps: Tensor) -> torch.Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = torch.einsum("bcs,bcs->bcs", pc, dc)

        loss = multipled.mean()

        return loss

In [None]:
P = skimage.segmentation.find_boundaries(y[:,0,None,:])

In [None]:
def generate_binary_structure(rank, connectivity):
    if connectivity < 1:
        connectivity = 1
    if rank < 1:
        if connectivity < 1:
            return np.array(0, dtype=bool)
        else:
            return np.array(1, dtype=bool)
    output = np.fabs(np.indices([3] * rank) - 1)
    output = np.add.reduce(output, 0)
    
    return np.asarray(output <= connectivity, dtype=bool)

In [None]:
struct = generate_binary_structure(ndim-1, 1)
struct[0,:]=False
struct[-1,:]=False
struct = struct[None,None,...]
struct

In [None]:
selem = torch.tensor(selem).type(torch.uint8)

In [None]:
import torch.nn.functional as F
bnd = F.conv2d(label_img, selem, padding=(selem.shape[2] // 2, selem.shape[2] // 2))

In [None]:
def erosion1d(signal, selem):
    inverted = torch.logical_not(signal).type(signal.dtype)
    out = F.conv1d(inverted, selem, padding=(selem.shape[-1] // 2,)) > 0
    return torch.logical_not(out)

def dilation1d(signal, selem):
    return F.conv1d(signal, selem, padding=(selem.shape[-1] // 2,)) > 0

def erosion2d(image, selem):
    inverted = torch.logical_not(image).type(image.dtype)
    out = F.conv2d(inverted, selem, padding=(selem.shape[2] // 2, selem.shape[2] // 2)) > 0
    return torch.logical_not(out)

def dilation2d(image, selem):
    return F.conv2d(image, selem, padding=(selem.shape[2] // 2, selem.shape[2] // 2)) > 0

In [None]:
selem = np.zeros((3,)*(ndim-1),dtype=bool)
selem[1,1,:] = True
selem1 = np.zeros((3,)*(ndim-1),dtype=bool)
selem1[0,0,:] = True
selem1[1,1,:] = True
selem1[2,2,:] = True

In [None]:
selem1

In [None]:
label_img.shape

In [None]:
Pbound = skimage.morphology.dilation(label_img[:,0,...].numpy(),selem.astype('bool')).squeeze()
QRSbound = skimage.morphology.dilation(label_img[:,1,...].numpy(),selem.astype('bool')).squeeze()
Tbound = skimage.morphology.dilation(label_img[:,2,...].numpy(),selem.astype('bool')).squeeze()
out2 = np.concatenate((Pbound[:,None,:],QRSbound[:,None,:],Tbound[:,None,:]),axis=1)
Per = skimage.morphology.erosion(label_img[:,0,...].numpy(),selem.astype('bool')).squeeze()
QRSer = skimage.morphology.erosion(label_img[:,1,...].numpy(),selem.astype('bool')).squeeze()
Ter = skimage.morphology.erosion(label_img[:,2,...].numpy(),selem.astype('bool')).squeeze()
er2 = np.concatenate((Per[:,None,:],QRSer[:,None,:],Ter[:,None,:]),axis=1)

In [None]:
out2.sum()

In [None]:
er2.sum()

In [None]:
out = dilation1d(label_img.type(torch.float32).squeeze(),torch.tensor(selem1).type(torch.float32))

In [None]:
er = erosion1d(label_img.type(torch.float32).squeeze(),torch.tensor(selem1).type(torch.float32))

In [None]:
out.sum()

In [None]:
er.sum()

In [None]:
i = 3
l = 1
plt.plot(er[i,l,:])
plt.plot(er2[i,l,:])

In [None]:
np.allclose(out2.astype('bool'),out.numpy())

In [None]:
out

In [None]:
scipy.ndimage.morphology.binary_dilation(label_img,selem)

In [None]:
label_img = y[:,:,None,None,:]

In [None]:
bnds = skimage.segmentation.find_boundaries(label_img).squeeze()

In [None]:
bnds.shape

In [None]:
import cv2

In [None]:
cv2.dilate(label_img.numpy(),selem)

In [None]:
plt.plot(y[0,1,:])
plt.plot(boundaries.squeeze()[0,1,:])

In [None]:
label_img = y[:,:,None,:]
connectivity=1
mode='thick'
background=0

In [None]:
if label_img.dtype == torch.bool:
    label_img = label_img.type(torch.uint8)
ndim = label_img.ndim
# selem = torch.tensor(generate_binary_structure(ndim, connectivity))
selem = np.zeros((3,)*(ndim-1),dtype=bool)
selem[1,1,1,:] = True
if mode != 'subpixel':
    boundaries = skimage.morphology.dilation(label_img, selem) != skimage.morphology.erosion(label_img, selem)
    if mode == 'inner':
        foreground_image = (label_img != background)
        boundaries &= foreground_image
    elif mode == 'outer':
        max_label = torch.iinfo(label_img.dtype).max
        background_image = (label_img == background)
        selem = generate_binary_structure(ndim, ndim)
        inverted_background = torch.tensor(label_img, copy=True)
        inverted_background[background_image] = max_label
        adjacent_objects = ((skimage.morphology.dilation(label_img, selem) !=
                             skimage.morphology.erosion(inverted_background, selem)) &
                            ~background_image)
        boundaries &= (background_image | adjacent_objects)
else:
    boundaries = _find_boundaries_subpixel(label_img)


In [None]:
boundaries

In [None]:
skimage.morphology.dilation

In [None]:
math.fabs()