# ECG Saliency Maps

In [None]:
import os
import sys
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras import backend as K
from ml4cvd.arguments import parse_args
from ml4cvd.models import make_multimodal_multitask_model
from ml4cvd.tensor_generators import train_valid_test_tensor_generators
from ml4cvd.definitions import ECG_REST_INDEPENDENT_LEADS
from ml4cvd.tensor_maps_ecg import get_ecg_dates
from ml4cvd.TensorMap import update_tmaps
from biosppy.signals.ecg import ecg
from scipy import ndimage

%matplotlib widget

In [None]:
# Bootstrap index and model version; both are interpolated into the data and model paths below.
BOOTSTRAP, MODEL = 0, "v14"

In [None]:
# Build the command line for ml4cvd's parse_args as an explicit list rather than by
# whitespace-splitting one string, so an expanded home directory that contains spaces
# still ends up as a single argument.
_bootstrap_dir = f"~/dropbox/sts-data/bootstraps/{BOOTSTRAP}"
sys.argv = [
    ".",
    "--num_workers", "1",
    "--tensors", "/data/ecg/mgh",
    "--train_csv", os.path.expanduser(f"{_bootstrap_dir}/train.csv"),
    "--valid_csv", os.path.expanduser(f"{_bootstrap_dir}/valid.csv"),
    "--test_csv", os.path.expanduser(f"{_bootstrap_dir}/test.csv"),
    "--input_tensors",
    "ecg_2500_std_sts_newest",
    "ecg_age_std_sts_newest",
    "ecg_sex_sts_newest",
    "--output_tensors",
    "sts_death",
    "--batch_size", "1",
    "--model_file", os.path.expanduser(f"~/dropbox/sts-ecg/results/{MODEL}/{BOOTSTRAP}/model_weights.h5"),
    "--output_folder", "/tmp",
    "--id", "saliency",
]
args = parse_args()
# Build the model (loading --model_file) and the train/valid/test generators from the parsed args.
model = make_multimodal_multitask_model(**args.__dict__)
generate_train, generate_valid, generate_test = train_valid_test_tensor_generators(**args.__dict__)

In [None]:
# Get data and gradients

# One batch from the test generator: inputs, outputs, an unused third element, and source file paths.
batch = next(generate_test)
input_tensors, output_tensors, _, paths = batch

# Wrap every input in a tf.Variable so the GradientTape tracks gradients w.r.t. the inputs.
it = {k: tf.Variable(v, dtype=float) for k, v in input_tensors.items()}

with tf.GradientTape() as tape:
    pred = model(it, training=False)
    # Rank the classes of the first sample only. Flattening the whole batch would yield
    # indices into the flattened array, which are invalid for pred[0] when batch_size > 1.
    class_idxs_sorted = np.argsort(pred.numpy()[0])[::-1]
    # Differentiate the score of the top-ranked class.
    loss = pred[0][class_idxs_sorted[0]]

grads = tape.gradient(loss, it)
# Normalize each input's gradient by its root-mean-square; 1e-6 guards against division by zero.
grads = {k: v / (K.sqrt(K.mean(K.square(v))) + 1e-6) for k, v in grads.items()}

In [None]:
# Register the datetime TensorMap and keep a handle to it as `tm`.
datetime_tmap_name = "ecg_datetime_sts_newest"
tmaps = dict()
update_tmaps(datetime_tmap_name, tmaps)  # fills `tmaps` in place; the return value is not used
tm = tmaps[datetime_tmap_name]

In [None]:
# Show the first file in the batch and the tensor `tm` reads from it (read while the file is open).
ecg_file = paths[0]
print(ecg_file)
with h5py.File(ecg_file, "r") as hd5:
    print(tm.tensor_from_file(tm, hd5))

In [None]:
# Cleanup data to voltage_tensor and voltage_gradient

key = 'input_ecg_2500_std_sts_newest_continuous'
voltage_tensor = input_tensors[key][0]  # first sample in the batch
voltage_gradient = grads[key][0].numpy()

# Find the input TensorMap whose rescale() is applied to the voltages below; fail with a
# clear message (instead of a bare IndexError) if it is not among the model inputs.
voltage_tmap_name = "ecg_2500_std_sts_newest"
tm = next((tmap for tmap in args.tensor_maps_in if tmap.name == voltage_tmap_name), None)
if tm is None:
    raise ValueError(f"No input TensorMap named {voltage_tmap_name!r} in args.tensor_maps_in")
voltage_tensor = tm.rescale(voltage_tensor) / 1000 # rescaled to microvolts, then millivolts

## Plot saliency maps for 10-second ECG leads

In [None]:
def plot_ecgs(tensor, gradient, lead_map=ECG_REST_INDEPENDENT_LEADS, hertz=250, y_max=2, blur=1, color_map=None):
    """
    Plot every lead of an ECG with its saliency painted as a heat map behind the trace.

    :param tensor: voltage array indexed as ``tensor[:, lead_index]``; plotted against a y-axis
        labelled mV.
    :param gradient: saliency array indexed the same way as ``tensor``.
    :param lead_map: mapping of lead name -> column index into ``tensor``/``gradient``. Panels are
        drawn top-to-bottom in ascending column order, so the indices need not be contiguous.
    :param hertz: sampling rate in Hz; currently unused (the x-axis is in samples).
    :param y_max: the y-axis and the heat map both span [-y_max, y_max].
    :param blur: sigma of the Gaussian filter applied to the saliency image.
    :param color_map: colormap for the saliency. Defaults to Matplotlib's "Blues" (512-entry
        lookup table) resampled to 256 colors.
    """
    if color_map is None:
        # plt.get_cmap rather than matplotlib.cm.get_cmap, which Matplotlib 3.9 removed.
        color_map = ListedColormap(plt.get_cmap("Blues", 512)(np.linspace(0.0, 1, 256)))
    # Sorting by column index keeps the panel order of a contiguous 0..n-1 lead map, and
    # squeeze=False keeps `axes` 2-D even when only one lead is plotted.
    leads = sorted(lead_map.items(), key=lambda item: item[1])
    _, axes = plt.subplots(len(leads), figsize=(10, 16), squeeze=False)
    for row, (lead, index) in enumerate(leads):
        panel = axes[row, 0]
        lead_tensor = tensor[:, index]
        lead_gradient = gradient[:, index]
        n_samples = lead_tensor.shape[-1]
        panel.plot(lead_tensor, color='r')
        # Repeat the 1-D saliency vertically so imshow can paint it as a background band.
        g = np.tile(lead_gradient, (n_samples, 1))
        g = ndimage.gaussian_filter(g, sigma=blur)
        image = panel.imshow(g, cmap=color_map, aspect='auto', extent=[0, n_samples, -y_max, y_max])
        panel.set_title(lead)
        panel.set_ylabel('mV')
        panel.set_ylim(-y_max, y_max)
        cb = plt.colorbar(image, ax=panel)
        cb.set_label('Salience')
        cb.set_ticks([g.min(), g.max()])
        cb.set_ticklabels(['Low', 'High'])
    plt.tight_layout()
    # plt.savefig(os.path.expanduser('~/saliency.png'))
    plt.show()


plot_ecgs(voltage_tensor, voltage_gradient)

## Median waveform saliency plots below are a work in progress

In [None]:
def stretch_ecg(raw_voltage, raw_gradient, raw_sampling_rate, desired_hr):
    """
    Resample one ECG lead and its gradient so that, read at ``raw_sampling_rate``, the beats
    occur at ``desired_hr``.

    :param raw_voltage: 1-D voltage trace of a single lead.
    :param raw_gradient: 1-D gradient aligned sample-for-sample with ``raw_voltage``.
    :param raw_sampling_rate: sampling rate of ``raw_voltage`` in Hz.
    :param desired_hr: target heart rate in beats per minute.
    :return: ``(stretched_voltage, stretched_gradient, stretched_peaks)``. The stretched arrays
        cover the whole raw trace, so their length depends on the detected heart rate.
        ``stretched_peaks`` are R-peak sample indices into the stretched arrays, as found by
        biosppy's ``ecg``.
    :raises ValueError: if biosppy returns no heart-rate values for the raw trace.
    """
    raw_features = ecg(raw_voltage, sampling_rate=raw_sampling_rate, show=False)
    heart_rates = raw_features[-1]  # biosppy's instantaneous heart rate (bpm)
    if len(heart_rates) == 0:
        raise ValueError("Could not estimate a heart rate for this lead (too few R peaks detected).")
    raw_hr = heart_rates.mean()

    # Each stretched sample advances `step` raw samples.
    step = desired_hr / raw_hr
    # Choose the output length so the last stretched sample still lies inside the raw trace.
    # Keeping the raw length would, when raw_hr < desired_hr, sample past the end, and
    # np.interp would clamp that region into a flat tail.
    n_stretched = int(np.floor((len(raw_voltage) - 1) / step)) + 1
    raw_time = np.arange(len(raw_voltage))
    stretched_time = np.arange(n_stretched) * step
    stretched_voltage = np.interp(stretched_time, raw_time, raw_voltage)
    stretched_gradient = np.interp(stretched_time, raw_time, raw_gradient)

    # Stretched samples are `step` raw samples apart, so the physical sampling rate of the
    # stretched trace is raw_sampling_rate / step. The previous formula multiplied by `step`,
    # which is inverted. Passing the physical rate keeps the detector's filters and timing
    # windows on real-world time scales.
    stretched_sampling_rate = raw_sampling_rate / step
    stretched_features = ecg(
        stretched_voltage,
        sampling_rate=stretched_sampling_rate,
        show=False,
    )
    stretched_peaks = stretched_features[2]  # biosppy's R-peak indices
    return stretched_voltage, stretched_gradient, stretched_peaks

In [None]:
def align_waves_and_gradients(
    voltage_tensor,
    voltage_gradient,
    median_size=250,
    use_median=True,
    use_abs=False,
    sampling_frequency=250,
    bpm=60,
):
    """
    Summarize every lead as one beat-aligned waveform and a matching gradient.

    Each lead (a column of the inputs, reached via ``.T``) is stretched to ``bpm`` with
    ``stretch_ecg``. For every R peak p1 that has a neighbouring peak on each side (p0, p2),
    a window of ``median_size`` samples is taken, starting halfway between p0 and p1. A window
    is skipped unless all ``median_size`` samples fit before p2. The kept windows are then
    combined sample-by-sample.

    :param voltage_tensor: 2-D voltage array with one lead per column.
    :param voltage_gradient: gradient array laid out like ``voltage_tensor``.
    :param median_size: number of samples per extracted beat window.
    :param use_median: combine windows with the median if True, otherwise with the mean.
    :param use_abs: take the absolute value of the combined gradient.
    :param sampling_frequency: sampling rate of the inputs in Hz.
    :param bpm: heart rate every lead is stretched to before windowing.
    :return: ``(waves, gradients)``, each of shape ``(n_leads, 1, median_size)``. A lead with
        no usable window yields all-NaN values.
    """
    combine = np.median if use_median else np.mean
    median_waves = []
    median_gradients = []
    for lead_voltage, lead_gradient in zip(voltage_tensor.T, voltage_gradient.T):
        stretched_voltage, stretched_gradient, stretched_peaks = stretch_ecg(
            lead_voltage, lead_gradient, sampling_frequency, bpm,
        )
        waves = []
        gradients = []
        for p0, p1, p2 in zip(stretched_peaks[:-2], stretched_peaks[1:-1], stretched_peaks[2:]):
            start = (p0 + p1) // 2
            # Require the full window before the next peak p2, so every window is real data.
            # Comparing against a fixed 250 let np.interp pad short windows with a flat tail
            # whenever median_size > 250.
            if p2 - start < median_size:
                continue
            end = start + median_size
            waves.append(stretched_voltage[start:end])
            gradients.append(stretched_gradient[start:end])

        # reshape keeps a (0, median_size) shape when no window survived.
        waves = np.array(waves, dtype=float).reshape(-1, median_size)
        gradients = np.array(gradients, dtype=float).reshape(-1, median_size)
        if len(waves) == 0:
            # No usable beat: emit NaNs so every lead keeps the same shape.
            wave = np.full(median_size, np.nan)
            gradient = np.full(median_size, np.nan)
        else:
            # Subtract the mean over all samples of all windows in this lead.
            waves -= np.mean(waves)
            wave = combine(waves, axis=0)
            gradient = combine(gradients, axis=0)
        if use_abs:
            gradient = np.abs(gradient)

        # One-element list per lead keeps the (n_leads, 1, median_size) output layout.
        median_waves.append([wave])
        median_gradients.append([gradient])
    return np.array(median_waves), np.array(median_gradients)

In [None]:
def plot_ecg_saliency(waves, grads, color_map, blur=1, lead_dictionary=ECG_REST_INDEPENDENT_LEADS, y_max=2):
    """
    Draw each lead's beat over a blurred saliency background on a 4 x 2 grid of panels.

    Lead i goes to row ``i % 4``, column ``i // 4``.

    :param waves: per-lead waveforms indexable as ``waves[lead][k]``; only the first entry of
        each lead is drawn. ``waves.shape[-1]`` sets the image width.
    :param grads: per-lead saliency, indexed like ``waves``.
    :param color_map: colormap for the saliency image.
    :param blur: sigma of the Gaussian filter applied to the saliency image.
    :param lead_dictionary: mapping of lead name -> lead index, used for panel titles.
    :param y_max: the y-axis and the saliency image span [-y_max, y_max].
    """
    lead_names = dict(zip(lead_dictionary.values(), lead_dictionary.keys()))
    n_samples = waves.shape[-1]
    _, axes = plt.subplots(4, 2, figsize=(16, 10), sharex=True)
    for lead_idx, lead_waves in enumerate(waves):
        col, row = divmod(lead_idx, 4)
        panel = axes[row, col]
        if len(lead_waves) > 0:
            panel.plot(lead_waves[0], color='#E31A1C')
        # Repeat the saliency vertically so it fills the panel, then blur it.
        background = ndimage.gaussian_filter(np.tile(grads[lead_idx], (n_samples, 1)), sigma=blur)
        image = panel.imshow(background, cmap=color_map, aspect='auto', extent=[0, n_samples, -y_max, y_max])
        panel.set_title(lead_names[lead_idx])
        panel.set_ylabel('mV')
        panel.set_ylim(-y_max, y_max)
        colorbar = plt.colorbar(image, ax=panel)
        colorbar.set_label('Salience')
        colorbar.set_ticks([background.min(), background.max()])
        colorbar.set_ticklabels(['Low', 'High'])
    plt.tight_layout()
    plt.show()

In [None]:
waves, gradients = align_waves_and_gradients(voltage_tensor, voltage_gradient, use_median=True, use_abs=True)

# Sweep blur levels and colormaps; each combination produces one figure.
for blur in [1]:
    for color_map in ['Blues']:
        # plt.get_cmap rather than matplotlib.cm.get_cmap, which Matplotlib 3.9 removed.
        blues = plt.get_cmap(color_map, 512)
        newcmp = ListedColormap(blues(np.linspace(0.0, 1, 256)))
        plot_ecg_saliency(waves, gradients, color_map=newcmp, blur=blur)