In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import platform
import random
import uuid
import os
import os.path
import skimage
import utils
import sak.wavelet
import sak.data
import sak.data.augmentation
import numpy as np
import scipy as sp
import scipy.signal
import pandas as pd
import networkx
import networkx.algorithms.approximation
import wfdb
import json
import tqdm
import matplotlib.pyplot as plt
from scipy.stats import norm
from sak.signal import StandardHeader

# Data loader to un-clutter code    
def load_data(filepath):
    dic = dict()
    with open(filepath) as f:
        text = list(f)
    for line in text:
        line = line.replace(' ','').replace('\n','').replace(',,','')
        if line[-1] == ',': line = line[:-1]
        head = line.split(',')[0]
        tail = line.split(',')[1:]
        if tail == ['']:
            tail = np.asarray([])
        else:
            tail = np.asarray(tail).astype(int)

        dic[head] = tail
    return dic


def trailonset(sig,on):
    on = on-sig[0]
    off = on-sig[0]+sig[-1]
    sig = sig+np.linspace(on,off,sig.size)
    
    return sig

def getcorr(segments):
    if len(segments) > 0:
        length = 2*max([seg.size for seg in segments])
    else:
        return np.zeros((0,0))

    corr = np.zeros((len(segments),len(segments)))

    for i,seg1 in enumerate(segments):
        for j,seg2 in enumerate(segments):
            if i != j:
                if seg1.size != seg2.size:
                    if seg1.size != 1:
                        x1 = sp.interpolate.interp1d(np.linspace(0,1,len(seg1)),seg1)(np.linspace(0,1,length))
                    else:
                        x1 = np.full((length,),seg1[0])
                    if seg2.size != 1:
                        x2 = sp.interpolate.interp1d(np.linspace(0,1,len(seg2)),seg2)(np.linspace(0,1,length))
                    else:
                        x2 = np.full((length,),seg2[0])
                else:
                    x1 = seg1
                    x2 = seg2
                if (x1.size == 1) and (x2.size == 1):
                    corr[i,j] = 1
                else:
                    c,_ = sak.signal.xcorr(x1,x2)
                    corr[i,j] = np.max(np.abs(c))
            else:
                corr[i,j] = 1
                
    return corr

def order_fiducials(qrs, other, mode='smaller'):
    # just in case
    other = np.copy(other)

    if mode=='smaller':
        filt = (other[None,:] < qrs[:,None])
    elif mode=='bigger':
        filt = (qrs[:,None] < other[None,:])
    else:
        raise ValueError("Mode must be {smaller, bigger}")
        
    if other.size == 0:
        return np.full_like(qrs,-1,dtype=int)
        
    for i in range(qrs.size):
        if i >= filt.shape[1]:
            newcol = np.zeros((qrs.size,),dtype=bool)
            if mode=='smaller':
                newcol[i:] = True
            elif mode=='bigger':
                newcol[:i+1] = True
            other = np.insert(other,i,-1)
            filt = np.insert(filt,i,newcol,axis=1)
        else:
            if mode=='smaller':
                j = np.nonzero(filt[:,i])[0][0]
            elif mode=='bigger':
                j = qrs.size-np.nonzero(np.flip(filt[:,i]))[0][0]-1

            if i != j:
                newcol = np.zeros((qrs.size,),dtype=bool)
                if mode=='smaller':
                    newcol[i:] = True
                elif mode=='bigger':
                    newcol[:i+1] = True
                other = np.insert(other,i,-1)
                filt = np.insert(filt,i,newcol,axis=1)

    return other

# Common operations

In [3]:
# Set base directory
if platform.system() in ['Linux', 'Linux2']:
    basedir = '/media/guille/DADES/DADES/Delineator'
else:
    basedir = r'C:\Users\Emilio\Documents\DADES\DADES\Delineator'
    
# Output data structures
Psignal = {}
PQsignal = {}
QRSsignal = {}
STsignal = {}
Tsignal = {}
TPsignal = {}

Pamplitudes = {}
PQamplitudes = {}
QRSamplitudes = {}
STamplitudes = {}
Tamplitudes = {}
TPamplitudes = {}

# Filter out samples
threshold = 0.99

# Load LUDB

In [4]:
# Load LUDB's manual delineations
P_LUDB   = sak.load_data(os.path.join(basedir,'ludb','P.csv'))
QRS_LUDB = sak.load_data(os.path.join(basedir,'ludb','QRS.csv'))
T_LUDB   = sak.load_data(os.path.join(basedir,'ludb','T.csv'))

# Halve to accomodate for 250hz sampling rate
P_LUDB   = {k: P_LUDB[k]//2   for k in P_LUDB}
QRS_LUDB = {k: QRS_LUDB[k]//2 for k in QRS_LUDB}
T_LUDB   = {k: T_LUDB[k]//2   for k in T_LUDB}

# Retrieve onsets
Pon_LUDB    = {k: P_LUDB[k][::2]    for k in P_LUDB}
QRSon_LUDB  = {k: QRS_LUDB[k][::2]  for k in QRS_LUDB}
Ton_LUDB    = {k: T_LUDB[k][::2]    for k in T_LUDB}

# Retrieve offsets
Poff_LUDB   = {k: P_LUDB[k][1::2]   for k in P_LUDB}
QRSoff_LUDB = {k: QRS_LUDB[k][1::2] for k in QRS_LUDB}
Toff_LUDB   = {k: T_LUDB[k][1::2]   for k in T_LUDB}

# Metrics for normalization
metric_intralead = np.max
metric_amplitude = sak.signal.abs_max

# Iterate over signals
for id_file in tqdm.tqdm(range(1,201)):
    if '{}###I'.format(id_file) not in  QRSon_LUDB: continue
    # Load singal
    (signal, header) = wfdb.rdsamp(os.path.join(basedir,'ludb','{}'.format(id_file)))
    
    # Obtain segmentation code
    seg_code = str(id_file)+'###I'
    
    # Sanity check for downsampling
    if header['fs'] != 500:
        print(header['fs'])

    # 0. Order leads as standard header
    sig_name = sak.map_upper(header['sig_name'])
    sort = sak.argsort_as(sig_name, StandardHeader)
    signal = signal[:,sort]
    
    # 1. Downsample and filter
    signal = sp.signal.decimate(signal,2,axis=0)
    signal = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/250., 'high'),signal.T).T
    signal = sp.signal.filtfilt(*sp.signal.butter(4, 125.0/250.,  'low'),signal.T).T
    
    # 2. Get segmentation centers and locate w.r.t. QRS wave
    # Reference, qrs wave
    qrs_on  = QRSon_LUDB.get(seg_code,[]).astype(int)
    qrs_center = (QRS_LUDB.get(seg_code,[])[::2] + QRS_LUDB.get(seg_code,[])[1::2])//2 # Reference
    qrs_off = QRSoff_LUDB.get(seg_code,[]).astype(int)
    # p wave
    p_on    = order_fiducials(qrs_center,Pon_LUDB.get(seg_code,[]),mode='smaller').astype(int)
    p_off   = order_fiducials(qrs_center,Poff_LUDB.get(seg_code,[]),mode='smaller').astype(int)
    # t wave
    t_on    = order_fiducials(qrs_center,Ton_LUDB.get(seg_code,[]),mode='bigger').astype(int)
    t_off   = order_fiducials(qrs_center,Toff_LUDB.get(seg_code,[]),mode='bigger').astype(int)
    
    # 3. Inter-lead: Compute the QRS amplitude w.r.t. other QRS in signal
    normalizing_factor = -np.inf
    for i,lead in enumerate(StandardHeader):
        for j in range(qrs_center.size):
            qrs_segment = sak.signal.on_off_correction(signal[qrs_on[j]:qrs_off[j],i])
            qrs_amplitude = metric_amplitude(qrs_segment)
            if qrs_amplitude > normalizing_factor:
                normalizing_factor = qrs_amplitude
                
    # 4. Per-lead: 1) on/off correct, 2) crop and 3) find relative amplitude
    for i,lead in enumerate(StandardHeader):
        for j in range(qrs_center.size):
            # Unique code
            segment_code = '{}_{}###{}'.format(id_file,lead,j)
            
            # Segments
            p_segment   = signal[p_on[j]:p_off[j],i]
            qrs_segment = signal[qrs_on[j]:qrs_off[j],i]
            t_segment   = signal[t_on[j]:t_off[j],i]
            if p_segment.size != 0:   p_segment   = sak.signal.on_off_correction(p_segment)
            if qrs_segment.size != 0: qrs_segment = sak.signal.on_off_correction(qrs_segment)
            if t_segment.size != 0:   t_segment   = sak.signal.on_off_correction(t_segment)
            
            pq_segment  = signal[p_off[j]:qrs_on[j],i] if p_off[j] != -1 else np.array([])
            st_segment  = signal[qrs_off[j]:t_on[j],i] if (t_on[j] != -1) else np.array([])
            if j in range(qrs_center.size-1):
                if (p_on[j+1] != -1):
                    if (t_off[j] != -1): tp_segment = signal[t_off[j]:p_on[j+1],i]
                    else:                tp_segment = signal[qrs_off[j]:p_on[j+1],i]
                elif (t_off[j] != -1):   tp_segment = signal[t_off[j]:qrs_on[j+1],i]
                else:                    tp_segment = np.array([])
            
            # Amplitudes calculation
            if p_segment.size != 0:   Pamplitudes[segment_code]   = metric_amplitude(p_segment)/metric_amplitude(qrs_segment)
            if qrs_segment.size != 0: QRSamplitudes[segment_code] = metric_amplitude(qrs_segment)/normalizing_factor
            if t_segment.size != 0:   Tamplitudes[segment_code]   = metric_amplitude(t_segment)/metric_amplitude(qrs_segment)
            if pq_segment.size != 0:  PQamplitudes[segment_code]  = metric_amplitude(pq_segment)/metric_amplitude(qrs_segment)
            if st_segment.size != 0:  STamplitudes[segment_code]  = metric_amplitude(st_segment)/metric_amplitude(qrs_segment)
            if tp_segment.size != 0:  TPamplitudes[segment_code]  = metric_amplitude(tp_segment)/metric_amplitude(qrs_segment)
            
            # Normalize segments
            if p_segment.size != 0:   p_segment = sak.data.ball_scaling(p_segment, metric=sak.signal.abs_max)
            if qrs_segment.size != 0: qrs_segment = sak.data.ball_scaling(qrs_segment, metric=sak.signal.abs_max)
            if t_segment.size != 0:   t_segment = sak.data.ball_scaling(t_segment, metric=sak.signal.abs_max)
            if pq_segment.size != 0:  pq_segment = sak.data.ball_scaling(pq_segment, metric=sak.signal.abs_max)
            if st_segment.size != 0:  st_segment = sak.data.ball_scaling(st_segment, metric=sak.signal.abs_max)
            if tp_segment.size != 0:  tp_segment = sak.data.ball_scaling(tp_segment, metric=sak.signal.abs_max)
            
            # Store signals
            if p_segment.size != 0:   Psignal[segment_code] = p_segment
            if pq_segment.size != 0:  PQsignal[segment_code] = pq_segment
            if qrs_segment.size != 0: QRSsignal[segment_code] = qrs_segment
            if st_segment.size != 0:  STsignal[segment_code] = st_segment
            if t_segment.size != 0:   Tsignal[segment_code] = t_segment
            if tp_segment.size != 0:  TPsignal[segment_code] = tp_segment


100%|██████████| 200/200 [00:03<00:00, 63.31it/s]


In [None]:
# Filter signals by length
for k in list(TPsignal.keys()):
    if isinstance(TPsignal[k],float):
        TPamplitudes.pop(k)
        TPsignal.pop(k)
    elif len(TPsignal[k]) > 250:
        TPamplitudes.pop(k)
        TPsignal.pop(k)

In [5]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

6336
6324
7332
7284
7284
7260


In [6]:
# Filter out redundant segments
for id_file in tqdm.tqdm(range(1,201)):
    # P signal
    p_ids = [a for a in list(Psignal) if a.startswith("{}_".format(id_file))]
    p_segments = [Psignal[id] for id in p_ids]
    corr = getcorr(p_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(p_ids) if i not in nodesclique]
    for seg in segments_exclude:
        Pamplitudes.pop(seg)
        Psignal.pop(seg)
        
    # QRS signal
    qrs_ids = [a for a in list(QRSsignal) if a.startswith("{}_".format(id_file))]
    qrs_segments = [QRSsignal[id] for id in qrs_ids]
    corr = getcorr(qrs_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(qrs_ids) if i not in nodesclique]
    for seg in segments_exclude:
        QRSamplitudes.pop(seg)
        QRSsignal.pop(seg)
        
    # T signal
    t_ids = [a for a in list(Tsignal) if a.startswith("{}_".format(id_file))]
    t_segments = [Tsignal[id] for id in t_ids]
    corr = getcorr(t_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(t_ids) if i not in nodesclique]
    for seg in segments_exclude:
        Tamplitudes.pop(seg)
        Tsignal.pop(seg)
        
    # PQ signal
    pq_ids = [a for a in list(PQsignal) if a.startswith("{}_".format(id_file))]
    pq_segments = [PQsignal[id] for id in pq_ids]
    corr = getcorr(pq_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(pq_ids) if i not in nodesclique]
    for seg in segments_exclude:
        PQamplitudes.pop(seg)
        PQsignal.pop(seg)
        
    # ST signal
    st_ids = [a for a in list(STsignal) if a.startswith("{}_".format(id_file))]
    st_segments = [STsignal[id] for id in st_ids]
    corr = getcorr(st_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(st_ids) if i not in nodesclique]
    for seg in segments_exclude:
        STamplitudes.pop(seg)
        STsignal.pop(seg)
        
    # TP signal
    tp_ids = [a for a in list(TPsignal) if a.startswith("{}_".format(id_file))]
    tp_segments = [TPsignal[id] for id in tp_ids]
    corr = getcorr(tp_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(tp_ids) if i not in nodesclique]
    for seg in segments_exclude:
        TPamplitudes.pop(seg)
        TPsignal.pop(seg)


100%|██████████| 200/200 [04:15<00:00,  1.28s/it]


In [7]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

3923
2541
3276
4230
2494
3860


#### OLD AMOUNTS OF SEGMENTS

P:   13295
PQ:  8875
QRS: 10953
ST:  13167
T:   12267
TP:  15949

# Load QT db

In [8]:
#### LOAD DATASETS ####
dataset             = pd.read_csv(os.path.join(basedir,'QTDB','Dataset.csv'), index_col=0)
dataset             = dataset.sort_index(axis=1)
labels              = np.asarray(list(dataset)) # In case no data augmentation is applied
description         = dataset.describe()
group               = {k: '_'.join(k.split('_')[:-1]) for k in dataset}
unique_ids          = list(set([k.split('_')[0] for k in dataset]))

# Load fiducials
Pon_QTDB = load_data(os.path.join(basedir,'QTDB','PonNew.csv'))
Poff_QTDB = load_data(os.path.join(basedir,'QTDB','PoffNew.csv'))
QRSon_QTDB = load_data(os.path.join(basedir,'QTDB','QRSonNew.csv'))
QRSoff_QTDB = load_data(os.path.join(basedir,'QTDB','QRSoffNew.csv'))
Ton_QTDB = load_data(os.path.join(basedir,'QTDB','TonNew.csv'))
Toff_QTDB = load_data(os.path.join(basedir,'QTDB','ToffNew.csv'))

for i,id_file in enumerate(tqdm.tqdm(unique_ids)):
    # Buggy files
    if id_file in ['sel35','sel36','sel103','sel232','sel310']: continue

    # Load singal
    signal = np.vstack((dataset[id_file+'_0'],dataset[id_file+'_1'])).T
    
    # Obtain segmentation code
    seg_code = '{}_0'.format(id_file)
    
    # 1. Downsample and filter
    signal = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/250., 'high'),signal.T).T
    signal = sp.signal.filtfilt(*sp.signal.butter(4, 125.0/250.,  'low'),signal.T).T
    
    # 2. Get segmentation centers and locate w.r.t. QRS wave
    # Reference, qrs wave
    qrs_on  = QRSon_QTDB.get(seg_code,[]).astype(int)
    qrs_off = QRSoff_QTDB.get(seg_code,[]).astype(int)
    qrs_center = (qrs_on+qrs_off)//2
    # p wave
    p_on    = order_fiducials(qrs_center,Pon_QTDB.get(seg_code,[]),mode='smaller').astype(int)
    p_off   = order_fiducials(qrs_center,Poff_QTDB.get(seg_code,[]),mode='smaller').astype(int)
    # t wave
    t_on    = order_fiducials(qrs_center,Ton_QTDB.get(seg_code,[]),mode='bigger').astype(int)
    t_off   = order_fiducials(qrs_center,Toff_QTDB.get(seg_code,[]),mode='bigger').astype(int)
    
    # 3. Inter-lead: Compute the QRS amplitude w.r.t. other QRS in signal
    normalizing_factor = -np.inf
    for i in range(2):
        for j in range(qrs_center.size):
            qrs_segment = sak.signal.on_off_correction(signal[qrs_on[j]:qrs_off[j],i])
            qrs_amplitude = metric_amplitude(qrs_segment)
            if qrs_amplitude > normalizing_factor:
                normalizing_factor = qrs_amplitude
                
    # 4. Per-lead: 1) on/off correct, 2) crop and 3) find relative amplitude
    for i in range(2):
        for j in range(qrs_center.size):
            # Unique code
            segment_code = '{}_{}###{}'.format(id_file,i,j)
            
            # Segments
            p_segment   = signal[p_on[j]:p_off[j],i]
            qrs_segment = signal[qrs_on[j]:qrs_off[j],i]
            t_segment   = signal[t_on[j]:t_off[j],i]
            if p_segment.size != 0:   p_segment   = sak.signal.on_off_correction(p_segment)
            if qrs_segment.size != 0: qrs_segment = sak.signal.on_off_correction(qrs_segment)
            if t_segment.size != 0:   t_segment   = sak.signal.on_off_correction(t_segment)
                
            if qrs_segment.size < 10: aghkjsghkj
            
            pq_segment  = signal[p_off[j]:qrs_on[j],i] if p_off[j] != -1 else np.array([])
            st_segment  = signal[qrs_off[j]:t_on[j],i] if (t_on[j] != -1) else np.array([])
            if j in range(qrs_center.size-1):
                if (p_on[j+1] != -1):
                    if (t_off[j] != -1): tp_segment = signal[t_off[j]:p_on[j+1],i]
                    else:                tp_segment = signal[qrs_off[j]:p_on[j+1],i]
                elif (t_off[j] != -1):   tp_segment = signal[t_off[j]:qrs_on[j+1],i]
                else:                    tp_segment = np.array([])
            
            # Amplitudes calculation
            if p_segment.size != 0:   Pamplitudes[segment_code]   = metric_amplitude(p_segment)/metric_amplitude(qrs_segment)
            if qrs_segment.size != 0: QRSamplitudes[segment_code] = metric_amplitude(qrs_segment)/normalizing_factor
            if t_segment.size != 0:   Tamplitudes[segment_code]   = metric_amplitude(t_segment)/metric_amplitude(qrs_segment)
            if pq_segment.size != 0:  PQamplitudes[segment_code]  = metric_amplitude(pq_segment)/metric_amplitude(qrs_segment)
            if st_segment.size != 0:  STamplitudes[segment_code]  = metric_amplitude(st_segment)/metric_amplitude(qrs_segment)
            if tp_segment.size != 0:  TPamplitudes[segment_code]  = metric_amplitude(tp_segment)/metric_amplitude(qrs_segment)
            
            # Normalize segments
            if p_segment.size != 0:   p_segment = sak.data.ball_scaling(p_segment, metric=sak.signal.abs_max)
            if qrs_segment.size != 0: qrs_segment = sak.data.ball_scaling(qrs_segment, metric=sak.signal.abs_max)
            if t_segment.size != 0:   t_segment = sak.data.ball_scaling(t_segment, metric=sak.signal.abs_max)
            if pq_segment.size != 0:  pq_segment = sak.data.ball_scaling(pq_segment, metric=sak.signal.abs_max)
            if st_segment.size != 0:  st_segment = sak.data.ball_scaling(st_segment, metric=sak.signal.abs_max)
            if tp_segment.size != 0:  tp_segment = sak.data.ball_scaling(tp_segment, metric=sak.signal.abs_max)
            
            # Store signals
            if p_segment.size != 0:   Psignal[segment_code] = p_segment
            if pq_segment.size != 0:  PQsignal[segment_code] = pq_segment
            if qrs_segment.size != 0: QRSsignal[segment_code] = qrs_segment
            if st_segment.size != 0:  STsignal[segment_code] = st_segment
            if t_segment.size != 0:   Tsignal[segment_code] = t_segment
            if tp_segment.size != 0:  TPsignal[segment_code] = tp_segment

100%|██████████| 105/105 [00:03<00:00, 30.67it/s]


In [9]:
# Filter signals by length
for k in list(TPsignal.keys()):
    if isinstance(TPsignal[k],float):
        TPamplitudes.pop(k)
        TPsignal.pop(k)
    elif len(TPsignal[k]) > 250:
        TPamplitudes.pop(k)
        TPsignal.pop(k)

In [10]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

9701
8313
9692
10226
8642
10152


In [12]:
# Filter out redundant segments
for i,id_file in enumerate(tqdm.tqdm(unique_ids)):
    # P signal
    p_ids = [a for a in list(Psignal) if a.startswith("{}_".format(id_file))]
    p_segments = [Psignal[id] for id in p_ids]
    corr = getcorr(p_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(p_ids) if i not in nodesclique]
    for seg in segments_exclude:
        Pamplitudes.pop(seg)
        Psignal.pop(seg)
        
    # QRS signal
    qrs_ids = [a for a in list(QRSsignal) if a.startswith("{}_".format(id_file))]
    qrs_segments = [QRSsignal[id] for id in qrs_ids]
    corr = getcorr(qrs_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(qrs_ids) if i not in nodesclique]
    for seg in segments_exclude:
        QRSamplitudes.pop(seg)
        QRSsignal.pop(seg)
        
    # T signal
    t_ids = [a for a in list(Tsignal) if a.startswith("{}_".format(id_file))]
    t_segments = [Tsignal[id] for id in t_ids]
    corr = getcorr(t_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(t_ids) if i not in nodesclique]
    for seg in segments_exclude:
        Tamplitudes.pop(seg)
        Tsignal.pop(seg)
        
    # PQ signal
    pq_ids = [a for a in list(PQsignal) if a.startswith("{}_".format(id_file))]
    pq_segments = [PQsignal[id] for id in pq_ids]
    corr = getcorr(pq_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(pq_ids) if i not in nodesclique]
    for seg in segments_exclude:
        PQamplitudes.pop(seg)
        PQsignal.pop(seg)
        
    # ST signal
    st_ids = [a for a in list(STsignal) if a.startswith("{}_".format(id_file))]
    st_segments = [STsignal[id] for id in st_ids]
    corr = getcorr(st_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(st_ids) if i not in nodesclique]
    for seg in segments_exclude:
        STamplitudes.pop(seg)
        STsignal.pop(seg)
        
    # TP signal
    tp_ids = [a for a in list(TPsignal) if a.startswith("{}_".format(id_file))]
    tp_segments = [TPsignal[id] for id in tp_ids]
    corr = getcorr(tp_segments)
    g = networkx.convert_matrix.from_numpy_matrix(corr < threshold)
    nodesclique = networkx.algorithms.approximation.max_clique(g)
    segments_exclude = [p for i,p in enumerate(tp_ids) if i not in nodesclique]
    for seg in segments_exclude:
        TPamplitudes.pop(seg)
        TPsignal.pop(seg)


100%|██████████| 105/105 [10:04<00:00,  5.76s/it]


In [22]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

6885
4770
5662
6592
4009
8105


# Load VT

In [30]:
Files = os.listdir(os.path.join(basedir,'SoO','RETAG'))
Files = [os.path.splitext(f)[0] for f in Files if os.path.splitext(f)[1] == '.txt']
Segmentations = pd.read_csv(os.path.join(basedir,'SoO','SEGMENTATIONS.csv'),index_col=0,header=None).T
Keys = Segmentations.keys().tolist()
Keys = [k for k in Keys if '-'.join(k.split('-')[:2]) in Files]
database = pd.read_csv(os.path.join(basedir,'SoO','DATABASE_MANUAL.csv'))

# Data storage
QRSsignalSoO = dict()
QRSgroupSoO = dict()

for k in tqdm.tqdm(Keys):
    # Retrieve general information
    fname = '-'.join(k.split('-')[:2]) + '.txt'
    ID = int(k.split('-')[0])
    
    # Read signal and segmentation
    Signal = pd.read_csv(os.path.join(basedir,'SoO','RETAG',fname),index_col=0).values
    (son,soff) = Segmentations[k]
    fs = database['Sampling_Freq'][database['ID'] == int(ID)].values[0]
    
    # Check correct segmentation
    if son > soff:
        print("(!!!) Check file   {:>10s} has onset ({:d}) > offset ({:d})".format(k, son, soff))
        continue

    # Up/downsample to 1000 Hz
    factor = int(fs/250)
    Signal = np.round(sp.signal.decimate(Signal.T, factor)).T
    fs = fs/factor
    son = int(son/factor)
    soff = int(soff/factor)
    
    # Filter baseline wander and high freq. noise
    Signal = sp.signal.filtfilt(*sp.signal.butter(4,   0.5/fs, 'high'),Signal.T).T
    Signal = sp.signal.filtfilt(*sp.signal.butter(4, 125.0/fs,  'low'),Signal.T).T
    
    # 3. Inter-lead: Compute the QRS amplitude w.r.t. other QRS in signal
    normalizing_factor = -np.inf
    for i,lead in enumerate(StandardHeader):
        qrs_segment = sak.signal.on_off_correction(Signal[son:soff,i])
        qrs_amplitude = metric_amplitude(qrs_segment)
        if qrs_amplitude > normalizing_factor:
            normalizing_factor = qrs_amplitude

    for i,lead in enumerate(StandardHeader):
        qrs_segment = sak.signal.on_off_correction(Signal[son:soff,i])
        QRSamplitudes['SOO{}_{}###{}'.format(k,lead,0)] = metric_amplitude(qrs_segment)/normalizing_factor
        qrs_segment = sak.data.ball_scaling(qrs_segment, metric=sak.signal.abs_max)
        QRSsignal['SOO{}_{}###{}'.format(k,lead,0)] = qrs_segment

100%|██████████| 288/288 [00:15<00:00, 18.79it/s]


In [31]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

6885
4770
7846
6592
4009
8105


# Delete too short or too long signals

In [None]:
# Signal lengths
Plength = {k: len(Psignal[k]) for k in Psignal.keys() if not isinstance(Psignal[k],float)}
PQlength = {k: len(PQsignal[k]) for k in PQsignal.keys() if not isinstance(PQsignal[k],float)}
QRSlength = {k: len(QRSsignal[k]) for k in QRSsignal.keys() if not isinstance(QRSsignal[k],float)}
STlength = {k: len(STsignal[k]) for k in STsignal.keys() if not isinstance(STsignal[k],float)}
Tlength = {k: len(Tsignal[k]) for k in Tsignal.keys() if not isinstance(Tsignal[k],float)}
TPlength = {k: len(TPsignal[k]) for k in TPsignal.keys() if not isinstance(TPsignal[k],float)}

In [None]:
# Filter signals by length
for k in list(Psignal.keys()):
    if isinstance(Psignal[k],float):                                Psignal.pop(k); Pamplitudes.pop(k)
    elif not ((len(Psignal[k]) > 2) and (len(Psignal[k]) < 45)):    Psignal.pop(k); Pamplitudes.pop(k)
for k in list(PQsignal.keys()):
    if isinstance(PQsignal[k],float):                               PQsignal.pop(k); PQamplitudes.pop(k)
    elif not ((len(PQsignal[k]) > 1) and (len(PQsignal[k]) < 35)):  PQsignal.pop(k); PQamplitudes.pop(k)
for k in list(QRSsignal.keys()):
    if isinstance(QRSsignal[k],float):                              QRSsignal.pop(k); QRSamplitudes.pop(k)
    elif not ((len(QRSsignal[k]) > 10)):                            QRSsignal.pop(k); QRSamplitudes.pop(k)
for k in list(STsignal.keys()):
    if isinstance(STsignal[k],float):                               STsignal.pop(k); STamplitudes.pop(k)
    elif not ((len(STsignal[k]) > 1) and (len(STsignal[k]) < 65)):  STsignal.pop(k); STamplitudes.pop(k)
for k in list(Tsignal.keys()):
    if isinstance(Tsignal[k],float):                                Tsignal.pop(k); Tamplitudes.pop(k)
    elif not ((len(Tsignal[k]) > 10) and (len(Tsignal[k]) < 100)):  Tsignal.pop(k); Tamplitudes.pop(k)
for k in list(TPsignal.keys()):
    if isinstance(TPsignal[k],float):                               TPsignal.pop(k); TPamplitudes.pop(k)
    elif not ((len(TPsignal[k]) > 2) and (len(TPsignal[k]) < 250)): TPsignal.pop(k); TPamplitudes.pop(k)

In [None]:
print(len(Psignal))
print(len(PQsignal))
print(len(QRSsignal))
print(len(STsignal))
print(len(Tsignal))
print(len(TPsignal))

# Save files

In [44]:
sak.pickledump(Psignal,os.path.join('.','pickle','Psignal_new.pkl'))
sak.pickledump(Pamplitudes,os.path.join('.','pickle','Pamplitudes_new.pkl'))
sak.pickledump(PQsignal,os.path.join('.','pickle','PQsignal_new.pkl'))
sak.pickledump(PQamplitudes,os.path.join('.','pickle','PQamplitudes_new.pkl'))
sak.pickledump(QRSsignal,os.path.join('.','pickle','QRSsignal_new.pkl'))
sak.pickledump(QRSamplitudes,os.path.join('.','pickle','QRSamplitudes_new.pkl'))
sak.pickledump(STsignal,os.path.join('.','pickle','STsignal_new.pkl'))
sak.pickledump(STamplitudes,os.path.join('.','pickle','STamplitudes_new.pkl'))
sak.pickledump(Tsignal,os.path.join('.','pickle','Tsignal_new.pkl'))
sak.pickledump(Tamplitudes,os.path.join('.','pickle','Tamplitudes_new.pkl'))
sak.pickledump(TPsignal,os.path.join('.','pickle','TPsignal_new.pkl'))
sak.pickledump(TPamplitudes,os.path.join('.','pickle','TPamplitudes_new.pkl'))

# CHECK FILES

In [None]:
# def getcorr_old(segments):
#     if len(segments) > 0:
#         length = 2*max([segments[i][2].size for i in range(len(segments))])
#     else:
#         return np.zeros((0,0))
# 
#     corr = np.zeros((len(segments),len(segments)))
# 
#     for i in range(len(segments)):
#         for j in range(len(segments)):
#             if i != j:
#                 if segments[i][2].size != segments[j][2].size:
#                     if segments[i][2].size != 1:
#                         x1 = sp.interpolate.interp1d(np.linspace(0,1,len(segments[i][2])),segments[i][2])(np.linspace(0,1,length))
#                     else:
#                         x1 = np.full((length,),segments[i][2][0])
#                     if segments[j][2].size != 1:
#                         x2 = sp.interpolate.interp1d(np.linspace(0,1,len(segments[j][2])),segments[j][2])(np.linspace(0,1,length))
#                     else:
#                         x2 = np.full((length,),segments[j][2][0])
#                 else:
#                     x1 = segments[i][2]
#                     x2 = segments[j][2]
#                 if (x1.size == 1) and (x2.size == 1):
#                     corr[i,j] = 1
#                 else:
#                     c,_ = sak.signal.xcorr(x1,x2)
#                     corr[i,j] = np.max(np.abs(c))
#             else:
#                 corr[i,j] = 1
# 
#     return corr

In [None]:
# # #### LOAD DATASETS ####
# # dataset             = pd.read_csv(os.path.join(basedir,'QTDB','Dataset.csv'), index_col=0)
# # dataset             = dataset.sort_index(axis=1)
# # labels              = np.asarray(list(dataset)) # In case no data augmentation is applied
# # description         = dataset.describe()
# # group               = {k: '_'.join(k.split('_')[:-1]) for k in dataset}
# # unique_ids          = list(set([k.split('_')[0] for k in dataset]))

# # Load fiducials
# Pon_QTDB = load_data(os.path.join(basedir,'QTDB','PonNew.csv'))
# Poff_QTDB = load_data(os.path.join(basedir,'QTDB','PoffNew.csv'))
# QRSon_QTDB = load_data(os.path.join(basedir,'QTDB','QRSonNew.csv'))
# QRSoff_QTDB = load_data(os.path.join(basedir,'QTDB','QRSoffNew.csv'))
# Ton_QTDB = load_data(os.path.join(basedir,'QTDB','TonNew.csv'))
# Toff_QTDB = load_data(os.path.join(basedir,'QTDB','ToffNew.csv'))

# path_SVG = '/home/guille/Escritorio/QTDB/'
# f,ax = plt.subplots(2,1,figsize=(40,6))
# for id_file in unique_ids:
#     if id_file in ['sel35','sel36','sel232']: continue
#     if id_file != 'sele0406': continue
#     w = 100
#     id0 = id_file + '_0'
#     id1 = id_file + '_1'
#     ax[0].plot(dataset[id0][QRSon_QTDB[id0][0]-w:QRSoff_QTDB[id0][-1]+w])
#     ax[0].set_xlim([QRSon_QTDB[id0][0]-w,QRSoff_QTDB[id0][-1]+w])
#     if id0 in Pon_QTDB:
#         [ax[0].axvspan(Pon_QTDB[id0][j],Poff_QTDB[id0][j], color='r', alpha=0.15) for j in range(len(Pon_QTDB[id0]))]
#     [ax[0].axvspan(QRSon_QTDB[id0][j],QRSoff_QTDB[id0][j], color='g', alpha=0.15) for j in range(len(QRSon_QTDB[id0]))]
#     [ax[0].axvspan(Ton_QTDB[id0][j],Toff_QTDB[id0][j], color='m', alpha=0.15) for j in range(len(Ton_QTDB[id0]))]
#     ax[1].plot(dataset[id1][QRSon_QTDB[id1][0]-w:QRSoff_QTDB[id1][-1]+w])
#     ax[1].set_xlim([QRSon_QTDB[id1][0]-w,QRSoff_QTDB[id1][-1]+w])
#     if id1 in Pon_QTDB:
#         [ax[1].axvspan(Pon_QTDB[id1][j],Poff_QTDB[id1][j], color='r', alpha=0.15) for j in range(len(Pon_QTDB[id1]))]
#     [ax[1].axvspan(QRSon_QTDB[id1][j],QRSoff_QTDB[id1][j], color='g', alpha=0.15) for j in range(len(QRSon_QTDB[id1]))]
#     [ax[1].axvspan(Ton_QTDB[id1][j],Toff_QTDB[id1][j], color='m', alpha=0.15) for j in range(len(Ton_QTDB[id1]))]
#     f.tight_layout()
#     f.subplots_adjust(hspace=0.00,wspace=0.05)
#     f.savefig(os.path.join(path_SVG,id_file+'.svg'))
#     [ax[i].clear() for i in range(2)]


In [None]:
# print(seg_code)
# seg_code = seg_code.split('_')[0] + '_1'
# print(j)
# print(QRSon_QTDB[seg_code][j])
# print(QRSoff_QTDB[seg_code][j])

# plt.figure()
# plt.plot(dataset[seg_code][QRSon_QTDB[seg_code][j]-100:QRSon_QTDB[seg_code][j]+100])
# plt.gca().axvspan(xmin=QRSon_QTDB[seg_code][j],xmax=QRSoff_QTDB[seg_code][j],alpha=0.15)
# seg_code = seg_code.split('_')[0] + '_1'
# plt.figure()
# plt.plot(dataset[seg_code][QRSon_QTDB[seg_code][j]-100:QRSon_QTDB[seg_code][j]+100])
# plt.gca().axvspan(xmin=QRSon_QTDB[seg_code][j],xmax=QRSoff_QTDB[seg_code][j],alpha=0.15)

In [None]:
# Pon_QTDB = load_data(os.path.join(basedir,'QTDB','PonNew.csv'))
# Poff_QTDB = load_data(os.path.join(basedir,'QTDB','PoffNew.csv'))
# QRSon_QTDB = load_data(os.path.join(basedir,'QTDB','QRSonNew.csv'))
# QRSoff_QTDB = load_data(os.path.join(basedir,'QTDB','QRSoffNew.csv'))
# Ton_QTDB = load_data(os.path.join(basedir,'QTDB','TonNew.csv'))
# Toff_QTDB = load_data(os.path.join(basedir,'QTDB','ToffNew.csv'))

# seg_code = 'sele0124_1'
# j = 48
# print(QRSon_QTDB[seg_code][j])
# print(QRSoff_QTDB[seg_code][j])
# print("")


# if 0:
#     mod_on = -7
#     mod_off = 9

#     QRSon_QTDB[seg_code.split('_')[0] + '_0'][j]  += mod_on
#     QRSoff_QTDB[seg_code.split('_')[0] + '_0'][j] += mod_off
#     QRSon_QTDB[seg_code.split('_')[0] + '_1'][j]  += mod_on
#     QRSoff_QTDB[seg_code.split('_')[0] + '_1'][j] += mod_off

#     print(QRSon_QTDB[seg_code][j])
#     print(QRSoff_QTDB[seg_code][j])


# plt.figure()
# plt.plot(dataset[seg_code][QRSon_QTDB[seg_code][j]-100:QRSon_QTDB[seg_code][j]+100])
# plt.gca().axvspan(xmin=QRSon_QTDB[seg_code][j],xmax=QRSoff_QTDB[seg_code][j],alpha=0.15)
# seg_code = seg_code.split('_')[0] + '_1'
# plt.figure()
# plt.plot(dataset[seg_code][QRSon_QTDB[seg_code][j]-100:QRSon_QTDB[seg_code][j]+100])
# plt.gca().axvspan(xmin=QRSon_QTDB[seg_code][j],xmax=QRSoff_QTDB[seg_code][j],alpha=0.15)