In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# local packages are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import platform
import random
import uuid
import math
import os
import os.path
import skimage
import sak
import sak.wavelet
import sak.data
import sak.data.augmentation
import sak.data.preprocessing
import numpy as np
import scipy as sp
import scipy.signal
import pandas as pd
import networkx
import networkx.algorithms.approximation
import wfdb
import json
import tqdm
import matplotlib.pyplot as plt
from scipy.stats import norm
from sak.signal import StandardHeader

# Data loader to un-clutter code    
def load_data(filepath):
    """Read a ragged, comma-separated file of integer values into a dict.

    Each non-blank line is expected to look like ``key,v1,v2,...``, optionally
    padded with spaces and any number of empty fields (``,,``). The first
    field is used as the dictionary key and the remaining non-empty fields
    are converted to an integer array.

    Parameters
    ----------
    filepath : str
        Path of the text file to read.

    Returns
    -------
    dict
        Maps every key to a 1-D integer ``np.ndarray`` (empty when the line
        carries no values). If a key appears more than once, the last
        occurrence wins.

    Raises
    ------
    ValueError
        If a non-empty value cannot be converted to an integer.
    """
    dic = dict()
    with open(filepath) as f:
        for raw in f:
            # Remove padding blanks and the line terminator, then split once
            fields = raw.replace(' ','').replace('\n','').split(',')
            head = fields[0]

            # Drop empty fields one by one instead of deleting ',,' from the
            # text, which would glue the neighbours of an interior empty
            # field together (e.g. '1,,2' -> '12')
            values = [v for v in fields[1:] if v != '']

            # Blank or comma-only lines carry no record: skip them rather
            # than failing on them
            if head == '' and not values:
                continue

            # An empty list becomes an empty integer array
            dic[head] = np.asarray(values).astype(int)
    return dic


def trailonset(sig,on):
    """Add a linear ramp to ``sig`` and return the result as a new array.

    The ramp has ``sig.size`` samples. It starts at ``on - sig[0]`` (so the
    first output sample equals ``on``) and ends at
    ``on - 2*sig[0] + sig[-1]``.

    Parameters
    ----------
    sig : np.ndarray
        1-D signal; it is not modified.
    on : float
        Value the first output sample takes.

    Returns
    -------
    np.ndarray
        ``sig`` plus the linear ramp.
    """
    first_sample = sig[0]
    last_sample = sig[-1]

    ramp_start = on-first_sample
    ramp_end = ramp_start-first_sample+last_sample

    return sig+np.linspace(ramp_start,ramp_end,sig.size)

def getcorr(segments):
    """Pairwise similarity matrix between 1-D segments.

    For every ordered pair of distinct segments the entry is the maximum
    absolute value of ``sak.signal.xcorr`` of the two. When the two segments
    differ in size, both are first brought to a common length (twice the
    longest segment in ``segments``): linearly interpolated, or repeated when
    the segment has a single sample. The diagonal is 1, and so is any pair of
    single-sample segments.

    Parameters
    ----------
    segments : sequence of np.ndarray
        1-D segments to compare.

    Returns
    -------
    np.ndarray
        ``(len(segments), len(segments))`` float matrix; shape ``(0, 0)``
        when ``segments`` is empty.
    """
    n_segments = len(segments)
    if n_segments == 0:
        return np.zeros((0,0))

    # Common length used whenever two segments differ in size
    length = 2*max([seg.size for seg in segments])

    def _to_common_length(seg):
        # A single sample cannot be interpolated: repeat it instead
        if seg.size == 1:
            return np.full((length,),seg[0])
        source_axis = np.linspace(0,1,len(seg))
        target_axis = np.linspace(0,1,length)
        return sp.interpolate.interp1d(source_axis,seg)(target_axis)

    corr = np.zeros((n_segments,n_segments))

    for row,seg_a in enumerate(segments):
        for col,seg_b in enumerate(segments):
            # A segment is fully correlated with itself
            if row == col:
                corr[row,col] = 1
                continue

            if seg_a.size == seg_b.size:
                sig_a,sig_b = seg_a,seg_b
            else:
                sig_a = _to_common_length(seg_a)
                sig_b = _to_common_length(seg_b)

            if (sig_a.size == 1) and (sig_b.size == 1):
                corr[row,col] = 1
            else:
                xc,_ = sak.signal.xcorr(sig_a,sig_b)
                corr[row,col] = np.max(np.abs(xc))

    return corr

def order_fiducials(qrs, other, mode='smaller'):
    """Insert ``-1`` placeholders into ``other`` while walking along ``qrs``.

    A boolean matrix ``filt`` of shape ``(qrs.size, other.size)`` is built
    first: ``filt[q, o]`` is ``other[o] < qrs[q]`` in ``'smaller'`` mode and
    ``qrs[q] < other[o]`` in ``'bigger'`` mode. Positions ``i`` in
    ``range(qrs.size)`` are then visited in order, and a ``-1`` is inserted
    at position ``i`` of ``other`` (with a matching column in ``filt``) when
    either

    * ``other`` has no element at position ``i``, or
    * the first (``'smaller'``) / last (``'bigger'``) row of ``filt`` that is
      ``True`` in column ``i`` is not row ``i``.

    Presumably the goal is that position ``i`` of the result pairs with
    ``qrs[i]``, with ``-1`` marking a missing fiducial -- confirm with callers.

    Parameters
    ----------
    qrs : np.ndarray
        1-D array of reference positions.
    other : np.ndarray
        1-D array of positions to align; it is copied, not modified.
    mode : {'smaller', 'bigger'}
        Comparison used to build ``filt`` (see above).

    Returns
    -------
    np.ndarray
        The copy of ``other`` with the inserted ``-1`` entries. When ``other``
        is empty, an integer array shaped like ``qrs`` filled with ``-1``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'smaller'`` nor ``'bigger'``.
    IndexError
        If a visited column of ``filt`` has no ``True`` entry (the
        ``np.nonzero(...)[0][0]`` lookup has nothing to return).
    """
    # Work on a copy so the caller's array is left untouched
    other = np.copy(other)

    # filt[q, o] compares every element of `other` against every QRS position
    if mode=='smaller':
        filt = (other[None,:] < qrs[:,None])
    elif mode=='bigger':
        filt = (qrs[:,None] < other[None,:])
    else:
        raise ValueError("Mode must be {smaller, bigger}")
        
    # Nothing to align: every position gets the placeholder
    if other.size == 0:
        return np.full_like(qrs,-1,dtype=int)
        
    for i in range(qrs.size):
        if i >= filt.shape[1]:
            # `other` ran out of elements: append a placeholder and a matching
            # column so both arrays keep the same number of entries
            newcol = np.zeros((qrs.size,),dtype=bool)
            if mode=='smaller':
                newcol[i:] = True
            elif mode=='bigger':
                newcol[:i+1] = True
            other = np.insert(other,i,-1)
            filt = np.insert(filt,i,newcol,axis=1)
        else:
            # Row index of the first (smaller) / last (bigger) True in column i
            if mode=='smaller':
                j = np.nonzero(filt[:,i])[0][0]
            elif mode=='bigger':
                j = qrs.size-np.nonzero(np.flip(filt[:,i]))[0][0]-1

            # Column i does not line up with row i: put a placeholder at
            # position i, which shifts the current element one slot right so
            # it is examined again on the next iteration
            if i != j:
                newcol = np.zeros((qrs.size,),dtype=bool)
                if mode=='smaller':
                    newcol[i:] = True
                elif mode=='bigger':
                    newcol[:i+1] = True
                other = np.insert(other,i,-1)
                filt = np.insert(filt,i,newcol,axis=1)

    return other

In [3]:
# Root folder of the datasets. The original location stays as the default;
# set the DELINEATOR_BASEDIR environment variable to run the notebook on
# another machine without editing this cell
basedir = os.environ.get('DELINEATOR_BASEDIR', '/media/guille/DADES/DADES/Delineator')

In [None]:
#### LOAD DATASETS ####
# One column per record/lead; columns are sorted by name for a stable order
dataset             = pd.read_csv(os.path.join(basedir,'QTDB','Dataset.csv'), index_col=0)
dataset             = dataset.sort_index(axis=1)
labels              = np.asarray(list(dataset)) # In case no data augmentation is applied
description         = dataset.describe()
# Column name without its last '_'-separated token
group               = {k: '_'.join(k.split('_')[:-1]) for k in dataset}
# First '_'-separated token of every column name, de-duplicated. Sorted because
# the iteration order of a set of strings changes between kernel restarts,
# which would make anything derived from this list irreproducible
unique_ids          = sorted(set([k.split('_')[0] for k in dataset]))

# Load validity (read with sak's loader, unlike the fiducials below)
validity            = sak.load_data(os.path.join(basedir,'QTDB','validity.csv'))

# Load fiducials: onsets and offsets of the P, QRS and T waves
Pon_QTDB            = load_data(os.path.join(basedir,'QTDB','PonNew.csv'))
Poff_QTDB           = load_data(os.path.join(basedir,'QTDB','PoffNew.csv'))
QRSon_QTDB          = load_data(os.path.join(basedir,'QTDB','QRSonNew.csv'))
QRSoff_QTDB         = load_data(os.path.join(basedir,'QTDB','QRSoffNew.csv'))
Ton_QTDB            = load_data(os.path.join(basedir,'QTDB','TonNew.csv'))
Toff_QTDB           = load_data(os.path.join(basedir,'QTDB','ToffNew.csv'))

In [None]:
# Build, for every record, the normalized signal (x) and its boolean wave masks (y)
x = {}
y = {}
for k in tqdm.tqdm(QRSon_QTDB):
    # Records without a validity entry cannot be cropped: skip them
    if k not in validity:
        print("Issue with file {}, continuing...".format(k))
        continue

    # Sample range of the record that is kept
    valid_on,valid_off = validity[k][0],validity[k][1]

    # Crop the signal, correct it and scale it by the median of the
    # windowed (200 samples) sak.signal.abs_max values
    signal = dataset[k][valid_on:valid_off].values
    signal = sak.signal.on_off_correction(signal)
    amplitude = np.median(sak.signal.moving_lambda(signal,200,sak.signal.abs_max))
    signal = signal/amplitude
    x[k] = signal[None,] # leading singleton axis -> shape (1, n_samples)

    # One mask row per wave: 0 -> P, 1 -> QRS, 2 -> T
    segmentation = np.zeros((3,dataset.shape[0]),dtype=bool)
    wave_fiducials = ((Pon_QTDB,Poff_QTDB),(QRSon_QTDB,QRSoff_QTDB),(Ton_QTDB,Toff_QTDB))
    for row,(onsets,offsets) in enumerate(wave_fiducials):
        if k not in onsets:
            continue
        for on,off in zip(onsets[k],offsets[k]):
            segmentation[row,on:off] = True

    # Crop the mask to the same range as the signal
    y[k] = segmentation[:,valid_on:valid_off]
    

In [409]:
import torch
import torch.utils

class DataQTDB(torch.utils.data.Dataset):
    '''Sliding-window dataset over per-record signals and masks.

    Every record ``k`` contributes the windows ``[n*stride, n*stride + window)``
    that fit completely inside it; records shorter than ``window`` contribute
    none. Windows are numbered consecutively across records, following the
    key order of ``x``.

    Parameters
    ----------
    x, y : dict
        Map record keys to 2-D arrays whose last axis is time. Both must
        have exactly the same keys.
    window : int
        Number of samples per window.
    stride : int
        Hop, in samples, between consecutive windows of a record.
    dtype : str
        Stored on the instance; it is not applied to the returned arrays.
    '''

    def __init__(self, x, y, window, stride, dtype='float32'):
        '''Store the inputs and precompute the window lookup table.'''
        assert set(x.keys()) == set(y.keys())
        # Store inputs
        self.window = window
        self.stride = stride
        self.dtype = dtype
        self.x = x
        self.y = y
        self.keys = list(x)

        # Windows per record. Clamped at zero so that a record shorter than
        # the window cannot subtract windows from its neighbours, and
        # measured along the last axis so multi-channel records are not
        # over-counted
        counts = [max(0, (x[k].shape[-1] - window + stride)//stride) for k in self.keys]

        # window_distribution[r] is the global index of record r's first window
        self.window_distribution = np.cumsum([0] + counts)
        self.num_windows = int(self.window_distribution[-1])

    def __len__(self):
        '''Total number of windows over all records.'''
        return self.num_windows
    
    def __get_key_window(self, i):
        """Map a global window index to ``(record key, window number in record)``."""
        # Position of the first cumulative count strictly greater than i; the
        # record that owns window i is the one just before it
        loc = int(np.searchsorted(self.window_distribution, i, side='right'))
        key = self.keys[loc-1]
        win = int(i-self.window_distribution[loc-1])
        
        return key,win
    
    def __getitem__(self, i):
        '''Return the ``(x, y)`` pair of arrays for global window ``i``.

        Negative indices count from the end. ``IndexError`` is raised for
        out-of-range indices, which is also what ends a plain
        ``for sample in dataset`` loop.
        '''
        if i < 0:
            i += self.num_windows
        # Validate before touching the data so that an invalid index can
        # never return a truncated or unrelated window
        if not (0 <= i < self.num_windows):
            raise IndexError("window index out of range for a dataset of {} windows".format(self.num_windows))

        # Retrieve window location
        key,n_window = self.__get_key_window(i)

        # Compute onsets and offsets for localization
        on  = n_window*self.stride
        off = on + self.window
        
        # Retrieve data
        x = self.x[key][:,on:off]
        y = self.y[key][:,on:off]
        
        return x,y

In [410]:
# Window length and hop, in samples. They were not defined in any earlier
# cell; the values are the ones used in the window-count checks further down
window = 2048
stride = 128

# Bound to the name `self` because the scratch cells below refer to it
self = DataQTDB(x,y,window,stride)

In [408]:
# Smoke test: walk through every window of the dataset once. The outputs are
# discarded and `save` is left empty by this cell
save = []
for i,out in enumerate(tqdm.tqdm(self)):
    pass

100%|██████████| 8070/8070 [00:00<00:00, 139328.44it/s]


In [None]:
# Total number of windows over all records, using the dataset's own settings
total = sum([(x[k].size - self.window + self.stride)//self.stride for k in x])

# dict views cannot be indexed in Python 3, so materialise the keys once
record_keys = list(self.x)

for i in range(total):
    # Retrieve single window: the record that owns global window i is the one
    # just before the first cumulative count that exceeds i. (Dividing by
    # self.num_windows, which is the total over all records, would always
    # give record 0)
    numRecord = int(np.argmax(i < self.window_distribution)) - 1
    key       = record_keys[numRecord]
    n_window  = int(i) - int(self.window_distribution[numRecord])


In [None]:
# Disambiguate data augmentation
# NOTE(review): assumes `i` enumerates several passes over the same
# self.num_windows windows, one pass per augmentation -- confirm
j         = int(i)//int(self.num_windows)   # index of the augmentation pass
i         = int(i)%int(self.num_windows)    # window index within a single pass

# The record that owns window i is the one just before the first cumulative
# count that exceeds i. dict views cannot be indexed in Python 3, hence list()
numRecord = int(np.argmax(i < self.window_distribution)) - 1
key       = list(self.x)[numRecord]
n_window  = i - int(self.window_distribution[numRecord])


In [108]:
import skimage

In [119]:
# Shape of the signal stored for `k`, the record key left over from an earlier loop
x[k].shape

(1, 8968)

In [131]:
# Number of windows record `k` yields with window=2048 and stride=128
# (same formula as in DataQTDB.__init__)
(x[k].size-2048+128)//128

55