In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path
import skimage
import skimage.segmentation
import sklearn.preprocessing
import sklearn.model_selection
import math
import shutil
import pathlib
import glob
import shutil
import uuid
import random
import platform
import torch
import torchvision
import numpy as np
import scipy as sp
import scipy.io
import scipy.signal
import pandas as pd
import networkx
import wfdb
import json
import tqdm
import dill
import pickle
import matplotlib.pyplot as plt

import scipy.stats

import src.data
import src.metrics
import utils
import utils.wavelet
import utils.data
import utils.data.augmentation
import utils.visualization
import utils.visualization.plot
import utils.torch
import utils.torch.nn
import utils.torch.nn as nn
import utils.torch.loss
import utils.torch.train
import utils.torch.data
import utils.torch.preprocessing
import utils.torch.models
import utils.torch.models.lego
import utils.torch.models.variational
import utils.torch.models.classification

from utils.signal import StandardHeader

def smooth(x: np.ndarray, window_size: int, conv_mode: str = 'same'):
    x = np.pad(np.copy(x),(window_size,window_size),'edge')
    window = np.hamming(window_size)/(window_size//2)
    x = np.convolve(x, window, mode=conv_mode)
    x = x[window_size:-window_size]
    return x

In [3]:
def dice(mask_1, mask_2):
    intersection = (mask_1 * mask_2).sum()
    union = mask_1.sum() + mask_2.sum()
    return 2.*intersection/(union + np.finfo('double').eps)


In [26]:
gt_on  = utils.load_data('/home/guille/Escritorio/Ruben/ONLYQRSFORRUBEN/onsets.csv')
gt_off = utils.load_data('/home/guille/Escritorio/Ruben/ONLYQRSFORRUBEN/offsets.csv')
pr_on  = utils.load_data('/home/guille/Escritorio/Ruben/ONLYQRSFORRUBEN/onsets.csv')
pr_off = utils.load_data('/home/guille/Escritorio/Ruben/ONLYQRSFORRUBEN/offsets.csv')

In [27]:
# Jitter on on/off - differ in mean/std for delineation metrics
for k in pr_on:
    pr_on[k]  = np.clip(pr_on[k]+np.random.randint(0,2),  a_min=0, a_max=np.max(pr_on[k]))
    pr_off[k] = np.clip(pr_off[k]+np.random.randint(0,2), a_min=0, a_max=np.max(pr_off[k]))

In [28]:
# Mark to delete
deleted = {}

for k in pr_on:
    indices = np.arange(pr_on[k].size)
    indices = np.random.permutation(indices)
    deleted_indices = indices[:np.random.randint(0,indices.size//2)]
    deleted[k] = np.sort(deleted_indices)

# Delete from prediction (both onset and offset) -> False negatives
for k in pr_on:
    pr_on[k]  = np.delete(pr_on[k],  deleted[k])
    pr_off[k] = np.delete(pr_off[k], deleted[k])


In [29]:
# Add random segments -> false positives
for k in pr_on:
    # Random number of FP
    num_FP = np.random.randint(min([5,pr_on[k].size-2]))
    
    if num_FP < 1: continue
    
    locations = np.random.permutation(np.arange(1,pr_on[k].size-1))[:num_FP]
    
    for loc in locations:
        lower_bound = pr_off[k][loc-1]
        onset  = pr_on[k][loc]
        offset = pr_off[k][loc]
        upper_bound = pr_on[k][loc+1]
        
        try:
            new_on  = np.random.randint(offset+10,upper_bound-10)
            new_off = np.random.randint(new_on+10,upper_bound-10)

            pr_on[k]  = np.sort(np.hstack((pr_on[k], [new_on])))
            pr_off[k] = np.sort(np.hstack((pr_off[k],[new_off])))
        except:
            continue


In [30]:
from typing import List
import numpy as np

def dice_score(input: np.ndarray, target: np.ndarray) -> float:
    intersection = (input * target).sum()
    union = input.sum() + target.sum()
    return 2.*intersection/(union + np.finfo('double').eps)


def filter_valid(onset, offset, validity_on = 0, validity_off = np.inf):
    validity_on  = np.array( validity_on)[np.newaxis,np.newaxis]
    validity_off = np.array(validity_off)[np.newaxis,np.newaxis]

    mask_on    = (onset  >= validity_on) & (onset  <= validity_off)
    mask_off   = (offset >= validity_on) & (offset <= validity_off)
    mask_total = np.any(mask_on & mask_off, axis=0) # beat has to be found in every one

    onset = onset[mask_total]
    offset = offset[mask_total]

    return onset, offset


def correspondence(input_onset, input_offset, target_onset, target_offset):
    filtA =  ( input_onset <=  target_onset[:,np.newaxis]) & ( target_onset[:,np.newaxis] <= input_offset)
    filtB =  ( input_onset <= target_offset[:,np.newaxis]) & (target_offset[:,np.newaxis] <= input_offset)
    filtC = ((target_onset <=   input_onset[:,np.newaxis]) & (  input_onset[:,np.newaxis] <= target_offset)).T
    filtD = ((target_onset <=  input_offset[:,np.newaxis]) & ( input_offset[:,np.newaxis] <= target_offset)).T

    filter = filtA | filtB | filtC | filtD

    return filter


def interlead_correspondence(input_onsets: List[np.ndarray], input_offsets: List[np.ndarray], 
                             target_onsets: List[np.ndarray], target_offsets: List[np.ndarray], 
                             validity_on: int, validity_off: int):
    # ##### NOT FINISHED #####
    # filtA =  (res_0_on <= res_1_on[:,np.newaxis]) & (res_1_on[:,np.newaxis] <= res_0_of)
    # filtB =  (res_0_on <= res_1_of[:,np.newaxis]) & (res_1_of[:,np.newaxis] <= res_0_of)
    # filtC = ((res_1_on <= res_0_on[:,np.newaxis]) & (res_0_on[:,np.newaxis] <= res_1_of)).T
    # filtD = ((res_1_on <= res_0_of[:,np.newaxis]) & (res_0_of[:,np.newaxis] <= res_1_of)).T
    # filter = filtA | filtB | filtC | filtD
    # return filter
    pass


def post_processing(input_onset,input_offset,target_onset,target_offset,validity_on,validity_off):
    input_onset,input_offset = filter_valid(input_onset,input_offset,validity_on,validity_off)
    target_onset,target_offset = filter_valid(target_onset,target_offset,validity_on,validity_off)
    
    return input_onset,input_offset,target_onset,target_offset


def compute_metrics(input_onset, input_offset, target_onset, target_offset):
    # Init output
    tp   = 0
    fp   = 0
    fn   = 0
    dice = 0
    onset_error  = []
    offset_error = []

    # Find correspondence between fiducials
    filter = correspondence(input_onset, input_offset, target_onset, target_offset)

    # Check correspondence of GT beats to detected beats
    corr  = dict()
    
    # Account for already detected beats to calculate false positives
    chosen = np.zeros((filter.shape[0],), dtype=bool)
    for i,column in enumerate(filter.T):
        corr[i] = np.where(column)[0]
        chosen = chosen | column
        
    # Retrieve beats detected that do not correspond to any GT beat (potential false positives)
    not_chosen = np.where(np.logical_not(chosen))[0]
    
    # Compute Dice coefficient
    mask_input  = np.zeros((np.max(np.hstack((input_offset,target_offset)))+10,),dtype=bool)
    mask_target = np.zeros((np.max(np.hstack((input_offset,target_offset)))+10,),dtype=bool)
    for (onset,offset) in zip(input_onset,input_offset):
        mask_input[onset:offset] = True
    for (onset,offset) in zip(target_onset,target_offset):
        mask_target[onset:offset] = True
    dice = dice_score(mask_input, mask_target)

    # Compute metrics - Fusion strategy of results of both leads, following Martinez et al.
    for i in range(filter.shape[1]):
        # If any GT beat has a correspondence to any segmented beat, true positive + accounts for on/offset error
        if len(corr[i]) != 0:
            # Mark beat as true positive
            tp += 1
            
            # Compute the onset-offset errors
            onset_error.append(int(target_onset[corr[i]]  - input_onset[i]))
            offset_error.append(int(target_offset[corr[i]] - input_offset[i]))
            
        # If any GT beat has a correspondence to more than one segmented beat, 
        #     the rest of the pairs have to be false positives (Martinez et al.)
        if len(corr[i]) > 1:
            fp += len(corr[i]) - 1
        
        # If any GT beat has no correspondence to any segmented beat, false negative
        if len(corr[i]) == 0:
            fn += 1
            
    # False positives will correspond to those existing in the results that do not correspond to any beat in the GT (the not chosen)
    fp += len(not_chosen)
    
    return tp,fp,fn,dice,onset_error,offset_error
        

def precision(tp: int, fp: int, fn: int) -> float:
    return tp/(tp+fp)

def recall(tp: int, fp: int, fn: int) -> float:
    return tp/(tp+fn)

def f1_score(tp: int, fp: int, fn: int) -> float:
    return tp/(tp+(fp+fn)/2)



In [31]:
compute_metrics(pr_on[k],pr_off[k],gt_on[k],gt_off[k])

(16,
 3,
 1,
 0.8464163822525598,
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0])

# Multi-to-single lead

In [159]:
# ludb
gt = utils.load_data('/media/guille/DADES/DADES/Delineator/ludb/QRS.csv')
pr = utils.load_data('/media/guille/DADES/DADES/Delineator/ludb/QRS.csv')
gt_on  = {k: gt[k][0::2] for k in gt}
gt_off = {k: gt[k][1::2] for k in gt}
pr_on  = {k: pr[k][0::2] for k in pr}
pr_off = {k: pr[k][1::2] for k in pr}


In [177]:
# Jitter on on/off - differ in mean/std for delineation metrics
for k in tqdm.tqdm(pr_on):
    pr_on[k]  = np.clip(pr_on[k]+np.random.randint(0,2),  a_min=0, a_max=np.max(pr_on[k]))
    pr_off[k] = np.clip(pr_off[k]+np.random.randint(0,2), a_min=0, a_max=np.max(pr_off[k]))

100%|██████████| 2388/2388 [00:00<00:00, 16083.65it/s]


In [182]:
# Mark to delete
deleted = {}

for k in tqdm.tqdm(pr_on):
    indices = np.arange(pr_on[k].size)
    indices = np.random.permutation(indices)
    deleted_indices = indices[:np.random.randint(0,indices.size//2)]
    deleted[k] = np.sort(deleted_indices)

# Delete from prediction (both onset and offset) -> False negatives
for k in tqdm.tqdm(pr_on):
    pr_on[k]  = np.delete(pr_on[k],  deleted[k])
    pr_off[k] = np.delete(pr_off[k], deleted[k])


100%|██████████| 2388/2388 [00:00<00:00, 80142.73it/s]
100%|██████████| 2388/2388 [00:00<00:00, 86416.33it/s]


In [187]:
# Add random segments -> false positives
for k in tqdm.tqdm(pr_on):
    # Random number of FP
    num_FP = np.random.randint(min([5,pr_on[k].size-2]))
    
    if num_FP < 1: continue
    
    locations = np.random.permutation(np.arange(1,pr_on[k].size-1))[:num_FP]
    
    for loc in locations:
        lower_bound = pr_off[k][loc-1]
        onset  = pr_on[k][loc]
        offset = pr_off[k][loc]
        upper_bound = pr_on[k][loc+1]
        
        try:
            new_on  = np.random.randint(offset+10,upper_bound-10)
            new_off = np.random.randint(new_on+10,upper_bound-10)

            pr_on[k]  = np.sort(np.hstack((pr_on[k], [new_on])))
            pr_off[k] = np.sort(np.hstack((pr_off[k],[new_off])))
        except:
            continue


100%|██████████| 2388/2388 [00:00<00:00, 244078.32it/s]


In [113]:
def cross_correspondence(input_onsets_A, input_offsets_A, input_onsets_B, input_offsets_B):
    filtA =  (input_onsets_A <=  input_onsets_B[:,None]) &  (input_onsets_B[:,None] <= input_offsets_A)
    filtB =  (input_onsets_A <= input_offsets_B[:,None]) & (input_offsets_B[:,None] <= input_offsets_A)
    filtC = ((input_onsets_B <=  input_onsets_A[:,None]) &  (input_onsets_A[:,None] <= input_offsets_B)).T
    filtD = ((input_onsets_B <= input_offsets_A[:,None]) & (input_offsets_A[:,None] <= input_offsets_B)).T

    filter = filtA | filtB | filtC | filtD

    return filter


In [188]:
k = '49'
key = k.split('###')[0]
listk = [k for k in pr_on.keys() if k.startswith('{}###'.format(key))]

In [189]:
pr_on["{}###AVL".format(k)] = np.hstack((pr_on["{}###AVL".format(k)],[2017, 2017+700]))
pr_off["{}###AVL".format(k)] = np.hstack((pr_off["{}###AVL".format(k)],[2217, 2217+700]))

In [192]:
pr_on["{}###AVL".format(k)] = pr_on["{}###AVL".format(k)][[0,2,3,4]]
pr_off["{}###AVL".format(k)] = pr_off["{}###AVL".format(k)][[0,2,3,4]]

In [193]:
filters = [cross_correspondence(pr_on[k1],pr_off[k1],pr_on[k2],pr_off[k2]) for k1 in listk for k2 in listk if k1 != k2]

In [197]:
in_on  = [ pr_on[k] for k in listk]
in_off = [pr_off[k] for k in listk]
tg_on  = [ gt_on[k] for k in listk]
tg_off = [gt_off[k] for k in listk]

In [122]:
filters,filters_corr,chosen,corr = src.metrics.compute_multilead_metrics(in_on,in_off,tg_on,tg_off)

In [146]:
falsepositive,filters_corr,not_chosen = src.metrics.compute_multilead_metrics(in_on,in_off,tg_on,tg_off)

In [147]:
falsepositive

1

In [196]:
not_chosen

[array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([3]),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64),
 array([], dtype=int64)]

In [206]:
for col1, col2 in zip(filters_corr[0],filters_corr[1]):
    print(col1)
    print(col2)
    print("")

[ True False False]
[ True False False]

[False  True False]
[False  True False]

[False False  True]
[False False  True]



In [202]:
src.metrics.compute_metrics(in_on[4],in_off[4],tg_on[4],tg_off[4],return_not_chosen=1)

{0: array([0]), 1: array([], dtype=int64), 2: array([1])}


(2, 2, 1, 0.2685624012638231, [-11, 0], [-6, 0], array([2, 3]))