In [1]:
import numpy as np
import scipy
from scipy.signal import find_peaks, argrelextrema, butter, filtfilt, sosfiltfilt, sosfilt, bode, freqz, freqs, sosfreqz, iirfilter
from scipy.fftpack import fft
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import control as ct
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from scipy.optimize import curve_fit
import heapq
import struct

In [14]:
class SpikeDetector(object):
    """Spike detector: IIR-filter a signal, smooth it, locate alternating local
    extrema and test every max-min-max window with `detect_spikes`.

    Parameters
    ----------
    x : array-like
        Original signal.
    t : array-like
        Time axis matching `x`; used as the x-axis in `plot_spikes`.
    fs : float
        Sampling rate, stored as `self.fs`.
    blackout_period : float
        Window length in seconds handed to `get_local_extrema`.
    b, a : sequence of float
        IIR coefficients handed to `iir_filter` (stored as numpy arrays).
    alpha, beta, MAX_ASYM, MIN_RATIO, MAX_COST, MIN_THRES, MAX_THRES,
    MAX_LDIST, MIN_RDIST, MAX_RDIST :
        Detection limits, stored on the instance and forwarded to
        `detect_spikes` by `detect()`.

    NOTE(review): `get_local_extrema` (via `time_to_index`) and `fit_values`
    read the module-level `fs` and `t`, not `self.fs` / `self.t`; confirm those
    globals describe the same signal as this detector.
    """

    def __init__(self, x, t, 
                 fs,
                 blackout_period=0.04, 
                 b = [0.487305044246353, 0, -0.487305044246353], 
                 a = [0, -0.846977922376899, 0.025389911507294],
                 alpha=0.0002, 
                 beta=3.75, 
                 MAX_ASYM=4, 
                 MIN_RATIO=1.01, 
                 MAX_COST=0.4, 
                 MIN_THRES=-0.004,
                 MAX_THRES=0.002, 
                 MAX_LDIST=0.02, 
                 MIN_RDIST=0.0001, 
                 MAX_RDIST=0.02):
        # Original signal, time axis and sampling rate
        self.x = x
        self.t = t
        self.fs = fs

        # Window length (seconds) used when searching for local extrema
        self.blackout_period = blackout_period

        # Transfer-function coefficients for the IIR filter. np.array copies,
        # so the shared list defaults are never mutated.
        self.b = np.array(b)
        self.a = np.array(a)

        # Detection limits; forwarded to detect_spikes in detect()
        self.alpha = alpha
        self.beta = beta
        self.MAX_ASYM = MAX_ASYM
        self.MIN_RATIO = MIN_RATIO
        self.MAX_COST = MAX_COST
        self.MIN_THRES = MIN_THRES
        self.MAX_THRES = MAX_THRES
        self.MAX_LDIST = MAX_LDIST
        self.MIN_RDIST = MIN_RDIST
        self.MAX_RDIST = MAX_RDIST

    def filter_signal(self, signal):
        """Apply this detector's IIR filter (self.b, self.a) to `signal`."""
        return iir_filter(self.b, self.a, signal)

    def detect(self):
        """Run the detection pipeline on `self.x`.

        Returns the list produced by `detect_spikes`: one `(1, payload)` tuple
        per accepted window. Also stores the filtered signal in
        `self.filtered_signal` (used by `plot_spikes`).
        """
        self.filtered_signal = self.filter_signal(self.x)

        # Smooth the filtered signal. Original note: choose alpha from
        # {0.1, 0.2, ..., 0.9} with the smallest MSE.
        smoothed_signal, _trend = brown_double_es(self.filtered_signal, alpha=0.1)

        # Alternating local minima / maxima over `blackout_period` windows
        local_minima, local_maxima = get_local_extrema(self.filtered_signal, self.blackout_period)

        # Forward the constructor's limits so they actually control detection
        return detect_spikes(
            self.x, self.filtered_signal, smoothed_signal, local_maxima, local_minima,
            alpha=self.alpha, beta=self.beta,
            MIN_RATIO=self.MIN_RATIO, MAX_ASYM=self.MAX_ASYM, MAX_COST=self.MAX_COST,
            MIN_THRES=self.MIN_THRES, MAX_THRES=self.MAX_THRES,
            MAX_LDIST=self.MAX_LDIST, MIN_RDIST=self.MIN_RDIST, MAX_RDIST=self.MAX_RDIST,
        )

    def plot_spikes(self, spikes):
        """Plot the filtered signal and mark each detected spike at (rp, rv).

        Filters `self.x` first if `detect()` has not been run yet. An empty
        `spikes` list plots the signal only.
        """
        if not hasattr(self, 'filtered_signal'):
            self.filtered_signal = self.filter_signal(self.x)

        # A single figure (no extra blank plt.figure()).
        fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi = 600)

        ax.plot(self.t, self.filtered_signal, color='blue', linestyle='-', label='Filtered signal (y)', zorder=1)

        df = get_spikes_df(spikes)
        if not df.empty:
            ax.scatter(df['rp'], df['rv'], color='red', label='Detected Spikes')

        ax.set_title("Filtered Signal with Detected Spikes")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Voltage (mV)")
        ax.legend()

        plt.show()

    def spikes_dataframe(self, spikes):
        """Return `spikes` as a DataFrame with one row per spike payload."""
        return get_spikes_df(spikes)
        
# Converts time t (in seconds) to a sample index.
def time_to_index(t, sample_rate=None):
    """Convert a time in seconds to a sample index.

    Parameters
    ----------
    t : float
        Time in seconds.
    sample_rate : float, optional
        Samples per second. When omitted, the module-level `fs` is used.

    Returns
    -------
    int
        `t * rate` rounded to the nearest integer. Rounding (rather than
        truncating) keeps products such as 0.29 * 100 == 28.999999999999996
        from landing one sample short.
    """
    rate = fs if sample_rate is None else sample_rate
    return int(round(t * rate))

# IIR filter implementation (transposed direct-form-II style recursion).
# Assumes b, a are non-empty and len(a) >= len(b)!
def iir_filter(b, a, x):
    """Filter `x` sample by sample with numerator `b` and denominator `a`.

    For each input sample the output is
        out = state[1] + b[0] * sample - a[0] * previous_out
    and the delay line is then updated for k = 1 .. len(b) - 1 as
        state[k] = state[k + 1] + b[k] * sample - a[k] * out
    where state[len(b)] is always zero. Entries of `a` beyond index
    len(b) - 1 are ignored. With a[0] == 0 (as in SpikeDetector's defaults)
    this matches scipy.signal.lfilter(b, [1, a[1], a[2], ...], x).

    Returns a numpy array with one output per input sample.
    """
    order = len(b)
    # Index 0 is unused and the last slot stays zero, so state[k + 1] is
    # always a valid read in the update loop below.
    state = np.zeros(order + 1)

    outputs = []
    current = 0
    for i in range(len(x)):
        sample = x[i]
        # `current` still holds the previous output on the right-hand side
        current = state[1] + b[0] * sample - a[0] * current
        for k in range(1, order):
            state[k] = state[k + 1] + b[k] * sample - a[k] * current
        outputs.append(current)

    return np.array(outputs)

# Brown's Double Exponential Smoothing (DES) Implementation for Smoothing Signals
def brown_double_es(y, alpha=0.1):
    """Brown's double exponential smoothing of a 1-D signal.

    Two cascaded exponential smoothers S1 and S2, both starting from 0, give
        level[i] = 2 * S1[i] - S2[i]
        trend[i] = alpha / (1 - alpha) * (S1[i] - S2[i])

    Parameters
    ----------
    y : array-like
        1-D signal (lists are accepted).
    alpha : float
        Smoothing factor. alpha == 1 divides by zero.

    Returns
    -------
    y_hat : ndarray, same length as y
        y_hat[0] == 0 and y_hat[i] = level[i - 1] + trend[i - 1] for i >= 1,
        i.e. the one-step-ahead forecast of y[i] from samples before i.
    trend : ndarray, same shape as y
        Trend estimate at every sample, including the last one.
    """
    y = np.asarray(y)
    n = len(y)

    trend = np.zeros(y.shape)
    # One extra slot: the forecast made from the last sample is computed and
    # then discarded so the output length matches y.
    y_hat = np.zeros(n + 1)

    s1 = 0.0  # first smoother state (explicit zero start)
    s2 = 0.0  # second smoother state
    for i in range(n):
        s1 = alpha * y[i] + (1 - alpha) * s1
        s2 = alpha * s1 + (1 - alpha) * s2

        level = 2 * s1 - s2
        trend[i] = (alpha / (1 - alpha)) * (s1 - s2)

        y_hat[i + 1] = level + trend[i]

    return y_hat[:n].copy(), trend

# Post-processing for get_local_extrema: for any consecutive [min, min, min, ...]
# or [max, max, max, ...] pattern, ONLY keep the most extreme extremum!
def post_processing(y, local_extrema):
    """Collapse runs of same-type extrema and drop repeated sample indices.

    Parameters
    ----------
    y : array-like
        Signal the extrema refer to.
    local_extrema : sequence of (index, kind) pairs
        Sorted candidate extrema. `kind` == 'min' marks a minimum; any other
        value is treated as a maximum. Rows of a 2-D numpy string array (as
        built by get_local_extrema) work too, because the index is passed
        through int().

    Returns
    -------
    list of tuple
        One entry per run of consecutive same-type candidates: the member with
        the smallest y for a 'min' run, otherwise the member with the largest
        y (earliest member on ties). Consecutive entries that share the same
        index are then reduced to the first. Entries keep the element types of
        the input rows.

    The input is not modified.
    """
    n = len(local_extrema)

    # Step 1: one representative per run of identical types
    collapsed = []
    start = 0
    while start < n:
        kind = local_extrema[start][1]
        end = start + 1
        while end < n and local_extrema[end][1] == kind:
            end += 1

        # vals spans the WHOLE run [start, end), so argmin/argmax offsets are
        # relative to `start` and the first member is also a candidate.
        vals = [y[int(local_extrema[k][0])] for k in range(start, end)]
        offset = np.argmin(vals) if kind == 'min' else np.argmax(vals)
        collapsed.append(tuple(local_extrema[start + int(offset)]))

        start = end

    # Step 2: remove consecutive entries at the same sample index (keep the first)
    local_extrema_unique = []
    for entry in collapsed:
        if local_extrema_unique and local_extrema_unique[-1][0] == entry[0]:
            continue
        local_extrema_unique.append(entry)

    return local_extrema_unique


# Find alternating local minima/maxima of `y` using sliding windows of length
# `blackout_period` (seconds), cleaned up by post_processing.
def get_local_extrema(y, blackout_period = 0.04, fs=None):
    """Locate local extrema of `y` with a sliding window.

    Every window of `blackout_period` seconds (one window per start sample,
    including the window that ends on the last sample) contributes its argmin
    as a 'min' candidate and its argmax as a 'max' candidate. The sorted,
    de-duplicated candidates go through `post_processing`, which keeps one
    extremum per run of same-type candidates.

    Parameters
    ----------
    y : array-like
        1-D signal.
    blackout_period : float
        Window length in seconds.
    fs : float, optional
        Sampling rate used to convert `blackout_period` to samples (rounded to
        the nearest integer). When omitted, `time_to_index` does the
        conversion, which reads the module-level `fs`.

    Returns
    -------
    (local_minima, local_maxima) : tuple of list of int
        Sample indices. If the first minimum precedes the first maximum it is
        dropped, because detect_spikes expects each window to start at a
        maximum.

    Raises
    ------
    ValueError
        If the window is shorter than one sample.
    """
    y = np.asarray(y)

    # Convert the window length from seconds to samples
    if fs is None:
        window = time_to_index(blackout_period)
    else:
        window = int(round(blackout_period * fs))
    if window < 1:
        raise ValueError("blackout_period must cover at least one sample")
    if len(y) < window:
        # Not even one full window fits: no extrema.
        return [], []

    # Read-only view of shape (len(y) - window + 1, window); row k is y[k : k + window].
    windows = np.lib.stride_tricks.sliding_window_view(y, window)
    starts = np.arange(windows.shape[0])
    min_positions = starts + np.argmin(windows, axis=1)
    max_positions = starts + np.argmax(windows, axis=1)

    # Neighbouring windows mostly report the same extremum; de-duplicate, then sort.
    candidates = sorted(
        {(int(i), 'min') for i in min_positions} | {(int(i), 'max') for i in max_positions}
    )

    # post_processing has always received a numpy array here (rows become strings).
    local_extrema = post_processing(y, np.array(candidates))

    # Split by type
    local_minima = [int(i) for i, kind in local_extrema if kind == 'min']
    local_maxima = [int(i) for i, kind in local_extrema if kind != 'min']

    # Drop a leading minimum so windows start at a maximum
    if local_minima and local_maxima and local_minima[0] < local_maxima[0]:
        local_minima = local_minima[1:]

    return local_minima, local_maxima

# Determine the fit values (times and amplitudes) of one max-min-max window.
# idx: index of the window in local_maxima / local_minima
def fit_values(y, local_minima, local_maxima, idx, time_axis=None):
    """Return times and values at the three extrema of window `idx`.

    The window is (local_maxima[idx], local_minima[idx], local_maxima[idx + 1]).

    Parameters
    ----------
    y : array-like
        Signal the extrema indices refer to.
    local_minima, local_maxima : sequence of int
        Sample indices of the minima / maxima.
    idx : int
        Window index.
    time_axis : array-like, optional
        Time stamps indexed like `y`. When omitted, the module-level `t` is
        used.

    Returns
    -------
    (lp, mp, rp, lv, mv, rv)
        Times (lp, mp, rp) and values (lv, mv, rv) at the left maximum, the
        minimum and the right maximum (rp is the window end time).
    """
    times = t if time_axis is None else time_axis

    left = local_maxima[idx]
    middle = local_minima[idx]
    right = local_maxima[idx + 1]

    return times[left], times[middle], times[right], y[left], y[middle], y[right]


# Windowed deviation statistic (called "MAD" here) used as the estimation error:
#   MAD = | sum(m + alpha * (|x| - m)) / sum(|x| - m) |
# where m is the mean of the smoothed signal over [i1, i2), x is an original
# signal value in that range, and alpha is a small user constant (e.g. 0.0002).
def mean_absolute_deviation(smoothed_data, original_data, i1, i2, alpha):
    """Compute the windowed deviation statistic over samples [i1, i2).

    Despite the name, this is not the textbook mean absolute deviation: it is
    the absolute value of sum(m + alpha * d) / sum(d), with d = |x| - m and
    m = mean(smoothed_data[i1:i2]). A zero sum(d) or an empty window yields
    inf/nan from numpy rather than raising.
    """
    baseline = np.mean(smoothed_data[i1:i2])
    excess = np.abs(original_data[i1:i2]) - baseline

    numerator = np.sum(baseline + alpha * excess)
    denominator = np.sum(excess)
    return np.abs(numerator / denominator)


# Calculate the threshold: thres = beta * m, where m is the mean of the
# smoothed signal over samples [i1, i2).
# (The function name's spelling is kept because callers use it.)
def calulate_threshold(smoothed_signal, i1, i2, beta):
    """Return beta times the mean of smoothed_signal[i1:i2]."""
    window = smoothed_signal[i1:i2]
    return np.mean(window) * beta

# Characteristic values of one max-min-max window, from its fit values
# (lp, mp, rp, lv, mv, rv) and the sample range [i1, i2).
# alpha: hyperparam for the MAD calculation
# beta: hyperparam for the threshold calculation
def characteristic_values(x, smoothed_signal, lp, mp, rp, lv, mv, rv, i1, i2, alpha, beta):
    """Return (ldist, rdist, ratio, asym, esterr, cost, thres) for one window.

    ldist  = mp - lp               (left max -> min)
    rdist  = rp - mp               (min -> right max)
    ratio  = rv / |mv|
    asym   = lv / rv
    esterr = mean_absolute_deviation(smoothed_signal, x, i1, i2, alpha)
    cost   = esterr / |mv|
    thres  = calulate_threshold(smoothed_signal, i1, i2, beta)
    """
    trough_depth = abs(mv)

    left_gap = mp - lp
    right_gap = rp - mp
    peak_ratio = rv / trough_depth
    asymmetry = lv / rv

    # esterr: deviation statistic of the signal within the window
    window_err = mean_absolute_deviation(smoothed_signal, x, i1, i2, alpha)
    relative_cost = window_err / trough_depth

    window_thres = calulate_threshold(smoothed_signal, i1, i2, beta)

    return left_gap, right_gap, peak_ratio, asymmetry, window_err, relative_cost, window_thres


# Compare the characteristic values of window `idx` with the given limits.
# idx: index of the window in local_maxima / local_minima
def is_spike(x, filtered_signal, smoothed_signal, local_maxima, local_minima, idx, alpha=0.0002, beta=3.75, MAX_ASYM=0.3, MIN_RATIO=4/3, MAX_COST=0.4, MIN_THRES=0.2, MAX_THRES=0.7, MAX_LDIST=0.02, MIN_RDIST=0.0001, MAX_RDIST=0.02):
    """Decide whether window (local_maxima[idx], local_minima[idx], local_maxima[idx + 1]) is a spike.

    The window is rejected (return 0) when any of these holds:
      - rv is not strictly positive
      - lv > rv * MAX_ASYM
      - rv < |mv| * MIN_RATIO
      - thres < MIN_THRES or thres > MAX_THRES
      - ldist > MAX_LDIST
      - rdist < MIN_RDIST or rdist > MAX_RDIST
      - esterr > |mv| * MAX_COST

    Otherwise returns (1, payload) with payload keys 'lp', 'mp', 'rp', 'lv',
    'mv', 'rv', 'ldist', 'rdist', 'ratio', 'asym', 'cost', 'thres'.

    Raises AssertionError if beta is outside [0, 7.875] (commented as the
    Q3.3 range).
    """
    # beta must lie in [0, 7.875] (Q3.3)
    assert 0 <= beta <= 7.875

    lp, mp, rp, lv, mv, rv = fit_values(filtered_signal, local_minima, local_maxima, idx)

    ldist, rdist, ratio, asym, esterr, cost, thres = characteristic_values(
        x, smoothed_signal, lp, mp, rp, lv, mv, rv,
        local_maxima[idx], local_maxima[idx + 1], alpha, beta,
    )

    trough = abs(mv)

    # A spike needs a strictly positive right peak; `not (rv > 0)` also
    # rejects NaN.
    if not rv > 0:
        return 0

    # Any single failed limit rejects the window.
    failed = (
        lv > rv * MAX_ASYM
        or rv < trough * MIN_RATIO
        or thres < MIN_THRES
        or thres > MAX_THRES
        or ldist > MAX_LDIST
        or rdist < MIN_RDIST
        or rdist > MAX_RDIST
        or esterr > trough * MAX_COST
    )
    if failed:
        return 0

    payload = dict(
        lp=lp,
        mp=mp,
        rp=rp,
        lv=lv,
        mv=mv,
        rv=rv,
        ldist=ldist,
        rdist=rdist,
        ratio=ratio,
        asym=asym,
        cost=cost,
        thres=thres,
    )
    return (1, payload)


# Detect ALL spikes: scan every max-min-max window and keep those is_spike accepts.
def detect_spikes(x, filtered_signal, smoothed_signal, local_maxima, local_minima, alpha = 0.0002, beta = 3.75, MIN_RATIO = 1.33, MAX_ASYM = 1, MAX_COST = 2, MIN_THRES = -0.5, MAX_THRES = 0.1, MAX_LDIST = 0.2, MIN_RDIST = 5e-5, MAX_RDIST = 0.2):
    """Return the list of `is_spike` results that were accepted.

    Window idx is (local_maxima[idx], local_minima[idx], local_maxima[idx + 1]);
    windows are visited in order. After an accepted spike the next window
    (which starts at the spike's right maximum) is skipped. All threshold
    keyword arguments are forwarded unchanged to `is_spike`.

    Raises AssertionError if a window's indices are not strictly increasing
    (max < min < max).
    """
    spikes = []

    # Bounding by BOTH lists guarantees termination even when the minima run
    # out before the maxima (the windowing needs minima[idx] and maxima[idx + 1]).
    n_windows = min(len(local_maxima) - 1, len(local_minima))

    idx = 0
    while idx < n_windows:
        assert local_maxima[idx] < local_minima[idx]
        assert local_minima[idx] < local_maxima[idx + 1]

        spike = is_spike(x, filtered_signal, smoothed_signal, local_maxima, local_minima, idx, alpha=alpha, beta=beta, MAX_ASYM=MAX_ASYM, MIN_RATIO=MIN_RATIO, MAX_COST=MAX_COST, MIN_THRES=MIN_THRES, MAX_THRES=MAX_THRES, MAX_LDIST=MAX_LDIST, MIN_RDIST=MIN_RDIST, MAX_RDIST=MAX_RDIST)

        if spike:
            spikes.append(spike)
            # Skip the window that starts at this spike's right maximum
            idx += 2
        else:
            idx += 1

    return spikes

def get_spikes_df(spikes):
    """Build a DataFrame from `(flag, payload)` spike tuples, one row per payload dict."""
    return pd.DataFrame([entry[1] for entry in spikes])