In [26]:
import ast
import wfdb
import torch
import numpy as np
import pandas as pd
import neurokit2 as nk
from scipy.stats import wilcoxon

from models import ECGNet, ECGWithMetaNet, ModelWrapper
from ecg_features import calc_features, load_features
from multilabel_metrics import accuracy_multilabel, f1_multilabel, roc_auc_multilabel
from multilabel_metrics import sensitivity_multilabel, specificity_multilabel, get_metrics, print_metrics

from chart_studio import plotly
from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objs as go
init_notebook_mode(connected=True)

pd.options.display.max_colwidth = 200
pd.options.display.max_columns = 200

### Подготовка датасетов

Классы и метрики, которые для них будем считать 

In [27]:
SCP_LABELS = {
    'SR': 'sinus rhythm',
    'SARRH': 'sinus arrhythmia',
    'SBRAD': 'bradycardia',
    'STACH': 'sinus tachycardia',
    'AFIB': 'artrial fibrillation',
}
SCP_LABELS_ARR = ['SR', 'SARRH', 'SBRAD', 'STACH', 'AFIB']

METRICS = [
    accuracy_multilabel,
    f1_multilabel,
    roc_auc_multilabel,
    sensitivity_multilabel,
    specificity_multilabel
]
LABELS = [f'{SCP_LABELS[l]} [{l}]' for l in SCP_LABELS_ARR] + ['Total']
METRICS_LABELS = ['Accuracy', 'f1', 'ROC AUC', 'Sensitivity [TPR]', 'Specificity [TNR]']

DB_ROOT = 'data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'

Загружаем диагнозы в виде SCP-кодов и делим их на отдельные колонки.

In [3]:
Y = pd.read_csv(f'{DB_ROOT}/ptbxl_database.csv')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

# Split scp labels into separate columns
for scp_label in SCP_LABELS_ARR:
    Y[scp_label] = Y.scp_codes.apply(lambda x: int(scp_label in x))

# If one of the illnesses or normal
Y['labels_cnt'] = Y[SCP_LABELS_ARR].sum(axis=1)
Y['has_label'] = Y.labels_cnt > 0

Y.head(2)

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,...,strat_fold,filename_lr,filename_hr,SR,SARRH,SBRAD,STACH,AFIB,labels_cnt,has_label
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,...,3,records100/00000/00001_lr,records500/00000/00001_hr,1,0,0,0,0,1,True
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,...,2,records100/00000/00002_lr,records500/00000/00002_hr,0,0,1,0,0,1,True


Оставляем в каждом фолде по 100 псевдо-случайных ЭКГ со статусом SR.

In [4]:
def count_nonzero(x):
    return np.sum(x > 0)

def get_random_n(obj, n, replace=False, seed=123):
    np.random.seed(seed)
    return obj.loc[np.random.choice(obj.index, n, replace), :]
    
SR_ecgids = Y[Y.SR == 1].groupby('strat_fold', as_index=False).apply(lambda r: get_random_n(r, 100))['ecg_id'].values
Y = Y[(Y.SR == 0) | (Y.ecg_id.isin(SR_ecgids))]

# Check result
Y[['strat_fold', 'has_label', 'ecg_id'] + SCP_LABELS_ARR].groupby(['strat_fold', 'has_label']).agg(count_nonzero).reset_index()

Unnamed: 0,strat_fold,has_label,ecg_id,SR,SARRH,SBRAD,STACH,AFIB
0,1,False,126,0,0,0,0,0
1,1,True,473,100,77,63,82,151
2,2,False,131,0,0,0,0,0
3,2,True,475,100,77,64,83,151
4,3,False,144,0,0,0,0,0
5,3,True,472,100,77,63,82,151
6,4,False,121,0,0,0,0,0
7,4,True,475,100,77,64,83,151
8,5,False,121,0,0,0,0,0
9,5,True,476,100,78,64,83,152


Загружаем и сглаживаем сигналы ЭКГ с sampling_rate=100

In [6]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(f"{path}/{f}") for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(f"{path}/{f}") for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

X = load_raw_data(Y, 100, DB_ROOT)

# Smooth ECG signal
X = np.apply_along_axis(nk.ecg_clean, 1, X, 100)

### Обучаем простую CNN-модель

Используем 10-fold cross validation. Сохраняем модели.

In [None]:
train_acc = {}
val_acc = {}
best_val_acc = {}

for i in range(1, 11):
    print(f'Start train {i}-th model')
    
    # Train dataset
    X_train = X[np.where(Y.strat_fold != i)]
    Y_train = Y[Y.strat_fold != i]
    y_train = Y_train[SCP_LABELS_ARR].values

    # Validation dataset
    X_val = X[np.where(Y.strat_fold == i)]
    Y_val = Y[Y.strat_fold == i]
    y_val = Y_val[SCP_LABELS_ARR].values

    wrapper = ModelWrapper(ECGNet())
    save_path = f'./models/simple_ecg_{i}'
    ta, va, bva = wrapper.train_n_epochs(
        X_train, None, y_train, X_val, None, y_val,
        save_path, epochs=35, verbose=False, batch_size=256
    )
    train_acc[i] = ta
    val_acc[i] = va
    best_val_acc[i] = bva

Считаем метрики

In [13]:
# Load models of 10-fold cross validation and evaluate
predictions_1 = []
for i in range(1, 11):
    print(f'Start evaluate {i}-th model')
    
    # Validation dataset
    X_val = X[np.where(Y.strat_fold == i)]
    X_val_T = torch.Tensor(np.transpose(X_val, [0, 2, 1]))
    Y_val = Y[Y.strat_fold == i]
    y_val = Y_val[SCP_LABELS_ARR].values
    
    wrapper = ModelWrapper(ECGNet())
    save_path = f'./models/simple_ecg_{i}'
    wrapper.load(save_path)
    y_pred_weights = wrapper.predict(X_val_T).detach().numpy()
    predictions_1.append([y_val, y_pred_weights])

Start evaluate 1-th model
Start evaluate 2-th model
Start evaluate 3-th model
Start evaluate 4-th model
Start evaluate 5-th model
Start evaluate 6-th model
Start evaluate 7-th model
Start evaluate 8-th model
Start evaluate 9-th model
Start evaluate 10-th model


In [14]:
metrics_1 = get_metrics(predictions_1, METRICS)
print_metrics(LABELS, METRICS_LABELS, metrics_1)

sinus rhythm [SR]
	Accuracy: 0.8337  f1: 0.4249  ROC AUC: 0.8429  Sensitivity [TPR]: 0.396  Specificity [TNR]: 0.9206
sinus arrhythmia [SARRH]
	Accuracy: 0.8537  f1: 0.1873  ROC AUC: 0.8084  Sensitivity [TPR]: 0.1532  Specificity [TNR]: 0.9559
bradycardia [SBRAD]
	Accuracy: 0.9306  f1: 0.6871  ROC AUC: 0.9553  Sensitivity [TPR]: 0.7313  Specificity [TNR]: 0.9539
sinus tachycardia [STACH]
	Accuracy: 0.9548  f1: 0.8328  ROC AUC: 0.9795  Sensitivity [TPR]: 0.8354  Specificity [TNR]: 0.9737
artrial fibrillation [AFIB]
	Accuracy: 0.8931  f1: 0.7592  ROC AUC: 0.954  Sensitivity [TPR]: 0.6778  Specificity [TNR]: 0.9649
Total
	Accuracy: 0.5606  f1: 0.5782  ROC AUC: 0.908  Sensitivity [TPR]: 0.5588  Specificity [TNR]: 0.9538


### Обучаем CNN-модель с мета-фичами

Считаем мета-фичи по очереди по каждому каналу. В пределах одного канала обработка параллельная на половине ядер.

In [32]:
%%capture --no-stdout
calc_features(X, 'meta_features/all', channels=12)

Start process 0-th channel
Start process 1-th channel
Start process 2-th channel
Start process 3-th channel
Start process 4-th channel
Start process 5-th channel
Start process 6-th channel
Start process 7-th channel
Start process 8-th channel
Start process 9-th channel
Start process 10-th channel
Start process 11-th channel
Meta features shape: (6055, 588)


Оставляем только интересующие мета-фичи и нормализуем

In [66]:
from sklearn.preprocessing import normalize

# 0: mean_rate
# 1: RR_min/RR_max
# 2: RR_mean/RR_max
# 3: RR_min/RR_mean
# 17: Q_max
# 25: S_max
# 33: RR_max
# 34: RR_min
# 39: PQ_mean
# 41: QRS_max
# 42: QRS_min

X_meta = load_features('meta_features/all', channels=12)
print(f'Meta features shape: {X_meta.shape}')

target_ids = []
for channel_id in range (0, 12):
    #try1: for i in [0, 1, 2, 17, 25, 33, 34, 39, 41, 42]:
    #try2: for i in [0, 1, 2, 25, 33, 34]:
    for i in [0, 1, 2, 3]:
        target_ids.append(i + 49*channel_id)
        
X_meta = X_meta[:, target_ids]
average_values = np.mean(X_meta, axis=0)
for i in range(average_values.shape[0]):
    replacement = average_values[i]
    X_meta[:, i] = np.where(X_meta[:, i] == 0, replacement, X_meta[:, i])
    
X_meta = normalize(X_meta, axis=0)
print(f'Target meta-features shape: {X_meta.shape}')

Meta features shape: (6055, 588)
Target meta-features shape: (6055, 48)


Используем 10-fold cross validation. Сохраняем модели.

In [None]:
train_acc_2 = {}
val_acc_2 = {}
best_val_acc_2 = {}

for i in range(1, 11):
    print(f'Start train {i}-th model')
    
    # Train dataset
    X_train = X[np.where(Y.strat_fold != i)]
    X_train_meta = X_meta[np.where(Y.strat_fold != i)]
    Y_train = Y[Y.strat_fold != i]
    y_train = Y_train[SCP_LABELS_ARR].values

    # Validation dataset
    X_val = X[np.where(Y.strat_fold == i)]
    X_val_meta = X_meta[np.where(Y.strat_fold == i)]
    Y_val = Y[Y.strat_fold == i]
    y_val = Y_val[SCP_LABELS_ARR].values

    wrapper = ModelWrapper(ECGWithMetaNet())
    save_path = f'./models/ecg_meta_{i}'
    ta, va, bva = wrapper.train_n_epochs(
        X_train, X_train_meta, y_train, X_val, X_val_meta, y_val,
        save_path, epochs=35, verbose=False, batch_size=256
    )
    train_acc_2[i] = ta
    val_acc_2[i] = va
    best_val_acc_2[i] = bva

In [62]:
# Load models of 10-fold cross validation and evaluate
predictions_2 = []

for i in range(1, 11):
    print(f'Start evaluate {i}-th model')
    
    # Validation dataset
    X_val = X[np.where(Y.strat_fold == i)]
    X_val_T = torch.Tensor(np.transpose(X_val, [0, 2, 1]))
    X_val_meta = X_meta[np.where(Y.strat_fold == i)]
    X_val_meta_T = torch.Tensor(X_val_meta)
    Y_val = Y[Y.strat_fold == i]
    y_val = Y_val[SCP_LABELS_ARR].values
    
    wrapper = ModelWrapper(ECGWithMetaNet())
    save_path = f'./models/ecg_meta_{i}'
    wrapper.load(save_path)
    y_pred_weights = wrapper.predict((X_val_T, X_val_meta_T)).detach().numpy()
    predictions_2.append([y_val, y_pred_weights])

Start evaluate 1-th model
Start evaluate 2-th model
Start evaluate 3-th model
Start evaluate 4-th model
Start evaluate 5-th model
Start evaluate 6-th model
Start evaluate 7-th model
Start evaluate 8-th model
Start evaluate 9-th model
Start evaluate 10-th model


In [12]:
metrics_2 = get_metrics(predictions_2, METRICS)
print_metrics(LABELS, METRICS_LABELS, metrics_2)

sinus rhythm [SR]
	Accuracy: 0.8605  f1: 0.5716  ROC AUC: 0.888  Sensitivity [TPR]: 0.568  Specificity [TNR]: 0.9186
sinus arrhythmia [SARRH]
	Accuracy: 0.8862  f1: 0.5344  ROC AUC: 0.8955  Sensitivity [TPR]: 0.5217  Specificity [TNR]: 0.9392
bradycardia [SBRAD]
	Accuracy: 0.939  f1: 0.6878  ROC AUC: 0.9608  Sensitivity [TPR]: 0.6485  Specificity [TNR]: 0.9731
sinus tachycardia [STACH]
	Accuracy: 0.9564  f1: 0.8502  ROC AUC: 0.9824  Sensitivity [TPR]: 0.9019  Specificity [TNR]: 0.9651
artrial fibrillation [AFIB]
	Accuracy: 0.9189  f1: 0.8301  ROC AUC: 0.9692  Sensitivity [TPR]: 0.8079  Specificity [TNR]: 0.9555
Total
	Accuracy: 0.6414  f1: 0.6948  ROC AUC: 0.9392  Sensitivity [TPR]: 0.6896  Specificity [TNR]: 0.9503


In [75]:
train_acc_2_mean = np.zeros(len(train_acc_2[1]))
val_acc_2_mean = np.zeros(len(val_acc_2[1]))

for i in range(1, 11):
    train_acc_2_mean += train_acc_2[i]
    val_acc_2_mean += val_acc_2[i]
train_acc_2_mean = train_acc_2_mean / 10
val_acc_2_mean = val_acc_2_mean / 10

In [79]:
train_trace = go.Scatter(x=np.arange(len(train_acc_2_mean)), y=train_acc_2_mean, name='Train')
val_trace = go.Scatter(x=np.arange(len(val_acc_2_mean)), y=val_acc_2_mean, name='Validation')
fig = go.Figure(data=[train_trace, val_trace], layout=go.Layout(title='Mean accuracy per epoch, CNN with meta'))
iplot(fig)

### Проверяем, что модели различаются статистически значимо

Считать стат значимость будем по f1-метрике.<br>
Для каждой модели у нас есть 6 выборок по 10 значений (5 классов + Total).<br>
В качестве статистического критерия используем критерий Уилкоксона (непараметрический, для небольших зависимых выборок)

In [25]:
for i, label in enumerate(LABELS):
    v1 = metrics_1[:, i, 1]
    v2 = metrics_2[:, i, 1]
    pval = wilcoxon(v1, v2)[1]
    print(label)
    print(f'\tmean f1 [1]: {np.round(np.mean(v1), 3)}, mean f1 [2]: {np.round(np.mean(v2), 3)}, p_value: {pval}')

sinus rhythm [SR]
	mean f1 [1]: 0.425, mean f1 [2]: 0.572, p_value: 0.00390625
sinus arrhythmia [SARRH]
	mean f1 [1]: 0.187, mean f1 [2]: 0.534, p_value: 0.001953125
bradycardia [SBRAD]
	mean f1 [1]: 0.687, mean f1 [2]: 0.688, p_value: 1.0
sinus tachycardia [STACH]
	mean f1 [1]: 0.833, mean f1 [2]: 0.85, p_value: 0.083984375
artrial fibrillation [AFIB]
	mean f1 [1]: 0.759, mean f1 [2]: 0.83, p_value: 0.005859375
Total
	mean f1 [1]: 0.578, mean f1 [2]: 0.695, p_value: 0.001953125


Как можно заметить, с помощью мета-фичей, основанных на R-пике можно статистически значимо улушить распознавание синусового ритма и аритмии