## Setup

In [None]:
import pandas as pd
import numpy as np

import wfdb
from wfdb import processing
import pywt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

lead_keys = [ 'I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ]
ld = dict(zip( lead_keys, range(len(lead_keys)) ))

### [PTB-XL](https://physionet.org/content/ptb-xl/1.0.1/)

In [2]:
PATH = "ptbxl/"
fs = 500.0  # sampling frequency (Hz)
samp_cutoff = 5000

metadata = pd.read_csv(PATH + "ptbxl_database.csv", index_col="ecg_id", dtype={
    'patient_id': np.int32,
})

def hasCAD(value):
    if "ischaemic heart" in value:
        return "Coronary artery disease"
    elif "normal ecg" in value:
        return "Healthy control"
    else:
        return "other"

records = metadata[["filename_hr"]].copy()

records["report"] = metadata["report"].map(hasCAD)
records = records[records["report"].isin(["Coronary artery disease", "Healthy control"])] 
records = records.reset_index(drop=True)

records.rename(columns={"filename_hr": "record", "report":"diagnostic_class"}, inplace=True)
records["record"] = PATH + records["record"]

records_ptbxl = records
records_ptbxl.describe()

Unnamed: 0,record,diagnostic_class
count,3224,3224
unique,3224,2
top,ptbxl/records500/00000/00253_hr,Healthy control
freq,1,2786


### [PTB Diagnostic ECG Database](https://physionet.org/content/ptbdb/1.0.0/)

In [3]:
PATH = "ptbdb/"
fs = 1000.0  # sampling frequency (Hz)
samp_cutoff = 30000

records = pd.read_fwf(PATH + "RECORDS", sep='\n', names=["record"])
records["diagnostic_class"] = [ wfdb.rdheader(PATH + i).comments[4][22:] for i in records["record"] ]
 
records = records[records["diagnostic_class"].isin(["Myocardial infarction", "Healthy control"])] 
records = records.reset_index(drop=True)

records["record"] = PATH + records["record"]

records_ptbdb = records
records_ptbdb.describe()

Unnamed: 0,record,diagnostic_class
count,448,448
unique,448,2
top,ptbdb/patient001/s0010_re,Myocardial infarction
freq,1,368


### [St Petersburg INCART 12-lead Arrhythmia Database](https://physionet.org/content/incartdb/1.0.0/)

In [4]:
PATH = "st-petersburg/"
fs = 257.0  # sampling frequency (Hz)
samp_cutoff = 462600

records = []

with open(PATH + "files-patients-diagnoses.txt") as f:
    content = f.readlines()
    for i in range(len(content)):
        if i % 3 == 1:
            buffer = content[i]
        elif i % 3 == 2:
            if "Coronary artery disease" in content[i]:
                records.extend(buffer.strip().split())
                
records = pd.DataFrame(
    list(zip(records, ["Coronary artery disease"] * len(records))), 
    columns=['record', 'diagnostic_class']
)

records["record"] = PATH + records["record"]

records_petersburg = records
records_petersburg.describe()

Unnamed: 0,record,diagnostic_class
count,17,17
unique,17,1
top,st-petersburg/I01,Coronary artery disease
freq,1,17


## Combine Datasets

In [5]:
all_records = pd.DataFrame(pd.concat([records_ptbxl, records_petersburg]))  # PTB DB is excluded
all_records.describe()

Unnamed: 0,record,diagnostic_class
count,3241,3241
unique,3241,2
top,ptbxl/records500/00000/00253_hr,Healthy control
freq,1,2786


## Test-Train Split

In [6]:
le = LabelEncoder()

X_temp, X_test, y_temp, y_test = train_test_split(
    # specifiying a random states seeds the random shuffle, 42 is common as per the scikit docs
    all_records["record"], le.fit_transform(all_records["diagnostic_class"]), test_size=0.1, random_state=42
)
X_train, X_validate, y_train, y_validate = train_test_split(X_temp, y_temp, test_size=0.1, random_state=42)

def uniq(data):
    return np.unique(data, return_counts=True)[1]

pd.DataFrame(
    [ uniq(y_train), uniq(y_test), uniq(y_validate) ],
    [ "Training Set", "Testing Set", "Validation Set" ],
    le.classes_,
)

Unnamed: 0,Coronary artery disease,Healthy control
Training Set,362,2262
Testing Set,50,275
Validation Set,43,249


## Pre-processing

In [7]:
def clean(lead_x):
    # transformation in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7531205/#j_joeb-2019-0007_s_003_s_002title
    wvlt = 'db6'

    # max useful decomposition level
    max_level = pywt.dwt_max_level(data_len=len(lead_x), filter_len=pywt.Wavelet(wvlt).dec_len)  

    coeffs = pywt.wavedec(
        lead_x,
        wvlt,
        level=max_level
    )

    coeffs[0] = np.zeros(len(coeffs[0]))  # remove lowest frequency band
    
    # inverse DWT (multilevel reconstruction)
    idwt = pywt.waverec(
        coeffs,
        wvlt,
    )
    
    return idwt

In [8]:
downsample_frequency = 250.0

def pre_process(record):
    if record.startswith("ptbxl"):
        samp_cutoff = 5000
    elif record.startswith("ptbdb"):
        samp_cutoff = 30000
    elif record.startswith("st-petersburg"):
        samp_cutoff = 462600
        
    record = wfdb.rdrecord(record, sampto=samp_cutoff)   
    
    signal = record.p_signal[:, ld['II']]  # has better timedelta than using 'channel' param
    
    signal = clean(signal) # cleaning signal before segmentation has more consistent results 
    
    # second return value discards the new sample locations
    samples, _ = wfdb.processing.resample_sig(signal, fs, downsample_frequency)
    
    return samples

In [9]:
segment_len = int(downsample_frequency * 5)  # five seconds worth of samples

def segment(signal, num_segments):
    return np.array_split(signal[:(segment_len * num_segments)], num_segments)

In [None]:
# 1. pre-process
# create a new pandas column for filtered signal, fill that column
all_records["signal"] = all_records["record"].apply(pre_process)

In [None]:
# 2. segment each record
all_signals = np.empty((0, segment_len))
all_keys = []

for index, row in all_records.iterrows():
    num_segments = int(len(row["signal"]) // segment_len)
    
    all_signals = np.vstack( (all_signals, segment(row["signal"], num_segments)) )
    all_keys.extend([row["diagnostic_class"]] * num_segments)

## Export

In [None]:
# 3. Save the keys and signals to local file
import pickle

with open("all_signals.pickle", "wb") as fp:
    pickle.dump(all_signals, fp)

with open("all_keys.pickle", "wb") as fp:
    pickle.dump(all_keys, fp)