In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import platform
import random
import uuid
import os
import os.path
import skimage
import utils
import utils.wavelet
import utils.data
import utils.data.augmentation
import numpy as np
import scipy as sp
import scipy.signal
import pandas as pd
import networkx
import networkx.algorithms.approximation
import wfdb
import json
import tqdm
import matplotlib.pyplot as plt
from scipy.stats import norm
from utils.signal import StandardHeader

# Data loader to un-clutter code    
def load_data(filepath):
    """Read a fiducial CSV into a ``{key: int array}`` dictionary.

    Each non-blank line is ``key,v1,v2,...``. Spaces are ignored, and empty
    fields (trailing padding such as ``,,,`` or interior gaps) are skipped.

    Parameters
    ----------
    filepath : str
        Path to the CSV file.

    Returns
    -------
    dict
        Maps the first field of each line to a 1-D integer numpy array with
        the remaining non-empty fields (empty array when there are none).
    """
    dic = dict()
    with open(filepath) as f:
        for line in f:
            parts = line.replace(' ', '').strip().split(',')
            if parts == ['']:
                # Blank line (e.g. trailing newline at end of file)
                continue
            head = parts[0]
            # Drop empty fields instead of deleting ',,' from the raw line,
            # which glued neighbouring numbers together ('5,,7' -> '57').
            tail = [field for field in parts[1:] if field]
            if tail:
                dic[head] = np.asarray(tail).astype(int)
            else:
                dic[head] = np.asarray([], dtype=int)
    return dic


def trailonset(sig,on):
    """Add a linear ramp to ``sig`` so that its first sample lands on ``on``.

    The ramp starts at ``on - sig[0]`` and ends at ``on - 2*sig[0] + sig[-1]``,
    so the last returned sample equals ``on + 2*(sig[-1] - sig[0])``.

    NOTE(review): the end of the ramp is derived from the already shifted
    onset value. If the intent was to also bring the last sample to ``on``,
    the end should be ``on - sig[-1]`` instead. Confirm before relying on it.
    """
    shift_start = on - sig[0]
    shift_end = shift_start - sig[0] + sig[-1]
    ramp = np.linspace(shift_start, shift_end, sig.size)
    return sig + ramp

def getcorr(segments):
    """Pairwise similarity matrix between segment waveforms.

    Each element of ``segments`` is a tuple whose third item (index 2) is a
    1-D numpy array with the waveform. For every ordered pair ``(i, j)`` with
    ``i != j``, the two waveforms are compared with ``utils.signal.xcorr``
    and the maximum absolute value of the returned correlation is stored in
    ``corr[i, j]``. The diagonal is 1, and so is any pair of length-1
    waveforms.

    Pairs of different length are first brought onto a common grid of
    ``2 * max_length`` points by linear interpolation (a length-1 waveform is
    repeated as a constant). Equal-length pairs are compared as-is.

    Parameters
    ----------
    segments : list of tuple
        ``(key, group, waveform)`` tuples.

    Returns
    -------
    np.ndarray
        ``(len(segments), len(segments))`` matrix, or ``(0, 0)`` when empty.
    """
    n = len(segments)
    if n == 0:
        return np.zeros((0, 0))

    waves = [seg[2] for seg in segments]
    length = 2 * max(w.size for w in waves)

    # Resampling depends only on the waveform itself and the shared `length`,
    # so do it once per segment instead of once per pair. It is only needed
    # when at least two different lengths exist; in that case every segment
    # is paired with a different-length one at some point.
    resampled = None
    if len({w.size for w in waves}) > 1:
        grid = np.linspace(0, 1, length)
        resampled = [
            sp.interpolate.interp1d(np.linspace(0, 1, len(w)), w)(grid)
            if w.size != 1 else np.full((length,), w[0])
            for w in waves
        ]

    corr = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i == j:
                corr[i, j] = 1
                continue
            if waves[i].size != waves[j].size:
                x1, x2 = resampled[i], resampled[j]
            else:
                x1, x2 = waves[i], waves[j]
            if (x1.size == 1) and (x2.size == 1):
                corr[i, j] = 1
            else:
                c, _ = utils.signal.xcorr(x1, x2)
                corr[i, j] = np.max(np.abs(c))

    return corr

def getdelete(segments, threshold):
    """Greedy list of segment indices to discard as near-duplicates.

    Rows of the ``getcorr`` similarity matrix are visited in order. A row
    whose index is already marked is skipped. Otherwise, every other column
    whose similarity exceeds ``threshold`` is marked, unless it already is.

    Returns
    -------
    list of int
        Marked indices, in the order they were marked.
    """
    corr = getcorr(segments)
    marked = []
    for row, similarities in enumerate(corr):
        if row in marked:
            continue
        for col, value in enumerate(similarities):
            if col != row and value > threshold and col not in marked:
                marked.append(col)
    return marked
    

# Define directories

In [3]:
# Root folder holding the raw databases ('ludb', 'QTDB'). Set the
# DELINEATOR_BASEDIR environment variable to override the per-machine default.
if platform.system() in ['Linux', 'Linux2']:
    default_basedir = '/media/guille/DADES/DADES/Delineator'
else:
    default_basedir = r'C:\Users\Emilio\Documents\DADES\DADES\Delineator'
basedir = os.environ.get('DELINEATOR_BASEDIR', default_basedir)

# Load LUDB

In [4]:
def _wave_fiducials(ann, symbol):
    """Onset, peak and offset sample indices for every ``symbol`` annotation.

    The annotation immediately before each ``symbol`` annotation is taken as
    its onset, and the one immediately after as its offset. This neighbour
    assumption is not validated here. Peaks lacking a neighbour on either side
    are dropped. All indices are halved to match the 2x decimation of the
    signal.

    Returns
    -------
    tuple of np.ndarray
        ``(onsets, peaks, offsets)``.
    """
    loc = np.where(np.array(ann.symbol) == symbol)[0]
    # Keep only peaks with both neighbours. Trimming loc[0] and then reading
    # loc[-1] failed with IndexError when a single edge peak was present.
    loc = loc[(loc >= 1) & (loc + 1 < len(ann.sample))]
    return ann.sample[loc - 1] // 2, ann.sample[loc] // 2, ann.sample[loc + 1] // 2


dataset = {}
Pon = {}
Ppeak = {}
Poff = {}
QRSon = {}
QRSpeak = {}
QRSoff = {}
Ton = {}
Tpeak = {}
Toff = {}
group = {}

# The filter design does not depend on the record, so compute it once.
# NOTE(review): after decimation the signal is presumably at 250 Hz (Nyquist
# 125 Hz), so these normalised cutoffs correspond to 0.25 Hz and 62.5 Hz. If
# 0.5 Hz / 125 Hz were intended, revisit. Confirm.
highpass_ba = sp.signal.butter(4,   0.5/250., 'high')
lowpass_ba = sp.signal.butter(4, 125.0/250.,  'low')

for i in tqdm.tqdm(range(200)):
    record_path = os.path.join(basedir, 'ludb', '{}'.format(i+1))
    (signal, header) = wfdb.rdsamp(record_path)

    # Reorder channels so that column c holds lead StandardHeader[c]. With a
    # (leads x channels) comparison matrix, np.where(...)[1] lists, for each
    # standard lead in order, the matching channel index. The former
    # (channels x leads) form yielded the inverse permutation.
    channel_names = np.array([x.upper() for x in header['sig_name']])
    sortOrder = np.where(np.asarray(StandardHeader)[:, None] == channel_names[None, :])[1]
    signal = signal[:, sortOrder]
    if header['fs'] != 500:
        print(header['fs'])
    signal = sp.signal.decimate(signal, 2, axis=0)

    # 1st step: reduce noise
    signal = sp.signal.filtfilt(*highpass_ba, signal.T).T
    signal = sp.signal.filtfilt(*lowpass_ba, signal.T).T

    # 2nd step: retrieve onsets and offsets
    for j, lead in enumerate(StandardHeader):
        name = str(i+1) + "_" + lead
        ann = wfdb.rdann(record_path, 'atr_{}'.format(lead.lower()))
        dataset[name] = signal[:, j]

        Pon[name], Ppeak[name], Poff[name] = _wave_fiducials(ann, 'p')
        QRSon[name], QRSpeak[name], QRSoff[name] = _wave_fiducials(ann, 'N')
        Ton[name], Tpeak[name], Toff[name] = _wave_fiducials(ann, 't')

        # Store group
        group[name] = str(i+1)

dataset = pd.DataFrame(dataset)

100%|██████████| 200/200 [00:17<00:00, 11.22it/s]


In [None]:
threshold = 0.99  # |xcorr| above which two segments of one type count as near-duplicates

PsignalLUDB = {}
PQsignalLUDB = {}
QRSsignalLUDB = {}
STsignalLUDB = {}
TsignalLUDB = {}
TPsignalLUDB = {}

PgroupLUDB = {}
PQgroupLUDB = {}
QRSgroupLUDB = {}
STgroupLUDB = {}
TgroupLUDB = {}
TPgroupLUDB = {}

# Segment type -> (signal store, group store)
storesLUDB = {
    'P': (PsignalLUDB, PgroupLUDB),
    'QRS': (QRSsignalLUDB, QRSgroupLUDB),
    'T': (TsignalLUDB, TgroupLUDB),
    'TP': (TPsignalLUDB, TPgroupLUDB),
    'PQ': (PQsignalLUDB, PQgroupLUDB),
    'ST': (STsignalLUDB, STgroupLUDB),
}

# (fiducial, next fiducial) -> type of the segment between them. Any gap after
# an offset that is not the P-offset -> QRS-onset (PQ) or QRS-offset ->
# T-onset (ST) gap is treated as baseline (TP).
segment_of_transition = {
    ('Pon', 'Poff'): 'P',
    ('QRSon', 'QRSoff'): 'QRS',
    ('Ton', 'Toff'): 'T',
    ('Poff', 'Pon'): 'TP',
    ('Poff', 'QRSon'): 'PQ',
    ('Poff', 'Ton'): 'TP',
    ('QRSoff', 'Pon'): 'TP',
    ('QRSoff', 'QRSon'): 'TP',
    ('QRSoff', 'Ton'): 'ST',
    ('Toff', 'Pon'): 'TP',
    ('Toff', 'QRSon'): 'TP',
    ('Toff', 'Ton'): 'TP',
}

# Buggy files
excluded_keys = set(['116_{}'.format(h) for h in StandardHeader] +
                    ['104_{}'.format(h) for h in StandardHeader] +
                    ['103_III'])

for k in tqdm.tqdm(dataset.keys()):
    if k in excluded_keys:
        continue
    pon = Pon.get(k,np.array([]))
    pof = Poff.get(k,np.array([]))
    qon = QRSon.get(k,np.array([]))
    qof = QRSoff.get(k,np.array([]))
    ton = Ton.get(k,np.array([]))
    tof = Toff.get(k,np.array([]))

    unordered_samples = np.concatenate([pon,pof,qon,qof,ton,tof,]).astype(float)
    unordered_symbols = np.concatenate([['Pon']*pon.size,['Poff']*pof.size,
                                        ['QRSon']*qon.size,['QRSoff']*qof.size,
                                        ['Ton']*ton.size,['Toff']*tof.size,])

    # Sort fiducials by sample. A stable argsort keeps ties in input order.
    # Two-way ties are then reordered so that the offset closing the
    # currently open wave comes first (e.g. Pon ... [Poff, QRSon]).
    order = np.argsort(unordered_samples, kind='stable')
    samples = []
    symbols = []
    pos = 0
    while pos < order.size:
        # Gather the run of fiducials sharing this sample
        end = pos + 1
        while end < order.size and unordered_samples[order[end]] == unordered_samples[order[pos]]:
            end += 1
        tied = list(order[pos:end])
        if len(tied) > 2:
            raise ValueError("Definitely should not happen. Check file {}".format(k))
        if len(tied) == 2:
            previous = symbols[-1] if symbols else None
            if previous not in ('Pon', 'QRSon', 'Ton'):
                raise ValueError("Should not happen at all. Check file {}".format(k))
            closing = previous[:-2] + 'off'
            tied_symbols = [unordered_symbols[idx] for idx in tied]
            if closing not in tied_symbols:
                # Previously this stalled the sort and silently dropped every
                # remaining fiducial of the file; keep the input order instead.
                print("Check file {}. {} not closed before sample {}".format(
                    k, previous, int(unordered_samples[tied[0]])))
            elif tied_symbols[0] != closing:
                tied.reverse()
        for idx in tied:
            samples.append(int(unordered_samples[idx]))
            symbols.append(unordered_symbols[idx])
        pos = end
    samples = np.array(samples, dtype=int)
    symbols = np.array(symbols)

    # Extract segments between consecutive fiducials (positional slicing)
    segments = {wave: [] for wave in storesLUDB}
    lead_signal = dataset[k]
    for start, stop, first, second in zip(samples[:-1], samples[1:], symbols[:-1], symbols[1:]):
        if start == stop:
            continue
        wave = segment_of_transition.get((first, second))
        if wave is not None:
            segments[wave].append((k, group[k], lead_signal.iloc[start:stop].values))
        elif first.endswith('on'):
            print("Check file {}. {} onset not followed by offset".format(k, first[:-2]))
        else:
            print("Check file {}. {} offset not followed by onset".format(k, first[:-3]))

    # Filter out too similar segments: keep an (approximate) maximum set of
    # segments whose pairwise |xcorr| stays below `threshold`
    for wave, segs in segments.items():
        if len(segs) < 2:
            continue  # zero or one segment is trivially kept
        corr = getcorr(segs)
        g = networkx.from_numpy_array(corr < threshold)
        nodesclique = networkx.algorithms.approximation.max_clique(g)
        segments[wave] = [seg for idx, seg in enumerate(segs) if idx in nodesclique]

    # Store segments
    for wave, segs in segments.items():
        signal_store, group_store = storesLUDB[wave]
        for idx, (seg_key, seg_group, seg_values) in enumerate(segs):
            signal_store[k + '_' + str(idx)] = seg_values
            group_store[k + '_' + str(idx)] = (seg_key, seg_group)


 17%|█▋        | 414/2400 [00:36<09:12,  3.60it/s]

In [None]:
# Number of LUDB segments kept per type: P, PQ, QRS, ST, T, TP
for segment_groups in (PgroupLUDB, PQgroupLUDB, QRSgroupLUDB, STgroupLUDB, TgroupLUDB, TPgroupLUDB):
    print(len(segment_groups))

# Load QTDB

In [None]:
#### LOAD DATASETS ####
dataset             = pd.read_csv(os.path.join(basedir,'QTDB','Dataset.csv'), index_col=0)
dataset             = dataset.sort_index(axis=1)
labels              = np.asarray(list(dataset)) # In case no data augmentation is applied
description         = dataset.describe()
group               = {k: '_'.join(k.split('_')[:-1]) for k in dataset}

# Zero-center data
for key in description:
    dataset[key]    = (dataset[key] - description[key]['mean'])/description[key]['std']
    
# Filter the data
for col in dataset:
    dataset[col] = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/250., 'high'),dataset[col].T).T
    dataset[col] = sp.signal.filtfilt(*sp.signal.butter(4, 125.0/250.,  'low'),dataset[col].T).T
    
# Load fiducials
Pon = load_data(os.path.join(basedir,'QTDB','PonNew.csv'))
Poff = load_data(os.path.join(basedir,'QTDB','PoffNew.csv'))
QRSon = load_data(os.path.join(basedir,'QTDB','QRSonNew.csv'))
QRSoff = load_data(os.path.join(basedir,'QTDB','QRSoffNew.csv'))
Ton = load_data(os.path.join(basedir,'QTDB','TonNew.csv'))
Toff = load_data(os.path.join(basedir,'QTDB','ToffNew.csv'))

In [None]:
threshold = 0.99  # |xcorr| above which two segments of one type count as near-duplicates

PsignalQTDB = {}
PQsignalQTDB = {}
QRSsignalQTDB = {}
STsignalQTDB = {}
TsignalQTDB = {}
TPsignalQTDB = {}

PgroupQTDB = {}
PQgroupQTDB = {}
QRSgroupQTDB = {}
STgroupQTDB = {}
TgroupQTDB = {}
TPgroupQTDB = {}

# Segment type -> (signal store, group store)
storesQTDB = {
    'P': (PsignalQTDB, PgroupQTDB),
    'QRS': (QRSsignalQTDB, QRSgroupQTDB),
    'T': (TsignalQTDB, TgroupQTDB),
    'TP': (TPsignalQTDB, TPgroupQTDB),
    'PQ': (PQsignalQTDB, PQgroupQTDB),
    'ST': (STsignalQTDB, STgroupQTDB),
}

# (fiducial, next fiducial) -> type of the segment between them. Any gap after
# an offset that is not the P-offset -> QRS-onset (PQ) or QRS-offset ->
# T-onset (ST) gap is treated as baseline (TP).
segment_of_transition = {
    ('Pon', 'Poff'): 'P',
    ('QRSon', 'QRSoff'): 'QRS',
    ('Ton', 'Toff'): 'T',
    ('Poff', 'Pon'): 'TP',
    ('Poff', 'QRSon'): 'PQ',
    ('Poff', 'Ton'): 'TP',
    ('QRSoff', 'Pon'): 'TP',
    ('QRSoff', 'QRSon'): 'TP',
    ('QRSoff', 'Ton'): 'ST',
    ('Toff', 'Pon'): 'TP',
    ('Toff', 'QRSon'): 'TP',
    ('Toff', 'Ton'): 'TP',
}

# Buggy files
excluded_keys = {'sel232_0', 'sel232_1'}

for k in tqdm.tqdm(dataset.keys()):
    if k in excluded_keys:
        continue
    pon = Pon.get(k,np.array([]))
    pof = Poff.get(k,np.array([]))
    qon = QRSon.get(k,np.array([]))
    qof = QRSoff.get(k,np.array([]))
    ton = Ton.get(k,np.array([]))
    tof = Toff.get(k,np.array([]))

    unordered_samples = np.concatenate([pon,pof,qon,qof,ton,tof,]).astype(float)
    unordered_symbols = np.concatenate([['Pon']*pon.size,['Poff']*pof.size,
                                        ['QRSon']*qon.size,['QRSoff']*qof.size,
                                        ['Ton']*ton.size,['Toff']*tof.size,])

    # Sort fiducials by sample. A stable argsort keeps ties in input order.
    # Two-way ties are then reordered so that the offset closing the
    # currently open wave comes first (e.g. Pon ... [Poff, QRSon]).
    order = np.argsort(unordered_samples, kind='stable')
    samples = []
    symbols = []
    pos = 0
    while pos < order.size:
        # Gather the run of fiducials sharing this sample
        end = pos + 1
        while end < order.size and unordered_samples[order[end]] == unordered_samples[order[pos]]:
            end += 1
        tied = list(order[pos:end])
        if len(tied) > 2:
            raise ValueError("Definitely should not happen. Check file {}".format(k))
        if len(tied) == 2:
            previous = symbols[-1] if symbols else None
            if previous not in ('Pon', 'QRSon', 'Ton'):
                raise ValueError("Should not happen at all. Check file {}".format(k))
            closing = previous[:-2] + 'off'
            tied_symbols = [unordered_symbols[idx] for idx in tied]
            if closing not in tied_symbols:
                # Previously this stalled the sort and silently dropped every
                # remaining fiducial of the file; keep the input order instead.
                print("Check file {}. {} not closed before sample {}".format(
                    k, previous, int(unordered_samples[tied[0]])))
            elif tied_symbols[0] != closing:
                tied.reverse()
        for idx in tied:
            samples.append(int(unordered_samples[idx]))
            symbols.append(unordered_symbols[idx])
        pos = end
    samples = np.array(samples, dtype=int)
    symbols = np.array(symbols)

    # Extract segments between consecutive fiducials (positional slicing)
    segments = {wave: [] for wave in storesQTDB}
    lead_signal = dataset[k]
    for start, stop, first, second in zip(samples[:-1], samples[1:], symbols[:-1], symbols[1:]):
        if start == stop:
            continue
        wave = segment_of_transition.get((first, second))
        if wave is not None:
            segments[wave].append((k, group[k], lead_signal.iloc[start:stop].values))
        elif first.endswith('on'):
            print("Check file {}. {} onset not followed by offset".format(k, first[:-2]))
        else:
            print("Check file {}. {} offset not followed by onset".format(k, first[:-3]))

    # Filter out too long TP segments (causing this to break)
    segments['TP'] = [seg for seg in segments['TP'] if seg[2].size < 250]

    # Filter out too similar segments: keep an (approximate) maximum set of
    # segments whose pairwise |xcorr| stays below `threshold`
    for wave, segs in segments.items():
        if len(segs) < 2:
            continue  # zero or one segment is trivially kept
        corr = getcorr(segs)
        g = networkx.from_numpy_array(corr < threshold)
        nodesclique = networkx.algorithms.approximation.max_clique(g)
        segments[wave] = [seg for idx, seg in enumerate(segs) if idx in nodesclique]

    # Store segments
    for wave, segs in segments.items():
        signal_store, group_store = storesQTDB[wave]
        for idx, (seg_key, seg_group, seg_values) in enumerate(segs):
            signal_store[k + '_' + str(idx)] = seg_values
            group_store[k + '_' + str(idx)] = (seg_key, seg_group)


In [None]:
# Number of QTDB segments kept per type: P, PQ, QRS, ST, T, TP
for segment_groups in (PgroupQTDB, PQgroupQTDB, QRSgroupQTDB, STgroupQTDB, TgroupQTDB, TPgroupQTDB):
    print(len(segment_groups))

# Merge databases

In [None]:
# Merge both databases: QTDB entries first, then LUDB. On a key collision the
# LUDB entry wins (presumably none occur, since keys embed the record name).
Psignal = {**PsignalQTDB, **PsignalLUDB}
Pgroup = {**PgroupQTDB, **PgroupLUDB}
PQsignal = {**PQsignalQTDB, **PQsignalLUDB}
PQgroup = {**PQgroupQTDB, **PQgroupLUDB}
QRSsignal = {**QRSsignalQTDB, **QRSsignalLUDB}
QRSgroup = {**QRSgroupQTDB, **QRSgroupLUDB}
STsignal = {**STsignalQTDB, **STsignalLUDB}
STgroup = {**STgroupQTDB, **STgroupLUDB}
Tsignal = {**TsignalQTDB, **TsignalLUDB}
Tgroup = {**TgroupQTDB, **TgroupLUDB}
TPsignal = {**TPsignalQTDB, **TPsignalLUDB}
TPgroup = {**TPgroupQTDB, **TPgroupLUDB}

In [None]:
# Number of merged segments per type: P, PQ, QRS, ST, T, TP
for segment_groups in (Pgroup, PQgroup, QRSgroup, STgroup, Tgroup, TPgroup):
    print(len(segment_groups))

# Delete too short or too long signals

In [None]:
# Signal lengths
Plength = {k: len(Psignal[k]) for k in Psignal.keys() if not isinstance(Psignal[k],float)}
PQlength = {k: len(PQsignal[k]) for k in PQsignal.keys() if not isinstance(PQsignal[k],float)}
QRSlength = {k: len(QRSsignal[k]) for k in QRSsignal.keys() if not isinstance(QRSsignal[k],float)}
STlength = {k: len(STsignal[k]) for k in STsignal.keys() if not isinstance(STsignal[k],float)}
Tlength = {k: len(Tsignal[k]) for k in Tsignal.keys() if not isinstance(Tsignal[k],float)}
TPlength = {k: len(TPsignal[k]) for k in TPsignal.keys() if not isinstance(TPsignal[k],float)}

In [None]:
# Filter signals by length
for k in list(Psignal.keys()):
    if isinstance(Psignal[k],float):
        Psignal.pop(k)
        Pgroup.pop(k)
    elif not ((len(Psignal[k]) > 1) and (len(Psignal[k]) < 45)):
        Psignal.pop(k)
        Pgroup.pop(k)
for k in list(PQsignal.keys()):
    if isinstance(PQsignal[k],float):
        PQsignal.pop(k)
        PQgroup.pop(k)
    elif not ((len(PQsignal[k]) > 1) and (len(PQsignal[k]) < 35)):
        PQsignal.pop(k)
        PQgroup.pop(k)
for k in list(QRSsignal.keys()):
    if isinstance(QRSsignal[k],float):
        QRSsignal.pop(k)
        QRSgroup.pop(k)
    elif not ((len(QRSsignal[k]) > 1) and (len(QRSsignal[k]) < 50)):
        QRSsignal.pop(k)
        QRSgroup.pop(k)
for k in list(STsignal.keys()):
    if isinstance(STsignal[k],float):
        STsignal.pop(k)
        STgroup.pop(k)
    elif not ((len(STsignal[k]) > 1) and (len(STsignal[k]) < 65)):
        STsignal.pop(k)
        STgroup.pop(k)
for k in list(Tsignal.keys()):
    if isinstance(Tsignal[k],float):
        Tsignal.pop(k)
        Tgroup.pop(k)
    elif not ((len(Tsignal[k]) > 1) and (len(Tsignal[k]) < 100)):
        Tsignal.pop(k)
        Tgroup.pop(k)
for k in list(TPsignal.keys()):
    if isinstance(TPsignal[k],float):
        TPsignal.pop(k)
        TPgroup.pop(k)
    elif not ((len(TPsignal[k]) > 1) and (len(TPsignal[k]) < 250)):
        TPsignal.pop(k)
        TPgroup.pop(k)

In [None]:
# Number of segments per type after length filtering: P, PQ, QRS, ST, T, TP
for segment_groups in (Pgroup, PQgroup, QRSgroup, STgroup, Tgroup, TPgroup):
    print(len(segment_groups))

# Save files

In [None]:
# Persist the segment dictionaries as they stand at this point of the notebook
# (merged and length-filtered, before amplitude normalization).
pickle_dir = os.path.join('.', 'pickle')
os.makedirs(pickle_dir, exist_ok=True)  # a fresh checkout may lack the folder
for store_name, store in (('Psignal', Psignal), ('Pgroup', Pgroup),
                          ('PQsignal', PQsignal), ('PQgroup', PQgroup),
                          ('QRSsignal', QRSsignal), ('QRSgroup', QRSgroup),
                          ('STsignal', STsignal), ('STgroup', STgroup),
                          ('Tsignal', Tsignal), ('Tgroup', Tgroup),
                          ('TPsignal', TPsignal), ('TPgroup', TPgroup)):
    utils.pickledump(store, os.path.join(pickle_dir, store_name + '.pkl'))

# Normalize amplitudes

In [None]:
normalization_strategy = 0  # which element of the (key, group) tuples defines a stratum

# One (initially empty) amplitude list per stratum present among the QRS groups
amplitudes = {}
for stratification in set(QRSgroup.values()):
    amplitudes[stratification[normalization_strategy]] = []
# Peak absolute amplitude of every QRS complex after on/off correction
for k, segment in QRSsignal.items():
    g = QRSgroup[k][normalization_strategy]
    amplitudes[g].append(np.max(np.abs(utils.signal.on_off_correction(segment))))


In [None]:
metric = np.median

# Reference QRS amplitude per stratum, computed once per stratum instead of
# once per segment.
qrs_scale = {g: metric(amps) for g, amps in amplitudes.items()}
# Pooled reference for strata that have segments of other types but no QRS
# segment left after filtering (previously a KeyError).
# NOTE(review): using a pooled fallback rather than dropping such segments is
# a modelling choice. Confirm.
pooled_amplitudes = [a for amps in amplitudes.values() for a in amps]
fallback_scale = metric(pooled_amplitudes) if pooled_amplitudes else np.nan
missing_strata = set()

# Values are replaced in place: re-running this cell on its own normalizes twice.
for signals, groups in ((Psignal, Pgroup), (PQsignal, PQgroup), (QRSsignal, QRSgroup),
                        (STsignal, STgroup), (Tsignal, Tgroup), (TPsignal, TPgroup)):
    for k in signals:
        g = groups[k][normalization_strategy]
        scale = qrs_scale.get(g)
        if scale is None:
            if g not in missing_strata:
                print("Check stratum {}. No QRS amplitude available, using pooled amplitude".format(g))
                missing_strata.add(g)
            scale = fallback_scale
        signals[k] = utils.signal.on_off_correction(signals[k]) / scale

# Adding inverses - Data augmentation

Not done here: inverted copies are generated at runtime instead, because precomputing them causes memory issues that break the mixup augmentation.

In [None]:
# Psignal.update({'-'+k: -Psignal[k] for k in Psignal})
# Pgroup.update({'-'+k: Pgroup[k] for k in Pgroup})
# PQsignal.update({'-'+k: -PQsignal[k] for k in PQsignal})
# PQgroup.update({'-'+k: PQgroup[k] for k in PQgroup})
# QRSsignal.update({'-'+k: -QRSsignal[k] for k in QRSsignal})
# QRSgroup.update({'-'+k: QRSgroup[k] for k in QRSgroup})
# STsignal.update({'-'+k: -STsignal[k] for k in STsignal})
# STgroup.update({'-'+k: STgroup[k] for k in STgroup})
# Tsignal.update({'-'+k: -Tsignal[k] for k in Tsignal})
# Tgroup.update({'-'+k: Tgroup[k] for k in Tgroup})
# TPsignal.update({'-'+k: -TPsignal[k] for k in TPsignal})
# TPgroup.update({'-'+k: TPgroup[k] for k in TPgroup})

In [None]:
# print(len(Pgroup))
# print(len(PQgroup))
# print(len(QRSgroup))
# print(len(STgroup))
# print(len(Tgroup))
# print(len(TPgroup))

# Mixup - Data augmentation

In [None]:
# number = 1

# too_much_it_squares_amount_of_data

# permuted = np.random.permutation(list(Psignal))
# for k1 in tqdm.tqdm(list(Psignal.keys())):
#     visited = {}
#     (k_all_1,g_1) = Pgroup[k1]
#     counter = 0
#     for k2 in permuted:
#         (k_all_2,g_2) = Pgroup[k2]
#         if (k1 != k2) & (g_1 != g_2):
#             visited[g_2] = visited.get(g_2,0)+1
#             if visited[g_2] > number:
#                 continue
#             if Psignal[k1].size != Psignal[k2].size:
#                 intlen = np.random.randint(min([Psignal[k1].size,Psignal[k2].size]),max([Psignal[k1].size,Psignal[k2].size]))
#                 x1 = sp.interpolate.interp1d(np.linspace(0,1,Psignal[k1].size),Psignal[k1])(np.linspace(0,1,intlen))
#                 x2 = sp.interpolate.interp1d(np.linspace(0,1,Psignal[k2].size),Psignal[k2])(np.linspace(0,1,intlen))
#             else:
#                 x1 = Psignal[k1]
#                 x2 = Psignal[k2]
#             (xhat,lmbda) = utils.data.augmentation.mixup(x1,x2,5.,1.5)
#             Psignal[k1+'m'+str(counter)] = xhat.squeeze()
#             Pgroup[k1+'m'+str(counter)] = Pgroup[k1]
#             counter += 1

# permuted = np.random.permutation(list(QRSsignal))
# for k1 in tqdm.tqdm(list(QRSsignal.keys())):
#     visited = {}
#     (k_all_1,g_1) = QRSgroup[k1]
#     counter = 0
#     for k2 in permuted:
#         (k_all_2,g_2) = QRSgroup[k2]
#         if (k1 != k2) & (g_1 != g_2):
#             visited[g_2] = visited.get(g_2,0)+1
#             if visited[g_2] > number:
#                 continue
#             if QRSsignal[k1].size != QRSsignal[k2].size:
#                 intlen = np.random.randint(min([QRSsignal[k1].size,QRSsignal[k2].size]),max([QRSsignal[k1].size,QRSsignal[k2].size]))
#                 x1 = sp.interpolate.interp1d(np.linspace(0,1,QRSsignal[k1].size),QRSsignal[k1])(np.linspace(0,1,intlen))
#                 x2 = sp.interpolate.interp1d(np.linspace(0,1,QRSsignal[k2].size),QRSsignal[k2])(np.linspace(0,1,intlen))
#             else:
#                 x1 = QRSsignal[k1]
#                 x2 = QRSsignal[k2]
#             (xhat,lmbda) = utils.data.augmentation.mixup(x1,x2,5.,1.5)
#             QRSsignal[k1+'m'+str(counter)] = xhat.squeeze()
#             QRSgroup[k1+'m'+str(counter)] = QRSgroup[k1]
#             counter += 1

# permuted = np.random.permutation(list(Tsignal))
# for k1 in tqdm.tqdm(list(Tsignal.keys())):
#     visited = {}
#     (k_all_1,g_1) = Tgroup[k1]
#     counter = 0
#     for k2 in permuted:
#         (k_all_2,g_2) = Tgroup[k2]
#         if (k1 != k2) & (g_1 != g_2):
#             visited[g_2] = visited.get(g_2,0)+1
#             if visited[g_2] > number:
#                 continue
#             if Tsignal[k1].size != Tsignal[k2].size:
#                 intlen = np.random.randint(min([Tsignal[k1].size,Tsignal[k2].size]),max([Tsignal[k1].size,Tsignal[k2].size]))
#                 x1 = sp.interpolate.interp1d(np.linspace(0,1,Tsignal[k1].size),Tsignal[k1])(np.linspace(0,1,intlen))
#                 x2 = sp.interpolate.interp1d(np.linspace(0,1,Tsignal[k2].size),Tsignal[k2])(np.linspace(0,1,intlen))
#             else:
#                 x1 = Tsignal[k1]
#                 x2 = Tsignal[k2]
#             (xhat,lmbda) = utils.data.augmentation.mixup(x1,x2,5.,1.5)
#             Tsignal[k1+'m'+str(counter)] = xhat.squeeze()
#             Tgroup[k1+'m'+str(counter)] = Tgroup[k1]
#             counter += 1

In [None]:
# print(len(Pgroup))
# print(len(PQgroup))
# print(len(QRSgroup))
# print(len(STgroup))
# print(len(Tgroup))
# print(len(TPgroup))

# Compute criteria

In [None]:
# Generate wavelets: one utils.wavelet.transform per stored segment, for each of the
# six segment families. NOTE(review): the 250. passed as second argument is presumably
# the sampling rate in Hz — confirm against utils.wavelet.transform.
def _wavelets(segments, desc, fs=250.):
    """Return ``{key: utils.wavelet.transform(segments[key], fs).squeeze()}`` for every key."""
    return {k: utils.wavelet.transform(segments[k], fs).squeeze()
            for k in tqdm.tqdm(segments.keys(), desc=desc)}

Pwavelet = _wavelets(Psignal, 'P')
PQwavelet = _wavelets(PQsignal, 'PQ')
QRSwavelet = _wavelets(QRSsignal, 'QRS')
STwavelet = _wavelets(STsignal, 'ST')
Twavelet = _wavelets(Tsignal, 'T')
TPwavelet = _wavelets(TPsignal, 'TP')

In [None]:
# Generate criteria
s = 3 # wavelet scale
eps = np.finfo('float').eps


def _boundary_criteria(w, scale_idx, eps):
    """Describe how the wavelet column ``w[:, scale_idx]`` starts and ends.

    ``w`` is indexed as ``w[position, scale]``. Returns
    ``((sign(first), first_step), (sign(last), last_step))``, where
    ``first_step = (col[0] - col[1] + eps) / range`` and
    ``last_step = (col[-1] - col[-2] + eps) / range``, with
    ``range = max(col) - min(col) + eps``.
    """
    col = w[:, scale_idx]
    # eps is added to the range (not to the minimum), so a flat column gives a finite
    # value with the numerator's sign instead of numerator / -eps.
    value_range = np.max(col) - np.min(col) + eps
    return ((np.sign(col[0]), (col[0] - col[1] + eps) / value_range),
            (np.sign(col[-1]), (col[-1] - col[-2] + eps) / value_range))


Pcriteria = {k: _boundary_criteria(Pwavelet[k], s, eps) for k in Psignal.keys()}
PQcriteria = {k: _boundary_criteria(PQwavelet[k], s, eps) for k in PQsignal.keys()}
QRScriteria = {k: _boundary_criteria(QRSwavelet[k], s, eps) for k in QRSsignal.keys()}
STcriteria = {k: _boundary_criteria(STwavelet[k], s, eps) for k in STsignal.keys()}
Tcriteria = {k: _boundary_criteria(Twavelet[k], s, eps) for k in Tsignal.keys()}
TPcriteria = {k: _boundary_criteria(TPwavelet[k], s, eps) for k in TPsignal.keys()}

# Generate sample record with QT and LUDB

In [None]:
# Freeze the key order of every segment dict once, so that random integer ids drawn
# later can be mapped back to keys by position.
(Pkeys, PQkeys, QRSkeys,
 STkeys, Tkeys, TPkeys) = [list(segment_dict.keys()) for segment_dict in
                           (Psignal, PQsignal, QRSsignal, STsignal, Tsignal, TPsignal)]

In [None]:
# %%timeit
# Assemble one N-sample record by chaining randomly drawn P / PQ / QRS / ST / T / TP
# segments, and build boolean P / QRS / T masks aligned with it.
N = 2048

# Hyperparams
size = 20                               # number of beats (P-PQ-QRS-ST-T-TP cycles) drawn
onset = np.random.randint(0,50)         # samples cropped from the start of the record
begining_wave = np.random.randint(0,6)  # segment index (0=P ... 5=TP) the first beat starts at
has_P = np.random.rand(1) > 0.1         # record-level switches: when False, every P / PQ / ST is dropped
has_PQ = np.random.rand(1) > 0.2
has_ST = np.random.rand(1) > 0.2
has_BBB = np.random.rand(1) > 0.9       # NOTE(review): not used in this cell
proba_P = 0.15                          # per-beat probabilities of dropping a segment
proba_PQ = 0.15
proba_QRS = 0.01
proba_ST = 0.15

##### Identifiers: one random entry per beat for every segment family (-1 = omitted)
id_P = np.random.randint(0,len(Psignal),size=size)
id_PQ = np.random.randint(0,len(PQsignal),size=size)
id_QRS = np.random.randint(0,len(QRSsignal),size=size)
id_ST = np.random.randint(0,len(STsignal),size=size)
id_T = np.random.randint(0,len(Tsignal),size=size)
id_TP = np.random.randint(0,len(TPsignal),size=size)

# In case QRS is not expressed (that beat's PQ, ST and T are dropped as well)
filt_QRS = np.random.rand(size) < proba_QRS

# Drop segments per beat, or for the whole record when the matching has_* switch is off
id_P[(np.random.rand(size) < proba_P) | np.logical_not(has_P)] = -1
id_PQ[filt_QRS | (np.random.rand(size) < proba_PQ) | np.logical_not(has_PQ)] = -1
id_QRS[filt_QRS] = -1
id_ST[filt_QRS | (np.random.rand(size) < proba_ST) | np.logical_not(has_ST)] = -1
id_T[filt_QRS] = -1

# Segment families in beat order: (segment dict, key list, per-beat ids, mask label).
# Mask labels: 0 = background, 1 = P, 2 = QRS, 3 = T.
segment_table = [
    (Psignal, Pkeys, id_P, 1),
    (PQsignal, PQkeys, id_PQ, 0),
    (QRSsignal, QRSkeys, id_QRS, 2),
    (STsignal, STkeys, id_ST, 0),
    (Tsignal, Tkeys, id_T, 3),
    (TPsignal, TPkeys, id_TP, 0),
]

beats = []
mask_parts = []
offset = 0
record_size = 0
done = False
for i in range(size):
    for j, (seg_signal, seg_keys, seg_ids, label) in enumerate(segment_table):
        if (i == 0) and (j < begining_wave):
            continue
        if seg_ids[i] != -1:
            # trailonset makes the segment's first sample equal `offset`
            # (the previous segment's last value; 0 for the very first segment)
            beat = trailonset(seg_signal[seg_keys[seg_ids[i]]], offset)
            beats.append(beat)
            mask_parts.append(np.full((beat.size,), label, dtype='int8'))
            offset = beat[-1]
            record_size += beat.size
        # Stop as soon as there are enough samples to crop N of them after the onset
        if (record_size-onset) >= N:
            done = True
            break
    if done:
        break

# Running out of beats would otherwise silently yield a record shorter than N samples.
assert record_size - onset >= N, (
    'generated {} samples from {} beats, need onset + N = {}; increase size'.format(
        record_size, size, onset + N))

# Obtain final stuff
signal = np.concatenate(beats)
masks = np.concatenate(mask_parts)
# One boolean column per wave: [:,0] = P, [:,1] = QRS, [:,2] = T
masks_all = masks[:, None] == np.array([1, 2, 3])

# Move onset
signal = signal[onset:onset+N]
masks_all = masks_all[onset:onset+N,:]

In [None]:
# Overlay the wave masks on the generated record: column 0 (red) = P, column 1
# (green) = QRS, column 2 (magenta) = T. Each boolean column is scaled to the signal's
# [min, max] range so a shaded band spans the full trace height.
mskplt = ((np.max(signal)-np.min(signal))*masks_all)+np.min(signal)

# Use the actual length rather than N so the plot also works for a shorter record.
x = np.arange(signal.size)
fig, ax = plt.subplots(figsize=(20,5))
ax.plot(x, signal)
for col, (color, wave) in enumerate([('red', 'P'), ('green', 'QRS'), ('magenta', 'T')]):
    ax.fill_between(x, mskplt[:, col], mskplt[:, col].min(),
                    linewidth=0, alpha=0.15, color=color, label=wave)
ax.set(title='Generated record with P / QRS / T masks', xlabel='sample', ylabel='amplitude')
ax.legend(loc='upper right')
plt.show()

In [None]:
# N = 2048
# s = 3

# has_P = np.random.rand(1) > 0.1
# has_PQ = np.random.rand(1) > 0.2
# has_ST = np.random.rand(1) > 0.2
# has_BBB = np.random.rand(1) > 0.9
# counter_BBB = 0
# repetitions_BBB = np.random.randint(2,4)

# beats = []
# ids = []

# # Include first beat
# ids.append(('TPsignal',np.random.randint(0,len(TPsignal))))
# beats.append(utils.signal.on_off_correction(TPsignal[list(TPsignal)[ids[-1][1]]]))
# size = beats[0].size
# masks = np.zeros((size,),dtype='int8')
# onset = np.random.randint(0,size)
# while size-onset < N:
#     # P wave (sometimes)
#     if has_BBB:
#         if counter_BBB == 0:
#             id_BBB_P = ('Psignal',np.random.randint(0,len(Psignal)))
#         ids.append(id_BBB_P)
#         p = trailonset(Psignal[list(Psignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(p)
#         masks = np.concatenate((masks,1*np.ones((p.size,),dtype='int8')))

#         if has_PQ:
#             # PQ segment
#             ids.append(('PQsignal',np.random.randint(0,len(PQsignal))))
#             pq = trailonset(PQsignal[list(PQsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#             beats.append(pq)
#             masks = np.concatenate((masks,np.zeros((pq.size,),dtype='int8')))
#     elif (np.random.rand(1) < 0.75) and (has_P):
#         ids.append(('Psignal',np.random.randint(0,len(Psignal))))
#         p = trailonset(Psignal[list(Psignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(p)
#         masks = np.concatenate((masks,1*np.ones((p.size,),dtype='int8')))

#         if has_PQ:
#             # PQ segment
#             ids.append(('PQsignal',np.random.randint(0,len(PQsignal))))
#             pq = trailonset(PQsignal[list(PQsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#             beats.append(pq)
#             masks = np.concatenate((masks,np.zeros((pq.size,),dtype='int8')))

#     # QRS wave
#     has_QRS = np.random.rand(1)
#     if has_BBB:
#         if counter_BBB%repetitions_BBB == 0:
#             id_BBB_QRS = ('QRSsignal',np.random.randint(0,len(QRSsignal)))
#             ids.append(id_BBB_QRS)
#             qrs = trailonset(QRSsignal[list(QRSsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#             beats.append(qrs)
#             masks = np.concatenate((masks,2*np.ones((qrs.size,),dtype='int8')))
#         else:
#             pass
#     elif (has_QRS < 0.99):
#         ids.append(('QRSsignal',np.random.randint(0,len(QRSsignal))))
#         qrs = trailonset(QRSsignal[list(QRSsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(qrs)
#         masks = np.concatenate((masks,2*np.ones((qrs.size,),dtype='int8')))
    
#     # ST segment
#     if has_BBB and (counter_BBB%repetitions_BBB != 0):
#         pass
#     elif (np.random.rand(1) < 0.75) and (has_ST):
#         ids.append(('STsignal',np.random.randint(0,len(STsignal))))
#         st = trailonset(STsignal[list(STsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(st)
#         masks = np.concatenate((masks,np.zeros((st.size,),dtype='int8')))

#     # T wave
#     if has_BBB:
#         if counter_BBB%repetitions_BBB == 0:
#             ids.append(('Tsignal',np.random.randint(0,len(Tsignal))))
#             t = trailonset(Tsignal[list(Tsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#             beats.append(t)
#             masks = np.concatenate((masks,3*np.ones((t.size,),dtype='int8')))
#         else:
#             pass
#     elif (has_QRS < 0.99):
#         ids.append(('Tsignal',np.random.randint(0,len(Tsignal))))
#         t = trailonset(Tsignal[list(Tsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(t)
#         masks = np.concatenate((masks,3*np.ones((t.size,),dtype='int8')))

#     # TP segment
#     if has_BBB:
#         if counter_BBB == 0:
#             id_BBB_TP = ('TPsignal',np.random.randint(0,len(TPsignal)))
#         ids.append(id_BBB_TP)
#         tp = trailonset(TPsignal[list(TPsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(tp)
#         masks = np.concatenate((masks,np.zeros((tp.size,),dtype='int8')))
#     else:
#         ids.append(('TPsignal',np.random.randint(0,len(TPsignal))))
#         tp = trailonset(TPsignal[list(TPsignal)[ids[-1][1]]],beats[-1][-1])[1:]
#         beats.append(tp)
#         masks = np.concatenate((masks,np.zeros((tp.size,),dtype='int8')))
    
#     # Account for total signal size
#     size = sum([beats[i].size for i in range(len(beats))])
        
#     # Update BBB counter
#     if has_BBB:
#         counter_BBB += 1

# sig = np.concatenate(beats)[onset:onset+2048]
# # sig = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/250, 'high'),sig)
# signal = sig# + np.convolve(np.cumsum(norm.rvs(scale=0.15**(2*0.5),size=N)),np.hamming(w)/(w/2),mode='same')

# masks = masks[onset:onset+2048]
# masks_all = np.zeros((N,3),dtype=bool)
# masks_all[:,0] = masks == 1
# masks_all[:,1] = masks == 2
# masks_all[:,2] = masks == 3
# mskplt = ((np.max(signal)-np.min(signal))*masks_all)+np.min(signal)

# # f,ax = plt.subplots(nrows=1,figsize=(20,4))
# # ax = np.array(ax)
# # if len(ax.shape) == 0: ax = ax[None]
# # [ax[i].set_xlim([0,N]) for i in range(ax.size)]
# # [ax[i].fill_between(np.arange(N), mskplt[:,0], mskplt[:,0].min(), linewidth=0, alpha=0.15, color='red') for i in range(ax.size)]
# # [ax[i].fill_between(np.arange(N), mskplt[:,1], mskplt[:,1].min(), linewidth=0, alpha=0.15, color='green') for i in range(ax.size)]
# # [ax[i].fill_between(np.arange(N), mskplt[:,2], mskplt[:,2].min(), linewidth=0, alpha=0.15, color='magenta') for i in range(ax.size)]
# # ax[0].plot(signal)
# # # ax[1].plot(wvlts,color='orange')
# # # ax[2].plot(wvlts_signal,color='orange')

In [None]:
# Quick look at three individual segments on shared axes.
# NOTE(review): P, QRS and T are not defined in this part of the notebook — confirm
# which segments they hold before relying on this plot.
fig, ax = plt.subplots()
for segment, wave in ((P, 'P'), (QRS, 'QRS'), (T, 'T')):
    ax.plot(segment, label=wave)
ax.legend()
plt.show()

In [None]:
# Plot Pw, QRSw and Tw back to back as one trace (use (P, QRS, T) for the other variant).
# NOTE(review): Pw / QRSw / Tw are not defined in this part of the notebook — confirm
# what they contain.
fig, ax = plt.subplots()
ax.plot(np.concatenate((Pw, QRSw, Tw)))
ax.set_title('Pw, QRSw, Tw concatenated')
plt.show()

In [None]:
# Alternative generator: keep drawing complete beats (P sometimes, then PQ, QRS, ST,
# T, TP) until the record covers onset + N samples, then high-pass filter it and plot
# it with its P / QRS / T masks.
N = 2048
w = 51              # window length of the disabled additive-noise term (see `signal = ...`)
fs = 250.           # assumed sampling rate in Hz (the same 250. is given to utils.wavelet.transform)
highpass_hz = 0.5   # cut-off of the 4th-order Butterworth high-pass

beats = []
ids = []            # (segment dict name, index) of every drawn segment, in order
mask_parts = []     # per-segment labels: 0 = background, 1 = P, 2 = QRS, 3 = T


def _append_segment(name, source, label):
    """Draw a random entry of ``source``, append it to the record and return its length.

    The entry goes through ``trailonset`` so that its first sample equals the last
    sample of the record so far. ``(name, index)`` is logged in ``ids`` and ``label``
    is recorded for every appended sample.

    NOTE(review): that junction value therefore appears twice; the disabled BBB variant
    drops it with ``[1:]`` — confirm which behaviour is intended.
    """
    idx = np.random.randint(0,len(source))
    ids.append((name, idx))
    segment = trailonset(source[list(source)[idx]], beats[-1][-1])
    beats.append(segment)
    mask_parts.append(np.full((segment.size,), label, dtype='int8'))
    return segment.size


# Include first beat
ids.append(('TPsignal',np.random.randint(0,len(TPsignal))))
beats.append(utils.signal.on_off_correction(TPsignal[list(TPsignal)[ids[-1][1]]]))
mask_parts.append(np.zeros((beats[0].size,), dtype='int8'))
size = beats[0].size
onset = np.random.randint(0,size)
while size-onset < N:
    # P wave (sometimes)
    if np.random.rand(1) < 0.75:
        size += _append_segment('Psignal', Psignal, 1)
    size += _append_segment('PQsignal', PQsignal, 0)    # PQ segment
    size += _append_segment('QRSsignal', QRSsignal, 2)  # QRS wave
    size += _append_segment('STsignal', STsignal, 0)    # ST segment
    size += _append_segment('Tsignal', Tsignal, 3)      # T wave
    size += _append_segment('TPsignal', TPsignal, 0)    # TP segment

# butter() expects the cut-off normalised to the Nyquist frequency (fs/2);
# 0.5/250 would normalise by fs and actually give a 0.25 Hz cut-off.
b, a = sp.signal.butter(4, highpass_hz/(fs/2), 'high')
sig = np.concatenate(beats)[onset:onset+N]
sig = sp.signal.filtfilt(b, a, sig)
signal = sig# + np.convolve(np.cumsum(norm.rvs(scale=0.15**(2*0.5),size=N)),np.hamming(w)/(w/2),mode='same')

masks = np.concatenate(mask_parts)[onset:onset+N]
# One boolean column per wave: [:,0] = P, [:,1] = QRS, [:,2] = T
masks_all = masks[:, None] == np.array([1, 2, 3])
mskplt = ((np.max(signal)-np.min(signal))*masks_all)+np.min(signal)

x = np.arange(N)
plt.figure(figsize=(20,5))
plt.plot(signal)
for col, color in enumerate(('red', 'green', 'magenta')):
    plt.gca().fill_between(x, mskplt[:,col], mskplt[:,col].min(), linewidth=0, alpha=0.15, color=color)
plt.show()

In [None]:
# Show the (segment dict name, index) pairs drawn for the most recent record.
list(ids)

In [None]:
# Inspect one stored segment by position: `f` is the segment dict, `i` the position.
# Presumably used to review entries such as those listed under "fiducials to delete".
i = 46000
f = TPsignal
key = list(f)[i]  # build the key list once instead of twice
print(key)
plt.plot(f[key])
plt.title(key)
plt.show()

# Fiducials to delete

* QRS - 111_AVF_24
* QRS - sel820_1_156
* QRS - 95_AVR_16
* ~~TP - sel306_1_211~~
* ~~TP - sel114_0_179~~
* PT - sel803_0_109