In [32]:
import pandas as pd
import numpy as np
import wfdb
import ast

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import accuracy_score

pd.options.display.max_colwidth = 200
pd.options.display.max_columns = 200

DB_ROOT = 'data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'

In [2]:
SCP_LABELS = {
    'SR': 'sinus rhythm',
    'SARRH': 'sinus arrhythmia',
    'SBRAD': 'bradycardia',
    'STACH': 'sinus tachycardia',
    'AFIB': 'artrial fibrillation',
}

Причесываем данные и анализируем их

In [3]:
Y = pd.read_csv(f'{DB_ROOT}/ptbxl_database.csv')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

# Split scp labels into separate columns
for scp_label in SCP_LABELS.keys():
    Y[scp_label] = Y.scp_codes.apply(lambda x: int(scp_label in x))

# If one of the illnesses or normal
Y['labels_cnt'] = Y[SCP_LABELS.keys()].sum(axis=1)
Y['has_label'] = Y.labels_cnt > 0

Y.head(2)

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,SR,SARRH,SBRAD,STACH,AFIB,labels_cnt,has_label
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,1,0,0,0,0,1,True
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,0,0,1,0,0,1,True


Проверяем сбалансированность классов и считаем, как часто встречаются несколько labels.

In [4]:
def count_nonzero(x):
    return np.sum(x > 0)

Y[['strat_fold', 'has_label', 'ecg_id'] + list(SCP_LABELS.keys())].groupby(['strat_fold', 'has_label']).agg(count_nonzero).reset_index()

Unnamed: 0,strat_fold,has_label,ecg_id,SR,SARRH,SBRAD,STACH,AFIB
0,1,False,126,0,0,0,0,0
1,1,True,2051,1678,77,63,82,151
2,2,False,131,0,0,0,0,0
3,2,True,2053,1678,77,64,83,151
4,3,False,144,0,0,0,0,0
5,3,True,2050,1678,77,63,82,151
6,4,False,121,0,0,0,0,0
7,4,True,2054,1679,77,64,83,151
8,5,False,121,0,0,0,0,0
9,5,True,2055,1679,78,64,83,152


Видим, что SR слишком много, оставим в каждом фолде по 100 ЭКГ со статусом SR

In [75]:
def get_random_n(obj, n, replace=False, seed=123):
    np.random.seed(seed)
    return obj.loc[np.random.choice(obj.index, n, replace), :]
    
SR_ecgids = Y[Y.SR == 1].groupby('strat_fold', as_index=False).apply(lambda r: get_random_n(r, 100))['ecg_id'].values

In [79]:
Y = Y[(Y.SR == 0) | (Y.ecg_id.isin(SR_ecgids))]
Y[['strat_fold', 'has_label', 'ecg_id'] + list(SCP_LABELS.keys())].groupby(['strat_fold', 'has_label']).agg(count_nonzero).reset_index()

Unnamed: 0,strat_fold,has_label,ecg_id,SR,SARRH,SBRAD,STACH,AFIB
0,1,False,126,0,0,0,0,0
1,1,True,473,100,77,63,82,151
2,2,False,131,0,0,0,0,0
3,2,True,475,100,77,64,83,151
4,3,False,144,0,0,0,0,0
5,3,True,472,100,77,63,82,151
6,4,False,121,0,0,0,0,0
7,4,True,475,100,77,64,83,151
8,5,False,121,0,0,0,0,0
9,5,True,476,100,78,64,83,152


In [80]:
# How many ecg examples with more than 1 label
print(f"ECG-examples with more than 1 label: {Y[Y.labels_cnt > 1].shape[0]}")
Y[Y.labels_cnt > 1][['report', 'scp_codes']]

ECG-examples with more than 1 label: 3


Unnamed: 0,report,scp_codes
283,"sinus bradycardia with sinus arrhythmia. the cause of the bradycardia is not evident. voltages are high in chest leads suggesting lvh. st segments are depressed in i, ii, avl, v4,5,6. this may be ...","{'LVH': 100.0, 'ISC_': 100.0, 'DIG': 100.0, 'VCLVH': 0.0, 'STD_': 0.0, 'SBRAD': 0.0, 'SARRH': 0.0}"
10362,"sinus bradycardia with sinus arrhythmia. the bradycardia may be physiological. st segments are elevated in i, ii, avf, v2-6, this is probably a normal variant. high v lead voltages are probably...","{'NORM': 100.0, 'SBRAD': 0.0, 'SARRH': 0.0}"
12282,sinus bradycardia with sinus arrhythmia. otherwise normal ecg. the cause of the bradycardia is not evident.,"{'NORM': 80.0, 'SBRAD': 0.0, 'SARRH': 0.0}"


In [81]:
# a special module that converts [batch, channel, w, h] to [batch, units]
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
    
class Print(nn.Module):
    def forward(self, x):
        print(f'printtt: {x.size()}')
        return x

class SimpleECGClassifier:
    def __init__(self):
        # input: [n, 12, 1000]
        self.model = nn.Sequential()
        self.model.add_module('p1', Print())
        self.model.add_module('conv1', nn.Conv1d(in_channels=12, out_channels=12, kernel_size=17, stride=2))
        # input: [n, 12, 492]
        self.model.add_module('p2', Print())
        #self.model.add_module('bnorm1', nn.BatchNorm1d(492))
        self.model.add_module('relu1', nn.ReLU())
        self.model.add_module('conv2', nn.Conv1d(in_channels=12, out_channels=12, kernel_size=10, stride=2))
        # input: [n, 12, 242]
        self.model.add_module('p3', Print())
        #self.model.add_module('bnorm2', nn.BatchNorm1d(242))
        self.model.add_module('relu2', nn.ReLU())
        self.model.add_module('conv3', nn.Conv1d(in_channels=12, out_channels=12, kernel_size=10, stride=2))
        # input: [n, 12, 117]
        self.model.add_module('p4', Print())
        #self.model.add_module('bnorm2', nn.BatchNorm1d(242))
        self.model.add_module('relu2', nn.ReLU())
        self.model.add_module('flatten1', Flatten())
        self.model.add_module('p5', Print())
        self.model.add_module('dense1', nn.Linear(1404, 64))
        self.model.add_module('p6', Print())
        self.model.add_module('relu3', nn.ReLU())
        self.model.add_module('dense2', nn.Linear(64, 5))
        self.model.add_module('p7', Print())
        self.model.add_module('sigmoid', nn.Sigmoid())
        
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.opt = torch.optim.Adam(self.model.parameters(), lr=0.01)
    
    def fit(self, X, y):
        self.model.train()
        self.opt.zero_grad()        
        prediction = self.model(X)
        loss = self.criterion(prediction, y)
        loss.backward()
        self.opt.step()
        
    def predict(self, X):
        self.model.eval()
        prediction = self.model(X)
        return prediction

Формируем входные данные с помощью функции из примера для данного датасета

In [82]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(f"{path}/{f}") for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(f"{path}/{f}") for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data

X = load_raw_data(Y, 100, DB_ROOT)

Готовим данные для train/validation/test датасетов

In [83]:
def iter_train_validation(X, Y):
    for fold in sorted(Y.strat_fold.unique()):
        X_train = X[np.where(Y.strat_fold != fold)]
        Y_train = Y[Y.strat_fold != fold]
        
        X_val = X[np.where(Y.strat_fold == fold)]
        Y_val = Y[Y.strat_fold == fold]
        
        yield X_train, Y_train, X_val, Y_val

test_fold = 10 # from the dataset recommendations

# K-fold train dataset
X_k_fold = X[np.where(Y.strat_fold != test_fold)]
Y_k_fold = Y[Y.strat_fold != test_fold]

# Test dataset
X_test = X[np.where(Y.strat_fold == test_fold)]
Y_test = Y[Y.strat_fold == test_fold]

In [86]:
def multiclass_accuracy(y, pred):
    y = y.detach().numpy()
    pred = pred.detach().numpy()
    pred = np.rint(pred)
    return sum((np.rint(y) == pred).all(axis=1)) / y.shape[0]

In [87]:
EPOCHS = 10
model_1 = SimpleECGClassifier()
for epoch in range(0, EPOCHS):
    for (X_train, Y_train, X_val, Y_val) in iter_train_validation(X_k_fold, Y_k_fold):
        X_train = torch.Tensor(np.transpose(X_train, [0, 2, 1]))
        y_train = torch.Tensor(Y_train[SCP_LABELS].values)
        
        X_val = torch.Tensor(np.transpose(X_val, [0, 2, 1]))
        y_val = torch.Tensor(Y_val[SCP_LABELS].values)
        
        model_1.fit(X_train, y_train)
        prediction_train_1 = model_1.predict(X_train)
        prediction_val_1 = model_1.predict(X_val)
        
        accuracy_train_1 = multiclass_accuracy(prediction_train_1, y_train)
        accuracy_val_1 = multiclass_accuracy(prediction_val_1, y_val)
        
        print(accuracy_train_1, accuracy_val_1)
        
        break
    break

printtt: torch.Size([4831, 12, 1000])
printtt: torch.Size([4831, 12, 492])
printtt: torch.Size([4831, 12, 242])
printtt: torch.Size([4831, 12, 117])
printtt: torch.Size([4831, 1404])
printtt: torch.Size([4831, 64])
printtt: torch.Size([4831, 5])
printtt: torch.Size([4831, 12, 1000])
printtt: torch.Size([4831, 12, 492])
printtt: torch.Size([4831, 12, 242])
printtt: torch.Size([4831, 12, 117])
printtt: torch.Size([4831, 1404])
printtt: torch.Size([4831, 64])
printtt: torch.Size([4831, 5])
printtt: torch.Size([599, 12, 1000])
printtt: torch.Size([599, 12, 492])
printtt: torch.Size([599, 12, 242])
printtt: torch.Size([599, 12, 117])
printtt: torch.Size([599, 1404])
printtt: torch.Size([599, 64])
printtt: torch.Size([599, 5])
0.21403436141585594 0.21035058430717862
