# Electrocardiogram Analysis using ECG-FM

The electrocardiogram (ECG) is a low-cost, non-invasive diagnostic test that has been ubiquitous in the assessment and management of cardiovascular disease for decades. ECG-FM is a pretrained, open foundation model for ECG analysis.

In this tutorial, we show how to perform inference for multi-label classification using a finetuned ECG-FM model. Specifically, we take a model finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/) and run inference on a sample of the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) to show how to adapt the predictions to a new set of labels.

## Overview
0. Installation
1. Prepare checkpoints
2. Prepare data
3. Run inference
4. Interpret results
5. Extra - Load models

## 0. Installation

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

Clone [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the requirements and installation section in the top-level README. After following those steps, install `pandas` and make the environment accessible within this notebook by running:
```
python3 -m pip install --user pandas
python3 -m pip install --user --upgrade jupyterlab ipywidgets ipykernel
python3 -m ipykernel install --user --name ecg_fm
```

In [None]:
import os
import pandas as pd
import torch

from fairseq_signals.utils.store import MemmapReader

In [None]:
root = os.getcwd()
root

In [None]:
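fairseq_signals_root = '/path/to/fairseq-signals'  # TODO: replace this placeholder with the absolute path to your fairseq-signals clone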
fairseq_signals_root = fairseq_signals_root.rstrip('/')
fairseq_signals_root

## 1. Prepare checkpoints

### Download checkpoints

The checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm-preprint). Alternatively, they can be downloaded using the commands below.

**Disclaimer: These models are different from those reported in our arXiv paper.** These BERT-Base sized models were trained purely on public data sources due to privacy concerns surrounding UHN-ECG data and patient identification. Validation for the final models will be available upon full publication.

In [None]:
from huggingface_hub import hf_hub_download

_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.pt',
    local_dir=os.path.join(root, 'ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.yaml',
    local_dir=os.path.join(root, 'ckpts'),
)

In [None]:
assert os.path.isfile(os.path.join(root, 'ckpts/physionet_finetuned.pt'))
assert os.path.isfile(os.path.join(root, 'ckpts/physionet_finetuned.yaml'))

## 2. Prepare data

The model used here was finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/). To simplify this tutorial, we have processed a sample of 10 ECGs (14 five-second segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) so that we can demonstrate how to adapt the predictions to a new set of labels.

To perform inference on a full dataset (or on your own dataset), refer to the flexible, end-to-end, multi-source data preprocessing pipeline described [here](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README explains how the data is organized, and preprocessing scripts are already implemented for several datasets.

### Update manifest

The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly.

In [None]:
segmented_split = pd.read_csv(
    os.path.join(root, 'data/code_15/segmented_split_incomplete.csv'),
    index_col='idx',
)
segmented_split['path'] = (root + '/data/code_15/segmented/') + segmented_split['path']
segmented_split.to_csv(os.path.join(root, 'data/code_15/segmented_split.csv'))

In [None]:
assert os.path.isfile(os.path.join(root, 'data/code_15/segmented_split.csv'))
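
Because the segmented split must contain absolute paths, a quick sanity check can catch entries that were left relative. The following is a minimal sketch using only the standard library; the helper name `find_relative_paths` and the example file names are hypothetical and not part of fairseq_signals.

```python
import os


def find_relative_paths(paths):
    """Return the entries of `paths` that are not absolute paths."""
    return [p for p in paths if not os.path.isabs(p)]


# Hypothetical example: one absolute entry and one relative entry
example_paths = [
    os.path.join(os.getcwd(), "data", "code_15", "segmented", "example_0.mat"),
    "example_1.mat",
]

relative = find_relative_paths(example_paths)
print(relative)
```

In the notebook, the same helper could be applied to `segmented_split['path']`; an empty list means every path is absolute.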

Run the following commands to generate the `test.tsv` file used for inference.

In [None]:
print(f"""cd {fairseq_signals_root}/scripts/preprocess
python manifests.py \\
    --split_file_paths "{root}/data/code_15/segmented_split.csv" \\
    --save_dir "{root}/data/manifests/code_15_subset10/"
""")

In [None]:
assert os.path.isfile(os.path.join(root, 'data/manifests/code_15_subset10/test.tsv'))

## 3. Run inference

Inside our environment, we can run the following command using Hydra's command-line interface to extract the logits for each segment. A GPU must be available.

In [None]:
print(f"""fairseq-hydra-inference \\
    task.data="{root}/data/manifests/code_15_subset10/" \\
    common_eval.path="{root}/ckpts/physionet_finetuned.pt" \\
    common_eval.results_path="{root}/outputs" \\
    model.num_labels=26 \\
    dataset.valid_subset="test" \\
    dataset.batch_size=10 \\
    dataset.num_workers=3 \\
    dataset.disable_validation=false \\
    distributed_training.distributed_world_size=1 \\
    distributed_training.find_unused_parameters=True \\
    --config-dir "{root}/ckpts" \\
    --config-name physionet_finetuned
""")

In [None]:
assert os.path.isfile(os.path.join(root, 'outputs/outputs_test.npy'))
assert os.path.isfile(os.path.join(root, 'outputs/outputs_test_header.pkl'))

## 4. Interpret results

The logits are ordered the same way as the samples in the manifest (rows) and the labels in the label definition (columns).
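
To make the ordering concrete, here is a small standard-library sketch with made-up numbers. Row `i` of the logits corresponds to sample `i` of the manifest, and column `j` corresponds to label `j` of the label definition. The segment identifiers, the three-label subset, and the logit values are illustrative assumptions, not model outputs.

```python
import math

# Illustrative assumptions: a three-label subset and two segments
label_names = ["AF", "SB", "STach"]
segment_ids = ["segment_0", "segment_1"]
logits = [
    [2.0, -1.0, 0.0],   # row 0 -> segment_0
    [-3.0, 0.5, 1.5],   # row 1 -> segment_1
]


def sigmoid(x: float) -> float:
    """Map a logit to an independent per-label probability."""
    return 1.0 / (1.0 + math.exp(-x))


# Pair each row with its segment and each column with its label
probabilities = {
    segment: {name: sigmoid(value) for name, value in zip(label_names, row)}
    for segment, row in zip(segment_ids, logits)
}
print(probabilities["segment_0"])
```

A sigmoid is applied per label, rather than a softmax across labels, because this is multi-label classification and several labels can be present at once.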

### Get predictions on PhysioNet 2021 labels

In [None]:
physionet2021_label_def = pd.read_csv(
    os.path.join(root, 'data/physionet2021/labels/label_def.csv'),
    index_col='name',
)
physionet2021_label_names = physionet2021_label_def.index
physionet2021_label_def

In [None]:
# Load the array of computed logits
logits = MemmapReader.from_header(
    os.path.join(root, 'outputs/outputs_test.npy')
)[:]
logits.shape

In [None]:
# Construct predictions from logits
pred = pd.DataFrame(
    torch.sigmoid(torch.tensor(logits)).numpy(),
    columns=physionet2021_label_names,
)

# Join in sample information
pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')
pred
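
Each row of `pred` is a segment, and a single ECG can contribute more than one segment (the sample has 10 ECGs but 14 segments). If ECG-level probabilities are wanted, one simple option is to average the segment probabilities per ECG. This is an assumption here, not something the tutorial prescribes. The sketch below uses toy values and a hypothetical `ecg_id` column; the actual identifier column in `segmented_split` may be named differently.

```python
import pandas as pd

# Toy segment-level probabilities; `ecg_id` is a hypothetical identifier column
segment_probs = pd.DataFrame(
    {
        "ecg_id": ["a", "a", "b"],
        "AF": [0.2, 0.6, 0.9],
        "SB": [0.7, 0.5, 0.1],
    }
)

# Average the probabilities of all segments belonging to the same ECG
ecg_probs = segment_probs.groupby("ecg_id")[["AF", "SB"]].mean()
print(ecg_probs)
```

Taking the per-ECG maximum instead of the mean is another reasonable choice when a finding may appear in only one segment.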

In [None]:
# Perform a (crude) thresholding of 0.5 for all labels
pred_thresh = pred.copy()
pred_thresh[physionet2021_label_names] = pred_thresh[physionet2021_label_names] > 0.5

# Construct a readable column of predicted labels for each sample
pred_thresh['labels'] = pred_thresh[physionet2021_label_names].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh['labels']
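
A single 0.5 cutoff is crude because labels can differ in prevalence and calibration. One refinement is to give each label its own threshold. The sketch below shows only the mechanics on toy probabilities; the threshold values are made up for illustration and would normally be tuned on a validation set.

```python
import pandas as pd

# Toy probabilities for two samples and two labels
probs = pd.DataFrame(
    {"AF": [0.62, 0.90], "SB": [0.35, 0.20]},
    index=pd.Index([0, 1], name="idx"),
)

# Hypothetical per-label thresholds (not tuned values)
thresholds = pd.Series({"AF": 0.7, "SB": 0.3})

# Compare every column against its own threshold, aligned on column names
above = probs.gt(thresholds, axis="columns")

# Build a readable column of predicted labels for each sample
labels = above.apply(
    lambda row: ", ".join(name for name, flag in row.items() if flag),
    axis=1,
)
print(labels)
```

With these toy values, sample 0 is labelled `SB` rather than `AF`, which is what a uniform 0.5 cutoff would give.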

### Map predictions to CODE-15 labels

In [None]:
code_15_label_def = pd.read_csv(
    os.path.join(root, 'data/code_15/labels/label_def.csv'),
    index_col='name',
)
code_15_label_names = code_15_label_def.index
code_15_label_def

In [None]:
label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}

physionet2021_label_def['name_mapped'] = physionet2021_label_def.index.map(label_mapping)
physionet2021_label_def
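
Before relying on a hand-written mapping, it can be worth checking which target labels it covers. The sketch below is a standard-library check. It assumes the CODE-15 label definition contains the six names listed in `target_label_names`; in the notebook, `code_15_label_names` would be used instead.

```python
from collections import Counter

label_mapping = {
    "CRBBB|RBBB": "RBBB",
    "CLBBB|LBBB": "LBBB",
    "SB": "SB",
    "STach": "ST",
    "AF": "AF",
}

# Assumed target label names; replace with the real label definition
target_label_names = ["1dAVb", "RBBB", "LBBB", "SB", "ST", "AF"]

mapped_targets = set(label_mapping.values())

# Target labels that no source label maps to (no prediction is available)
unmapped = [name for name in target_label_names if name not in mapped_targets]

# Mapped names that do not exist in the target label set (likely typos)
unknown = sorted(mapped_targets - set(target_label_names))

# Target names produced by more than one source label (would collide on rename)
duplicated = [t for t, n in Counter(label_mapping.values()).items() if n > 1]

print(unmapped, unknown, duplicated)
```

Under the stated assumption, only `1dAVb` has no mapped source label, so no prediction would be produced for it.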

In [None]:
pred_mapped = pred.copy()
pred_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_mapped.rename(label_mapping, axis=1, inplace=True)
pred_mapped

In [None]:
pred_thresh_mapped = pred_thresh.copy()
pred_thresh_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_thresh_mapped.rename(label_mapping, axis=1, inplace=True)
pred_thresh_mapped['predicted'] = pred_thresh_mapped[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh_mapped

### Compare predicted CODE-15 to actual

In [None]:
code_15_labels = pd.read_csv(os.path.join(root, 'data/code_15/labels/labels.csv'), index_col='idx')
code_15_labels['actual'] = code_15_labels[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
code_15_labels

In [None]:
# Visualize predicted and actual labels side-by-side
pred_thresh_mapped[['predicted']].join(code_15_labels[['actual']], how='left')
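
Beyond eyeballing the two columns, per-label agreement can be summarized by counting true positives, false positives, and false negatives. The sketch below uses toy boolean tables rather than the tutorial's outputs. With only a handful of samples, such counts are illustrative and not a validation of the model.

```python
import pandas as pd

# Toy thresholded predictions and ground-truth labels (same rows, same columns)
predicted = pd.DataFrame({"AF": [True, False, True], "SB": [False, False, True]})
actual = pd.DataFrame({"AF": [True, False, False], "SB": [False, True, True]})

# Per-label confusion counts, computed column-wise
counts = pd.DataFrame(
    {
        "tp": (predicted & actual).sum(),   # predicted and present
        "fp": (predicted & ~actual).sum(),  # predicted but absent
        "fn": (~predicted & actual).sum(),  # missed
    }
)
print(counts)
```

In the notebook, the boolean label columns of `pred_thresh_mapped` and `code_15_labels` would take the place of the toy tables, provided both are aligned on `idx`.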

## 5. Extra - Load models

Outside of the scripts/hydra client, models can be easily loaded as shown below:

In [None]:
from fairseq_signals.models import build_model_from_checkpoint

In [None]:
model_finetuned = build_model_from_checkpoint(
    checkpoint_path=os.path.join(root, 'ckpts/physionet_finetuned.pt')
)
model_finetuned

In [None]:
# Run if the pretrained model hasn't already been downloaded
from huggingface_hub import hf_hub_download

_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.pt',
    local_dir=os.path.join(root, 'ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.yaml',
    local_dir=os.path.join(root, 'ckpts'),
)

In [None]:
model_pretrained = build_model_from_checkpoint(
    checkpoint_path=os.path.join(root, 'ckpts/mimic_iv_ecg_physionet_pretrained.pt')
)
model_pretrained