In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import wfdb
import ast
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
import random

In [2]:
db_url = './ptbxl_database.csv'
scp_url = './scp_statements.csv'

In [3]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path+f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path+f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

In [4]:
def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_data.index:
            tmp.append(agg_data.loc[key].diagnostic_subclass)
    return list(set(tmp))

In [5]:
annotation_data = pd.read_csv(db_url, index_col='ecg_id')
annotation_data.scp_codes = annotation_data.scp_codes.apply(lambda x: ast.literal_eval(x))

In [6]:
signal_data = load_raw_data(annotation_data, 100, './')

In [7]:
agg_data = pd.read_csv(scp_url, index_col = 0)

In [8]:
agg_data

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
...,...,...,...,...,...,...,...,...,...,...,...,...
BIGU,"bigeminal pattern (unknown origin, SV or Ventr...",,,1.0,,,Statements related to ectopic rhythm abnormali...,"bigeminal pattern (unknown origin, SV or Ventr...",,,,
AFLT,atrial flutter,,,1.0,,,Statements related to impulse formation (abnor...,atrial flutter,51.0,MDC_ECG_RHY_ATR_FLUT,,
SVTAC,supraventricular tachycardia,,,1.0,,,Statements related to impulse formation (abnor...,supraventricular tachycardia,55.0,MDC_ECG_RHY_SV_TACHY,,D3-31290
PSVT,paroxysmal supraventricular tachycardia,,,1.0,,,Statements related to impulse formation (abnor...,paroxysmal supraventricular tachycardia,,MDC_ECG_RHY_SV_TACHY_PAROX,,


In [9]:
agg_data = agg_data[agg_data.diagnostic == 1]

In [10]:
annotation_data['diagnostic_subclass'] = annotation_data.scp_codes.apply(aggregate_diagnostic)

In [11]:
annotation_data.shape

(21837, 28)

In [12]:
agg_data.head(12)

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
IMI,inferior myocardial infarction,1.0,,,MI,IMI,Myocardial Infarction,inferior myocardial infarction,161.0,,,
ASMI,anteroseptal myocardial infarction,1.0,,,MI,AMI,Myocardial Infarction,anteroseptal myocardial infarction,165.0,,,
LVH,left ventricular hypertrophy,1.0,,,HYP,LVH,Ventricular Hypertrophy,left ventricular hypertrophy,142.0,,C71076,
LAFB,left anterior fascicular block,1.0,,,CD,LAFB/LPFB,Intraventricular and intra-atrial Conduction d...,left anterior fascicular block,101.0,MDC_ECG_BEAT_BLK_ANT_L_HEMI,C62267,D3-33140
ISC_,non-specific ischemic,1.0,,,STTC,ISC_,Basic roots for coding ST-T changes and abnorm...,ischemic ST-T changes,226.0,,,


In [13]:
annotation_data['diagnostic_subclass'] = annotation_data.scp_codes.apply(aggregate_diagnostic)

In [14]:
annotation_data['diagnostic_subclass'].value_counts()

diagnostic_subclass
[NORM]                            9083
[STTC]                            1404
[IMI]                             1250
[AMI]                              702
[ISC_, LVH]                        502
                                  ... 
[LAO/LAE, ISCI, ISCA, RAO/RAE]       1
[IVCD, LVH, NST_, STTC]              1
[IMI, ISCA, ILBBB]                   1
[LAO/LAE, LVH, CLBBB]                1
[ISCI, RAO/RAE]                      1
Name: count, Length: 691, dtype: int64

In [15]:
delete_id=[]
signal_info=[]
for index in range(len(annotation_data)):
    if len(annotation_data['diagnostic_subclass'].iat[index]) == 0: 
        delete_id.append(index)
    else: 
        signal_info.append(signal_data[index])

In [16]:
for index in range(len(delete_id)):
    annotation_data = annotation_data.drop(annotation_data.index[[delete_id[index] - index]])

In [17]:
annotation_data['diagnostic_subclass'].value_counts()

diagnostic_subclass
[NORM]                              9083
[STTC]                              1404
[IMI]                               1250
[AMI]                                702
[ISC_, LVH]                          502
                                    ... 
[IMI, CLBBB, _AVB, AMI, LAO/LAE]       1
[IRBBB, LAO/LAE, _AVB, STTC]           1
[IRBBB, ISC_, LVH, LMI]                1
[RVH, IMI, CRBBB, AMI]                 1
[ISCI, RAO/RAE]                        1
Name: count, Length: 690, dtype: int64

In [18]:
signal_info=np.array(signal_info)

In [19]:
signal_info.shape

(21430, 1000, 12)

In [20]:
subclass_code=np.array(annotation_data['diagnostic_subclass'])
subclass_code

array([list(['NORM']), list(['NORM']), list(['NORM']), ...,
       list(['ISCA']), list(['NORM']), list(['NORM'])], dtype=object)

In [21]:
subclass_code = subclass_code.reshape(len(subclass_code), 1)

In [22]:
subclass_code = subclass_code.reshape(-1, 1)
subclass_code

array([[list(['NORM'])],
       [list(['NORM'])],
       [list(['NORM'])],
       ...,
       [list(['ISCA'])],
       [list(['NORM'])],
       [list(['NORM'])]], dtype=object)

In [23]:
signal_info=signal_info.reshape(signal_info.shape[0],12000)
signal_info.shape

(21430, 12000)

In [24]:
annotation_data

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_subclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM]
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM]
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM]
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM]
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,,", alles,",,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr,[STTC]
21834,20703.0,93.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,[NORM]
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,,", I-AVR,",,,,,2,records100/21000/21835_lr,records500/21000/21835_hr,[ISCA]
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,[NORM]


In [25]:
# test_fold = 30
# X_train = signal_info[np.where(annotation_data.strat_fold  != test_fold)]
# Y_train = annotation_data[(annotation_data.strat_fold  != test_fold)].diagnostic_subclass

# X_test = signal_info[np.where(annotation_data.strat_fold  == test_fold)]
# Y_test = annotation_data[(annotation_data.strat_fold  == test_fold)].diagnostic_subclass

# print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)

In [26]:
X_train, X_test, Y_train, Y_test = train_test_split(signal_info, annotation_data['diagnostic_subclass'], test_size=0.30, random_state=42)
print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)

In [27]:
from sklearn.preprocessing import MultiLabelBinarizer

encoder = MultiLabelBinarizer()
y_train_encoded = encoder.fit_transform(Y_train)
y_train_encoded

array([[0, 1, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [28]:
y_train_encoded[0]

array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0])

In [29]:
encoder = MultiLabelBinarizer()
y_test_encoded = encoder.fit_transform(Y_test)
y_test_encoded

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [30]:
np.save('TrainSubClass\X_train',X_train)
np.save('TrainSubClass\Y_train',y_train_encoded)
np.save('TestSubClass\X_test',X_test)
np.save('TestSubClass\Y_test',y_test_encoded)

In [31]:
print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)

(15001, 12000) (15001,) (6429, 12000) (6429,)


In [32]:
Y_test.value_counts()

diagnostic_subclass
[NORM]                             2713
[STTC]                              442
[IMI]                               386
[AMI]                               199
[IMI, AMI]                          152
                                   ... 
[ISCA, STTC, LAFB/LPFB]               1
[IVCD, IMI, LAFB/LPFB, AMI]           1
[IVCD, ISCA, LVH, AMI]                1
[IRBBB, LVH, CRBBB]                   1
[LAO/LAE, AMI, IVCD, ISC_, LVH]       1
Name: count, Length: 398, dtype: int64