# Notebook 2 - Preprocessing

Before calculating features from the raw signals, we should preprocess:
- cleanse data
- judge data quality

It is very useful to remove unwanted errors and noise, so that we only target our analysis on meaningful information.

`"Garbage in, garbage out"`

In [None]:
from multiprocessing import Pool, cpu_count
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wfdb
from wfdb import processing

from vt.records import get_alarms, data_dir

In [None]:
# Load the alarm metadata used throughout this notebook. Below, `alarms` is
# queried as alarms.loc[record_name, 'result'] and `record_names` is sliced
# and iterated over.
# NOTE(review): record_names_true / record_names_false are presumably the
# record names split by alarm result -- confirm against vt.records.get_alarms.
alarms, record_names, record_names_true, record_names_false = get_alarms()

## Section 0 - Measuring Flatline and Saturation

We do not want to analyze signals that are not accurately measured:

- flatline signals, i.e., when the lead is disconnected
- saturated signals, i.e., when the signal calibration produces values that exceed the digital input range

**Goal**: Define quantitative metrics that can measure flatness and saturation.

In [None]:
from scipy.stats import mode

In [None]:
def get_missing_prop(sig):
    """
    Get the proportion of missing (nan) values in the signal.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.

    Returns
    -------
    A float in [0, 1] for a 1d signal (0.0 for an empty signal),
    or a list with one float per channel for a 2d signal.
    """
    if sig.ndim == 2:
        return [get_missing_prop(sig[:, ch]) for ch in range(sig.shape[1])]

    sig_len = len(sig)
    # An empty signal has no samples, hence none missing. Also avoids 0/0.
    if sig_len == 0:
        return 0.0

    # Normalize the nan count so the result really is a proportion.
    return np.count_nonzero(np.isnan(sig)) / sig_len


def is_missing(sig, missing_thresh=0.4):
    """
    Determine whether a signal has too many missing values.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.
    missing_thresh : float
        Maximum tolerated proportion of nan samples.

    Returns
    -------
    True if the proportion of nans exceeds missing_thresh, else False.
    For a 2d signal, a list with one bool per channel.
    """
    if sig.ndim == 2:
        return [is_missing(sig[:, ch], missing_thresh) for ch in range(sig.shape[1])]

    # get_missing_prop is already normalized by the signal length.
    return bool(get_missing_prop(sig) > missing_thresh)


def get_mode_prop(sig):
    """
    Get the proportion of samples of the signal equal
    to the mode (the most frequent value).

    nan samples never count as the mode, but they do count
    towards the signal length used as the denominator.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.

    Returns
    -------
    A float in [0, 1] for a 1d signal (0.0 if the signal is empty
    or entirely nan), or a list with one float per channel for a
    2d signal.
    """
    if sig.ndim == 2:
        return [get_mode_prop(sig[:, ch]) for ch in range(sig.shape[1])]

    sig_len = len(sig)
    if sig_len == 0:
        return 0.0

    valid = sig[~np.isnan(sig)]
    if valid.size == 0:
        return 0.0

    # Count occurrences of every distinct value; the largest count is
    # the count of the mode. This avoids depending on the return shape
    # of scipy.stats.mode, which differs between scipy versions.
    counts = np.unique(valid, return_counts=True)[1]
    return float(counts.max() / sig_len)


def is_flatline(sig, mode_thresh=0.8):
    """
    Determine whether a signal is flatline by inspecting the
    proportion of samples that match the mode.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.
    mode_thresh : float
        Proportion of samples equal to the mode above which the
        signal is considered flatline.

    Returns
    -------
    True if the mode proportion exceeds mode_thresh, else False.
    For a 2d signal, a list with one bool per channel.
    """
    if sig.ndim == 2:
        return [is_flatline(sig[:, ch], mode_thresh) for ch in range(sig.shape[1])]

    # get_mode_prop already returns a proportion, so it is compared to
    # the threshold directly (no further division by the length).
    return bool(get_mode_prop(sig) > mode_thresh)

def get_edge_prop(sig, n_bins=8):
    """
    Get the fraction of non-nan samples that fall in the lowest
    and highest of the n_bins histogram bins.

    For a 2d signal, a list with one value per channel (column)
    is returned. A signal without any non-nan sample gives 0.
    """
    if sig.ndim == 2:
        n_channels = sig.shape[1]
        return [get_edge_prop(sig[:, ch], n_bins) for ch in range(n_channels)]

    # Discard nans before binning
    valid_mask = ~np.isnan(sig)
    samples = sig[valid_mask]
    if samples.size == 0:
        return 0

    # Histogram the remaining samples and add up the two outermost bins
    counts, _ = np.histogram(samples, bins=n_bins)
    n_edge = counts[0] + counts[-1]

    return n_edge / np.sum(counts)

def is_saturated(sig, edge_thresh=0.3):
    """
    Decide whether a signal is saturated: it is when the
    proportion of samples in the max/min histogram bins (see
    get_edge_prop) is above edge_thresh.

    For a 2d signal, a list with one bool per channel (column)
    is returned.
    """
    if sig.ndim == 2:
        # Iterating over the transpose yields one channel at a time
        return [is_saturated(channel, edge_thresh) for channel in sig.T]

    edge_prop = get_edge_prop(sig)
    return bool(edge_prop > edge_thresh)


def is_valid(sig, missing_thresh, mode_thresh, edge_thresh):
    """
    Decide whether a signal is valid, i.e. whether it passes
    all of the following checks, evaluated in this order:
    - not too many missing values (is_missing)
    - not flatline (is_flatline)
    - not saturated (is_saturated)

    For a 2d signal, a list with one bool per channel (column)
    is returned.
    """
    if sig.ndim == 2:
        return [is_valid(channel, missing_thresh, mode_thresh, edge_thresh)
                for channel in sig.T]

    # Guard clauses: the first failing check settles the answer
    if is_missing(sig, missing_thresh):
        return False
    if is_flatline(sig, mode_thresh):
        return False
    return not is_saturated(sig, edge_thresh)


## Section 1 - Visualize performance

Visualize metrics and see whether metrics match desired labelling

In [None]:
def visualize_record_quality(record_name, start_sec=290, stop_sec=300):
    """
    Plot a time window of the first three channels of a record and
    print a few 'quality' metrics (missing, mode and edge proportions)
    computed on that window.

    The window runs from start_sec to stop_sec. The trace is drawn in
    red when the record's alarm 'result' is truthy, in blue otherwise.
    """
    # Samples per second, used to convert seconds to sample indices
    fs = 250
    record_path = os.path.join(data_dir, record_name)

    # Load only the requested window of channels 0-2
    signal, fields = wfdb.rdsamp(record_path,
                                 sampfrom=start_sec*fs, sampto=stop_sec*fs,
                                 channels=[0, 1, 2])

    # One value per channel for each metric
    metrics = [
        ('missing prop:', get_missing_prop(signal)),
        ('mode prop', get_mode_prop(signal)),
        ('edge prop', get_edge_prop(signal)),
    ]

    sig_style = 'r' if alarms.loc[record_name, 'result'] else 'b'

    wfdb.plot_items(signal=signal, title='Record %s' % record_name, sig_style=sig_style,
                    figsize=(14, 7))
    for label, value in metrics:
        print(label, value)

In [None]:
# Visualize the signal quality metrics of the last 10s (the default
# 290-300 s window of visualize_record_quality) for a sample of records.
for record_name in record_names[260:275]:
    visualize_record_quality(record_name)
    print('')  # blank line between the output of consecutive records
    # Observation: 10s seems too large of a window for these aggregate
    # metrics. They need to be computed on shorter, localized windows.

In [None]:
# Visualize signal quality metrics of the last 10s, split into two
# consecutive 5s windows (290-295 s and 295-300 s), for the same
# sample of records.
for record_name in record_names[260:275]:
    visualize_record_quality(record_name, start_sec=290, stop_sec=295)
    visualize_record_quality(record_name, start_sec=295, stop_sec=300)
    print('')  # blank line between the output of consecutive records


It seems the proportion of values in edge bins does a poor job at picking up saturation.

Can you think of something better for any of the situations we are trying to detect?

In [None]:
# Define your functions here

# def is_saturated():
#     pass

# def is_flatline():
#     pass

# def visualize_record_quality(record_name, start_sec=290, stop_sec=300):
#     """
#     Visualize a few 'quality' parameters of a single record.
    
#     """
#     fs = 250
#     # Read record
#     signal, fields = wfdb.rdsamp(os.path.join(data_dir, record_name),
#                                  sampfrom=start_sec*fs, sampto=stop_sec*fs,
#                                  channels=[0, 1, 2])
    
#     # Insert features here!!!!!
    
#     if alarms.loc[record_name, 'result']:
#         sig_style = 'r'
#     else:
#         sig_style = 'b'

#     wfdb.plot_items(signal=signal, title='Record %s' % record_name, sig_style=sig_style,
#                     figsize=(14, 7))
    
#     # Print your features here!!!!!
#     print()
    

In [None]:
# Visualize signal quality metrics of the last 10s, in 5s windows
# for record_name in record_names[260:275]:
#     visualize_record_quality(record_name, start_sec=290, stop_sec=295)
#     visualize_record_quality(record_name, start_sec=295, stop_sec=300)
#     print('')

### Another possible solution

This process is often iterative!

In [None]:
def get_consec_prop(sig):
    """
    Get the proportion of the signal that shares the same
    value as its previous sample.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.

    Returns
    -------
    A float for a 1d signal (0.0 for an empty signal), or a list
    with one float per channel for a 2d signal. The count of
    repeated samples is divided by the full signal length.
    """
    if sig.ndim == 2:
        return [get_consec_prop(sig[:, ch]) for ch in range(sig.shape[1])]

    sig_len = len(sig)
    # Nothing can repeat in an empty signal. Also avoids 0/0.
    if sig_len == 0:
        return 0.0

    # A zero first difference means the sample equals its predecessor.
    n_consec = np.count_nonzero(np.diff(sig) == 0)
    return n_consec / sig_len


def is_saturated(sig, consec_thresh=0.01):
    """
    Determine whether or not a signal is saturated, depending
    on whether the proportion of consecutive identical samples
    crosses the threshold.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.
    consec_thresh : float
        Proportion of repeated samples above which the signal is
        considered saturated.

    Returns
    -------
    True if the proportion exceeds consec_thresh, else False.
    For a 2d signal, a list with one bool per channel.
    """
    if sig.ndim == 2:
        # Forward this function's own threshold to every channel.
        return [is_saturated(sig[:, ch], consec_thresh) for ch in range(sig.shape[1])]

    return bool(get_consec_prop(sig) > consec_thresh)

def visualize_record_quality(record_name, start_sec=290, stop_sec=300):
    """
    Plot a time window of the first three channels of a record and
    print the mode and consecutive-sample proportions computed on
    that window.

    The window runs from start_sec to stop_sec. The trace is drawn in
    red when the record's alarm 'result' is truthy, in blue otherwise.
    """
    # Samples per second, used to convert seconds to sample indices
    fs = 250
    record_path = os.path.join(data_dir, record_name)

    # Load only the requested window of channels 0-2
    signal, fields = wfdb.rdsamp(record_path,
                                 sampfrom=start_sec*fs, sampto=stop_sec*fs,
                                 channels=[0, 1, 2])

    # One value per channel for each metric
    mode_prop, consec_prop = get_mode_prop(signal), get_consec_prop(signal)

    sig_style = 'r' if alarms.loc[record_name, 'result'] else 'b'

    wfdb.plot_items(signal=signal, title='Record %s' % record_name, sig_style=sig_style,
                    figsize=(14, 7))
    print('mode prop', mode_prop, 'consec prop', consec_prop)


In [None]:
# Visualize signal quality metrics of the last 10s, split into two
# consecutive 5s windows (290-295 s and 295-300 s), this time using the
# visualize_record_quality defined just above (mode and consecutive
# proportions).
for record_name in record_names[260:275]:
    visualize_record_quality(record_name, start_sec=290, stop_sec=295)
    visualize_record_quality(record_name, start_sec=295, stop_sec=300)


## The issue is difficult

Problems with potential solutions:
- Detect saturation
  - Using proportion of samples in edge bins: flat one-sided ecgs will have most of their samples in one of the edge bins.
  - Using proportion of samples with 0 gradient: low resolution digitization may produce many consecutive identical samples even when the signal is not saturated
- Detect flatline
  - Using standard deviation: unscaled signals have different amplitude ranges from the expected values
  - Using proportion of samples in mode: flatlines may not be *totally* flat. There could be very small sample variations.


## Section 2 - Filling Missing Values

In [None]:
from scipy import interpolate

In [None]:
def fill_missing(sig):
    """
    Fill missing values of a signal by linearly interpolating
    between present samples, and extending the earliest/latest
    present values backwards/forwards to the signal edges.

    Parameters
    ----------
    sig : numpy array
        1d signal, or 2d array with one channel per column.
        Missing samples are marked with nan.

    Returns
    -------
    numpy array of the same shape as sig without any nan:
    - 2d input: a new array, every channel filled independently.
    - 1d input without nans: the input array itself (not a copy).
    - 1d input made only of nans (or empty): a new array of zeros.
    - otherwise: a filled copy; the input is left untouched.
    """
    if sig.ndim == 2:
        clean_sig = np.empty(sig.shape)
        for ch in range(sig.shape[1]):
            clean_sig[:, ch] = fill_missing(sig=sig[:, ch])
        return clean_sig

    sig_len = len(sig)
    invalid = np.isnan(sig)
    n_invalid = np.count_nonzero(invalid)

    # Return flatline for completely empty signal
    if n_invalid == sig_len:
        return np.zeros(sig_len)

    # Nothing to fill
    if n_invalid == 0:
        return sig

    valid_inds = np.flatnonzero(~invalid)
    invalid_inds = np.flatnonzero(invalid)

    clean_sig = sig.copy()
    # np.interp interpolates linearly between the valid samples and, by
    # default, holds the first/last valid value outside of their range.
    # Unlike scipy's interp1d it also works with a single valid sample.
    clean_sig[invalid_inds] = np.interp(invalid_inds, valid_inds, sig[valid_inds])

    return clean_sig

## Section 3 - Filtering to Remove Noise

Remove high and low frequency components of the signals

In [None]:
from scipy.signal import butter, filtfilt

In [None]:
def bandpass(sig, fs=250, f_low=0.5, f_high=40, order=4):
    """
    Bandpass filter the signal with a Butterworth filter applied
    forwards and backwards (filtfilt).

    fs is the sampling frequency; f_low and f_high are the band
    edges in the same unit as fs, and order is the order passed to
    scipy.signal.butter. A 2d signal is filtered one channel
    (column) at a time and returned as an array of the same shape.
    """
    if sig.ndim == 2:
        filtered = np.zeros(sig.shape)
        # Rows of the transpose are the channels of the signal
        for ch, channel in enumerate(sig.T):
            filtered[:, ch] = bandpass(channel, fs, f_low, f_high, order)
        return filtered

    # butter expects the band edges as fractions of the Nyquist frequency
    nyquist = 0.5 * fs
    band = [f_low / nyquist, f_high / nyquist]
    b, a = butter(order, band, btype='band')

    return filtfilt(b, a, sig, axis=0)

def is_valid(sig, missing_thresh=0.4, mode_thresh=0.8, consec_thresh=0.01):
    """
    Decide whether a signal segment is valid. A segment is
    valid only if it passes all of these checks, evaluated in
    this order:
    - not too many missing values (is_missing)
    - not flatline (is_flatline)
    - not saturated (is_saturated)

    For a 2d signal, a list with one bool per channel (column)
    is returned.
    """
    if sig.ndim == 2:
        return [is_valid(channel, missing_thresh, mode_thresh, consec_thresh)
                for channel in sig.T]

    # Guard clauses: the first failing check settles the answer
    if is_missing(sig, missing_thresh=missing_thresh):
        return False
    if is_flatline(sig, mode_thresh=mode_thresh):
        return False
    return not is_saturated(sig, consec_thresh=consec_thresh)

In [None]:
def visualize_clean(record_name, start_sec=290, stop_sec=300, check_invalid=True):
    """
    Plot a cleaned time window of the first three channels of a
    record, together with beat locations.

    Steps, in order:
    - read the window from start_sec to stop_sec
    - if check_invalid is set, judge each raw channel with is_valid
      (otherwise all three channels are taken as valid)
    - detect beats on channels 0 and 1 with gqrs, and read the
      'wabp2' annotations for the third trace
    - fill missing values, then bandpass filter
    - plot: red for a true alarm, blue for a false alarm, and green
      for every channel judged invalid
    """
    # Samples per second, used to convert seconds to sample indices
    fs = 250
    record_path = os.path.join(data_dir, record_name)
    first_samp = start_sec * fs
    last_samp = stop_sec * fs

    # Read record
    signal, fields = wfdb.rdsamp(record_path, sampfrom=first_samp,
                                 sampto=last_samp, channels=[0,1,2])

    # Validity is judged on the raw signal, before any cleaning
    valid = is_valid(signal) if check_invalid else [True] * 3

    # Beat indices, one array per plotted channel
    beat_inds = [
        processing.gqrs_detect(signal[:, 0], fs=fs),
        processing.gqrs_detect(signal[:, 1], fs=fs),
        wfdb.rdann(record_path, 'wabp2', sampfrom=first_samp,
                   sampto=last_samp, shift_samps=True).sample,
    ]

    # Clean the signals
    signal = bandpass(fill_missing(signal), fs=250, f_low=0.5, f_high=30, order=4)

    # Title text and base colour depend on the alarm result
    if alarms.loc[record_name, 'result']:
        result, base_colour = 'True Alarm', 'r'
    else:
        result, base_colour = 'False Alarm', 'b'
    style = [base_colour] * 3

    # Invalid channels override the base colour
    for ch, channel_ok in enumerate(valid):
        if not channel_ok:
            style[ch] = 'g'

    wfdb.plot_items(signal=signal, ann_samp=beat_inds, time_units='seconds', fs=fs,
                    title='Record: %s %s' % (record_name, result), figsize = (16, 8),
                    ylabel=fields['sig_name'], sig_style=style, ann_style=['k*'])


In [None]:
# Plot the cleaned 290-300 s window, with beat markers and validity
# colouring, for the same sample of records used above.
for record_name in record_names[260:275]:
    visualize_clean(record_name, start_sec=290, stop_sec=300)