This notebook is designed to produce saliency maps for ECG transformer classification models in the [fairseq-signals](https://github.com/Jwoo5/fairseq-signals) repository.

Before running this notebook, complete the following prerequisite steps:
1. Compute the `saliency_{split}.npy` files by running `fairseq-hydra-validate` with the `common_eval.extract=[saliency]` command-line argument
2. Run the `saliency.py` script to generate an `attn_max_{split}.npy` file

Here is an example command for step 1, assuming [this preprocessing procedure](https://github.com/Jwoo5/fairseq-signals/blob/master/scripts/preprocess/ecg/README.md) was followed:
```
FAIRSEQ_ROOT="TODO"
MANIFEST_DIR="TODO"
LABEL_DIR="TODO"
OUTPUT_DIR="TODO"
CHECKPOINT_NUM="TODO"

CHECKPOINT="$OUTPUT_DIR/checkpoint$CHECKPOINT_NUM.pt"
NUM_LABELS=$(($(wc -l < "$LABEL_DIR/label_def.csv") - 1))

fairseq-hydra-validate \
    task.data=$MANIFEST_DIR \
    common_eval.path=$CHECKPOINT \
    common_eval.extract=[saliency] \
    common_eval.results_path=$OUTPUT_DIR \
    model.num_labels=$NUM_LABELS \
    dataset.valid_subset=test \
    dataset.batch_size=256 \
    dataset.num_workers=10 \
    dataset.disable_validation=false \
    distributed_training.distributed_world_size=1 \
    distributed_training.find_unused_parameters=True \
    +task.label_file=$LABEL_DIR/y.npy \
    --config-dir $FAIRSEQ_ROOT/examples/w2v_cmsc/config/finetuning/ecg_transformer \
    --config-name diagnosis
```

# Setup

In [None]:
from typing import Tuple
import os
import yaml

import numpy as np
import pandas as pd

from scipy.io import loadmat
from scipy.ndimage import map_coordinates

from fairseq_signals.utils.file import extract_filename

import matplotlib.pyplot as plt
import numpy as np

def blend_colors_hex(start_color: str, end_color: str, activations: np.ndarray) -> np.ndarray:
    """
    Linearly blend between two hex colors, once per blend factor.

    Parameters
    ----------
    start_color : str
        Start color as a ``'#RRGGBB'`` hexadecimal string (blend factor 0).
    end_color : str
        End color as a ``'#RRGGBB'`` hexadecimal string (blend factor 1).
    activations : np.ndarray
        Array-like of blend factors in [0, 1]; 0 yields `start_color` and
        1 yields `end_color`. Flattened before blending.

    Returns
    -------
    np.ndarray
        Array of shape ``(activations.size, 3)`` holding RGB triples scaled
        to [0, 1] (a color format matplotlib accepts), one row per blend
        factor. Despite the function name, no hex strings are returned.

    Raises
    ------
    ValueError
        If any blend factor is outside [0, 1] or is NaN.
    """
    activations = np.asarray(activations, dtype=float)

    # Phrased as a negated in-range test so NaN (which fails every comparison)
    # is rejected as well, instead of silently producing NaN colors.
    if not np.all((activations >= 0) & (activations <= 1)):
        raise ValueError("All blend factors must be between 0 and 1.")

    def hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
        # '#RRGGBB' -> (R, G, B) ints in [0, 255]; character 0 is the '#'.
        return tuple(int(hex_color[i: i + 2], 16) for i in (1, 3, 5))

    start_rgb = np.array(hex_to_rgb(start_color))
    end_rgb = np.array(hex_to_rgb(end_color))

    # Per-factor linear interpolation; np.outer turns the (n,) factors and
    # the (3,) channel vectors into an (n, 3) array of blended channels.
    blended_rgb = np.outer(1 - activations, start_rgb) + np.outer(activations, end_rgb)

    # Scale 0-255 channel values down to 0-1 floats.
    return blended_rgb / 255

def colored_line_segments(data: np.ndarray, colors: np.ndarray, ax=None, **kwargs):
    """
    Plot `data` as a polyline whose segments are colored individually.

    Segment ``i`` joins points ``(i, data[i])`` and ``(i + 1, data[i + 1])``
    and is drawn with ``colors[i]``.

    Parameters
    ----------
    data : np.ndarray
        Sequence of y-values; the x-values are the sample indices.
    colors : np.ndarray
        One color per segment (``len(data) - 1`` entries), each in a format
        accepted by the ``color`` argument of ``plot``, e.g. the RGB rows
        returned by `blend_colors_hex`.
    ax : optional
        Axes to draw on. When None, segments are drawn with ``plt.plot``
        onto pyplot's current axes.
    **kwargs
        Extra keyword arguments forwarded to every ``plot`` call.

    Raises
    ------
    ValueError
        If `colors` does not have exactly one fewer element than `data`,
        since each segment needs its own color.

    Returns
    -------
    None

    Notes
    -----
    One ``plot`` call is issued per segment, so very long signals draw slowly.
    """
    if len(colors) != len(data) - 1:
        raise ValueError("Colors array must have one fewer elements than data array.")

    # Same drawing call either way; only the target differs.
    target = plt if ax is None else ax
    for i, color in enumerate(colors):
        target.plot([i, i + 1], [data[i], data[i + 1]], color=color, **kwargs)

In [None]:
# --- Paths (replace the '...' placeholders) ---
manifest_path = '...'  # Multi-source 'manifest.csv' filepath
run_directory = '...'  # Training run directory holding 'config.yaml' and 'attn_max_{split}.npy'
segmented_dir = '...'  # Root of the segmented files whose raw signal values get the attention coloring

# --- Selection ---
split = 'test'  # Split whose 'attn_max_{split}.npy' / '{split}.tsv' are used
lead = 'II'     # Lead to plot

# --- Signal geometry ---
# NOTE(review): these two are not read by any cell visible in this notebook.
sample_size = 2500
sample_rate = 500

# Load

In [None]:
# Per-sample attention-output max values for this split (written by `saliency.py`)
attn_max_file = os.path.join(run_directory, f'attn_max_{split}.npy')
attn_max = np.load(attn_max_file)
attn_max.shape

In [None]:
# Load the multi-source manifest and tag its sampling columns as the *original*
# values, so they cannot clash with the per-split 'sample_size' column built below.
manifest = pd.read_csv(manifest_path, low_memory=False).rename(
    columns={'sample_size': 'sample_size_org', 'sample_rate': 'sample_rate_org'},
)
manifest

In [None]:
with open(os.path.join(run_directory, 'config.yaml'), "r") as f:
    config = yaml.safe_load(f)

manifest_dir = config['task']['data']
# A missing or null 'label_file' entry both mean "no labels".
label_file = config['task'].get('label_file')

# One row per entry of the split TSV, with its sample size from that file.
meta = pd.read_csv(os.path.join(manifest_dir, f'{split}.tsv'), sep='\t', index_col='Unnamed: 0')
meta = meta[meta.columns[0]].rename('sample_size')
meta.index.name = 'file'
meta = meta.reset_index()

# Drop the trailing '_<digits>' before '.mat' ('x_3.mat' -> 'x.mat') to recover
# the manifest's 'save_file' key. Raw string: '\d' is an invalid escape otherwise.
meta['save_file'] = extract_filename(meta['file']).replace(r"_\d+\.mat$", '.mat', regex=True)
meta = meta.merge(
    manifest[['save_file', 'idx', 'sample_size_org', 'sample_rate_org']],
    on='save_file',
    how='left',
    # A duplicated 'save_file' in the manifest would duplicate split rows and
    # misalign them with attn_max below; fail loudly instead.
    validate='many_to_one',
)
unmatched = meta['idx'].isna()
if unmatched.any():
    raise ValueError(
        f"{int(unmatched.sum())} '{split}' entries have no matching 'save_file' in the manifest, "
        f"e.g. {meta.loc[unmatched, 'save_file'].iloc[0]!r}"
    )

# Assumes attn_max rows follow the split TSV order -- TODO confirm against saliency.py.
meta['attn_max'] = list(attn_max)

# Incorporate labels
if label_file is not None:
    label_dir = os.path.dirname(label_file)
    label_def = pd.read_csv(os.path.join(label_dir, "label_def.csv"), index_col='name')
    y = np.load(label_file)

    # Gather each entry's label row via its manifest index
    labels = y[meta["idx"].to_numpy(dtype=int)]

    # One boolean column per label; its default positional index lines up
    # row-by-row with meta's index from the merge above.
    labels_pd = pd.DataFrame(
        labels,
        columns=label_def.index,
    ).astype(bool)
    labels_pd.index.name = 'idx'
    meta = pd.concat([meta, labels_pd], axis=1)

meta

# Filter

In [None]:
# Randomly pick up to 3 recordings labelled 'Sinus rhythm' to visualize.
# Capped at the number of matches: sampling without replacement raises when
# asked for more rows than exist.
candidates = meta[meta['Sinus rhythm']]
meta_filtered = candidates.sample(n=min(3, len(candidates))).copy()
meta_filtered

# Prepare plots

In [None]:
# Row of the chosen lead within each `feats` matrix; assumes the preprocessing
# stored leads in this order (TODO confirm). Raises ValueError for an unknown lead.
lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
lead_row = lead_names.index(lead)
meta_filtered['lead'] = lead_row

In [None]:
# Load the original signal values of the selected lead for each sampled recording
meta_filtered['seg_path'] = segmented_dir.rstrip('/') + '/' + meta_filtered['file']

# Explicit check instead of `assert` (skipped under `python -O`), naming what is missing.
missing = meta_filtered.loc[
    ~meta_filtered['seg_path'].map(os.path.isfile).astype(bool), 'seg_path'
]
if not missing.empty:
    raise FileNotFoundError(
        f"{len(missing)} segmented file(s) not found, e.g. {missing.iloc[0]!r}"
    )

meta_filtered['feats'] = meta_filtered.apply(
    lambda row: loadmat(row['seg_path'])['feats'][row['lead']],
    axis=1,
)
# Number of samples in each extracted lead signal
meta_filtered['sample_size_extracted'] = meta_filtered['feats'].apply(
    lambda feats: feats.shape[0]
)

In [None]:
def prep_saliency_values(row):
    """
    Resample one recording's saliency to per-segment resolution and min-max normalize it.

    Parameters
    ----------
    row : mapping (e.g. a pandas row)
        Must provide ``'attn_max'`` (1-D array of saliency values) and
        ``'sample_size_extracted'`` (number of signal samples to align with).

    Returns
    -------
    np.ndarray
        ``sample_size_extracted - 1`` values in [0, 1], one per line segment
        between consecutive signal samples. This is the color count
        `colored_line_segments` requires. A saliency with no variation yields
        all zeros instead of the NaNs a 0/0 normalization would produce.
    """
    saliency = np.asarray(row['attn_max'])
    n_segments = row['sample_size_extracted'] - 1

    # Constant input carries no relative saliency. Resampling it could also
    # introduce round-off noise that normalization would blow up, so stop here.
    if np.ptp(saliency) == 0:
        return np.zeros(n_segments)

    # Evenly spaced fractional positions spanning the original samples;
    # map_coordinates evaluates its default (cubic) spline at those positions.
    positions = np.linspace(0, saliency.shape[0] - 1, n_segments)
    resampled = map_coordinates(saliency, positions[np.newaxis, :])

    # Min-max normalization to [0, 1]
    resampled = resampled - resampled.min()
    span = resampled.max()
    if span == 0:
        # All sampled positions landed on equal values; nothing to normalize.
        return np.zeros_like(resampled)
    return resampled / span

# Per recording: normalized per-segment saliency, then one segment color each,
# blended from '#0047AB' (low saliency) to '#DC143C' (high saliency).
meta_filtered['saliency_prepped'] = meta_filtered.apply(prep_saliency_values, axis=1)
meta_filtered['colors'] = meta_filtered['saliency_prepped'].map(
    lambda saliency: blend_colors_hex('#0047AB', '#DC143C', saliency)
)

In [None]:
# Plot samples: the selected lead of each recording, colored by saliency
for _, row in meta_filtered.iterrows():
    # Unnumbered figure, so it can never be reused by or for figures from other cells.
    fig, ax = plt.subplots(figsize=(20, 2))
    ax.axis('off')
    colored_line_segments(row['feats'], row['colors'], ax=ax)
    # Called once the axes exist; on an empty figure it would have no effect.
    fig.tight_layout()

In [None]:
# Newline-joined names of each sampled recording's positive labels, in
# label_def order. Recordings with no positive label get an empty string
# (rather than NaN), and an all-negative selection no longer raises.
label_mask = meta_filtered[label_def.index].to_numpy(dtype=bool)
label_names = label_def.index.to_numpy()
meta_filtered['label_str'] = pd.Series(
    ['\n'.join(label_names[mask]) for mask in label_mask],
    index=meta_filtered.index,
    dtype=object,
)
meta_filtered['label_str']

In [None]:
# Plot with true labels on the right-hand side
for _, row in meta_filtered.iterrows():
    # Unnumbered figure, so it can never be reused by or for figures from other cells.
    fig, ax = plt.subplots(figsize=(20, 2))
    ax.axis('off')
    # tight_layout needs existing axes to have any effect; the subplots_adjust
    # that follows then reserves the right 10% of the figure for the labels.
    fig.tight_layout()
    fig.subplots_adjust(right=0.9)
    fig.text(
        0.9,
        0.5,
        row['label_str'],
        verticalalignment='center',
        horizontalalignment='left',
    )
    colored_line_segments(row['feats'], row['colors'], ax=ax)