# Bonus Notebook - Ventricular Beats

Extraction and classification of beats as ventricular or non-ventricular

In [None]:
from multiprocessing import Pool, cpu_count
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import interpolate
from scipy.stats import mode, skew, kurtosis
from scipy.signal import butter, filtfilt
from sklearn.model_selection import train_test_split
import wfdb
from wfdb import processing

from vt.records import get_alarms, data_dir
from vt.features import calc_moments, calc_spectral_ratios
from vt.preprocessing import is_valid, fill_missing, bandpass, normalize

In [None]:
alarms, record_names, record_names_true, record_names_false = get_alarms()

## Section 0 - A highly targeted feature - a ventricular beat classifier

We need to create better features for our alarm classification system. Our classifiers are limited by how well our chosen features discriminate between true and false alarms. Simple features such as the standard deviation and power ratios calculated over the entire 10 s window are not sufficient to do so.

Let's think about the specific challenge: identifying whether **ventricular tachycardia** has occurred. This involves two components:
- Ventricular beats
- Tachycardia
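Once beats are labelled as ventricular or non-ventricular, the two components can be checked together. The following is a minimal sketch that assumes per-beat labels from a classifier like the one built in this notebook. `has_vt_run` is a hypothetical helper, and its run length of 5 beats and 100 bpm threshold are assumptions chosen for illustration:

```python
import numpy as np


def has_vt_run(beat_inds, is_ventricular, fs=250, min_beats=5, min_hr=100):
    """Hypothetical helper: return True if any min_beats consecutive beats
    are all ventricular and their mean heart rate exceeds min_hr bpm.
    The default thresholds are illustrative assumptions."""
    beat_inds = np.asarray(beat_inds)
    is_ventricular = np.asarray(is_ventricular, dtype=bool)
    # Slide a window of min_beats consecutive beats over the record
    for start in range(len(beat_inds) - min_beats + 1):
        stop = start + min_beats
        # Component 1: every beat in the window is ventricular
        if not is_ventricular[start:stop].all():
            continue
        # Component 2: tachycardia, from the mean rr interval in the window
        mean_rr_sec = np.mean(np.diff(beat_inds[start:stop])) / fs
        if 60.0 / mean_rr_sec > min_hr:
            return True
    return False


# Beats every 125 samples at 250 Hz: rr = 0.5 s, i.e. 120 bpm
beat_inds = np.arange(10) * 125
print(has_vt_run(beat_inds, [True] * 10))                                  # True
print(has_vt_run(beat_inds, [True] * 4 + [False] + [True] * 4 + [False]))  # False
```

Checking the rate inside the window, rather than over the whole record, keeps a short fast ventricular run from being diluted by slower normal beats elsewhere.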

Idea: Make a classifier that classifies beats as ventricular or non-ventricular. To build it, extract beats from the training records into a beat bank, derive features from them, and fit an unsupervised classifier to these features. Because ECG channels differ from one another, we will train a separate classifier for each channel. We will not use the BP channels here, because the morphological difference between ventricular and non-ventricular beats is less obvious in them.

We are using an unsupervised classifier because there are no labels for individual beats, only for true and false alarms. We could instead extract beats and label them according to the alarm result of their record. Perhaps that is another method to try!
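The alternative of labelling beats by the outcome of their record's alarm can be sketched as follows. `weak_beat_labels` is a hypothetical helper. It assumes beats have already been grouped by record and that alarm outcomes are available as booleans (for example, from the `result` column of `alarms`). The resulting labels are noisy: a true alarm record still contains many normal beats, and a false alarm record can contain ventricular beats.

```python
import numpy as np


def weak_beat_labels(beats_by_record, alarm_results):
    """Hypothetical helper: give every beat the label of its record's alarm
    (1 for a true alarm, 0 for a false alarm)."""
    beats, labels = [], []
    for record_name, record_beats in beats_by_record.items():
        beats.extend(record_beats)
        labels.extend([int(alarm_results[record_name])] * len(record_beats))
    return beats, np.array(labels)


# Toy example with made-up record names and beats
beats_by_record = {'rec_a': [np.zeros(50), np.ones(60)], 'rec_b': [np.ones(55)]}
alarm_results = {'rec_a': True, 'rec_b': False}
beats, labels = weak_beat_labels(beats_by_record, alarm_results)
print(labels)  # [1 1 0]
```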

## Visualize signal distribution

In [None]:
from collections import Counter

# Get the ECG channel distribution
ecg_channels = []

for record_name in record_names:
    record = wfdb.rdheader(os.path.join(data_dir, record_name))
    # The first 2 channels are ECGs
    ecg_channels.extend(record.sig_name[:2])

# Count the occurrences of each ECG channel across all records
channel_frequencies = Counter(ecg_channels)
ecg_channels = list(channel_frequencies)

display('ECG Channel frequencies: ', channel_frequencies)

There are three signals that appear in only one record each. We will not train a classifier for these signals. If we encounter such signals in the testing data, we will skip ventricular beat detection for that channel.
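The choice of which channels to train on can be made explicit with a frequency threshold. In the minimal sketch below, `trainable_channels` is a hypothetical helper. The minimum of 2 records is an assumption matching the "appear in only one record" criterion above:

```python
from collections import Counter


def trainable_channels(channel_names, min_records=2):
    """Hypothetical helper: keep channels that appear in at least
    min_records records."""
    counts = Counter(channel_names)
    return sorted(name for name, n in counts.items() if n >= min_records)


# Made-up channel list, one entry per record per ECG channel
channel_names = ['II', 'V', 'II', 'V', 'III', 'II', 'aVF']
print(trainable_channels(channel_names))  # ['II', 'V']
```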

## Section 1 - Extracting Beats

We will extract all beats within the last 20 seconds of each true alarm training record and store them in a beat bank. This should give us a good sample of ventricular and non-ventricular beats.

However, recall that our QRS detectors are not perfect (if they were, this challenge would be trivial) and often output QRS locations where there is no beat.

Because the clustering algorithm forces whatever we give it into two clusters, including features from non-beats picked up by the detector may teach it to group ventricular beats with false alarm signal patterns. This is unsurprising, since these patterns are what set off the (false) alarms in the first place!

This is quite troublesome, so we must be careful to put only real beats (or as high a proportion of real beats as possible) into the beat bank.
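One simple heuristic for raising the proportion of real beats is to discard detections that fall within a refractory period of the previously kept detection. This is our own assumption and not part of `get_beats` or `get_beat_bank` below. At an assumed maximum plausible heart rate of 200 bpm, real beats are at least 0.3 s apart. `apply_refractory` is a hypothetical helper:

```python
import numpy as np


def apply_refractory(qrs_inds, fs=250, max_hr=200):
    """Hypothetical helper: drop detections closer than the minimum rr
    interval implied by max_hr to the previously kept detection."""
    qrs_inds = np.sort(np.asarray(qrs_inds))
    if len(qrs_inds) == 0:
        return qrs_inds
    # Minimum rr interval in samples, e.g. 75 samples at 250 Hz and 200 bpm
    min_rr = fs * 60.0 / max_hr
    kept = [qrs_inds[0]]
    for ind in qrs_inds[1:]:
        if ind - kept[-1] >= min_rr:
            kept.append(ind)
    return np.array(kept)


# The detection at sample 210 is only 10 samples after a kept beat, so it is dropped
print(apply_refractory([0, 200, 210, 410, 610]))
```

Keeping the earlier of two close detections is itself a heuristic. A spurious detection just before a real beat would survive, while the real beat would be dropped.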

In [None]:
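def get_beats(sig, beat_inds, prop_left=1/5, rr_limits=(75, 500), view=False):
    """
    Given a signal and beat locations, extract the beats.
    Each beat window extends prop_left of the rr interval to the previous
    beat on the left of the beat index, and 1 - prop_left of the rr
    interval to the next beat on the right.

    Exceptions are the first and last beats, which each lack one
    neighboring rr interval, and rr intervals that are too short or too
    long, which are replaced by the mean of the acceptable rr intervals.
    Beats whose windows would run past the signal boundaries are skipped.

    Parameters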
    ---------
    sig : numpy array
        The 1d signal array
    beat_inds : numpy array
        The locations of the beat indices
    prop_left : float, optional
        The fraction/proportion of the beat that lies to the left of the
        beat index. The remaining 1-prop_left lies to the right.
    rr_limits : tuple, optional
        Low and high limits of acceptable rr values, in samples. The
        default limits of 75 and 500 samples correspond to 200 bpm and
        30 bpm at fs = 250 Hz.
    view : bool, optional
        Whether to display the individual beats collected
    
    Returns
    -------
    beats : list
        List of numpy arrays representing beats.
    centers : list
        List of relative locations of the beat centers for each beat
    """
    
    prop_right = 1 - prop_left
    sig_len = len(sig)
    n_beats = len(beat_inds)
    
    # List of numpy arrays of beat segments
    beats = []
    # qrs complex detection index relative to the start of each beat
    centers = []
    # rr intervals, used to extract beats
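    rr = np.diff(beat_inds)
    # Mean of the acceptable rr intervals, used to replace implausible ones
    valid_rr = rr[(rr > rr_limits[0]) & (rr < rr_limits[1])]
    if len(valid_rr) == 0:
        # No acceptable rr intervals: the detections are unreliable,
        # so return no beats rather than failing on an undefined mean
        return beats, centers
    mean_rr = np.mean(valid_rr)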
    
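    for i in range(n_beats):
        # Previous and next rr intervals for this qrs. The first beat has no
        # previous interval and the last beat has no next one, so each
        # borrows its only neighboring interval.
        rr_prev = rr[max(0, i - 1)]
        rr_next = rr[min(i, n_beats - 2)]

        # Replace implausible rr intervals with the mean acceptable interval
        if not rr_limits[0] < rr_prev < rr_limits[1]:
            rr_prev = mean_rr
        if not rr_limits[0] < rr_next < rr_limits[1]:
            rr_next = mean_rr

        # Window lengths to the left and right of the beat index
        len_left = int(rr_prev * prop_left)
        len_right = int(rr_next * prop_right)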
        
        # Skip beats too close to boundaries
        if beat_inds[i] - len_left < 0 or beat_inds[i] + len_right > sig_len-1:
            continue
        
        beats.append(sig[beat_inds[i] - len_left:beat_inds[i] + len_right])
        centers.append(len_left)        
        
        if view:
            # Viewing results
            print('len_left:', len_left, 'len_right:', len_right)
            plt.plot(beats[-1])
            plt.plot(centers[-1], beats[-1][centers[-1]], 'r*')
            plt.show()

    return beats, centers
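The beats returned by `get_beats` have different lengths, because each window scales with the neighboring rr intervals. The features in Section 3 can be computed on variable-length beats, but comparing beats sample by sample (for example, averaging them into a template) needs a common length. Below is a minimal sketch using linear interpolation with `np.interp`. `resample_beat` and its 100-sample target length are assumptions, not part of the original pipeline:

```python
import numpy as np


def resample_beat(beat, n_samples=100):
    """Hypothetical helper: linearly resample a 1d beat to n_samples points,
    mapping its first and last samples onto the new endpoints."""
    beat = np.asarray(beat, dtype=float)
    old_x = np.linspace(0, 1, len(beat))
    new_x = np.linspace(0, 1, n_samples)
    return np.interp(new_x, old_x, beat)


# Two beats of different lengths become directly comparable
beats = [np.sin(np.linspace(0, np.pi, 80)), np.sin(np.linspace(0, np.pi, 130))]
template = np.mean([resample_beat(b) for b in beats], axis=0)
print(template.shape)  # (100,)
```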


In [None]:
def get_beat_bank(start_sec=280, stop_sec=300):
    """
    Make a beat bank of ecgs by extracting all beats from
    the same time section of channels 0 and 1 of all true
    alarm training records.
    """
    fs = 250
    beat_bank = {}
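    # No cheating! We should only have access to training data. A fixed
    # random_state makes the split reproducible, so the held-out records
    # stay the same every time this is run.
    records_train, records_test = train_test_split(record_names, random_state=0)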
    
    for record_name in records_train:
        # Skip false alarm records
        if not alarms.loc[record_name, 'result']:
            continue
        # Read record
        signal, fields = wfdb.rdsamp(os.path.join(data_dir, record_name),
                                     sampfrom=start_sec*fs, sampto=stop_sec*fs,
                                     channels=[0, 1])
        # Determine which signals are valid
        valid = is_valid(signal)
            
        # Clean the signals, removing nans
        signal = fill_missing(sig=signal)
        # Filter the signal
        signal = bandpass(signal, fs=fs, f_low=0.5, f_high=40, order=2)
        
        # Get beats from each channel
        for ch in range(2):
            sig_ch = signal[:, ch]
            sig_name = fields['sig_name'][ch]
            
            # Skip the signals with too few instances
            if sig_name.startswith('aV'):
                continue
                
            # Skip flatline signals
            if not valid[ch]:
                continue

            # Get beat locations
            qrs_inds = processing.xqrs_detect(sig_ch, fs=fs,
                                              verbose=False)
            # Skip if too few beats
            if len(qrs_inds) < 2:
                continue
            # Normalize the signal
            sig_ch = normalize(sig_ch)
            # Get the beats
            beats, _ = get_beats(sig_ch, qrs_inds)
            # Add the beats to this signal's list, creating it if needed
            beat_bank.setdefault(sig_name, []).extend(beats)
    print('Finished obtaining beat bank')
    
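    # Remove signals without beats from the dictionary. Iterate over a copy
    # of the keys, since deleting from a dict while iterating over it
    # raises a RuntimeError.
    for sig_name in list(beat_bank):
        if len(beat_bank[sig_name]) == 0:
            print('Obtained no beats for signal %s. Removing.' % sig_name)
            del beat_bank[sig_name]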
    return beat_bank

In [None]:
# Get the beats and display the results
beat_bank = get_beat_bank()
for sig_name in beat_bank:
    print('%d beats extracted for signal %s' % (len(beat_bank[sig_name]), sig_name))

In [None]:
# Visualize some obtained beats
for sig_name in beat_bank: 
    for beat in beat_bank[sig_name][:3]:
        plt.plot(beat)
        plt.title('Signal %s' % sig_name)
        plt.show()

## Section 3 - Calculating Features from Beats

For each signal type, we calculate features for its beats, and feed the feature array into an unsupervised classifier.

In the overall alarm classification challenge, each record has a set of features. In this challenge, each beat has a set of features.

In [None]:
def calc_beat_features(beat):
    """
    Calculate features from a single beat
    
    Parameters
    ----------
    beat : numpy array
        1d array of the beat signal

    Returns
    -------
    features : pandas DataFrame
        A single-row dataframe of the beat's features
    """
    feature_labels = ['skew', 'kurtosis', 'lfp', 'mfp', 'hfp']
    
    features = [skew(beat), kurtosis(beat)] + list(calc_spectral_ratios(beat, fs=250))
    features = pd.DataFrame([features], columns=feature_labels)
    return features
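`calc_spectral_ratios` comes from the project's own `vt.features` module, whose implementation is not shown here. For intuition, the sketch below shows one plausible way to compute low, mid and high frequency power fractions for a beat using `scipy.signal.periodogram`. `band_power_ratios` is hypothetical and its band edges are assumptions. The project's function may define its bands or normalization differently:

```python
import numpy as np
from scipy.signal import periodogram


def band_power_ratios(beat, fs=250, bands=((0.5, 5), (5, 15), (15, 40))):
    """Hypothetical stand-in for a spectral ratio feature: the fraction of
    power (between the lowest and highest band edge) falling in each band.
    The band edges are illustrative assumptions."""
    freqs, psd = periodogram(beat, fs=fs)
    in_range = (freqs >= bands[0][0]) & (freqs < bands[-1][1])
    total = psd[in_range].sum()
    if total == 0:
        # Flat beat: the ratios are undefined
        return [np.nan] * len(bands)
    return [psd[(freqs >= lo) & (freqs < hi)].sum() / total for lo, hi in bands]


# A 10 Hz sine lasting 1 s puts almost all its power in the middle band
fs = 250
t = np.arange(fs) / fs
print(np.round(band_power_ratios(np.sin(2 * np.pi * 10 * t), fs=fs), 3))
```

Beats are short, so the frequency resolution is coarse. A 200-sample beat at 250 Hz gives 1.25 Hz bins, so the lowest band holds only a few of them.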

In [None]:
# A dictionary of dataframes. Each key is a signal name, and each value
# is a dataframe of features for all beats of that signal.
beat_bank_features = {}

# Use one pool for all signals and close it when done. Mapping a function
# defined in the notebook across processes relies on the 'fork' start
# method (the default on Linux).
with Pool(processes=max(1, cpu_count() - 1)) as pool:
    for sig_name in beat_bank:
        # Features for this signal, one single-row dataframe per beat
        beat_bank_features_sig = pool.map(calc_beat_features, beat_bank[sig_name])
        # Stack into one dataframe indexed by beat number, and add to the dictionary
        beat_bank_features[sig_name] = pd.concat(beat_bank_features_sig,
                                                 ignore_index=True)
        print('Finished calculating beat features for signal %s' % sig_name)

In [None]:
# Visualize some features
for sig_name in beat_bank:
    print('Signal %s' % sig_name)
    display(beat_bank_features[sig_name].head())

## Section 4 - Training unsupervised classifiers for beats

K-Means Clustering: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html

The algorithm aims to separate the n-dimensional data into K clusters so as to minimize the within-cluster sum of squared distances from each point to its cluster mean.
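In symbols, k-means looks for clusters $C_k$ with means $\mu_k$ that minimize $\sum_k \sum_{x \in C_k} \lVert x - \mu_k \rVert^2$. scikit-learn reports the value of this objective for the fitted model as `inertia_`. The sketch below uses made-up two-dimensional data and recomputes that value by hand from `labels_` and `cluster_centers_`:

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up data: two well-separated blobs of 2d "features"
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(6, 1, (50, 2))])

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

# Within-cluster sum of squared distances to each cluster's center
wcss = sum(((X[km.labels_ == k] - km.cluster_centers_[k]) ** 2).sum()
           for k in range(km.n_clusters))
print(np.isclose(wcss, km.inertia_))  # True
```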

K-means needs no labels, so we only need to pass in the feature array. The challenge lies in the previous step: choosing the features that will most effectively separate the different classes.
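K-means measures distance in raw feature units, so features with large ranges dominate the clustering regardless of how informative they are. This is worth checking here. The beat features (skew, kurtosis and power ratios) live on different scales, so standardizing them first may change the clusters. The sketch below uses made-up data to show the effect with scikit-learn's `StandardScaler`. Whether it helps on the real beat features is untested here:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Made-up features: column 0 separates two groups but has a small range,
# column 1 is uninformative noise with a much larger range
informative = np.concatenate([rng.normal(0, 0.05, 100), rng.normal(1, 0.05, 100)])
noise = rng.normal(0, 50, 200)
X = np.column_stack([informative, noise])
groups = np.repeat([0, 1], 100)


def agreement(labels, groups):
    """Fraction of points matching the groups, allowing for the fact that
    k-means cluster numbers are arbitrary."""
    match = np.mean(labels == groups)
    return max(match, 1 - match)


raw_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
scaled_labels = make_pipeline(
    StandardScaler(), KMeans(n_clusters=2, n_init=10, random_state=0)
).fit_predict(X)
print('raw: %.2f, standardized: %.2f' % (agreement(raw_labels, groups),
                                         agreement(scaled_labels, groups)))
```

Standardizing is a modelling choice too, since it treats every feature as equally important, which is not automatically right either.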

In [None]:
from sklearn.cluster import KMeans

In [None]:
# Train a k-means classifier for each signal type
beat_classifiers = {}

for sig_name in beat_bank_features:
    # n_init is set explicitly because its default changed in recent
    # scikit-learn versions
    clf_kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
    beat_classifiers[sig_name] = clf_kmeans.fit(beat_bank_features[sig_name])

# clf_kmeans.cluster_centers_

In [None]:
# Visualize some of the beats
for sig_name in beat_classifiers:
    # Plot group 0 in blue, group 1 in red
    zero_inds = np.where(beat_classifiers[sig_name].labels_==0)[0] 
    one_inds = np.where(beat_classifiers[sig_name].labels_==1)[0]
    
    for i in range(min(2, len(zero_inds))):
        plt.plot(beat_bank[sig_name][zero_inds[i]], 'b')
        plt.title('Signal %s class 0' % sig_name)
        plt.show()
        
    for i in range(min(2, len(one_inds))):
        plt.plot(beat_bank[sig_name][one_inds[i]], 'r')
        plt.title('Signal %s class 1' % sig_name)
        plt.show()
                   