In [None]:
!pip install wfdb
!pip install ecg_plot
!pip install xgboost==1.7.6

In [None]:
import pandas as pd
import ast
import wfdb
import numpy as np
import ecg_plot
import matplotlib.pyplot as plt
import os
import xgboost as xgb
import sklearn
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, roc_auc_score, ConfusionMatrixDisplay

## Loading data

First, the data are loaded using the code from 'example_physionet.py', with some modifications.

In [None]:
path = 'https://physionet.org/files/ptb-xl/1.0.2/'
sampling_rate=100

In [None]:
# load and convert annotation data
annotations = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')
annotations.scp_codes = annotations.scp_codes.apply(lambda x: ast.literal_eval(x))

In [None]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)

In [None]:
def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
annotations['diagnostic_superclass'] = annotations.scp_codes.apply(aggregate_diagnostic)
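
To make the aggregation step concrete, here is a minimal sketch on toy data. The table `toy_agg_df` and the helper `toy_aggregate_diagnostic` are hypothetical stand-ins for `scp_statements.csv` and `aggregate_diagnostic`; the three code-to-class pairs are only illustrative.

```python
import pandas as pd

# Hypothetical stand-in for scp_statements.csv: SCP code -> diagnostic superclass.
toy_agg_df = pd.DataFrame(
    {"diagnostic_class": ["NORM", "MI", "STTC"]},
    index=["NORM", "IMI", "NDT"],
)


def toy_aggregate_diagnostic(scp_codes, lookup):
    """Return the sorted, de-duplicated superclasses for one record's SCP codes."""
    classes = {
        lookup.loc[code, "diagnostic_class"]
        for code in scp_codes
        if code in lookup.index
    }
    return sorted(classes)


# "SR" is not in the toy table, so it contributes no superclass.
example_codes = {"IMI": 100.0, "NDT": 50.0, "SR": 0.0}
superclasses = toy_aggregate_diagnostic(example_codes, toy_agg_df)
print(superclasses)
```

A record can therefore end up with zero, one, or several superclasses, which matters later when the target column is built.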

In [None]:
def load_raw_data(df, sampling_rate, path):
    data = []
    for i, f in enumerate(df.filename_lr):
        folder = int(f.split('/')[-1].split('_')[0]) // 1000 * 1000
        file = f.split('/')[-1]
        localfile = path + f"records{sampling_rate}/{folder:05d}/{file}"
        if i%1000 == 0:
            print(i)
        try:
            if os.path.exists(localfile):
                datum = wfdb.rdsamp(f"{localfile}/")
            else:
                datum = wfdb.rdsamp(file, pn_dir=f"ptb-xl/1.0.2/records{sampling_rate}/{folder:05d}/")
            data.append(datum)
        except wfdb.io._url.NetFileNotFoundError:
            data.append([np.empty((0, 0), dtype=object), {}])
            continue
    return np.array([signal for signal, meta in data])

# Load raw signal data.
# If the pickle exists locally, it is loaded; otherwise the signals are fetched from the server and processed.
if not os.path.exists("raw_annotations.pickle"):
    raw_data = load_raw_data(annotations, sampling_rate, path)
    # Transpose each record from (samples, leads) to (leads, samples).
    transposed_raw_data = [signal.T for signal in raw_data]
    # Align on the 'ecg_id' index; a default 0..n-1 index would attach waveforms to the wrong records.
    annotations['waveforms'] = pd.Series(transposed_raw_data, index=annotations.index)
    annotations.to_pickle("raw_annotations.pickle")
else:
    annotations = pd.read_pickle('raw_annotations.pickle')

0
1000
2000
3000
4000
5000


Now that the annotations and the waveforms are loaded into a single dataframe, a record can be shown to a doctor, given the requested record number and the sampling frequency.

In [None]:
ecg_id = 5


In [None]:
annotations_without_waveforms = annotations.loc[:, annotations.columns != 'waveforms']

Information about the required record is shown:

#### Annotations and plot for a record

In [None]:
def get_record_data(record_id):
    # Look up by 'ecg_id' label so the row matches the waveform selected in plot_signal.
    return annotations.loc[record_id]
get_record_data(ecg_id)

For the same record, the ECG is plotted using the ecg_plot library:

#### Plotted signal

In [None]:
def plot_signal(data, record_id):
    signal = data[record_id]
    # Plot the ECG signal
    ecg_plot.plot(signal, sample_rate=sampling_rate, title="ECG Signal, Record "+str(record_id), show_lead_name=True)

In [None]:
plot_signal(annotations['waveforms'], ecg_id)

## Exploratory Data Analysis and basic cleaning

In [None]:
annotations.head(2)

In [None]:
annotations.tail(2)

In [None]:
annotations.shape

Since the 'ecg_id' index does not match the row positions, the index is reset so that 'ecg_id' becomes a regular column.

In [None]:
annotations = annotations.reset_index()

In [None]:
annotations.info()

In [None]:
annotations_without_waveforms.describe()

We see that some age values are wrong, since the maximum is 300 years. The PTB-XL documentation indicates that ages above 89 are recorded as 300 for privacy reasons, so the true ages cannot be recovered; let's remove these rows.

In [None]:
annotations[annotations.age>120].age.count()

In [None]:
annotations[annotations.age>120].index

Because the ECG signals are stored in the 'waveforms' column of the same dataframe, dropping these rows also removes the corresponding signals.

In [None]:
annotations_age_clean = annotations.drop(annotations[annotations.age>120].index)

In [None]:
annotations_age_clean.shape

In addition, 30 rows report electrode problems; these rows are removed as well.

In [None]:
annotations_electrodes_clean = annotations_age_clean[annotations_age_clean.electrodes_problems.isna()]

In [None]:
annotations_electrodes_clean.shape

In [None]:
# Copy so that later column assignments do not act on a view of the original dataframe
annotations_clean = annotations_electrodes_clean.copy()

### Demographics

In [None]:
annotations_clean.age.plot(kind='hist')

In [None]:
annotations_clean.sex.value_counts().plot(kind='bar')

In [None]:
annotations_clean.site.value_counts().plot(kind='bar')

In [None]:
annotations_clean.device.value_counts().plot(kind='bar')

In [None]:
annotations_clean.nurse.value_counts().plot(kind='bar')

In [None]:
annotations_clean.shape

In [None]:
annotations_clean[annotations_clean.height.notna()].shape

In [None]:
annotations_clean[annotations_clean.weight.notna()].shape

### Diagnosis

In [None]:
annotations_clean_without_waveform = annotations_clean.loc[:, annotations_clean.columns != 'waveforms']

In [None]:
annotations_value_counts = annotations_clean_without_waveform.diagnostic_superclass.value_counts()
annotations_value_counts

In [None]:
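# Flatten the label lists of all records (not just the unique combinations) so that counts are per record
complete_annotations_value_counts = annotations_clean_without_waveform.diagnostic_superclass.explode().dropna().to_list()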

In [None]:
diagnosis_set = set(complete_annotations_value_counts)
diagnosis_set

Since each record can have several classifications, the number of records in each diagnostic group is counted:

In [None]:
diagnosis_count = dict()
for diagnosis_type in diagnosis_set:
    diagnosis_count[diagnosis_type] = complete_annotations_value_counts.count(diagnosis_type)
diagnosis_count
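
As a cross-check of the counting logic, the same per-record count can be sketched with `collections.Counter` on toy label lists. The data below are hypothetical; the point is only that every record contributes one count for each label it carries.

```python
from collections import Counter

# Hypothetical per-record label lists, standing in for 'diagnostic_superclass'.
toy_labels = [["NORM"], ["MI", "STTC"], ["NORM"], [], ["STTC"]]

# Flatten once and count: a record with two labels is counted in both groups.
toy_counts = Counter(label for labels in toy_labels for label in labels)
print(dict(toy_counts))
```

Note that the counts add up to the number of labels, not the number of records, so they can exceed the dataframe length.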

In [None]:
# Missing values were already filtered out, so every remaining class is plotted
labels = list(diagnosis_count.keys())
values = list(diagnosis_count.values())
plt.pie(values, labels=labels)
plt.show()

### Conclusions after preliminary exploration

A basic cleaning has been applied to the dataset, removing invalid age values and rows with electrode problems. 21479 of the initial 21801 rows remain. This preliminary exploration also showed that columns such as height and weight are populated in only 6811 and 9259 rows respectively, which must be taken into account when a model is trained, depending on how the chosen model deals with null values. Some columns, such as 'site' or 'nurse', show little variation and will probably not make a difference for the target value. For other variables, such as sex, age, and the target value, the dataset is quite balanced.


## Identifying the heart beat of the signal, average and total heart beat in the signal

In this section, the mean heart rate of each record is calculated with the function 'get_heart_rate'. The QRS duration and amplitude are also calculated, using one waveform per record, by applying 'get_qrs_duration_and_amplitude'. QRS VAT and other QRS parameters used to detect abnormalities remain to be calculated in future versions.

In [None]:
# Computes the mean time between successive R peaks (sample differences divided by the sampling frequency),
# inverts it to obtain beats per second, and multiplies by 60 to obtain beats per minute.
def get_heart_rate(peaks, fs):
    # N peaks delimit N - 1 intervals; with fewer than two peaks the rate is undefined.
    if len(peaks) < 2:
        return np.nan
    time_period_per_point = 1 / fs
    heartbeat_period = sum(
        (peaks[index + 1] - peaks[index]) * time_period_per_point
        for index in range(len(peaks) - 1)
    ) / (len(peaks) - 1)

    return 60. / heartbeat_period
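
A quick worked check of the interval arithmetic, as a sketch with a hypothetical helper name (`mean_heart_rate_bpm`) and synthetic peak positions. It assumes the mean interval should be taken over the N - 1 gaps between N peaks.

```python
def mean_heart_rate_bpm(peaks, fs):
    """Mean heart rate from R-peak sample indices; needs at least two peaks."""
    if len(peaks) < 2:
        return float("nan")
    # The intervals telescope: their sum is simply last peak minus first peak.
    mean_interval_s = (peaks[-1] - peaks[0]) / (len(peaks) - 1) / fs
    return 60.0 / mean_interval_s


# Synthetic R peaks every 80 samples at 100 Hz: 0.8 s per beat, about 75 bpm.
synthetic_peaks = [10, 90, 170, 250, 330]
rate = mean_heart_rate_bpm(synthetic_peaks, fs=100)
print(round(rate, 1))
```

Dividing the summed intervals by N instead of N - 1 would bias the rate upwards, and the bias grows as the number of detected peaks shrinks.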

In [None]:
# Row-wise helper: adds 'qrs_duration' and 'qrs_amplitude' to a dataframe row and returns it
def get_qrs_duration_and_amplitude_row(row, fs):
    # Lead index 11 is the V6 signal
    qrs_duration, qrs_amplitude = get_qrs_duration_and_amplitude(row['waveforms'][11], row['r_peaks'], fs)

    row['qrs_duration'] = qrs_duration
    row['qrs_amplitude'] = qrs_amplitude
    return row

In [None]:
# Obtains the mean QRS duration and amplitude of a waveform
def get_qrs_duration_and_amplitude(waveform, peaks, fs):

    # Divides the waveform by the middle point between peaks, returning a subset of waveforms with each peak in the middle of it
    def chunk_waveform(waveform, peaks):
        chunks = []
        for index in range(1, len(peaks) - 1):
            prev_distance = peaks[index] - peaks[index - 1]
            next_distance = peaks[index + 1] - peaks[index]
            prev_cut_point = peaks[index - 1] + int(prev_distance / 2)
            next_cut_point = peaks[index] + int(next_distance / 2)
            chunk = waveform[prev_cut_point: next_cut_point]
            chunks.append(chunk)

        return chunks

    # Counts the points from the start of the chunk that stay within 5% of the peak value around the initial baseline
    # Afterwards, multiplies that count by the sampling period to obtain the quiet time before the complex
    def get_q_start(chunk, fs, threshold = 0.05):
        time_period_per_point = 1 / fs
        peak = max(chunk)
        initial_value = sum(chunk[:6]) / len(chunk[:6])
        low_points = 0
        for index, point in enumerate(chunk):
            if abs(point - initial_value) <= peak * threshold:
                low_points += 1
            else:
                break
        return low_points * time_period_per_point

    # Counts the points from the end of the chunk (scanning backwards) that stay within 5% of the peak value around the final baseline
    # Afterwards, multiplies that count by the sampling period to obtain the quiet time after the complex
    def get_s_finish(chunk, fs, threshold = 0.05):
        inverted_chunk = chunk[::-1]
        time_period_per_point = 1 / fs
        peak = max(inverted_chunk)
        initial_value = sum(inverted_chunk[:6]) / len(inverted_chunk[:6])
        low_points = 0
        for index, point in enumerate(inverted_chunk):
            if abs(point - initial_value) <= peak * threshold:
                low_points += 1
            else:
                break
        return low_points * time_period_per_point

    chunks = chunk_waveform(waveform, peaks)
    # Fewer than three peaks leave no complete beat to measure
    if not chunks:
        return np.nan, np.nan
    qrs_duration = 0
    qrs_amplitude = 0
    time_period_per_point = 1 / fs
    # For each beat, the quiet time before and after the complex is subtracted from the chunk length
    for chunk in chunks:
        qrs_amplitude += max(chunk) - min(chunk)
        period = len(chunk) * time_period_per_point
        duration = period - get_q_start(chunk, fs) - get_s_finish(chunk, fs)
        qrs_duration += duration
    qrs_amplitude /= len(chunks)
    qrs_duration /= len(chunks)

    return qrs_duration, qrs_amplitude
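
The threshold idea behind the duration estimate can be illustrated on a synthetic beat. This is a simplified sketch, not the notebook's function: `active_duration_s` is a hypothetical name, and the beat is an idealised flat baseline with a triangular spike.

```python
def active_duration_s(chunk, fs, threshold=0.05, baseline_points=6):
    """Seconds between the first and last samples that leave the baseline band."""
    band = max(chunk) * threshold

    def quiet_run(samples):
        # Count leading samples that stay within the band around the local baseline.
        head = samples[:baseline_points]
        baseline = sum(head) / len(head)
        count = 0
        for value in samples:
            if abs(value - baseline) > band:
                break
            count += 1
        return count

    active = len(chunk) - quiet_run(chunk) - quiet_run(chunk[::-1])
    return max(active, 0) / fs


# Synthetic beat at 100 Hz: 20 flat samples, a 9-sample spike, 21 flat samples.
spike = [0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]
beat = [0.0] * 20 + spike + [0.0] * 21
duration = active_duration_s(beat, fs=100)
print(duration)
```

On real signals the P and T waves also leave the baseline band, so a measure of this kind tends to cover more than the QRS complex alone; that limitation applies to any simple amplitude threshold.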

In [None]:
# 'r_peaks' is stored as the printed form of an integer array, e.g. '[ 50 130 210]'; split on whitespace to parse it
annotations_clean['r_peaks'] = annotations_clean.r_peaks.apply(lambda x: [int(p) for p in x.strip().strip('[]').split()])
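
For illustration, the parsing of one 'r_peaks' cell can be tried in isolation. The raw string below is hypothetical, and `parse_peak_string` is an illustrative helper name; the sketch assumes the column holds whitespace-separated integers wrapped in square brackets.

```python
def parse_peak_string(text):
    """Parse a printed integer array such as '[ 50 130 210]' into a list of ints."""
    return [int(token) for token in text.strip().strip("[]").split()]


# Hypothetical raw cell value, with padding spaces and a wrapped line.
raw = "[ 50 130 210\n 290 370]"
peaks = parse_peak_string(raw)
print(peaks)
```

Splitting on whitespace avoids depending on how many padding spaces separate the numbers.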

In [None]:
annotations_clean['heart_rate'] = annotations_clean.r_peaks.apply(lambda x: get_heart_rate(x, sampling_rate))
annotations_clean_all_waveforms = annotations_clean[annotations_clean['waveforms'].apply(lambda x: not isinstance(x, float))]
annotations_clean_all_waveforms = annotations_clean_all_waveforms.apply(lambda row: get_qrs_duration_and_amplitude_row(row, sampling_rate), axis=1)

## Training a model

The dataset with all the calculated parameters is now used to train an XGBoost classifier, and the performance achieved is reviewed.

Since many records are classified in several categories, the dataset is exploded so that each row carries a single classification for training.

In [None]:
annotations_clean_exploded = annotations_clean_all_waveforms.explode('diagnostic_superclass').dropna(subset=['diagnostic_superclass'])
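
The effect of exploding can be seen on a toy dataframe (hypothetical values): a record with two labels becomes two rows with identical features, and a record with no label disappears after `dropna`.

```python
import pandas as pd

toy = pd.DataFrame({
    "ecg_id": [1, 2, 3],
    "age": [54, 61, 47],
    "diagnostic_superclass": [["NORM"], ["MI", "STTC"], []],
})

# An empty list explodes to a single NaN row, which dropna then removes.
exploded = toy.explode("diagnostic_superclass").dropna(subset=["diagnostic_superclass"])
rows_per_record = exploded.groupby("ecg_id").size().to_dict()
print(exploded)
```

Rows that share the same features but carry different labels cannot all be predicted correctly by a single-label classifier, which is worth keeping in mind when reading the accuracy.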

In [None]:
annotations_clean_exploded.shape

In [None]:
category_columns = ['heart_axis', 'diagnostic_superclass']

In [None]:
# columns = ['strat_fold', 'age', 'sex', 'height', 'weight', 'nurse','site', 'device', 'heart_axis', 'second_opinion', 'initial_autogenerated_report', 'validated_by_human', 'baseline_drift', 'static_noise', 'burst_noise', 'extra_beats', 'pacemaker','diagnostic_superclass', 'heart_rate', 'qrs_duration', 'qrs_amplitude']
columns = ['strat_fold', 'age', 'sex', 'heart_axis', 'diagnostic_superclass', 'heart_rate', 'qrs_duration', 'qrs_amplitude']

In [None]:
data = annotations_clean_exploded[columns].copy()

In [None]:
annotations_clean_exploded['diagnostic_superclass']

In [None]:
label_encoders = {}
for column in category_columns:
  le = preprocessing.LabelEncoder()
  data[column] = le.fit_transform(data[column])
  label_encoders[column] = le
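
A short sketch of what the label encoding does, on hypothetical class names: `LabelEncoder` sorts the distinct values and maps each to its position, and `inverse_transform` recovers the original strings.

```python
from sklearn.preprocessing import LabelEncoder

toy_classes = ["NORM", "MI", "STTC", "NORM", "CD"]

encoder = LabelEncoder()
codes = encoder.fit_transform(toy_classes)
# classes_ is sorted, so the integer codes follow the alphabetical order of the labels.
decoded = encoder.inverse_transform(codes)
print(list(encoder.classes_), list(codes))
```

Keeping the fitted encoders, as the loop above does in `label_encoders`, is what makes it possible to show readable class names in the confusion matrix later.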

In [None]:
# Folds 1-8 are used for training, fold 9 for validation and fold 10 for testing
test_fold = 10
valid_fold = 9
data_train = data[(data['strat_fold'] != test_fold) & (data['strat_fold'] != valid_fold)]
data_valid = data[data['strat_fold'] == valid_fold]
data_test = data[data['strat_fold'] == test_fold]

In [None]:
# 'strat_fold' only defines the split, so it is excluded from the features together with the target
feature_columns = [c for c in data.columns if c not in ('diagnostic_superclass', 'strat_fold')]
x_train = data_train[feature_columns]
x_test = data_test[feature_columns]
x_valid = data_valid[feature_columns]
y_train = data_train['diagnostic_superclass']
y_test = data_test['diagnostic_superclass']
y_valid = data_valid['diagnostic_superclass']
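
The fold-based split can be sketched on toy data. The assignment of fold 9 to validation and fold 10 to testing is an assumption of this sketch, as are the column names `age` and `label`.

```python
import pandas as pd

toy = pd.DataFrame({
    "strat_fold": [1, 2, 8, 9, 9, 10],
    "age": [50, 60, 70, 40, 55, 65],
    "label": [0, 1, 0, 1, 0, 1],
})

VALID_FOLD, TEST_FOLD = 9, 10  # assumed split
is_valid = toy["strat_fold"] == VALID_FOLD
is_test = toy["strat_fold"] == TEST_FOLD
train, valid, test = toy[~is_valid & ~is_test], toy[is_valid], toy[is_test]

# The fold id is used only for splitting, never as a model input.
features = [c for c in toy.columns if c not in ("label", "strat_fold")]
x_train_toy, y_train_toy = train[features], train["label"]
print(len(train), len(valid), len(test))
```

Leaving the fold id among the features would give the model an input whose test-time value never occurs during training.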

In [None]:
xgb_classifier = xgb.XGBClassifier(eta=0.006)
xgb_classifier.fit(x_train, y_train)
y_pred = xgb_classifier.predict(x_test)

In [None]:
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: %.2f%%" % (accuracy * 100.0))

In [None]:
cm = confusion_matrix(y_test, y_pred, labels=xgb_classifier.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_encoders['diagnostic_superclass'].inverse_transform(xgb_classifier.classes_)))
disp.plot()
plt.show()

The results show that only 44% accuracy is achieved. This figure is nonetheless obscured by the fact that some records were originally classified in several categories: after exploding, such a record appears once per label with identical features, so the model cannot predict all of its rows correctly. This deserves a deeper study, since a multilabel model might work better.
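
As a possible starting point for the multilabel idea, the target could be encoded as a binary indicator matrix with one column per superclass. The sketch below uses scikit-learn's `MultiLabelBinarizer` on hypothetical label lists; which model is then trained on this matrix is left open.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical per-record label lists; the last record has no diagnostic label.
toy_labels = [["NORM"], ["MI", "STTC"], ["CD", "MI"], []]

binarizer = MultiLabelBinarizer()
y_multi = binarizer.fit_transform(toy_labels)
print(list(binarizer.classes_))
print(y_multi)
```

With this representation each record stays a single row, so the features are no longer duplicated across conflicting labels.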

## Some conclusions and future steps

I have tried to create a comprehensive project in the limited time I had: getting an overview of the problem, studying the available data, and creating a process that uses that data to predict a possible diagnosis from the kind of information this dataset contains.

* I checked the 'wfdb' library, which has some methods for processing these datasets and representing the electrocardiogram in a doctor-friendly way. Sadly, these methods required the annotations to be in a format different from the existing one, which was loaded from a .csv file. A method could probably be developed to convert the available annotations into a compatible format, but I preferred to focus my efforts on more data-related tasks.
* I created some functions to extract the average heart rate from the available data (which was quite straightforward, since the peak indexes are contained in the annotations) and to calculate the QRS duration and amplitude from the V6 signal, with the objective of using them to represent the information from the electrocardiograms in a classification model. As a future development, more methods can be built to extract QRS-related data from the rest of the voltage signals, which can be used to train better models.
* I did a first study of the variables contained in the dataset, which I then used together with the calculated parameters to train a classification model (XGBoost classifier) with basic parameters. More advanced tuning and feature selection should be done in the future.
* As a continuation of the project, more models should be trained and compared to find the one that fits our data best. One of these models should be an LSTM, which would be suited to represent the information from the electrocardiograms much better than any parameter extraction.


Sources of information for the project were:
- https://wfdb.readthedocs.io/
- https://en.wikipedia.org/wiki/QRS_complex