# Notebook 1 - Wavelets and Frequency Information

Reference: PyWavelets API documentation - https://pywavelets.readthedocs.io/en/latest/ref/index.html

## Section 0 - Import libraries, and load metadata and beats

In [None]:
from multiprocessing import Pool, cpu_count
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import cwt, ricker, convolve
from sklearn import svm, neighbors
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import pywt
import wfdb

from bc.beats import get_beats, get_beat_bank, BEAT_TYPES
from bc.io import ann_to_df
from bc.plot import plot_beat, plot_four_beats

# Paths are resolved relative to the parent of the current working directory
base_dir = os.path.abspath('..')
data_dir = os.path.join(base_dir, 'data')

# Table of record names and the beat types they contain, indexed by record name.
# The record column is read as object so the names are not parsed as numbers.
beat_table_path = os.path.join(data_dir, 'beat-types.csv')
beat_table = pd.read_csv(beat_table_path, dtype={'record': object}).set_index('record')

In [None]:
# Load the beats of each wanted type (N, L, R, V), applying the standard ecg bandpass filter.
# Every get_beat_bank call yields a (beats, centers) pair.
(n_beats, n_centers), (l_beats, l_centers), (r_beats, r_centers), (v_beats, v_centers) = [
    get_beat_bank(data_dir=data_dir, beat_table=beat_table,
                  wanted_type=beat_type, filter=True)
    for beat_type in ('N', 'L', 'R', 'V')
]

In [None]:
# Visualize the first beat of each type, in the order N, L, R, V
example_beats = [n_beats[0], l_beats[0], r_beats[0], v_beats[0]]
example_centers = [n_centers[0], l_centers[0], r_centers[0], v_centers[0]]
plot_four_beats(beats=example_beats, centers=example_centers)

# Regular qrs width is about 0.05s = 0.05 * 360 = 18 samples

## Section 1 - Explore Wavelets

In [None]:
# 360 * .2 = 72 samples in 0.2s.
# 40hz signal = 360 / 40 = 9 samples.

widths = np.arange(6, 73, 6)


def plot_ricker_cwt(sig, widths, title):
    """
    Plot the Ricker-wavelet continuous wavelet transform of a 1d signal.

    Parameters
    ----------
    sig : 1d array
        The signal to transform.
    widths : 1d array
        Wavelet widths passed to the transform, in increasing order. The first
        and last values also label the vertical extent of the image.
    title : str
        Title of the figure.

    Returns
    -------
    cwtmatr : 2d array
        The transform output that was drawn.
    """
    cwtmatr = cwt(data=sig, wavelet=ricker, widths=widths)
    # Symmetric colour limits keep zero in the middle of the diverging colormap
    color_limit = abs(cwtmatr).max()
    plt.figure(figsize=(6.4, 4.8))
    plt.imshow(cwtmatr, extent=[0, sig.shape[0], widths[-1], widths[0]],
               cmap='coolwarm', aspect='auto', vmax=color_limit, vmin=-color_limit)
    plt.title(title)
    plt.xlabel('Sample')
    plt.ylabel('Wavelet width')
    return cwtmatr


cwtmatr = plot_ricker_cwt(n_beats[0][:, 0], widths, 'Normal beat, channel 0 - Ricker CWT')

plot_beat(n_beats[0], n_centers[0], style='C0', title='Normal Beat')

# The second channel is negated before it is transformed
cwtmatr = plot_ricker_cwt(-n_beats[0][:, 1], widths,
                          'Normal beat, channel 1 (negated) - Ricker CWT')

cwtmatr = plot_ricker_cwt(v_beats[0][:, 0], widths, 'Ventricular beat, channel 0 - Ricker CWT')

We are using the `pywt` library. Explore the available wavelet families and wavelets.

In [None]:
# Printed in order: all wavelet families, all wavelets, the wavelets of one
# particular family ('bior'), and the continuous wavelets.
wavelet_listings = (
    pywt.families(),
    pywt.wavelist(),
    pywt.wavelist(family='bior'),
    pywt.wavelist(kind='continuous'),
)
for listing in wavelet_listings:
    print(listing, '\n')

In [None]:
# Plot each continuous wavelet function in its own figure, skipping complex valued ones
for wavelet_name in pywt.wavelist(kind='continuous'):
    wavelet = pywt.ContinuousWavelet(wavelet_name)
    if wavelet.complex_cwt:
        continue
    plt.plot(wavelet.wavefun()[0])
    plt.title(wavelet_name)
    plt.show()

In [None]:
# Show the documentation of wavefun. `wavelet` is whichever object the
# plotting loop in the previous cell assigned last.
help(wavelet.wavefun)

In [None]:
# Plot the gaus1 wavelet function with default wavefun settings
wavelet = pywt.ContinuousWavelet('gaus1')
fig, ax = plt.subplots()
ax.plot(wavelet.wavefun()[0])
ax.set_title('gaus1')
plt.show()

In [None]:
# Overlay the gaus1 wavelet function for levels 1..8; each curve is shifted up by its level
wavelet = pywt.ContinuousWavelet('gaus1')
for level in range(1, 9):
    psi = wavelet.wavefun(level=level, length=200)[0]
    plt.plot(psi + level, '*-')
    # The title is overwritten on every pass, so it ends up showing the last level
    plt.title(str(level))

In [None]:
# Overlay the gaus1 wavelet function evaluated with several `length` values
wavelet = pywt.ContinuousWavelet('gaus1')
for length in (16, 32, 64, 128):
    psi = wavelet.wavefun(length=length)[0]
    plt.plot(psi, '*-')
    # The title is overwritten on every pass, so it ends up showing the last length
    plt.title(str(length))
plt.show()

In [None]:
# Normal beat, first channel: gaus1 CWT over scales 1..29
beat_signal = n_beats[0][:, 0]
cwt_scales = np.arange(1, 30)
coef, freqs = pywt.cwt(beat_signal, cwt_scales, 'gaus1')

plt.figure(figsize=(12, 9))
plt.imshow(coef, cmap='coolwarm')
plt.show()

In [None]:
# Normal beat, first channel: gaus2 CWT over scales 1..29, sampled at 360 Hz
beat_signal = n_beats[0][:, 0]
cwt_scales = np.arange(1, 30)
coef, freqs = pywt.cwt(beat_signal, cwt_scales, 'gaus2', sampling_period=1/360)

plt.figure(figsize=(12, 9))
plt.imshow(coef, cmap='coolwarm')
plt.show()

In [None]:
# Dimensions of the most recent CWT coefficient matrix
# (presumably scales x samples - confirm against the pywt.cwt documentation)
coef.shape

In [None]:
# Second value returned by the most recent pywt.cwt call
freqs

In [None]:

# Draw every row of the coefficient matrix as its own line;
# the legend entries are the scale values 1..29
plt.figure()
for scale_coefs in coef:
    plt.plot(scale_coefs)

plt.legend([str(scale) for scale in range(1, 30)])
plt.show()

In [None]:
# RBBB gaus1: first channel, scales 1..29, sampled at 360 Hz
beat_signal = r_beats[0][:, 0]
cwt_scales = np.arange(1, 30)
coef, freqs = pywt.cwt(beat_signal, cwt_scales, 'gaus1', sampling_period=1/360)

plt.figure(figsize=(12, 9))
plt.imshow(coef, cmap='coolwarm')
plt.show()

In [None]:
# RBBB gaus2: first channel, scales 1..29, sampled at 360 Hz
beat_signal = r_beats[0][:, 0]
cwt_scales = np.arange(1, 30)
coef, freqs = pywt.cwt(beat_signal, cwt_scales, 'gaus2', sampling_period=1/360)

plt.figure(figsize=(12, 9))
plt.imshow(coef, cmap='coolwarm')
plt.show()

## Section 2 - Building Features

In [None]:
# Every beat is described by a fixed set of features fed into a classifier.
# Two wavelets (gaus1, gaus2) are applied to each of the two channels (II, V),
# and for each combination we keep the maximum convolution value and the
# scale at which it occurs. Names are ordered channel -> wavelet -> statistic.
feature_names = [
    f'{wavelet_name}_{statistic}_{channel}'
    for channel in ('II', 'V')
    for wavelet_name in ('gaus1', 'gaus2')
    for statistic in ('max', 'max_scale')
]



In [None]:
# Empty frame with one column per feature. The column list is `feature_names`;
# `features` does not exist yet at this point of a top-to-bottom run.
n_features = pd.DataFrame(columns=feature_names)

In [None]:
def calc_wavelet_features(beat, scales=None, wavelets=('gaus1', 'gaus2'), sampling_period=1/360):
    """
    Calculate the wavelet features for the 2 channel beat.

    For each of the two channels and each wavelet, the continuous wavelet
    transform is computed and two values are kept: the maximum coefficient
    and the scale at which that maximum occurs. With the default arguments
    this gives 8 features.

    Parameters
    ----------
    beat : 2d array
        Indexed as [sample, channel]. Columns 0 and 1 are used
        (named 'MLII' and 'V1' in the original notes).
    scales : array-like, optional
        Scales passed to `pywt.cwt`. Defaults to `np.arange(1, 30)`.
    wavelets : sequence of str, optional
        Names of the continuous wavelets to apply, in feature order.
    sampling_period : float, optional
        Sampling period passed to `pywt.cwt`. Defaults to 1/360.

    Returns
    -------
    features : list
        `[max_coef, max_scale]` for every wavelet on channel 0, followed by
        the same for channel 1.
    """
    if scales is None:
        scales = np.arange(1, 30)
    scales = np.asarray(scales)
    features = []

    for ch in range(2):
        # Channel 1 is negated for qrs complexes to match the gaus1 wavelet deflection.
        # Q: Why would a time domain reversal (beat[::-1, ch]) not affect our feature?
        sig = -beat[:, ch] if ch == 1 else beat[:, ch]

        for wavefun in wavelets:
            # Calculate continuous wavelet transform (the second return value is not needed)
            coef, _ = pywt.cwt(sig, scales, wavefun, sampling_period=sampling_period)
            # Locate the maximum once; the row index of the coefficient matrix
            # selects the scale at which the maximum convolution value occurs
            scale_idx, sample_idx = np.unravel_index(np.argmax(coef), coef.shape)
            # Save the features
            features += [coef[scale_idx, sample_idx], scales[scale_idx]]
    return features
    

In [None]:
# Sanity check the feature function on a single beat
# Normal
features = calc_wavelet_features(beat=n_beats[0])
features

In [None]:
# Sanity check the feature function on a single beat
# LBBB
features = calc_wavelet_features(beat=l_beats[0])
features

In [None]:
# Try on some beats
# RBBB (this cell previously repeated the LBBB beat, leaving RBBB untested)
features = calc_wavelet_features(r_beats[0])
features

In [None]:
# Sanity check the feature function on a single beat
# Vent
features = calc_wavelet_features(beat=v_beats[0])
features

## Section 3 - Calculate wavelet features and classify

In [None]:
# Concatenate the four beat banks and build the matching integer labels:
# 0 for n_beats, 1 for l_beats, 2 for r_beats, 3 for v_beats
beat_banks = (n_beats, l_beats, r_beats, v_beats)
all_beats = n_beats + l_beats + r_beats + v_beats
labels = [class_id for class_id, bank in enumerate(beat_banks) for _ in range(len(bank))]

In [None]:
# Calculate features for all records using multiple cpus.
# Leave one cpu free, but always keep at least one worker: Pool rejects 0 processes.
n_workers = max(1, cpu_count() - 1)
# The context manager shuts the worker processes down once the map has finished
with Pool(processes=n_workers) as pool:
    feature_rows = pool.map(calc_wavelet_features, all_beats)

# Combine features into a data frame, one row per beat, with the class label in 'type'
features = pd.DataFrame(feature_rows, columns=feature_names)
features['type'] = labels

In [None]:
# Question: How many rows and columns should the feature matrix have?
# Preview the first rows of the combined feature table
features.head()

In [None]:
# Visualize some results so that we can see if there is any inter-group difference

In [None]:
# Recall the features
# feature_names = ['gaus1_max_II', 'gaus1_max_scale_II', 'gaus2_max_II', 'gaus2_max_scale_II',
#            'gaus1_max_V', 'gaus1_max_scale_V', 'gaus2_max_V', 'gaus2_max_scale_V']

# Show some features: one figure per feature, one histogram per beat class
for f in feature_names:
    plt.figure()
    for groupnum in range(4):
        plt.hist(features.loc[np.equal(labels, groupnum), f])
    plt.title(f)
    # BEAT_TYPES is the name imported from bc.beats (no `beat_types` is defined).
    # Assumes it lists the classes in label order 0..3 - confirm against bc.beats.
    plt.legend(BEAT_TYPES)
    plt.show()

# Question: Which features do you think will be more useful?
# Gaus1 is the sinusoid, gaus2 is the hat.


In [None]:
# Split data into training and testing sets.
# Only the wavelet feature columns are used as inputs: `features` also holds
# the 'type' label column, which must not be shown to the classifiers.
x_train, x_test, y_train, y_test = train_test_split(features[feature_names], labels,
                                                    train_size=0.75, test_size=0.25,
                                                    random_state=0)
print('Number of training records: %d' % len(x_train))
print('Number of testing records: %d' % len(x_test))

In [None]:
# Fit each classifier on the training set, then predict the test set.
# fit() hands back the estimator itself, so construction and fitting are chained.

# LR
clf_lr = LogisticRegression().fit(x_train, y_train)
y_predict_lr = clf_lr.predict(x_test)

# KNN
clf_knn = neighbors.KNeighborsClassifier().fit(x_train, y_train)
y_predict_knn = clf_knn.predict(x_test)

# SVM
clf_svm = svm.SVC().fit(x_train, y_train)
y_predict_svm = clf_svm.predict(x_test)

# And GB
clf_gb = GradientBoostingClassifier().fit(x_train, y_train)
y_predict_gb = clf_gb.predict(x_test)

In [None]:
# Print one classification report per classifier, each preceded by its heading
report_inputs = [
    ('Logistic Regression', y_predict_lr),
    ('K Nearest Neigbors', y_predict_knn),
    ('Support Vector Machines', y_predict_svm),
    ('Gradient Boosting', y_predict_gb),
]
for heading, y_predict in report_inputs:
    print(heading)
    print(classification_report(y_test, y_predict, target_names=BEAT_TYPES))



In [None]:
# 'gaus1' is a continuous wavelet, so it is built with ContinuousWavelet
# (pywt.Wavelet is for discrete wavelets)
w = pywt.ContinuousWavelet('gaus1')

