In [28]:
import os
import wfdb
import pickle
import pandas as pd
import numpy as np
from keras.utils import to_categorical
from tqdm import tqdm_notebook
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
%matplotlib inline

In [29]:
data_root = "./data/"
fs = 250
note_list = ['00735','03665','04043','04936','05091','06453','08378','08405','08434','08455']

In [30]:
keys = []
for i in os.listdir(data_root):
    if not ".dat" in i:continue
    key = i.split(".")[0]
    if key not in note_list:
        keys.append(key)
print("There are",len(keys),"records")

There are 15 records


## Split keys to train and test 

In [4]:
cross = 5
num_keys_every_cross = int(len(keys)/cross)
for cross_idx in range(cross):
    test_keys = keys[cross_idx*num_keys_every_cross:(cross_idx+1)*num_keys_every_cross]
    assert(len(test_keys)==num_keys_every_cross)
    print(test_keys)

['08219', '06426', '05261']
['04746', '07879', '04126']
['07162', '04908', '04048']
['06995', '08215', '05121']
['04015', '07859', '07910']


## Check R peak Detection Annotation in .qrs files

In [31]:
save_dir = "./CheckQRS"
win_len = 10 #second
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
    
for key in keys:
    save_folder = os.path.join(save_dir,key)
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    
    signals,fields = wfdb.rdsamp(os.path.join(data_root,key))
    ann_RRI = wfdb.rdann(os.path.join(data_root,key),extension="qrs").sample
    plt.plot(signals[:,0])
    time_axis = [i/fs for i in range(len(signals))]
    start = 0
    end = start + win_len
    plt.plot(time_axis,signals[:,0])
    plt.xlabel("Time(s)")
    
    print("plot R-peak detection for",key)
    for idx in tqdm_notebook(ann_RRI[:100]):
        plt.scatter(x=idx/fs,y=signals[idx,0],color="r")
    
    while(end<=(ann_RRI[100+1])/fs):
        plt.xlim(start,end)
        start += win_len
        end = start+win_len
        plt.savefig(os.path.join(save_folder,"start="+str(start)+"s.png"))
    plt.clf()
    print("saved R-peak detection for",key,"\n")

plot R-peak detection for 08219


HBox(children=(IntProgress(value=0), HTML(value='')))


saved R-peak detection for 08219 



KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

## Check Annotation Labels

In [32]:
for key in keys:
    ann = wfdb.rdann(os.path.join(data_root,key),extension="atr")
    print(key,set(ann.aux_note))

08219 {'(AFIB', '(N'}
06426 {'(AFL', '(AFIB', '(N', '(J'}
05261 {'(AFIB', '(N'}
04746 {'(AFIB', '(N'}
07879 {'(AFIB', '(N', '(J'}
04126 {'(AFIB', '(N'}
07162 {'(AFIB'}
04908 {'(AFL', '(AFIB', '(N'}
04048 {'(AFIB', '(N'}
06995 {'(AFL', '(AFIB', '(N'}
08215 {'(AFL', '(AFIB', '(N'}
05121 {'(AFIB', '(N', '(J'}
04015 {'(AFIB', '(N'}
07859 {'(AFIB'}
07910 {'(AFL', '(AFIB', '(N'}


## Check shortest segment length

In [33]:
for key in keys:
    ann = wfdb.rdann(os.path.join(data_root,key),extension="atr")
    seg_lengths = []
    total_length = 10*3600*fs
    if len(ann.sample)==1:
        seg_lengths.append(total_length - ann.sample[0])
    else:
        for i,s in enumerate(ann.sample[:-1]):
            seg_lengths.append(ann.sample[i+1]-s)
    
    print(key,min(seg_lengths))

08219 2226
06426 381
05261 848
04746 1174
07879 685
04126 4269
07162 8999826
04908 883
04048 4387
06995 1239
08215 12560
05121 624
04015 421
07859 8999957
07910 1695


## Make dataset

In [34]:
def label2int(label):
    if label=="N":
        return 0
    if label=="AFIB":
        return 1
    if label=="AFL":
        return 2
    if label== "J":
        return 3

In [35]:
def generate_dataset(key,seg_len=5):
    ann = wfdb.rdann(os.path.join(data_root,key),extension="atr")
    signals,fields = wfdb.rdsamp(os.path.join(data_root,key))
    annotation_indics = ann.sample
    annotation_labels = [label2int(i[1:]) for i in ann.aux_note]
    last_index = 0
    X = []
    y = []

    #special process when there is only one annotation
    if(len(annotation_indics)==1):
        long_seg = signals[annotation_indics[0]:]
        label = annotation_labels[0]
        start = 0
        end = start+seg_len*fs
        while(end<len(long_seg)):
            short_seg = long_seg[start:end]
            start += seg_len*fs
            end = start + seg_len*fs
            X.append(short_seg)
            y.append(label)
        X = np.array(X)
        #y = to_categorical(np.array(y),num_classes=4)
        y = np.array(y)
        return X,y
    
    for i,idx in enumerate(annotation_indics):
        long_seg = signals[last_index:idx]
        label = annotation_labels[i]
        last_index = idx
        
        #slice long_seg to short_seg
        if len(long_seg)<seg_len*fs:continue
        start = 0
        end = start+seg_len*fs
        while(end<len(long_seg)):
            short_seg = long_seg[start:end]
            start += seg_len*fs
            end = start + seg_len*fs
            X.append(short_seg)
            y.append(label)
    X = np.array(X)
    #y = to_categorical(np.array(y),num_classes=4)
    y=np.array(y)
    return X,y

In [63]:
def make_balanced_and_shuffle(X,y,num_calss=2):
    one_hot_y = to_categorical(y,4)
    
    num_each_class = np.sum(one_hot_y,axis=0)
    num = int(min(num_each_class[0],num_each_class[1]))
    
    balanced_X = []
    balanced_y = []
    for class_idx in range(num_calss):
        indices = np.where(y==class_idx)
        selected_X = X[indices]
        selected_y = y[indices]
        balanced_X.append( selected_X[:num] )
        balanced_y.append( selected_y[:num])
    X = np.concatenate(balanced_X,axis=0)
    y = np.concatenate(balanced_y,axis=0)
    rand_index = np.random.permutation(len(X))
    X = X[rand_index]
    y = to_categorical(y[rand_index],num_calss)
    return X,y

In [64]:
def normalizaiton(X_train,X_test):
    ch1 = X_train[:,:,0]
    ch2 = X_train[:,:,1]
    
    scaler1 = MinMaxScaler()
    scaler1.fit(ch1)
    scaler2 = MinMaxScaler()
    scaler2.fit(ch2)
    
    X_train = np.transpose(np.array([scaler1.transform(ch1),scaler2.transform(ch2)]),[1,2,0])
    X_test = np.transpose(np.array([scaler1.transform(X_test[:,:,0]),scaler2.transform(X_test[:,:,1])]),[1,2,0])
    return X_train,X_test

In [65]:
cross = 5
num_keys_every_cross = int(len(keys)/cross)
save_dir = "./dataset"
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

for cross_idx in range(cross):
    save_path = os.path.join(save_dir,"data-cross-"+str(cross_idx)+".pickle")
    X_train = []
    X_test = []
    y_train = []
    y_test = []
    test_keys = keys[cross_idx*num_keys_every_cross:(cross_idx+1)*num_keys_every_cross]
    print(test_keys)
    assert(len(test_keys)==num_keys_every_cross)
    for key in keys:
        X,y = generate_dataset(key)
        if key in test_keys:
            X_test.append(X)
            y_test.append(y)
        else:
            X_train.append(X)
            y_train.append(y)
            
    X_train = np.concatenate(X_train,axis=0)
    y_train = np.concatenate(y_train,axis=0)
    X_test = np.concatenate(X_test,axis=0)
    y_test = np.concatenate(y_test,axis=0)
    
    X_train,y_train = make_balanced_and_shuffle(X_train,y_train)
    X_test,y_test = make_balanced_and_shuffle(X_test,y_test)
    
    X_train,X_test = normalizaiton(X_train,X_test)
    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)
    
    dataset = {"test_key":test_keys,"X_train":X_train,"y_train":y_train,"X_test":X_test,"y_test":y_test}
    with open(save_path,"wb") as f:
        pickle.dump(obj=dataset,file=f)
    print("saved at",save_path)


['08219', '06426', '05261']
(19002, 1250, 2)
(19002, 2)
(6330, 1250, 2)
(6330, 2)
saved at ./dataset/data-cross-0.pickle
['04746', '07879', '04126']
(16968, 1250, 2)
(16968, 2)
(8364, 1250, 2)
(8364, 2)
saved at ./dataset/data-cross-1.pickle
['07162', '04908', '04048']
(24376, 1250, 2)
(24376, 2)
(956, 1250, 2)
(956, 2)
saved at ./dataset/data-cross-2.pickle
['06995', '08215', '05121']
(16534, 1250, 2)
(16534, 2)
(8798, 1250, 2)
(8798, 2)
saved at ./dataset/data-cross-3.pickle
['04015', '07859', '07910']
(24448, 1250, 2)
(24448, 2)
(884, 1250, 2)
(884, 2)
saved at ./dataset/data-cross-4.pickle
