In [27]:
import pandas as pd
import numpy as np
import wfdb
import ast

In [28]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path+f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path+f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

In [29]:
help(ast.literal_eval)

Help on function literal_eval in module ast:

literal_eval(node_or_string)
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.



In [30]:
path = '../data/physionet.org/files/ptb-xl/1.0.3/'
sampling_rate=500

In [35]:
# load and convert annotation data
Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')

# Pandas probably read the scp_codes as strings, we can change them to be interpreted as dictionaries with ast
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 

In [24]:
# Load raw signal data
X = load_raw_data(Y, sampling_rate, path)

KeyboardInterrupt: 

In [37]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

In [38]:
agg_df

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
IMI,inferior myocardial infarction,1.0,,,MI,IMI,Myocardial Infarction,inferior myocardial infarction,161.0,,,
ASMI,anteroseptal myocardial infarction,1.0,,,MI,AMI,Myocardial Infarction,anteroseptal myocardial infarction,165.0,,,
LVH,left ventricular hypertrophy,1.0,,,HYP,LVH,Ventricular Hypertrophy,left ventricular hypertrophy,142.0,,C71076,
LAFB,left anterior fascicular block,1.0,,,CD,LAFB/LPFB,Intraventricular and intra-atrial Conduction d...,left anterior fascicular block,101.0,MDC_ECG_BEAT_BLK_ANT_L_HEMI,C62267,D3-33140
ISC_,non-specific ischemic,1.0,,,STTC,ISC_,Basic roots for coding ST-T changes and abnorm...,ischemic ST-T changes,226.0,,,


In [39]:
def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

In [40]:
# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [44]:
Y['diagnostic_superclass'].value_counts()

diagnostic_superclass
[NORM]                 9069
[MI]                   2532
[STTC]                 2400
[CD]                   1708
[MI, CD]               1297
[HYP, STTC]             781
[MI, STTC]              599
[HYP]                   535
[CD, STTC]              471
[]                      411
[CD, NORM]              407
[MI, HYP, STTC]         361
[CD, HYP]               300
[MI, CD, STTC]          223
[CD, HYP, STTC]         211
[MI, HYP]               183
[MI, CD, HYP, STTC]     156
[MI, CD, HYP]           117
[NORM, STTC]             28
[CD, NORM, STTC]          5
[CD, HYP, NORM]           2
[HYP, NORM]               2
[MI, CD, HYP, NORM]       1
Name: count, dtype: int64

In [None]:
# Split data into train and test
test_fold = 10
# Train
X_train = X[np.where(Y.strat_fold != test_fold)]
y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
# Test
X_test = X[np.where(Y.strat_fold == test_fold)]
y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass

In [None]:
Y

In [9]:
Y.scp_codes

ecg_id
1                 {'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}
2                             {'NORM': 80.0, 'SBRAD': 0.0}
3                               {'NORM': 100.0, 'SR': 0.0}
4                               {'NORM': 100.0, 'SR': 0.0}
5                               {'NORM': 100.0, 'SR': 0.0}
                               ...                        
21833    {'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...
21834             {'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}
21835                           {'ISCAS': 50.0, 'SR': 0.0}
21836                           {'NORM': 100.0, 'SR': 0.0}
21837                           {'NORM': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21799, dtype: object