In [None]:
import ast
import wfdb
import numpy as np
import pandas as pd

# Widen pandas' notebook rendering so long reports and wide metadata tables are not truncated.
pd.set_option("display.max_colwidth", 200)
pd.set_option("display.max_columns", 200)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import neurokit2 as nk
import multiprocessing
import warnings

from sklearn import metrics

from chart_studio import plotly
from plotly import tools
from plotly.subplots import make_subplots
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
# Initialise plotly's offline notebook mode (imported from plotly.offline above).
init_notebook_mode(connected=True)

# Root directory of the PTB-XL dataset; the metadata CSV and the record files
# read in this notebook are addressed relative to this path.
DB_ROOT = 'data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'

Делим scp_codes на отдельные колонки. Проверяем сбалансированность классов.

In [8]:
# SCP codes used as classification targets, mapped to a human-readable name.
SCP_LABELS = {
    'SR': 'sinus rhythm',
    'SARRH': 'sinus arrhythmia',
    'SBRAD': 'bradycardia',
    'STACH': 'sinus tachycardia',
    'AFIB': 'atrial fibrillation',
}
# Ordered list of the codes above. Derived from the mapping (dicts keep
# insertion order) so the list and the mapping cannot drift apart.
SCP_LABELS_ARR = list(SCP_LABELS)

# Load the metadata table. scp_codes is read as text, so parse every cell
# into the Python literal it spells out (a dict of code -> value).
Y = pd.read_csv(f'{DB_ROOT}/ptbxl_database.csv')
Y['scp_codes'] = Y['scp_codes'].apply(ast.literal_eval)

# Split scp labels into separate columns: 1 if the code is present, else 0
for scp_label in SCP_LABELS_ARR:
    Y[scp_label] = Y['scp_codes'].map(lambda codes: int(scp_label in codes))

# Number of tracked labels per record, and whether the record has any of them
Y['labels_cnt'] = Y[SCP_LABELS_ARR].sum(axis=1)
Y['has_label'] = Y['labels_cnt'].gt(0)

Y.head(2)

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,SR,SARRH,SBRAD,STACH,AFIB,labels_cnt,has_label
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,1,0,0,0,0,1,True
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,0,0,1,0,0,1,True


In [5]:
def count_nonzero(x):
    """Count the entries of ``x`` that are strictly greater than zero.

    Despite the name, negative entries are not counted: the test is ``x > 0``.
    ``x`` must support element-wise comparison with ``0``.
    """
    is_positive = x > 0
    return np.sum(is_positive)

# Per (strat_fold, has_label) group: how many records there are and how many carry each label.
(
    Y[['strat_fold', 'has_label', 'ecg_id', *SCP_LABELS_ARR]]
    .groupby(['strat_fold', 'has_label'])
    .agg(count_nonzero)
    .reset_index()
)

Unnamed: 0,strat_fold,has_label,ecg_id,SR,SARRH,SBRAD,STACH,AFIB
0,1,False,126,0,0,0,0,0
1,1,True,2051,1678,77,63,82,151
2,2,False,131,0,0,0,0,0
3,2,True,2053,1678,77,64,83,151
4,3,False,144,0,0,0,0,0
5,3,True,2050,1678,77,63,82,151
6,4,False,121,0,0,0,0,0
7,4,True,2054,1679,77,64,83,151
8,5,False,121,0,0,0,0,0
9,5,True,2055,1679,78,64,83,152


Видим, что лейблов SR (синусовый ритм) слишком много. Оставим в каждом фолде по 100 псевдослучайных ЭКГ со статусом SR.

In [6]:
def get_random_n(obj, n, replace=False, seed=123):
    """Return ``n`` pseudo-randomly chosen rows of ``obj``.

    The draw is reproducible: the same ``obj`` length, ``n``, ``replace`` and
    ``seed`` always select the same row positions.

    Parameters
    ----------
    obj : pandas.DataFrame
        Table to sample rows from.
    n : int
        Number of rows to draw.
    replace : bool, default False
        Whether the same row may be drawn more than once.
    seed : int, default 123
        Seed of the private random generator used for this draw.

    Returns
    -------
    pandas.DataFrame
        The selected rows, in the order they were drawn.

    Raises
    ------
    ValueError
        If ``replace`` is False and ``n`` exceeds the number of rows.
    """
    # A private generator keeps the draw reproducible without reseeding the
    # process-wide numpy RNG, which would affect every later np.random call.
    rng = np.random.RandomState(seed)
    # Draw row positions rather than index labels so that a non-unique index
    # cannot turn one drawn label into several returned rows.
    positions = rng.choice(len(obj), size=n, replace=replace)
    return obj.iloc[positions, :]
    
# Draw 100 SR records from every stratification fold and remember their ecg_id values.
sr_sampled = (
    Y[Y.SR == 1]
    .groupby('strat_fold', as_index=False)
    .apply(lambda fold_rows: get_random_n(fold_rows, 100))
)
SR_ecgids = sr_sampled['ecg_id'].values

# Keep every non-SR record plus only the sampled SR records.
keep_mask = (Y.SR == 0) | Y.ecg_id.isin(SR_ecgids)
Y = Y[keep_mask]
(
    Y[['strat_fold', 'has_label', 'ecg_id', *SCP_LABELS_ARR]]
    .groupby(['strat_fold', 'has_label'])
    .agg(count_nonzero)
    .reset_index()
)

Unnamed: 0,strat_fold,has_label,ecg_id,SR,SARRH,SBRAD,STACH,AFIB
0,1,False,126,0,0,0,0,0
1,1,True,473,100,77,63,82,151
2,2,False,131,0,0,0,0,0
3,2,True,475,100,77,64,83,151
4,3,False,144,0,0,0,0,0
5,3,True,472,100,77,63,82,151
6,4,False,121,0,0,0,0,0
7,4,True,475,100,77,64,83,151
8,5,False,121,0,0,0,0,0
9,5,True,476,100,78,64,83,152


Считаем, как часто встречаются несколько labels для одной ЭКГ.<br>
P.S. Для данной задачи всего 3 случая.

In [7]:
# Records that carry more than one of the tracked labels.
multi_label_rows = Y[Y.labels_cnt > 1]
print(f"ECG-examples with more than 1 label: {multi_label_rows.shape[0]}")
multi_label_rows[['report', 'scp_codes']]

ECG-examples with more than 1 label: 3


Unnamed: 0,report,scp_codes
283,"sinus bradycardia with sinus arrhythmia. the cause of the bradycardia is not evident. voltages are high in chest leads suggesting lvh. st segments are depressed in i, ii, avl, v4,5,6. this may be ...","{'LVH': 100.0, 'ISC_': 100.0, 'DIG': 100.0, 'VCLVH': 0.0, 'STD_': 0.0, 'SBRAD': 0.0, 'SARRH': 0.0}"
10362,"sinus bradycardia with sinus arrhythmia. the bradycardia may be physiological. st segments are elevated in i, ii, avf, v2-6, this is probably a normal variant. high v lead voltages are probably...","{'NORM': 100.0, 'SBRAD': 0.0, 'SARRH': 0.0}"
12282,sinus bradycardia with sinus arrhythmia. otherwise normal ecg. the cause of the bradycardia is not evident.,"{'NORM': 80.0, 'SBRAD': 0.0, 'SARRH': 0.0}"


Теперь смотрим на характерные признаки заболеваний.

In [16]:
def load_raw_data(df, sampling_rate, path):
    """Read the waveform of every record listed in ``df`` into one array.

    Parameters
    ----------
    df : table with ``filename_lr`` and ``filename_hr`` columns holding record
        paths relative to ``path``.
    sampling_rate : ``100`` selects the ``filename_lr`` column; any other
        value selects ``filename_hr``.
    path : root directory prepended to every record path.

    Returns
    -------
    numpy.ndarray built from the signal part of each ``wfdb.rdsamp`` result,
    in the row order of ``df``.
    """
    # NOTE(review): rates other than 100 silently fall back to filename_hr — confirm this is intended.
    record_names = df.filename_lr if sampling_rate == 100 else df.filename_hr
    signals = []
    for record_name in record_names:
        # rdsamp yields a (signal, metadata) pair; only the signal is kept.
        signal, _meta = wfdb.rdsamp(f"{path}/{record_name}")
        signals.append(signal)
    return np.array(signals)

# Load the sampling_rate=100 (filename_lr) waveform of every record still listed in Y.
X = load_raw_data(df=Y, sampling_rate=100, path=DB_ROOT)

# Clean ECG signal: run the cleaner over every 1-D slice taken along axis 1,
# telling it the same sampling rate the records were loaded with.
X = np.apply_along_axis(nk.ecg_clean, axis=1, arr=X, sampling_rate=100)

In [15]:
test_fold = 10  # fold recommended by the dataset for testing
# Rows of X are selected positionally by the boolean mask computed on Y.
sr_in_test_fold = (Y.SR == 1) & (Y.strat_fold == test_fold)
# NOTE(review): if X is laid out as (records, samples, leads), the second [0]
# picks a single time step across leads rather than one lead's series — confirm.
SR_ecg = X[sr_in_test_fold][0][0]

NameError: name 'X' is not defined