In [1]:
import getpass
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal

from get_iEEG_data import *
from spike_detector import *
from iEEG_helper_functions import *

In [2]:
def create_pwd_file(username, password, fname=None):
    """Write ``password`` (encoded to bytes) to a login file.

    Parameters
    ----------
    username : str
        Account name; its first three characters are used to build the
        default file name ``"<abc>_ieeglogin.bin"``.
    password : str
        Password text to store. It is written as-is, without encryption.
    fname : str, optional
        Explicit output path. When omitted, the default name derived from
        ``username`` is used (relative to the current working directory).
    """
    target = fname if fname is not None else "{}_ieeglogin.bin".format(username[:3])
    with open(target, "wb") as handle:
        handle.write(password.encode())
    print("-- -- IEEG password file saved -- --")


# The account password must never be stored in the notebook. It is taken from the
# IEEG_PASSWORD environment variable, falling back to an interactive prompt.
# NOTE(review): a password was previously committed in this cell; treat it as
# compromised and rotate it.
ieeg_password = os.environ.get("IEEG_PASSWORD") or getpass.getpass(
    "IEEG password for dma: "
)
create_pwd_file("dma", ieeg_password)

# The login file written above is read back here and is passed by name to
# get_iEEG_data in a later cell.
with open("dma_ieeglogin.bin", "r") as f:
    s = Session("dma", f.read())

# Open the dataset and keep only the channels chosen by electrode_selection.
ds = s.open_dataset("HUP210_phaseII")
all_channel_labels = np.array(ds.get_channel_labels())
label_idxs = electrode_selection(all_channel_labels)
labels = all_channel_labels[label_idxs]

-- -- IEEG password file saved -- --


In [3]:
# Clip window. Both bounds are scaled by 1e6 before being handed to get_iEEG_data
# (presumably seconds -> microseconds; confirm against get_iEEG_data).
clip_start = 179677 + (72600 / 1024)
clip_length = 60

ieeg_data, fs = get_iEEG_data(
    "dma",
    "dma_ieeglogin.bin",
    "HUP210_phaseII",
    clip_start * 1e6,
    (clip_start + clip_length) * 1e6,
    labels,
)

# Later cells use the sampling rate as an integer.
fs = int(fs)

In [4]:
# Screen the recording for bad channels; element 0 of the result is used as the
# index of the channels to keep.
good_channels_res = detect_bad_channels_optimized(ieeg_data.to_numpy(), fs)
good_channel_indicies = good_channels_res[0]
good_labels = labels[good_channel_indicies]
ieeg_data = ieeg_data[good_labels]

# Re-reference the kept channels (helper name suggests a common-average montage;
# its exact behaviour is defined in iEEG_helper_functions).
ieeg_data = common_average_montage(ieeg_data)


# Apply the filters directly on the DataFrame
# NOTE: each filter result is wrapped in a fresh pd.DataFrame without column names,
# so from here on the columns are the integers 0..n-1 rather than channel labels
# (the spike loop further down indexes columns by integer channel id).
# NOTE(review): this cell rebinds `ieeg_data` several times, so re-running it
# without re-running the download cell operates on already-filtered data.
ieeg_data = pd.DataFrame(notch_filter(ieeg_data.values, 59, 61, fs))
ieeg_data = pd.DataFrame(bandpass_filter(ieeg_data.values, 1, 70, fs))

ieeg_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,180,181,182,183,184,185,186,187,188,189
0,7.507928,2.125174,8.540212,0.357006,-4.551463,-17.110671,-15.294157,-75.082156,-75.407262,-30.703291,...,6.578267,5.322186,4.252136,2.192042,-7.168558,-26.368686,-64.437617,-85.056307,-54.985452,-4.403216
1,10.963548,4.301169,10.330361,1.503340,-3.682178,-15.489678,-15.263365,-77.764374,-76.767731,-31.167498,...,7.001846,5.936817,4.790761,2.319565,-6.585044,-25.207571,-64.402307,-87.344945,-57.820404,-7.098123
2,14.321397,6.500112,12.129850,2.689941,-2.773226,-13.861167,-15.234097,-80.364181,-78.048396,-31.563032,...,7.373700,6.475069,5.277968,2.423922,-6.028722,-24.112919,-64.411549,-89.644190,-60.642614,-9.461599
3,17.504992,8.747131,13.955796,3.954757,-1.783153,-12.209362,-15.203165,-82.801981,-79.172064,-31.829357,...,7.645019,6.868037,5.662821,2.480735,-5.526403,-23.139549,-64.495517,-91.952975,-63.436568,-11.226597
4,20.461515,11.058034,15.830545,5.326356,-0.674855,-10.512439,-15.165433,-85.014694,-80.079047,-31.926122,...,7.776745,7.064336,5.900559,2.467215,-5.103763,-22.328453,-64.674011,-94.265431,-66.189038,-12.228998
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61435,23.432388,8.458657,5.793279,11.920832,10.546166,3.854880,-20.265253,-30.479836,-43.780163,-34.553647,...,-1.145399,-6.779017,-14.032431,-27.248361,-41.809307,-76.203978,-86.607226,-58.634684,-17.750950,18.026751
61436,21.593321,7.493015,4.701843,10.709001,9.761437,3.889029,-19.303377,-29.105274,-41.542769,-32.141293,...,-1.418810,-6.571751,-13.820478,-26.396635,-40.424244,-73.114258,-81.062226,-54.028594,-15.561067,17.126227
61437,19.603638,6.573561,3.601326,9.488630,8.985855,3.965517,-18.385736,-27.856705,-39.543270,-29.851616,...,-1.643307,-6.299833,-13.592476,-25.556373,-39.242612,-70.460163,-75.735613,-49.457882,-13.522115,15.012783
61438,17.515713,5.710749,2.502656,8.268827,8.220314,4.072075,-17.505036,-26.709120,-37.733481,-27.666732,...,-1.824955,-5.975142,-13.343379,-24.713700,-38.207230,-68.125448,-70.577384,-44.924582,-11.617095,11.994490


In [5]:
# Run the detector on the cleaned, filtered recording.
output = spike_detector(data=ieeg_data, fs=fs, labels=good_labels)

# The reported count is the number of distinct values in column 2 of the result,
# so rows that share a value there are counted once.
n_distinct_groups = len(np.unique(output[:, 2]))
print(f"{n_distinct_groups} spikes detected")

Using optimized spike detector...
35 spikes detected


In [6]:
# Cast the detector output to integers so that columns 0 and 1 can be used
# directly as sample / column indices in the per-spike loop below.
output = output.astype(int)
output

array([[  619,   154,     0],
       [  633,    69,     0],
       [  644,    39,     0],
       ...,
       [60160,    41,    34],
       [60163,   154,    34],
       [60164,   153,    34]])

In [7]:
def line_length(x):
    """Return the sum of absolute differences between consecutive samples.

    The input is flattened first, so a multi-dimensional array is treated as
    one long sequence. Inputs with fewer than two samples give 0.
    """
    flat = np.asanyarray(x).ravel()
    steps = flat[1:] - flat[:-1]
    return np.abs(steps).sum()


def zero_crossing_around_mean(x):
    """Count sign changes of ``x`` about its own mean.

    A crossing is counted wherever two consecutive mean-subtracted samples
    have a strictly negative product; samples that land exactly on the mean
    therefore do not register a crossing.
    """
    deviations = x - np.mean(x)
    sign_flips = (deviations[1:] * deviations[:-1]) < 0
    return np.sum(sign_flips)

In [8]:
def _spike_polarity(spike_signal, idx, lag=3):
    """Classify the sample at ``idx`` against the samples ``lag`` steps to each side.

    Returns 1 when it is strictly above both (positive peak), -1 when it is
    strictly below both (negative peak), and 0 otherwise (on a slope or tied).
    """
    here = spike_signal[idx]
    before = spike_signal[idx - lag]
    after = spike_signal[idx + lag]
    if before < here and after < here:
        return 1
    if before > here and after > here:
        return -1
    return 0


def _pick_flank_point(spike_signal, candidates, peak, polarity):
    """Choose the extremum that bounds the spike on one side of ``peak``.

    ``candidates`` is an index array of extrema on one side of the peak. Returns
    the chosen index, or None when there are no candidates.
    """
    if len(candidates) == 0:
        return None

    # Prefer extrema that have come back past half of the peak amplitude; if none
    # has, keep every candidate.
    half_peak = 0.5 * spike_signal[peak]
    values = spike_signal[candidates]
    passes = values < half_peak if polarity > 0 else values > half_peak
    pool = candidates[passes] if passes.any() else candidates

    # Restrict to extrema within 50 samples of the peak, widening to 100 and then
    # to all candidates when nothing is that close (e.g. a wide spike).
    offsets = np.abs(pool - peak)
    nearby = pool[offsets <= 50]
    if len(nearby) == 0:
        nearby = pool[offsets <= 100]
    if len(nearby) == 0:
        nearby = candidates

    # Positive spike -> deepest trough; negative spike -> highest crest.
    nearby_values = spike_signal[nearby]
    best = np.argmin(nearby_values) if polarity > 0 else np.argmax(nearby_values)
    return nearby[best]


def _find_slow_wave_end(spike_signal, right_point, polarity, allmaxima, allminima):
    """Locate the end of the slow wave that follows the spike, or return None.

    Maxima and minima to the right of ``right_point`` are walked in pairs. Once the
    signal between ``right_point`` and a maximum has crossed its own mean (or the
    value at ``right_point`` is at least 100 in magnitude), the first extremum that
    lies at least 50 samples away and has returned past zero or past the value at
    ``right_point`` is taken as the end (a trough for positive spikes, a crest for
    negative ones).
    """
    later_maxima = allmaxima[allmaxima > right_point]
    later_minima = allminima[allminima > right_point]
    start_value = spike_signal[right_point]
    start_is_large = np.abs(start_value) >= 100

    crossings_seen = 0
    for crest, trough in zip(later_maxima, later_minima):
        if zero_crossing_around_mean(spike_signal[right_point:crest]) >= 1:
            crossings_seen += 1
        if crossings_seen < 1 and not start_is_large:
            continue

        if polarity > 0:
            candidate = trough
            has_returned = (spike_signal[trough] < 0) or (
                spike_signal[trough] < start_value
            )
        else:
            candidate = crest
            has_returned = (spike_signal[crest] > 0) or (
                spike_signal[crest] > start_value
            )

        if has_returned and (candidate - right_point >= 50):
            return candidate
    return None


def extract_spike_morphology(spike_signal):
    """
    Find the morphological features of a single spike.

    Major assumption: the segment is centred on the detection, i.e. the spike peak
    lies within 50 samples of index 1000 (the caller passes a window of 1000
    samples either side of the detection).

    Parameters
    ----------
    spike_signal : 1-D array-like
        Single-channel segment containing the spike.

    Returns
    -------
    basic_features : tuple or None
        ``(peak, left_point, right_point, slow_end, slow_max)``; sample indices
        into the segment, plus the peak-to-peak amplitude of the slow wave.
        Entries are None when they could not be determined.
    advanced_features : tuple or None
        ``(rise_amp, decay_amp, slow_width, slow_amp, rise_slope, decay_slope,
        average_amp, linelen)``; entries are None when unavailable.
    is_valid : bool
        True only when left point, right point and slow-wave end were all found.
    bad_reason : str or None
        ``"Short segment"`` when the segment is too short to analyse (both
        feature tuples are None in that case), ``"Bad feature"`` when a required
        feature is missing, otherwise None.
    """
    center = 1000  # expected position of the detection inside the segment
    half_window = 50  # peak search range around the centre
    lag = 3  # neighbour offset used to decide peak polarity

    # The peak search indexes up to center + half_window + lag - 2, so anything
    # shorter (including an empty segment) cannot be analysed safely.
    if len(spike_signal) < center + half_window + lag:
        return None, None, False, "Short segment"

    # detrend the spike (make the mean = 0)
    spike_signal = spike_signal - np.mean(spike_signal)

    allmaxima = signal.argrelextrema(spike_signal, np.greater)[0]
    allminima = signal.argrelextrema(spike_signal, np.less)[0]

    # Candidate peaks of either sign near the centre, at least one standard
    # deviation tall, expressed as indices into the full segment.
    stndev = np.std(spike_signal)
    window_start = center - half_window
    window = spike_signal[window_start : center + half_window]
    peaks_pos = signal.find_peaks(window, height=stndev)[0] + window_start
    peaks_neg = signal.find_peaks(-1 * window, height=stndev)[0] + window_start

    # Keep only candidates that are a clear positive or negative peak relative to
    # their +/- lag neighbours (drops samples on a slope and ties). Built as a new
    # list so that no candidate is skipped.
    candidates = [
        idx
        for idx in np.concatenate([peaks_pos, peaks_neg])
        if _spike_polarity(spike_signal, idx, lag) != 0
    ]

    peak = left_point = right_point = slow_end = slow_max = None

    if candidates:
        # the peak closest to the detection is the main reference point
        peak = min(candidates, key=lambda idx: abs(idx - center))
        polarity = _spike_polarity(spike_signal, peak, lag)

        # A positive spike is bounded by troughs, a negative one by crests; only
        # the three nearest extrema on each side are considered.
        flank_extrema = allminima if polarity > 0 else allmaxima
        left_point = _pick_flank_point(
            spike_signal, flank_extrema[flank_extrema < peak][-3:], peak, polarity
        )
        right_point = _pick_flank_point(
            spike_signal, flank_extrema[flank_extrema > peak][:3], peak, polarity
        )

        if left_point is None or right_point is None:
            # Without both flanks none of the downstream features are defined.
            left_point = right_point = None
        else:
            slow_end = _find_slow_wave_end(
                spike_signal, right_point, polarity, allmaxima, allminima
            )
            if slow_end is not None:
                slow_segment = spike_signal[right_point:slow_end]
                slow_max = np.max(slow_segment) - np.min(slow_segment)

    rise_amp = decay_amp = slow_width = slow_amp = None
    rise_slope = decay_slope = average_amp = linelen = None

    if left_point is not None and right_point is not None:
        rise_amp = np.abs(spike_signal[peak] - spike_signal[left_point])
        decay_amp = np.abs(spike_signal[peak] - spike_signal[right_point])
        rise_slope = (spike_signal[peak] - spike_signal[left_point]) / (
            peak - left_point
        )
        decay_slope = (spike_signal[right_point] - spike_signal[peak]) / (
            right_point - peak
        )
        # mean of the two amplitudes (parenthesised: "/" binds tighter than "+")
        average_amp = (rise_amp + decay_amp) / 2

        if slow_end is not None:
            slow_width = slow_end - right_point
            slow_amp = slow_max
            linelen = line_length(spike_signal[left_point:slow_end])

    basic_features = peak, left_point, right_point, slow_end, slow_max
    advanced_features = (
        rise_amp,
        decay_amp,
        slow_width,
        slow_amp,
        rise_slope,
        decay_slope,
        average_amp,
        linelen,
    )

    if left_point is None or right_point is None or slow_end is None:
        is_valid = False
        bad_reason = "Bad feature"
    else:
        is_valid = True
        bad_reason = None

    return basic_features, advanced_features, is_valid, bad_reason

In [9]:
num_good = 0

for spike in output:
    peak_index, channel_id = spike[0], spike[1]

    # Window of 1000 samples either side of the detection, from that detection's
    # column of the filtered recording.
    window = ieeg_data[peak_index - 1000 : peak_index + 1000]
    spike_signal = window[channel_id].to_numpy()

    (
        basic_features,
        advanced_features,
        is_valid,
        bad_reason,
    ) = extract_spike_morphology(spike_signal)

    if is_valid:
        num_good += 1
        # Optional visual check of an accepted spike:
        # peak, left_point, right_point, slow_end, slow_max = basic_features
        # plt.plot(spike_signal)
        # plt.plot(peak, spike_signal[peak], "x")
        # plt.plot(left_point, spike_signal[left_point], "o")
        # plt.plot(right_point, spike_signal[right_point], "o")
        # plt.plot(slow_end, spike_signal[slow_end], "o", color="k")
        # plt.title("A spike")
        # plt.xlim(250, 1750)
        # plt.show()
        continue

    # Short segments are skipped silently; any other rejection is reported.
    if bad_reason != "Short segment":
        print(bad_reason)
        # Optional visual check of a rejected detection:
        # plt.plot(spike_signal)
        # plt.title(f"NOT a spike because of {bad_reason}")
        # plt.xlim(250, 1750)
        # plt.show()

In [10]:
# Number of detections for which extract_spike_morphology reported is_valid=True
num_good

384

In [11]:
# Total number of rows returned by the detector, for comparison with num_good
len(output)

398