In [1]:
import os
from os import listdir

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import peakutils
import pywt
import seaborn 
import wfdb
from sklearn.model_selection import train_test_split

In [2]:
# Directory that holds the WFDB record files.
# It can be overridden with the AFDETECT_DATA_PATH environment variable; when the
# variable is not set, the original location is used.
# os.path.join(..., '') guarantees a trailing path separator, because record paths
# are built elsewhere in this notebook as `dataPath + <record id>`.
dataPath = os.path.join(
    os.environ.get('AFDETECT_DATA_PATH', '/Users/oobiri/Documents/Carleton/AFDetect/data/ltaf/'),
    '',
)

In [3]:
# Function that gets patient ids.
# Per the original note: raw data for patients 00735 & 03665 is unavailable, and
# patients 04936 & 05091 include incorrect annotations (the latter two are skipped below).
def get_ids(filePath):
    """Return the record ids found in the directory ``filePath``.

    A record id is the name of a ``.dat`` file with the ``.dat`` suffix removed.
    Files whose name contains '04936' or '05091' are excluded.

    Parameters
    ----------
    filePath : str
        Directory to scan. This argument is now actually used; previously the
        function listed the module-level ``dataPath`` regardless of what was passed.

    Returns
    -------
    list of str
        Record ids in sorted order, so the result does not depend on the
        arbitrary order in which ``listdir`` returns entries.
    """
    suffix = '.dat'
    excluded = ('04936', '05091')
    patientIds = []

    for filename in sorted(listdir(filePath)):
        if not filename.endswith(suffix):
            continue
        if any(bad_id in filename for bad_id in excluded):
            continue
        # Strip only the trailing suffix (str.replace would remove every occurrence)
        patientIds.append(filename[:-len(suffix)])

    return patientIds

In [4]:
# Ids of every usable record found in the data directory
pIds = get_ids(filePath=dataPath)

In [5]:
# Count how often each annotation symbol occurs in every patient record.
# The per-patient frames are collected in a list and concatenated once at the end;
# calling pd.concat on a growing frame inside the loop re-copies all earlier rows
# on every iteration.
symbol_counts = []

for pi in pIds:
    file = dataPath + pi
    annotation = wfdb.rdann(file, 'atr')
    sym = annotation.symbol

    # One row per distinct symbol: the symbol, its count, and the record it came from
    values, counts = np.unique(sym, return_counts=True)
    df_sub = pd.DataFrame({'sym':values, 'val':counts, 'pi':[pi]*len(counts)})
    symbol_counts.append(df_sub)

if symbol_counts:
    df = pd.concat(symbol_counts, axis = 0)
else:
    # pd.concat raises on an empty list; keep the expected columns so later cells still work
    df = pd.DataFrame(columns=['sym', 'val', 'pi'])

In [6]:
# Total number of occurrences of each annotation symbol across all records, most frequent first
(
    df
    .groupby('sym')['val']
    .sum()
    .sort_values(ascending=False)
)

sym
N    8710873
A     152332
V     132679
+      53704
"       5959
Q         89
Name: val, dtype: int64

In [5]:
# Annotation symbols treated as non-beat markers (not referenced in the visible cells)
nonBeat = [
    "[", "!", "]", "x", "(", ")", "p", "t", "u", "`", "'",
    "^", "|", "~", "+", "s", "T", "*", "D", "=", '"', "@",
]

# For when using arrhythmia datasets (not referenced in the visible cells)
abnormalBeat = [
    "L", "R", "B", "A", "a", "J", "S", "V", "r",
    "F", "e", "j", "n", "E", "/", "f", "Q", "?",
]

# Symbol this notebook treats as the AF beat.
# NOTE(review): in the standard WFDB annotation codes 'A' is an atrial premature
# beat, and AF episodes are usually marked by '+' rhythm annotations — confirm
# that 'A' is the intended AF label.
afBeat = "A"

# Beat symbols counted as non-AF: the same list as abnormalBeat with "A" replaced by "N"
nonafBeat = [
    "L", "R", "B", "N", "a", "J", "S", "V", "r",
    "F", "e", "j", "n", "E", "/", "f", "Q", "?",
]

In [6]:
# Function that loads a patient's signals and annotations
def load_ecg(file):
    """Read one WFDB record together with its 'atr' annotations.

    Parameters
    ----------
    file : str
        Record path without extension, passed unchanged to ``wfdb.rdrecord``
        and ``wfdb.rdann``.

    Returns
    -------
    tuple
        ``(p_signal, ann_sym, ann_sample)``: the record's ``p_signal`` attribute,
        and the annotation's ``symbol`` and ``sample`` attributes.
    """
    rec = wfdb.rdrecord(file)
    ann = wfdb.rdann(file, 'atr')
    return rec.p_signal, ann.symbol, ann.sample

In [7]:
# Creates the x,y matrices for each beat
def build_XY(p_signal, df_ann, num_cols, nonaf):
    """Cut one fixed-width window around every annotated beat and label it.

    Parameters
    ----------
    p_signal : 1-D array-like
        Signal samples of a single channel (each window is written into one row,
        so a window must be a flat run of ``num_cols`` values).
    df_ann : pandas.DataFrame
        Beat annotations with columns ``ann_sample`` (sample index of the beat)
        and ``ann_sym`` (annotation symbol).
    num_cols : int
        Window width in samples. The window spans ``num_cols // 2`` samples on
        each side of the beat, so an even value is expected.
    nonaf : container of str
        Symbols that receive label 1; every other symbol receives label 0.

    Returns
    -------
    signals : numpy.ndarray, shape (n_kept, num_cols)
        One window per kept beat.
    labels : numpy.ndarray, shape (n_kept, 1)
        1.0 when the beat's symbol is in ``nonaf``, otherwise 0.0.
    sym : list
        Annotation symbol of every kept beat, in the same order as the rows.

    Beats whose window would run past either end of ``p_signal`` are dropped.
    """
    # Half-window derived from num_cols, so the function no longer depends on the
    # module-level num_sec / fs variables (which are not parameters of this function).
    half_window = num_cols // 2
    numRows = len(df_ann)

    # Pre-allocate for the maximum possible number of rows; trimmed after the loop
    signals = np.zeros((numRows, num_cols))
    labels = np.zeros((numRows, 1))
    sym = []

    max_row = 0  # number of rows filled so far

    for ann_sample, ann_sym in zip(df_ann.ann_sample.values, df_ann.ann_sym.values):
        left = max(0, ann_sample - half_window)
        right = min(len(p_signal), ann_sample + half_window)
        x = p_signal[left:right]
        # A clipped window is shorter than num_cols; skip it
        if len(x) == num_cols:
            signals[max_row,:] = x
            labels[max_row,:] = int(ann_sym in nonaf)
            sym.append(ann_sym)
            max_row += 1

    # Drop the unused pre-allocated rows
    signals = signals[:max_row,:]
    labels = labels[:max_row,:]

    return signals, labels, sym

In [8]:
# Creates a dataset of windows centered on beats, num_sec seconds before and after each beat
def create_dataset(pids, num_sec, fs, nonaf):
    """Build the signal/label matrices for all records in ``pids``.

    For every record the first signal channel is loaded, annotations are
    restricted to the symbols in ``nonaf`` plus 'A', and ``build_XY`` extracts
    one window of ``2*num_sec*fs`` samples per remaining beat.

    Parameters
    ----------
    pids : iterable of str
        Record ids; each one is appended to the module-level ``dataPath``.
    num_sec : int
        Seconds of signal on each side of a beat.
    fs : int
        Samples per second, used to convert ``num_sec`` to a sample count.
    nonaf : iterable of str
        Symbols passed to ``build_XY`` as the label-1 class.

    Returns
    -------
    all_signals : numpy.ndarray, shape (n_beats, 2*num_sec*fs)
    all_labels : numpy.ndarray, shape (n_beats, 1)
    all_sym : list
        Annotation symbol for every row.

    Raises
    ------
    AssertionError
        If the row counts of signals, labels and symbols disagree.
    """
    num_cols = 2*num_sec*fs
    # Symbols kept as beats: the non-AF symbols plus the AF symbol 'A'
    beat_symbols = list(nonaf) + ['A']

    # Per-record results are collected and concatenated once after the loop;
    # np.append inside the loop would re-copy the whole matrix for every record.
    signal_chunks = []
    label_chunks = []
    all_sym = []

    max_rows = []

    for pi in pids:
        file = dataPath + pi

        p_signal, ann_sym, ann_sample = load_ecg(file)

        # grab the first signal
        p_signal = p_signal[:,0]

        # make df to exclude the nonbeats
        df_ann = pd.DataFrame({'ann_sym':ann_sym,
                              'ann_sample':ann_sample})
        df_ann = df_ann.loc[df_ann.ann_sym.isin(beat_symbols)]

        X,Y,sym = build_XY(p_signal,df_ann, num_cols, nonaf)
        all_sym.extend(sym)
        max_rows.append(X.shape[0])
        signal_chunks.append(X)
        label_chunks.append(Y)

    if signal_chunks:
        all_signals = np.concatenate(signal_chunks, axis = 0)
        all_labels = np.concatenate(label_chunks, axis = 0)
    else:
        # No records: return empty matrices with the usual number of columns
        all_signals = np.zeros((0,num_cols))
        all_labels = np.zeros((0,1))

    # check sizes make sense
    assert np.sum(max_rows) == all_signals.shape[0], 'number of signals, max_rows rows messed up'
    assert all_labels.shape[0] == all_signals.shape[0], 'number of signals, labels rows messed up'
    assert all_labels.shape[0] == len(all_sym), 'number of labels, sym rows messed up'

    return all_signals, all_labels, all_sym

In [9]:
# Window half-length in seconds, and the number of samples per second used to
# convert it into a sample count (the sampling rate is assumed, not read from the records)
num_sec, fs = 15, 128

In [None]:
# Build the beat-centered windows, their labels and their symbols for every record
all_signals, all_labels, all_sym = create_dataset(
    pids=pIds,
    num_sec=num_sec,
    fs=fs,
    nonaf=nonafBeat,
)

In [None]:
# Hold out 33% of the rows for validation, with a fixed seed.
# NOTE(review): rows are split individually, so windows from the same record
# (including overlapping windows of neighbouring beats) can end up in both sets —
# consider splitting by record id instead.
X_train, X_valid, y_train, y_valid = train_test_split(
    all_signals,
    all_labels,
    test_size=0.33,
    random_state=42,
)