# Clean ECG using Deep Learning

In [19]:
from utils.main_model import DDPM
from utils.denoising_model_small import ConditionalModel
import torch
import yaml
from utils.signal_processing import filter_sos, filter_iir
import scipy.signal as signal
import scipy.fftpack    
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

## Load data

In [20]:
import pandas as pd
import glob

# Recording sessions per subject: directory (relative to ./data/logging) -> session label.
# Switch subjects by changing SUBJECT instead of (un)commenting blocks.
SUBJECT_DIRECTORIES = {
    'Nhan': {
        '1712094934_Nhan/1712094934_BASE_LINE': 'BASE_LINE',
        '1712094934_Nhan/1712095682_VOR_1HZ': 'VOR_1HZ',
        '1712094934_Nhan/1712095831_VOR_1HZ_REST': 'VOR_1HZ_REST',
        '1712094934_Nhan/1712096475_VOR_2HZ': 'VOR_2HZ',
        '1712094934_Nhan/1712096619_VOR_2HZ_REST': 'VOR_2HZ_REST',
        '1712094934_Nhan/1712097275_VOR_3HZ': 'VOR_3HZ',
        '1712094934_Nhan/1712097418_VOR_3HZ_REST': 'VOR_3HZ_REST',
        '1712094934_Nhan/1712098656_VOR_4HZ': 'VOR_4HZ',
        '1712094934_Nhan/1712098799_VOR_4HZ_REST': 'VOR_4HZ_REST',
    },
    'Saad': {
        '1713994149_Saad/1713994149_BASE_LINE': 'BASE_LINE',
        '1713994149_Saad/1713994824_VOR_1HZ': 'VOR_1HZ',
        '1713994149_Saad/1713994964_VOR_1HZ_REST': 'VOR_1HZ_REST',
        '1713994149_Saad/1713995471_VOR_2HZ': 'VOR_2HZ',
        '1713994149_Saad/1713995610_VOR_2HZ_REST': 'VOR_2HZ_REST',
        '1713994149_Saad/1713996271_VOR_3HZ': 'VOR_3HZ',
        '1713994149_Saad/1713996412_VOR_3HZ_REST': 'VOR_3HZ_REST',
        '1713994149_Saad/1713996958_VOR_4HZ': 'VOR_4HZ',
        '1713994149_Saad/1713997104_VOR_4HZ_REST': 'VOR_4HZ_REST',
    },
    'Jacob': {
        '1713293584_Jacob/1713293584_BASE_LINE': 'BASE_LINE',
        '1713293584_Jacob/1713294328_VOR_1HZ': 'VOR_1HZ',
        '1713293584_Jacob/1713294472_VOR_1HZ_REST': 'VOR_1HZ_REST',
        '1713293584_Jacob/1713295133_VOR_2HZ': 'VOR_2HZ',
        '1713293584_Jacob/1713295276_VOR_2HZ_REST': 'VOR_2HZ_REST',
        '1713293584_Jacob/1713295919_VOR_3HZ': 'VOR_3HZ',
        '1713293584_Jacob/1713296059_VOR_3HZ_REST': 'VOR_3HZ_REST',
        '1713293584_Jacob/1713296831_VOR_4HZ': 'VOR_4HZ',
        '1713293584_Jacob/1713296971_VOR_4HZ_REST': 'VOR_4HZ_REST',
    },
}
SUBJECT = 'Jacob'  # which subject's recordings to load
directories = SUBJECT_DIRECTORIES[SUBJECT]

# Per-session DataFrames and their row counts, keyed by session label.
data_frames = {}
length = {}

# Row offset of the current session inside the concatenated `data` built below.
global_start_idx = 0
for dir_key, label in directories.items():
    file_pattern = f'./data/logging/{dir_key}/CONVERTED_*'
    file_names = sorted(glob.glob(file_pattern))
    if not file_names:
        # pd.concat([]) would only report "No objects to concatenate"; name the path instead.
        raise FileNotFoundError(f'No files match {file_pattern}')

    # ignore_index: each file restarts its own index, which would leave duplicate labels.
    session_df = pd.concat(
        (pd.read_csv(file_name, sep='\t') for file_name in file_names),
        ignore_index=True,
    )
    data_frames[label] = session_df.drop(columns=['index', 'time stamp'])
    length[label] = len(data_frames[label])

    # Record where the VOR_4HZ session sits (as row offsets) inside `data`.
    if label == 'VOR_4HZ':
        vor_4hz_start = global_start_idx
        vor_4hz_end = global_start_idx + length[label]
    global_start_idx += length[label]

subjective_feedback = data_frames['VOR_4HZ']['subjective feedback'].to_numpy().copy()
# All sessions stacked in `directories` order, with a single contiguous row index.
data = pd.concat(data_frames.values(), ignore_index=True)


## Process data

In [21]:
fs = 250  # sampling rate used by every filter below (Hz)
# Columns 0..6 of `data` are filtered; any further columns pass through unfiltered.
N_FILTERED_CHANNELS = 7
# Column that additionally gets the Butterworth low-pass + FIR high-pass stage.
# NOTE(review): presumably the ECG lead — confirm against the recording setup.
FIR_CHANNEL = 6


def _bandpass_notch(x, low_hz, fs):
    """Band-pass ``x`` to [low_hz, 50] Hz, then apply 50 Hz and 60 Hz notch filters."""
    y = filter_iir(x, 'bandpass', [low_hz, 50.0], fs, 4)
    y = filter_iir(y, 'notch', 50.0, fs, 2, notch_width=4)
    return filter_iir(y, 'notch', 60.0, fs, 2, notch_width=4)


def _lowpass_fir_highpass(x, fs, cutoff_hz=4.0, width_hz=5.0, ripple_db=60.0):
    """Zero-phase 4th-order Butterworth low-pass at 50 Hz, then a Kaiser-window FIR high-pass.

    Parameters
    ----------
    x : 1-D array to filter.
    fs : sampling rate in Hz.
    cutoff_hz : FIR high-pass cutoff in Hz.
    width_hz : FIR transition width in Hz (passed to ``kaiserord``).
    ripple_db : FIR stop-band attenuation in dB (passed to ``kaiserord``).
    """
    b, a = signal.butter(4, 50 / (fs / 2), 'low')
    x_lowpassed = signal.filtfilt(b, a, x)

    nyq_rate = fs / 2.0
    numtaps, beta = signal.kaiserord(ripple_db, width_hz / nyq_rate)
    # firwin(pass_zero=False) raises for an even tap count; force it odd.
    # (At fs=250 kaiserord already returns 183, so this is a no-op there.)
    numtaps |= 1
    taps = signal.firwin(numtaps, cutoff_hz / nyq_rate, window=('kaiser', beta), pass_zero=False)
    # NOTE(review): lfilter is causal, so this stage delays the output by
    # (numtaps - 1) / 2 samples relative to the other (zero-phase) stages/channels.
    return signal.lfilter(taps, 1.0, x_lowpassed)


# Build three stacks, one row per column of `data`:
#   raw_data               - untouched samples
#   filtered_data          - 5 Hz low cut (plus the FIR stage on FIR_CHANNEL)
#   before_processing_data - same chain but with a 1 Hz low cut
filtered_data = []
raw_data = []
before_processing_data = []
for i in range(data.shape[1]):
    column_raw = data.iloc[:, i].to_numpy().copy()
    raw_data.append(column_raw)

    # Unfiltered columns pass through unchanged in BOTH outputs. Without this default,
    # before_processing_data would reuse the previous channel's filtered signal.
    column_1hz = column_raw
    column_5hz = column_raw
    if i < N_FILTERED_CHANNELS:
        column_1hz = _bandpass_notch(column_raw, 1, fs)
        column_5hz = _bandpass_notch(column_raw, 5, fs)
    if i == FIR_CHANNEL:
        column_5hz = _lowpass_fir_highpass(column_5hz, fs)
        column_1hz = _lowpass_fir_highpass(column_1hz, fs)

    filtered_data.append(column_5hz)
    before_processing_data.append(column_1hz)

# Stack into (n_columns, n_samples) arrays.
filtered_data = np.array(filtered_data)
raw_data = np.array(raw_data)

asr_data = filtered_data.copy()  # working copy for the ASR stage
before_processing_data = np.array(before_processing_data)

## Clean the signal with Artifact Subspace Reconstruction (ASR)

In [22]:
from meegkit.asr import ASR
# Start from a fresh copy of `filtered_data` so re-running this cell does not
# band-pass and ASR-clean the same samples a second time.
asr_data = filtered_data.copy()

N_ASR_CHANNELS = 6  # channels 0..5 are cleaned, in adjacent pairs (0,1), (2,3), (4,5)
for ch in range(N_ASR_CHANNELS):
    asr_data[ch] = filter_iir(asr_data[ch], 'bandpass', [4, 40.0], fs, 4)

# Calibration segment, in seconds into the concatenated recording. `.copy()` keeps the
# reference data independent of the in-place cleaning of `asr_data` below.
CALIB_START_S, CALIB_END_S = 3140, 3180
calib = asr_data[:N_ASR_CHANNELS, CALIB_START_S * fs:CALIB_END_S * fs].copy()
if calib.shape[1] < (CALIB_END_S - CALIB_START_S) * fs:
    raise ValueError(
        f'Recording has only {asr_data.shape[1] / fs:.1f} s of data; the ASR calibration '
        f'window {CALIB_START_S}-{CALIB_END_S} s does not fit'
    )

window_size = 10 * fs  # 10 s window in samples

for i in range(N_ASR_CHANNELS // 2):
    rows = slice(i * 2, i * 2 + 2)
    # NOTE(review): meegkit documents method values 'riemann' / 'euclid'; 'rieman'
    # matches neither — confirm which ASR flavour is intended.
    asr = ASR(sfreq=fs, method='rieman', estimator='lwf', win_overlap=0.7, cutoff=20)
    _, sample_mask = asr.fit(calib[rows])

    # Clean the whole recording for this pair in consecutive windows with the fitted ASR.
    for idx in range(0, asr_data.shape[1], window_size):
        cols = slice(idx, idx + window_size)
        bad_data = asr_data[rows, cols]
        result = asr.transform(bad_data)
        # ASR.transform returns the cleaned (n_channels, n_samples) array itself, unlike
        # ASR.fit's (clean, sample_mask) tuple. Tuple-unpacking a 2-row array would bind
        # row 0 and row 1 separately, then broadcast row 0 over both channels on assignment.
        clean_data = result[0] if isinstance(result, tuple) else result
        asr_data[rows, cols] = clean_data


## A similar concept to the PoAS