In [4]:
%load_ext autoreload
%autoreload 2

In [1]:
from torch.utils.tensorboard import SummaryWriter

In [2]:
import pywt
from collections import Counter 
from scipy import signal as sig

In [3]:
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import wfdb
from wfdb import processing
import os
from functools import partial
from utils import plot_signal_with_r_peaks, load_patient, extract_test_windows, mean
from tqdm import tqdm
import torch.nn as nn
from collections import Counter 

In [4]:
from model_unet1d import *
import torch_optimizer as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [5]:
data_dir = 'mit-bih-arrhythmia'

data_dir = 'filtered-mit-bih-arrhythmia'

In [6]:
# Window/Segment length 

l = 20 #seconds
# window stride for testing. 

s = 10 #seconds
# Sapmling frequency of ecg signal is 360 hz. 
# (https://archive.physionet.org/physiobank/database/html/mitdbdir/intro.htm#annotations)

fs = 360
# Window/Segment length in samples. 
win_size = l*fs
print(f"Window size is {win_size}")
# Stride for test window in samples. 
stride = s*fs

Window size is 7200


In [7]:
def preprocess(ecg):
    
    signal = ecg.flatten()

    # Apply DWT
    wavelet = 'db4'
    level = 9
    coeffs = pywt.wavedec(signal, wavelet, level=level)

    # Remove baseline wandering
    coeffs[1:] = [pywt.threshold(i, np.std(i)/2) for i in coeffs[1:]]
    filtered_signal = pywt.waverec(coeffs, wavelet)

    # Apply lowpass filter
    nyquist = int(180)
    cutoff = 40
    b, a = sig.butter(4, cutoff/nyquist, 'low')
    filtered_signal = sig.filtfilt(b, a, filtered_signal)
    return filtered_signal

In [8]:
def extract_training_windows(ecg, win_size, Rpeaks_pos):
    
    print('Preparing Training Data for R-peaks detection')

    # Total windows in ecg. (Number of training examples)
    tot_wins = int(len(ecg)/win_size)

    X_train = np.zeros((tot_wins,win_size), dtype=np.float64)
    y_train = np.zeros((tot_wins,win_size))

    # Annotations for each window
    R = []
    
    normalize = partial(processing.normalize_bound, lb=-1, ub=1)
    
    for i in tqdm(range(tot_wins)):
    
        # Start of window in whole ecg stream
        st = i*win_size
        
        # End of window
        end = st + win_size

        # R peaks in the current window
        rIndx = np.where((Rpeaks_pos >= st) & (Rpeaks_pos < end))[0]

        R.append(Rpeaks_pos[rIndx]-st)

                  
        for j in Rpeaks_pos[rIndx]:
            r = int(j)-st
            y_train[i,r-2:r+3] = 1


        # If ecg window is non zero. Normalize it. 
        if ecg[st:end].any():
            X_train[i,:] = np.squeeze(np.apply_along_axis(normalize, 0, ecg[st:end]))
        # All zero ecg window
        else:
            X_train[i,:] = ecg[st:end].T


    X_train = np.expand_dims(X_train, axis=1)
    
    y_train = np.expand_dims(y_train, axis=1)
    
    return X_train, y_train, R

In [9]:
patients = [x for x in list(set([x.split('.')[0] for x in os.listdir(f'./{data_dir}/')])) if len(x) > 1]
print(f"The number of patient records is {len(patients)} ")




The number of patient records is 48 


In [10]:
np.random.RandomState(seed=42).permutation(patients)

array(['231', '115', '215', '208', '219', '202', '233', '105', '220',
       '222', '102', '121', '221', '210', '205', '119', '212', '230',
       '109', '203', '100', '101', '116', '200', '122', '207', '113',
       '214', '104', '228', '213', '232', '117', '112', '217', '103',
       '111', '123', '118', '124', '201', '106', '234', '107', '209',
       '223', '114', '108'], dtype='<U3')

In [11]:
# Let's define  the splitting strategy for training our detection model
def data_split(patients):
    SEED = 42
    indices = np.random.RandomState(seed=SEED).permutation(patients)
    m = len(patients)
    training_idx, val_idx, test_idx = indices[:int(m*0.7)], indices[int(m*0.7):int(m*0.85)], indices[int(m*0.85):]

    return training_idx, test_idx, val_idx

In [12]:
training_idx, test_idx, val_idx = data_split(patients)

In [13]:
def extract_testing_windows(ecg, win_size, Rpeaks_pos):
    overlap = 2000
    hop_size = win_size - overlap
    
    ecg_pad = np.pad(ecg,(0, 5200), mode='edge')
    
    print('Preparing testing data for R-peaks detection')
    
    # Total windows in ecg. (Number of training examples)
    tot_wins = int((len(ecg) - overlap)/ hop_size)
    
    X_train = np.zeros((tot_wins,win_size), dtype=np.float64)
    y_train = np.zeros((tot_wins,win_size))

    # Annotations for each window
    R = []
    
    normalize = partial(processing.normalize_bound, lb=-1, ub=1)
    
#     for i in range(0, len(pad_sig), stride):
#     for i in range(0, len(tot_wins), stride):

    pad_id = np.arange(ecg_pad.shape[0])
    
    for i in tqdm(range(tot_wins)):
    
        # Start of window in whole ecg stream
        st = i*win_size
        
        # End of window
        end = st + win_size

        # R peaks in the current window
        rIndx = np.where((Rpeaks_pos >= st) & (Rpeaks_pos < end))[0]

        R.append(Rpeaks_pos[rIndx]-st)

                  
        for j in Rpeaks_pos[rIndx]:
            r = int(j)-st
            y_train[i,r-2:r+3] = 1


        # If ecg window is non zero. Normalize it. 
        if ecg_pad[st:end].any():
            X_train[i,:] = np.squeeze(np.apply_along_axis(normalize, 0, ecg_pad[st:end]))
        # All zero ecg window
        else:
#             print(i)
            pass
#             X_train[i,:] = ecg_pad[st:end].T


    X_train = np.expand_dims(X_train, axis=1)
    
    y_train = np.expand_dims(y_train, axis=1)
    
    return X_train, y_train, R

In [14]:
def data_laoding(idx, win_size, data_dir): 
    
    for i, pat_num in zip(range(len(idx)), idx):

        ecg, Rpeaks_pos = load_patient(data_dir, str(pat_num))

        X_train_s, y_train_s, R = extract_training_windows(ecg, win_size, Rpeaks_pos)
        if i == 0:
            X_train = X_train_s
            y_train = y_train_s
        else:
            X_train = np.concatenate((X_train, X_train_s))
            y_train = np.concatenate((y_train, y_train_s))
            
    return X_train , y_train, R

In [15]:
# def data_laoding_testing(idx, win_size, data_dir): 
    
#     for i, pat_num in zip(range(len(idx)), idx):

#         ecg, Rpeaks_pos = load_patient(data_dir, str(pat_num))
        
#         X_test_s, y_test_s, R = extract_testing_windows(ecg, win_size, Rpeaks_pos)
#         if i == 0:
#             X_test = X_test_s
#             y_test = y_test_s
#         else:
#             X_test = np.concatenate((X_test, X_test_s))
#             y_test = np.concatenate((y_test, y_test_s))
            
#     return X_test , X_test, R

In [16]:
class CustomImageDataset(Dataset):
    
    def __init__(self, data_dir = 'mit-bih-arrhythmia',
                 idx_im = training_idx, label = 'train'):
        
        self.win_size = 7200
        self.idx_im = idx_im

        self.data_dir = data_dir
        if label == 'train':
            self.X, self.y, self.R = data_laoding(self.idx_im, self.win_size, self.data_dir)
            
        if label == 'test':
            self.X, self.y, self.R = data_laoding_testing(self.idx_im, self.win_size, self.data_dir)
        
        
    def __len__(self):
        return len(self.idx_im)

    def __getitem__(self, idx):
        
        X_inp = torch.from_numpy(self.X[idx]).float()
        y_inp = torch.from_numpy(self.y[idx]).float()
        half_pad = (7232 - self.win_size) // 2
        p = torch.nn.ConstantPad1d(half_pad, 0)
        X_inp1 = p(X_inp)
        y_inp1 = p(y_inp)

        return X_inp1, y_inp1

In [17]:
training_data = CustomImageDataset('filtered-mit-bih-arrhythmia', training_idx, 'train') 
train_dataloader = DataLoader(training_data, batch_size=4, shuffle=True)

val_data = CustomImageDataset('filtered-mit-bih-arrhythmia',  val_idx, 'train')
val_dataloader = DataLoader(val_data, batch_size=4, shuffle=True)

# testing_data = CustomImageDataset('filtered-mit-bih-arrhythmia',  test_idx, 'test')
# test_dataloader = DataLoader(testing_data, batch_size=4, shuffle=True)


Loading Data for Patient : 231
Total Beats :  328
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 10469.76it/s]


Loading Data for Patient : 115
Total Beats :  1954
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13623.77it/s]


Loading Data for Patient : 215
Total Beats :  3367
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 12522.80it/s]

Loading Data for Patient : 208





Total Beats :  2631
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13794.53it/s]


Loading Data for Patient : 219
Total Beats :  2174
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15695.29it/s]


Loading Data for Patient : 202
Total Beats :  2124
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 14985.01it/s]


Loading Data for Patient : 233
Total Beats :  3139
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 11534.08it/s]

Loading Data for Patient : 105





Total Beats :  2568
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 11141.56it/s]


Loading Data for Patient : 220
Total Beats :  2065
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13156.08it/s]


Loading Data for Patient : 222
Total Beats :  2406
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15106.14it/s]


Loading Data for Patient : 102
Total Beats :  108
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 18453.63it/s]


Loading Data for Patient : 121
Total Beats :  1864
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16134.70it/s]


Loading Data for Patient : 221
Total Beats :  2450
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 14012.15it/s]


Loading Data for Patient : 210
Total Beats :  2634
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 10978.58it/s]

Loading Data for Patient : 205





Total Beats :  2658
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15318.86it/s]

Loading Data for Patient : 119





Total Beats :  2090
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 17110.30it/s]

Loading Data for Patient : 212





Total Beats :  924
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15552.38it/s]

Loading Data for Patient : 230





Total Beats :  2463
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13228.92it/s]

Loading Data for Patient : 109





Total Beats :  39
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16793.64it/s]

Loading Data for Patient : 203





Total Beats :  3018
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 12966.28it/s]


Loading Data for Patient : 100
Total Beats :  2274
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13505.33it/s]


Loading Data for Patient : 101
Total Beats :  1864
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 13865.47it/s]


Loading Data for Patient : 116
Total Beats :  2413
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15805.69it/s]


Loading Data for Patient : 200
Total Beats :  2747
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 14677.37it/s]


Loading Data for Patient : 122
Total Beats :  2477
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 14780.82it/s]


Loading Data for Patient : 207
Total Beats :  236
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16589.93it/s]


Loading Data for Patient : 113
Total Beats :  1790
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 17387.72it/s]


Loading Data for Patient : 214
Total Beats :  281
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 18348.68it/s]


Loading Data for Patient : 104
Total Beats :  210
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 17917.57it/s]


Loading Data for Patient : 228
Total Beats :  2094
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16929.96it/s]


Loading Data for Patient : 213
Total Beats :  2929
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 15105.54it/s]


Loading Data for Patient : 232
Total Beats :  1383
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16697.07it/s]


Loading Data for Patient : 117
Total Beats :  1536
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16074.24it/s]


Loading Data for Patient : 112
Total Beats :  2540
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16781.69it/s]

Loading Data for Patient : 217





Total Beats :  473
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 17821.98it/s]

Loading Data for Patient : 103





Total Beats :  2085
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16123.67it/s]


Loading Data for Patient : 111
Total Beats :  2
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 20505.59it/s]


Loading Data for Patient : 123
Total Beats :  1519
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 16203.96it/s]


Loading Data for Patient : 118
Total Beats :  113
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 17311.17it/s]

Loading Data for Patient : 124





Total Beats :  62
Preparing Training Data for R-peaks detection


100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 20505.59it/s]


In [18]:
# print('Total Windows : ', len(X_train))


# # All indexes of training windows.
# idx = np.arange(len(X_train))

# # indexes other than S and V beats. (Normal R peaks)
# rem_idx =  idx

# # # Choose 50% of remaining. 
# norm_idx = np.random.choice(rem_idx, size=int(len(rem_idx)*(1/3)), replace=False)

# X_train = np.concatenate((X_train[norm_idx],X_train[idx]))
# y_train = np.concatenate((y_train[norm_idx],y_train[idx]))

# assert len(X_train) == len(y_train)

# print('Selected Windows : ', len(X_train))

# print('Saving Data')

In [19]:
import datetime
def setup_experiment(title, logdir="./logs"):
    experiment_name = "{}@{}".format(title, datetime.datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    writer = SummaryWriter(log_dir=os.path.join(logdir, experiment_name))
    best_model_path = f"{title}.best.pth"
    return writer, experiment_name, best_model_path

In [20]:
model = UNet_1D()


# The learning rate needs to be lower than the default used by Adam, otherwise the learning can be unstable.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=5e-4)

# optimizer = torch.optim.Adam(model.parameters())

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20,40,60], gamma=0.5)

writer, experiment_name, best_model_path = setup_experiment(model.__class__.__name__, logdir="./tb_transf")
print(f"Experiment name: {experiment_name}")


print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")
# sum(p.numel() for p in model.parameters() if p.requires_grad)

Experiment name: UNet_1D@28.04.2023-15:43:33
Model has 79,153 trainable parameters


In [21]:
def run_epoch(model, iterator, optimizer, criterion, phase='train', epoch=0, writer=None):
    is_train = (phase == 'train')
    if is_train:
        model.train()
    else:
        model.eval()
    
    epoch_loss = 0

    # variables for calculating accuracy
    n_predicted = 0
    n_true_predicted = 0
    
    with torch.set_grad_enabled(is_train):
        for i, batch in enumerate(iterator):
            global_i = len(iterator) * epoch + i
            
            # unpack batch
            inputs, labels = batch
#             text, postag = batch.text, batch.postag
            
            # make prediction
            pred = model(inputs)
            
            
            # calculate loss
            loss = criterion(pred, labels)
            
            if is_train:
                # make optimization step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
 
            epoch_loss += loss.item()

        # dump epoch metrics to tensorboard
        if writer is not None:
            writer.add_scalar(f"loss_epoch/{phase}", epoch_loss / len(iterator), epoch)

        return epoch_loss / len(iterator)

In [22]:
n_epochs = 100
log = 'logs'
if not os.path.exists(log):
    os.makedirs(log)
best_model_path = f"{log}/1d_unet.best.pth"
    
best_val_loss = float('+inf')
for epoch in range(n_epochs):    
    train_loss = run_epoch(model, train_dataloader, optimizer, criterion, phase='train', epoch=epoch, writer = writer)
    val_loss = run_epoch(model, val_dataloader, None, criterion, phase='val', epoch=epoch, writer=writer)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
    
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {val_loss:.3f} ') 

Epoch: 01
	Train Loss: 0.791
	 Val. Loss: 0.511 
Epoch: 02
	Train Loss: 0.608
	 Val. Loss: 0.487 
Epoch: 03
	Train Loss: 0.502
	 Val. Loss: 0.456 
Epoch: 04
	Train Loss: 0.426
	 Val. Loss: 0.421 
Epoch: 05
	Train Loss: 0.360
	 Val. Loss: 0.377 
Epoch: 06
	Train Loss: 0.310
	 Val. Loss: 0.340 
Epoch: 07
	Train Loss: 0.266
	 Val. Loss: 0.311 
Epoch: 08
	Train Loss: 0.229
	 Val. Loss: 0.278 
Epoch: 09
	Train Loss: 0.197
	 Val. Loss: 0.241 
Epoch: 10
	Train Loss: 0.169
	 Val. Loss: 0.212 
Epoch: 11
	Train Loss: 0.140
	 Val. Loss: 0.181 
Epoch: 12
	Train Loss: 0.127
	 Val. Loss: 0.161 
Epoch: 13
	Train Loss: 0.106
	 Val. Loss: 0.140 
Epoch: 14
	Train Loss: 0.099
	 Val. Loss: 0.132 
Epoch: 15
	Train Loss: 0.086
	 Val. Loss: 0.121 
Epoch: 16
	Train Loss: 0.078
	 Val. Loss: 0.113 
Epoch: 17
	Train Loss: 0.067
	 Val. Loss: 0.107 
Epoch: 18
	Train Loss: 0.061
	 Val. Loss: 0.100 
Epoch: 19
	Train Loss: 0.052
	 Val. Loss: 0.092 
Epoch: 20
	Train Loss: 0.047
	 Val. Loss: 0.085 
Epoch: 21
	Train Los

In [23]:
def calculate_stats(r_ref, r_ans, thr , fs):


    FP = 0
    TP = 0
    FN = 0
    for j in range(len(r_ref)):
        loc = np.where(np.abs(r_ans - r_ref[j]) <= thr*fs)[0]
            
        if len(loc) >= 1:
            TP += 1
            FP += len(loc) - 1
        elif len(loc) == 0:
            FN += 1

    Recall = (TP / (FN + TP))*100
    Precision = (TP / (FP + TP))*100

    if Recall + Precision == 0:
        F1_score = 0
    else:
        F1_score = (2 * Recall * Precision / (Recall + Precision))
    print("Recall:{}, Precision(FNR):{}, F1-Score:{}".format(Recall,Precision,F1_score))
    print("Total {}".format(len(r_ref)))
    return TP, FN, FP, Precision, Recall, F1_score
    


In [38]:
# os.environ['TENSORBOARD_BINARY'] = './tb_trans'

# %reload_ext tensorboard
# logs_base_dir = "./tb_trans"
# %tensorboard --logdir {logs_base_dir}

In [24]:
# custom evaluation 
stats_R = []

win_size = 7200
stride = 7200//2
model_path =  f"{log}/1d_unet.best.pth"

model = UNet_1D()

torch.load(model_path)

precision = []
recall = []
f1_score = []
tp_all = []
fn_all = []
fp_all = []
for i, pat_num in zip(range(len(test_idx)), test_idx):
    
    ecg, Rpeaks_pos = load_patient(data_dir, str(pat_num))

    padded_indices, data_windows = extract_test_windows(ecg, win_size, stride)
    data_windows = np.transpose(data_windows, (0, 2, 1))

    X_inp = torch.from_numpy(data_windows).float()
    
    half_pad = (7232 - 20*360) // 2
    p = torch.nn.ConstantPad1d(half_pad, 0)
    data_windows = p(X_inp)

    predictions = model(data_windows)[:, :, 16:-16].detach().numpy()
    predictions = np.transpose(predictions, (0, 2, 1))
    
    predictions = mean(win_idx=padded_indices, preds=predictions, 
                           orig_len=ecg.shape[0],win_size=win_size,
                           stride= stride)
    assert(predictions.shape == ecg.shape)
    
    threshold = 0.5
    above_thresh = predictions[predictions > threshold]
    above_threshold_idx = np.where(predictions > threshold)[0]
    
    correct_up = processing.correct_peaks(sig=ecg,
                                          peak_inds=above_threshold_idx,
                                          search_radius=30,
                                          smooth_window_size=30,
                                          peak_dir='up')
    
    filtered_peaks = []
    filtered_probs = []

    for peak_id in tqdm(np.unique(correct_up)):

        points_in_peak = np.where(correct_up == peak_id)[0]
        if points_in_peak.shape[0] >= 5:
            filtered_probs.append(above_thresh[points_in_peak].mean())
            filtered_peaks.append(peak_id)
    

    filtered_peaks = np.asarray(filtered_peaks)
    filtered_probs = np.asarray(filtered_probs)
    
    thr = 0.15 # 150 ms
    fs = 360
    TP, FN, FP, pr_pat, rec_pat, f1_pat = calculate_stats(Rpeaks_pos, filtered_peaks, thr, fs)
    tp_all.append(TP)
    fn_all.append(FN)
    fp_all.append(FP)
    
    precision.append(pr_pat)
    recall.append(rec_pat)
    f1_score.append(f1_pat)
    

Loading Data for Patient : 201
Total Beats :  1888


100%|████████████████████████████████████████████████████████████| 55363/55363 [00:03<00:00, 14265.55it/s]


Recall:100.0, Precision(FNR):37.1946414499606, F1-Score:54.221711659965536
Total 1888
Loading Data for Patient : 106
Total Beats :  2068


100%|████████████████████████████████████████████████████████████| 62413/62413 [00:04<00:00, 14520.68it/s]


Recall:100.0, Precision(FNR):34.283819628647215, F1-Score:51.06172839506173
Total 2068
Loading Data for Patient : 234
Total Beats :  2706


100%|████████████████████████████████████████████████████████████| 67492/67492 [00:04<00:00, 13710.43it/s]


Recall:100.0, Precision(FNR):32.736511008952334, F1-Score:49.32555596062705
Total 2706
Loading Data for Patient : 107
Total Beats :  60


100%|████████████████████████████████████████████████████████████| 60329/60329 [00:04<00:00, 14652.32it/s]


Recall:100.0, Precision(FNR):28.846153846153843, F1-Score:44.776119402985074
Total 60
Loading Data for Patient : 209
Total Beats :  3026


100%|████████████████████████████████████████████████████████████| 58969/58969 [00:03<00:00, 15139.66it/s]


Recall:100.0, Precision(FNR):32.394818541912, F1-Score:48.93668634268618
Total 3026
Loading Data for Patient : 223
Total Beats :  2602


100%|████████████████████████████████████████████████████████████| 67085/67085 [00:04<00:00, 13974.85it/s]


Recall:100.0, Precision(FNR):34.39978847170809, F1-Score:51.19024198308086
Total 2602
Loading Data for Patient : 114
Total Beats :  1876


100%|████████████████████████████████████████████████████████████| 56009/56009 [00:03<00:00, 14903.09it/s]


Recall:100.0, Precision(FNR):28.67186305975852, F1-Score:44.56586292908896
Total 1876
Loading Data for Patient : 108
Total Beats :  1761


100%|████████████████████████████████████████████████████████████| 52882/52882 [00:03<00:00, 15066.58it/s]

Recall:100.0, Precision(FNR):36.91050094319849, F1-Score:53.9191671769749
Total 1761





In [48]:
p(X_inp).shape

torch.Size([182, 1, 7232])

In [26]:
np.mean(precision), np.mean(recall), np.mean(f1_score)

(33.179762118786385, 100.0, 49.74963423130878)

In [43]:
# filtered data 
np.mean(precision), np.mean(recall), np.mean(f1_score)

(41.21190402002332, 99.95071201730943, 58.30041044428638)

In [115]:
# filtered data 
np.mean(precision), np.mean(recall), np.mean(f1_score)

(38.02923894402076, 99.52015426207396, 54.856202381625394)

In [None]:
# raw data

In [119]:
from torchsummary import summary

In [126]:
# summary(model(20*360))

summary(model, (1, 7232))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 16, 3616]             160
        Downsample-2             [-1, 16, 3616]               0
            Conv1d-3             [-1, 16, 1808]           2,320
       BatchNorm1d-4             [-1, 16, 1808]              32
        Downsample-5             [-1, 16, 1808]               0
            Conv1d-6              [-1, 32, 904]           3,104
       BatchNorm1d-7              [-1, 32, 904]              64
        Downsample-8              [-1, 32, 904]               0
            Conv1d-9              [-1, 32, 452]           6,176
      BatchNorm1d-10              [-1, 32, 452]              64
       Downsample-11              [-1, 32, 452]               0
           Conv1d-12              [-1, 64, 226]           6,208
      BatchNorm1d-13              [-1, 64, 226]             128
       Downsample-14              [-1, 

In [82]:
for i, pat_num in zip(range(len(test_idx)), test_idx):
    print(pat_num)

223
101
208
220
115
219
116



In [47]:
model

UNet_1D(
  (down_stack): ModuleList(
    (0): Downsample(
      (conv): Conv1d(1, 16, kernel_size=(9,), stride=(2,), padding=(4,))
      (batchnorm): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Downsample(
      (conv): Conv1d(16, 16, kernel_size=(9,), stride=(2,), padding=(4,))
      (batchnorm): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): Downsample(
      (conv): Conv1d(16, 32, kernel_size=(6,), stride=(2,), padding=(2,))
      (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Downsample(
      (conv): Conv1d(32, 32, kernel_size=(6,), stride=(2,), padding=(2,))
      (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): Downsample(
      (conv): Conv1d(32, 64, kernel_size=(3,), stride=(2,), padding=(1,))
      (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=Tr