# Notebook 0 - Exploration of the Data

## Section 0 - Import libraries and load basic metadata

In [None]:
import os

import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt
import wfdb

# from bc.beats import get_beats, get_beat_bank
from bc.io import ann_to_df
from bc.plot import plot_beat, plot_four_beats

# The project root is one directory above the notebook's working directory;
# record and annotation files live in <root>/data.
base_dir = os.path.abspath('..')
data_dir = os.path.join(base_dir, 'data')

# Table of record names and the beat types they contain, indexed by record
# name. Record names are read with object dtype so values like '100' stay text.
beat_table = pd.read_csv(
    os.path.join(data_dir, 'beat-types.csv'),
    dtype={'record': object},
).set_index('record')

In [None]:
# Preview the first few rows (records) of the beat-type table
beat_table.head(n=5)

## Section 1 - Inspect signal and annotation content

In [None]:
# Records whose 'L' (left bundle branch block) entry is positive
l_records = beat_table.loc[beat_table['L'] > 0].index.values

# For every such record, summarize and plot the sample window 22000-32000
for rec_name in l_records:
    rec_path = os.path.join(data_dir, rec_name)
    rec = wfdb.rdrecord(rec_path, sampfrom=22000, sampto=32000)
    # shift_samps=True makes annotation samples relative to sampfrom so they
    # line up with the cropped record
    ann = wfdb.rdann(
        rec_path,
        extension='atr',
        sampfrom=22000,
        sampto=32000,
        shift_samps=True,
        summarize_labels=True,
    )
    # Index the label summary by symbol (this modifies ann.contained_labels
    # in place) and show which annotation types occur in the window
    label_summary = ann.contained_labels
    label_summary.set_index('symbol', inplace=True)
    display(label_summary.loc[:, ['description', 'n_occurrences']])
    wfdb.plot_wfdb(record=rec, annotation=ann, plot_sym=True)

### Extracting and zooming in on beats

In [None]:
def get_beats(sig, qrs_inds, beat_types, wanted_type, prop_left=0.3,
              rr_limits=(108, 540), fixed_width=None, single_chan=False,
              view=False):
    """
    Given a signal and beat locations, extract the beats of a certain
    type.

    Each beat spans prop_left of the rr interval to the previous beat
    on the left, and 1-prop_left of the rr interval to the next beat on
    the right.

    Exceptions: the first beat uses the following rr interval on both
    sides, the last beat uses the preceding rr interval on both sides,
    and an rr interval outside `rr_limits` is replaced by the mean of
    the in-limit rr intervals.

    Parameters
    ----------
    sig : numpy array
        The single or multi-channel signal, with time along axis 0.
    qrs_inds : numpy array
        The locations of the beat indices.
    beat_types : list
        The labeled beat types, aligned with `qrs_inds`.
    wanted_type : str
        The type of beat to extract. All others will be skipped, though
        their qrs locations will be used to calculate beat boundaries.
    prop_left : float, optional
        The fraction/proportion of the beat that lies to the left of the
        beat index. The remaining 1-prop_left lies to the right.
    rr_limits : tuple, optional
        Low and high (exclusive) limits of acceptable rr values. Default
        limits 108 and 540 samples correspond to 200bpm and 40bpm at
        fs=360.
    fixed_width : int, optional
        If not None, ignore `prop_left` and `rr_limits` and return beats
        of exactly this many samples, with int(fixed_width / 2) samples
        to the left of the beat index.
    single_chan : bool, optional
        If sig has more than 1 channel, specifies whether to keep only
        the first channel. Option ignored if sig has one channel.
    view : bool, optional
        Whether to display the individual beats collected.

    Returns
    -------
    beats : list
        List of numpy arrays representing beats.
    centers : list
        List of the beat index's position within each beat.

    Notes
    -----
    Beats whose window would extend past either end of `sig` are
    skipped. When `fixed_width` is None, beats are also skipped when
    their width cannot be estimated: if fewer than two beat locations
    are given, or if a neighbouring rr interval is out of limits and no
    in-limit interval exists to average.
    """
    prop_right = 1 - prop_left
    sig_len = sig.shape[0]
    n_beats = len(qrs_inds)

    # List of numpy arrays of beat segments
    beats = []
    # qrs complex detection index relative to the start of each beat
    centers = []

    if fixed_width is None:
        if n_beats < 2:
            # No rr interval exists, so beat widths cannot be estimated.
            return beats, centers
        # rr intervals, used to size each beat
        rr = np.diff(qrs_inds)
        valid_rr = rr[(rr > rr_limits[0]) & (rr < rr_limits[1])]
        # Replacement for out-of-limit intervals; None when none are valid
        # (averaging an empty array would give NaN and crash int() below).
        mean_rr = np.average(valid_rr) if valid_rr.size else None
    else:
        # Give the remainder of an odd width to the right side so every
        # beat is exactly fixed_width samples long.
        len_left_fixed = int(fixed_width / 2)
        len_right_fixed = int(fixed_width) - len_left_fixed

    for i in range(n_beats):
        # Only keep wanted beat types
        if beat_types[i] != wanted_type:
            continue

        if fixed_width is None:
            # Previous and next rr intervals for this qrs (edge beats
            # reuse their single neighbouring interval)
            rr_prev = rr[max(0, i - 1)]
            rr_next = rr[min(i, n_beats - 2)]

            # Constrain the rr intervals
            if not (rr_limits[0] < rr_prev < rr_limits[1]):
                rr_prev = mean_rr
            if not (rr_limits[0] < rr_next < rr_limits[1]):
                rr_next = mean_rr
            if rr_prev is None or rr_next is None:
                continue
            len_left = int(rr_prev * prop_left)
            len_right = int(rr_next * prop_right)
        else:
            len_left, len_right = len_left_fixed, len_right_fixed

        start = qrs_inds[i] - len_left
        stop = qrs_inds[i] + len_right
        # Skip beats too close to boundaries. The slice end is exclusive,
        # so stop == sig_len still fits inside the signal.
        if start < 0 or stop > sig_len:
            continue

        if sig.ndim == 1:
            beat = sig[start:stop]
        elif single_chan:
            beat = sig[start:stop, 0]
        else:
            beat = sig[start:stop, :]
        beats.append(beat)
        centers.append(len_left)

        if view:
            # Viewing results
            plt.plot(beat)
            plt.plot(len_left, beat[len_left], 'r*')
            plt.show()

    return beats, centers



In [None]:
# Read the first 2000 samples of the first L record plus its annotations,
# then extract and display the L beats found on channel 0.
first_l_path = os.path.join(data_dir, l_records[0])
sig, fields = wfdb.rdsamp(first_l_path, sampto=2000)
ann = wfdb.rdann(first_l_path, extension='atr', sampto=2000)
# Annotation samples/symbols as a dataframe, without the non-beat '+' and '~' marks
qrs_df = ann_to_df(ann, rm_sym=['+', '~'])
beats, centers = get_beats(
    sig=sig[:, 0],
    qrs_inds=qrs_df['sample'].values,
    beat_types=qrs_df['symbol'].values,
    wanted_type='L',
    view=True,
)

## Section 2 - Load and Visualize Beat Types

In [None]:
def get_beat_bank(data_dir, beat_table, wanted_type, single_chan=False,
                  fixed_width=None, min_len=1):
    """
    Make a beat bank by extracting every beat of one type from the MITDB
    records listed in `beat_table`.

    A record is used when its `wanted_type` entry in `beat_table` is at
    least `min_len`. 40/48 of the records have channels MLII and V1; the
    records with a different channel set are skipped.

    Parameters
    ----------
    data_dir : str
        Directory containing the WFDB record and 'atr' annotation files.
    beat_table : pandas DataFrame
        Table indexed by record name (strings), with one column per beat
        symbol. NOTE(review): an earlier docstring described the entries
        as seconds of each beat type; confirm the unit against
        beat-types.csv.
    wanted_type : str
        Beat symbol to extract, eg. 'N', 'L', 'R' or 'V'.
    single_chan : bool, optional
        Passed to `get_beats`: keep only the first channel.
    fixed_width : int, optional
        Passed to `get_beats`: extract beats of this fixed width.
    min_len : int or float, optional
        Minimum `beat_table[wanted_type]` value for a record to be used.

    Returns
    -------
    all_beats : list
        Beats from all used records, concatenated in record order.
    all_centers : list
        Beat center positions, aligned with `all_beats`.
    """
    # records with alternative channel sets (not MLII + V1)
    ALT_SIG_RECORDS = {'100', '102', '103', '104', '114', '117', '123', '124'}

    records = beat_table.loc[beat_table[wanted_type] >= min_len].index.values

    all_beats, all_centers = [], []
    for rec_name in records:
        # Skip the records with different channels
        if rec_name in ALT_SIG_RECORDS:
            continue
        # Load the signals and beat annotations
        rec_path = os.path.join(data_dir, rec_name)
        sig, _ = wfdb.rdsamp(rec_path)
        ann = wfdb.rdann(rec_path, extension='atr')
        # Get the peak samples and symbols in a dataframe. Remove the non-beat annotations
        qrs_df = ann_to_df(ann, rm_sym=['+', '~'])
        # Get the beats and centers of the record
        beats, centers = get_beats(sig=sig, qrs_inds=qrs_df['sample'].values,
                                   beat_types=qrs_df['symbol'].values,
                                   wanted_type=wanted_type,
                                   single_chan=single_chan,
                                   fixed_width=fixed_width)
        all_beats.extend(beats)
        all_centers.extend(centers)

    return all_beats, all_centers


### Normal Beats

In [None]:
# Bank of all normal ('N') beats and their center positions
n_beats, n_centers = get_beat_bank(wanted_type='N', data_dir=data_dir,
                                   beat_table=beat_table)

In [None]:
# Plot up to three normal beats; zip stops early (instead of raising
# IndexError) if the bank holds fewer than three.
for beat, center in zip(n_beats[:3], n_centers[:3]):
    plot_beat(beat, center, title='Normal Beat')

### Left Bundle Branch Block

In [None]:
# Bank of all left bundle branch block ('L') beats and their center positions
l_beats, l_centers = get_beat_bank(wanted_type='L', data_dir=data_dir,
                                   beat_table=beat_table)

In [None]:
# Plot up to three LBBB beats; zip stops early (instead of raising
# IndexError) if the bank holds fewer than three.
for beat, center in zip(l_beats[:3], l_centers[:3]):
    plot_beat(beat, center, style='C1', title='LBBB')

### Right Bundle Branch Block

In [None]:
# Bank of all right bundle branch block ('R') beats and their center positions
r_beats, r_centers = get_beat_bank(wanted_type='R', data_dir=data_dir,
                                   beat_table=beat_table)

In [None]:
# Plot up to three RBBB beats; zip stops early (instead of raising
# IndexError) if the bank holds fewer than three.
for beat, center in zip(r_beats[:3], r_centers[:3]):
    plot_beat(beat, center, style='C2', title='RBBB')

### Ventricular Premature Beat

In [None]:
# Bank of all ventricular premature ('V') beats and their center positions
v_beats, v_centers = get_beat_bank(wanted_type='V', data_dir=data_dir,
                                   beat_table=beat_table)

In [None]:
# Plot up to three ventricular premature beats; zip stops early (instead of
# raising IndexError) if the bank holds fewer than three.
for beat, center in zip(v_beats[:3], v_centers[:3]):
    plot_beat(beat, center, style='C3', title='Ventricular Premature Beat')

### Compare all beats

In [None]:
# Compare the first beat of each type (N, L, R, V) side by side
first_beats = [bank[0] for bank in (n_beats, l_beats, r_beats, v_beats)]
first_centers = [bank[0] for bank in (n_centers, l_centers, r_centers, v_centers)]
plot_four_beats(beats=first_beats, centers=first_centers)

## Section 3 - Signal Filtering

In [None]:
def bandpass(sig, fs=360, f_low=0.5, f_high=40, order=4):
    """
    Bandpass filter the signal with a zero-phase Butterworth filter
    applied along the first (time) axis.

    Parameters
    ----------
    sig : array-like
        1d signal, or multi-channel signal with time along axis 0. Each
        channel (column) is filtered independently.
    fs : float, optional
        Sampling frequency in Hz.
    f_low, f_high : float, optional
        Lower and upper cutoff frequencies in Hz. Both must lie strictly
        between 0 and fs / 2, with f_low < f_high.
    order : int, optional
        Butterworth order passed to `scipy.signal.butter`.

    Returns
    -------
    sig_filt : numpy array
        The filtered signal, same shape as `sig`.

    Raises
    ------
    ValueError
        From scipy if the cutoffs are invalid, or if the signal is too
        short along axis 0 for filtfilt's default edge padding.
    """
    # Accept lists and other array-likes, not only numpy arrays.
    sig = np.asarray(sig)

    f_nyq = 0.5 * fs
    wlow = f_low / f_nyq
    whigh = f_high / f_nyq
    # Design the filter once, rather than once per channel.
    b, a = butter(order, [wlow, whigh], btype='band')
    # filtfilt runs the filter forward and backward (zero phase) and, with
    # axis=0, filters each column independently, so multi-channel input
    # needs no per-channel loop.
    return filtfilt(b, a, sig, axis=0)

In [None]:
# Cutoff frequencies (Hz) shared by all four filtered beats
f_low, f_high = 0.5, 40
cutoffs = dict(f_low=f_low, f_high=f_high)

# NOTE(review): each call filters one isolated, short beat segment, so
# filtfilt's edge padding affects a large fraction of it; filtering the
# whole record before extracting beats may give cleaner results.
n_beat_filtered = bandpass(n_beats[0], **cutoffs)
l_beat_filtered = bandpass(l_beats[0], **cutoffs)
r_beat_filtered = bandpass(r_beats[0], **cutoffs)
v_beat_filtered = bandpass(v_beats[0], **cutoffs)

In [None]:
# Report the cutoffs used, then compare the four filtered beats
print(f'Bandpass filtered beats with cutoff [{f_low}hz, {f_high}hz]')
filtered_beats = [n_beat_filtered, l_beat_filtered, r_beat_filtered, v_beat_filtered]
first_centers = [n_centers[0], l_centers[0], r_centers[0], v_centers[0]]
plot_four_beats(beats=filtered_beats, centers=first_centers)

### Task: Tweak the bandpass filter's cutoff frequencies and visualize the effect on the ECG signals
