# Playground for Interictal Spike Detection

In [6]:
!rm -rf __pycache__

import numpy as np
import pandas as pd
import seaborn as sns
from ieeg.auth import Session
from matplotlib import pyplot as plt
from scipy.signal import butter, filtfilt, iirnotch

from get_iEEG_data import *
from spike_detector import *
from iEEG_helper_functions import *

In [7]:
def create_pwd_file(username, password, fname=None):
    if fname is None:
        fname = "{}_ieeglogin.bin".format(username[:3])
    with open(fname, "wb") as f:
        f.write(password.encode())
    print("-- -- IEEG password file saved -- --")


create_pwd_file("dma", "mycqEv-pevfo4-roqfan")

with open("dma_ieeglogin.bin", "r") as f:
    s = Session("dma", f.read())

ds = s.open_dataset("HUP210_phaseII")
all_channel_labels = np.array(ds.get_channel_labels())
label_idxs = electrode_selection(all_channel_labels)
labels = all_channel_labels[label_idxs]

ieeg_data, fs = get_iEEG_data(
    "dma",
    "dma_ieeglogin.bin",
    "HUP210_phaseII",
    (179677 + (72600 / 1024)) * 1e6,
    (179677 + (72600 / 1024) + 60) * 1e6,
    labels,
)

fs = int(fs)

-- -- IEEG password file saved -- --


In [8]:
good_channels_res = detect_bad_channels_optimized(ieeg_data.to_numpy(), fs)
good_channel_indicies = good_channels_res[0]
good_labels = labels[good_channel_indicies]
ieeg_data = ieeg_data[good_labels]

ieeg_data = common_average_montage(ieeg_data)


def notch_filter(data, low_cut, high_cut, fs, order=4):
    nyq = 0.5 * fs
    low = low_cut / nyq
    high = high_cut / nyq
    b, a = iirnotch(w0=(low + high) / 2, Q=30, fs=fs)
    y = filtfilt(b, a, data, axis=0)
    return y


def bandpass_filter(data, lowcut, highcut, fs, order=4):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype="band")
    y = filtfilt(b, a, data, axis=0)
    return y


# Apply the filters directly on the DataFrame
ieeg_data = pd.DataFrame(notch_filter(ieeg_data.values, 59, 61, fs))
ieeg_data = pd.DataFrame(bandpass_filter(ieeg_data.values, 1, 70, fs))

ieeg_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,180,181,182,183,184,185,186,187,188,189
0,7.507928,2.125174,8.540212,0.357006,-4.551463,-17.110671,-15.294157,-75.082156,-75.407262,-30.703291,...,6.578267,5.322186,4.252136,2.192042,-7.168558,-26.368686,-64.437617,-85.056307,-54.985452,-4.403216
1,10.963548,4.301169,10.330361,1.503340,-3.682178,-15.489678,-15.263365,-77.764374,-76.767731,-31.167498,...,7.001846,5.936817,4.790761,2.319565,-6.585044,-25.207571,-64.402307,-87.344945,-57.820404,-7.098123
2,14.321397,6.500112,12.129850,2.689941,-2.773226,-13.861167,-15.234097,-80.364181,-78.048396,-31.563032,...,7.373700,6.475069,5.277968,2.423922,-6.028722,-24.112919,-64.411549,-89.644190,-60.642614,-9.461599
3,17.504992,8.747131,13.955796,3.954757,-1.783153,-12.209362,-15.203165,-82.801981,-79.172064,-31.829357,...,7.645019,6.868037,5.662821,2.480735,-5.526403,-23.139549,-64.495517,-91.952975,-63.436568,-11.226597
4,20.461515,11.058034,15.830545,5.326356,-0.674855,-10.512439,-15.165433,-85.014694,-80.079047,-31.926122,...,7.776745,7.064336,5.900559,2.467215,-5.103763,-22.328453,-64.674011,-94.265431,-66.189038,-12.228998
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61435,23.432388,8.458657,5.793279,11.920832,10.546166,3.854880,-20.265253,-30.479836,-43.780163,-34.553647,...,-1.145399,-6.779017,-14.032431,-27.248361,-41.809307,-76.203978,-86.607226,-58.634684,-17.750950,18.026751
61436,21.593321,7.493015,4.701843,10.709001,9.761437,3.889029,-19.303377,-29.105274,-41.542769,-32.141293,...,-1.418810,-6.571751,-13.820478,-26.396635,-40.424244,-73.114258,-81.062226,-54.028594,-15.561067,17.126227
61437,19.603638,6.573561,3.601326,9.488630,8.985855,3.965517,-18.385736,-27.856705,-39.543270,-29.851616,...,-1.643307,-6.299833,-13.592476,-25.556373,-39.242612,-70.460163,-75.735613,-49.457882,-13.522115,15.012783
61438,17.515713,5.710749,2.502656,8.268827,8.220314,4.072075,-17.505036,-26.709120,-37.733481,-27.666732,...,-1.824955,-5.975142,-13.343379,-24.713700,-38.207230,-68.125448,-70.577384,-44.924582,-11.617095,11.994490


## Erin's Spike Detector - Alfredo's Implementation

I will try to reproduce the spike detectors proposed here: https://www.sciencedirect.com/science/article/pii/S1388245707001666?via%3Dihub, which are the ones that Erin currently uses

Actually, Erin's code can be accessed here: https://github.com/erinconrad/FC_toolbox/blob/main/spike_detector/clean_detector.m
so it is just a matter of translating this matlab code into a Python code

Detector based off of Erin's code - there are 3 functions that are needed for this to run: 1. function for filtering the ieeg data (`eegfilt`), 2. function for finding peaks (`findpeaks`), 3. function for ensuring that spikes are detected in more than one channel (`multi_channel_requirements`)

`findpeaks` function definition

In [None]:
def findpeaks(s):
    """
    Inputs:
        s: timeseries signal
    Outputs:
        p: location in the signal with an increasing signal (positive slope)
        t: location in the signal with a decreasing signal (negative slope)
    """
    ds = np.diff(s)
    ds = np.hstack((ds[0], ds))  # pad diff
    filt = np.where(ds[1:] == 0)[0] + 1  # find zeros
    ds[filt] = ds[filt - 1]  # replace zeros
    ds = np.sign(ds)
    # compute the second derivative -  inflection points
    ds = np.diff(ds)
    t = np.where(ds > 0)[0]
    p = np.where(ds < 0)[0]
    return p, t

`eegfilt` function definition

In [None]:
def eegfilt(x, fc, typ, fs):
    """
    Inputs:
        x: timeseries signal
        fc: cutoff frequency
        typ: type of filtering (lp - lowpass, hp - highpass)
        fs: sampling rate of x
    Outputs:
        p: location in the signal with an increasing signal (positive slope)
        t: location in the signal with a decreasing signal (negative slope)
    """
    # filter eeg data using Butterworth filter
    # out = eegfilt(data,cutfreq,typ);
    # out = eegfilt(data,70,'hp'); high pass with 70Hz cutoff

    # EEG_BUTTER - Butterworth filter implementation
    # xf = eeg_butter(x,sampl_freq,cutoff_freq,filter_type,num_poles)

    np = 6  # order of the butterworth filter

    if np.sum(fc >= fs / 2):
        raise ValueError("Cutoff frequency must be < one half the sampling rate")

    fn = fs / 2

    if typ == "bp":
        typ = "lp"

    if typ == "lp":
        B, A = butter(np, fc / fn)
    elif typ == "hp":
        B, A = butter(np, fc / fn, "high")
    elif typ == "st":
        B, A = butter(np, fc / fn, "stop")

    out = filtfilt(B, A, x)
    return out

`multichannel_requirements` function that makes sure that the spikes occur in 2 or more channels

In [None]:
def multichannel_requirements(gdf, nchs, fs):
    # Parameters
    min_chs = 2  # spike should be on at least 2 channels
    max_chs = int(nchs * 0.5)  # on no more than half the channels
    min_time = int(100 * 1e-3 * fs)  # 100 ms to look for other spikes

    final_spikes = []

    s = 0
    curr_seq = [s]
    last_time = gdf[s, 1]

    while s < gdf.shape[0] - 1:
        # move to next spike time
        new_time = gdf[s + 1, 1]

        # if it's within the time diff
        if new_time - last_time < min_time:
            curr_seq.append(s + 1)  # append it to the current sequence

            if s == gdf.shape[0] - 2:
                # done with sequence, check if the number of involved chs is appropriate
                l = len(np.unique(gdf[curr_seq, 0]))
                if min_chs <= l <= max_chs:
                    final_spikes.append(
                        np.hstack(
                            (
                                gdf[curr_seq, :],
                                (gdf[curr_seq, 1] - np.min(gdf[curr_seq, 1]))[
                                    :, np.newaxis
                                ],
                            )
                        )
                    )

        else:
            # done with sequence, check if the length of sequence is appropriate
            l = len(np.unique(gdf[curr_seq, 0]))
            if min_chs <= l <= max_chs:
                final_spikes.append(
                    np.hstack(
                        (
                            gdf[curr_seq, :],
                            (gdf[curr_seq, 1] - np.min(gdf[curr_seq, 1]))[
                                :, np.newaxis
                            ],
                        )
                    )
                )

            # reset sequence
            curr_seq = [s + 1]

        # increase the last time
        last_time = gdf[s + 1, 1]

        # increase the current spike
        s += 1
    if len(final_spikes) > 0:
        multichannel_spikes = np.vstack(final_spikes)
    else:
        print("No spikes meet the criteria...")
        multichannel_spikes = []
    return multichannel_spikes

#### Actual Spike Detector

In [None]:
def clean_detector(
    eeg_df,
    fs,
    remove_channels=[
        "EEG EKG 02-Ref",
        "ECG1",
        "EEG EKG1-Ref",
        "EKG2",
        "EKG",
        "EKG1",
        "EKG02",
        "EEG EKG2-Ref",
        "EEG EKG 01-Ref",
        "EEG EKG-Ref",
        "ECG2",
        "EKG01",
    ],
):
    # extract the data from the dataframe
    eeg = eeg_df.values

    ## Parameters
    tmul = 19  # minimum relative amplitude (compared to baseline)
    absthresh = 100  # minimum absolute amplitude (uV)
    sur_time = (
        0.5  # surround time (in s) against which to compare for relative amplitude
    )
    close_to_edge = 0.05  # time (in s) surrounding start and end of sample to ignore
    too_high_abs = 1e3  # amplitude above which I reject it as artifact
    spkdur = [15, 200]  # spike duration must be within this range (in ms)
    spkdur = np.array(spkdur) * fs // 1000  # convert above to samples
    lpf1 = 30  # low pass filter for artifact component
    hpf = 7  # high pass filter for spikey component

    ## Initialize things
    all_spikes = np.empty((0, 4))
    nchs = eeg.shape[1]
    ch_names = list(eeg_df.columns)

    ## Iterate channels and detect spikes
    print("Detecting spikes through channels")
    for j in range(nchs):
        if ch_names[j] not in remove_channels:
            # initialize out array with final spike info
            out = np.empty((0, 3))

            # extract channel data
            data = eeg[:, j]

            # Skip if all nans
            if np.sum(np.isnan(data)) > 0:
                continue

            # re-adjust the mean of the data to be zero
            data = data - np.nanmean(data)

            # initialize array with tentative spike info
            spikes = []

            # Low pass filter to remove artifact
            b, a = butter(4, lpf1 / (fs / 2), btype="lowpass")
            lpdata = filtfilt(b, a, data)  # low pass filter

            # high pass filter to get the spikey part
            b, a = butter(4, hpf / (fs / 2), btype="highpass")
            hpdata = filtfilt(b, a, lpdata)  # high pass filter

            # establish the baseline for the relative amplitude threshold
            lthresh = np.median(np.abs(hpdata))
            thresh = lthresh * tmul  # this is the final threshold we want to impose

            # Run the spike detector to find both negative and positive spikes
            for k in range(2):
                if k == 1:
                    kdata = -hpdata  # flip the sign of the data to find positive spikes
                else:
                    kdata = hpdata

                # find peaks (spp) and troughs (spv) in the data
                spp, spv = findpeaks(kdata)

                # find peak-to-peak durations within allowable range
                idx = np.where(np.diff(spp) <= spkdur[1])[0]

                # peak before list
                startdx = spp[idx]

                # peak after list
                startdx1 = spp[idx + 1]

                # Loop over peaks
                for i in range(len(startdx)):
                    # find the valley that is between the two peaks
                    spkvalley = spv[np.where((spv > startdx[i]) & (spv < startdx1[i]))]

                    # If the height from valley to either peak is big enough, it could be a spike
                    max_height = max(
                        abs(kdata[startdx1[i]] - kdata[spkvalley]),
                        abs(kdata[startdx[i]] - kdata[spkvalley]),
                    )
                    if (
                        max_height > thresh
                    ):  # if amplitude from peak to valley is large enough, append as a spike
                        # add the location of the spike valley, the duration of spike from peak 1 to peak 2 and the amplitude from peak to valley
                        spikes.append(
                            [spkvalley[0], startdx1[i] - startdx[i], max_height]
                        )

            if len(spikes) > 0:
                # Add channel number and convert spike time to samples
                spikes = [[a, b, c[0]] for a, b, c in spikes]
                spikes = np.array(spikes)
                # print
                # spikes[:, -1] = list(map(lambda x: x[0], spikes[:, -1]))
                spikes = spikes.astype(float)

                # check different properties for each of the detected spikes to make sure that they are truly spikes
                # make sure they are not too small in amplitude (noise), too sharp/short in time (noise), or too large (artifact)
                toosmall = []
                toosharp = []
                toobig = []

                # for each spike - spikes is a n_spikes by properties array
                for i in range(spikes.shape[0]):
                    # re-define baseline to be period surrounding spike
                    istart = int(
                        max(1, round(spikes[i, 0] - sur_time * fs))
                    )  # starting time for the surrounding timepoints is either 1 (if spike is at timepoint 0), or at the surrounding timepoint
                    iend = int(
                        min(len(hpdata), round(spikes[i, 0] + sur_time * fs))
                    )  # same but for ending time. It accounts for the spike being in the last timepoint

                    # define a regional local threshold within the specified time range
                    alt_thresh = np.median(np.abs(hpdata[istart:iend])) * tmul

                    if (
                        spikes[i, 2] > alt_thresh and spikes[i, 2] > absthresh
                    ):  # both parts together are bigger than thresh: so have some flexibility in relative sizes
                        if (
                            spikes[i, 1] * 1000 / fs > spkdur[0]
                        ):  # spike wave cannot be too sharp: then it is either too small or noise
                            if spikes[i, 2] < too_high_abs:
                                out = np.vstack(
                                    (out, spikes[i, :])
                                )  # add info of spike to output list
                            else:
                                toobig.append(spikes[i, 0])
                        else:
                            toosharp.append(spikes[i, 0])
                    else:
                        toosmall.append(spikes[i, 0])

                if out.shape[0] > 0:
                    # Re-align spikes to peak of the spikey component
                    timeToPeak = [
                        -0.15,
                        0.15,
                    ]  # Only look 150 ms before and after the currently defined peak
                    fullSurround = [-sur_time, sur_time] * fs
                    idxToPeak = (np.array(timeToPeak) * fs).astype(int)

                    for i in range(out.shape[0]):
                        currIdx = out[i, 0]
                        surround_idx = np.arange(
                            max(1, round(currIdx + fullSurround[0])),
                            min(round(currIdx + fullSurround[1]), len(hpdata)),
                        ).astype(int)
                        idxToLook = np.arange(
                            max(1, round(currIdx + idxToPeak[0])),
                            min(round(currIdx + idxToPeak[1]), len(hpdata)),
                        ).astype(int)
                        snapshot = data[idxToLook] - np.median(data[surround_idx])
                        # Look at the high frequency data (where the mean is subtracted already)
                        I = np.argmax(np.abs(snapshot))
                        # The peak is the maximum absolute value of this
                        out[i, 0] = idxToLook[0] + I - 1

                    all_spikes = np.vstack(
                        (all_spikes, np.hstack((np.full((out.shape[0], 1), j), out)))
                    )

    # convert the last column to a list of numbers instead of a list of arrays
    # output the numpy array with the spikes
    gdf = all_spikes
    gdf = np.unique(gdf, axis=0)

    # sort by times and put ch first
    if gdf.size > 0:
        gdf = gdf[gdf[:, 1].argsort(), :]  # sort by time

        """
        times = gdf[:,0]
        chs = gdf[:,1]
        I = np.argsort(times)
        chs = chs[I]
        times = times[I]
        gdf = np.vstack((chs, times)).T
        """
    # Remove those at beginning and end
    if gdf.size > 0:
        close_idx = int(close_to_edge * fs)
        gdf = gdf[gdf[:, 1] >= close_idx]
        gdf = gdf[gdf[:, 1] <= eeg.shape[0] - close_idx]

    # remove duplicates
    if gdf.size > 0:
        keep = np.ones(gdf.shape[0], dtype=bool)

        # take diff of times
        diff_times = np.hstack((np.inf, np.diff(gdf[:, 1])))

        # take diff of chs
        diff_chs = np.hstack((np.inf, np.diff(gdf[:, 0])))

        # find those that are close in time and the same ch
        too_close = np.logical_and(abs(diff_times) < 100e-3 * fs, diff_chs == 0)

        keep[too_close] = 0
        keep = np.array(keep)

        n_removed = np.sum(~keep)
        gdf = gdf[keep]

    # execute the multichannel requirements
    gdf = multichannel_requirements(gdf, nchs, fs)

    if len(gdf) > 0:
        # convert gdf into a pandas dataframe
        df_spikes = pd.DataFrame()
        df_spikes["Channel Number"] = gdf[:, 0].astype(int)
        df_spikes["Channel Name"] = list(
            map(lambda x: ch_names[x], gdf[:, 0].astype(int))
        )
        df_spikes["Spike Location"] = gdf[:, 1].astype(int)
        df_spikes["Spike Duration"] = gdf[:, 2].astype(int)
        df_spikes["Spike Amplitude"] = gdf[:, 3]
    else:
        df_spikes = pd.DataFrame(
            columns=[
                "Channel Number",
                "Channel Name",
                "Spike Location",
                "Spike Duration",
                "Spike Amplitude",
            ]
        )
    return df_spikes

Function for plotting the spikes

In [None]:
def plot_spikes(ieeg_data, spike_df):
    channels_with_spikes = list(set(spike_df["Channel Number"].values))

    plt.figure(figsize=(20, 20))
    for i, ch in enumerate(channels_with_spikes):
        plt.subplot(len(channels_with_spikes), 1, i + 1)
        plt.plot(ieeg_data.values[:, ch])
        sns.despine(top=True, right=True, left=True, bottom=True)
        plt.xticks([])
        plt.yticks([])

        spike_locations = spike_df[spike_df["Channel Number"] == ch][
            "Spike Location"
        ].values
        spike_amplitudes = spike_df[spike_df["Channel Number"] == ch][
            "Spike Amplitude"
        ].values
        for j, (location, amplitude) in enumerate(
            zip(spike_locations, spike_amplitudes)
        ):
            plt.plot(
                location, ieeg_data.values[location, ch], ".", markersize=10, color="r"
            )
        channel_name = spike_df[spike_df["Channel Number"] == ch][
            "Channel Name"
        ].values[0]
        plt.ylabel(channel_name)

#### Test the Spike Detector

In [None]:
def identify_bad_channels(values, channel_indices, channel_labels, fs):
    """
    Identifies 'bad' channels in an EEG dataset based on various criteria such as high variance, missing data,
    crossing absolute threshold, high variance above baseline, and 60 Hz noise.

    Parameters:
    values (numpy.ndarray): A 2D array of EEG data where each column is a different channel and each row is a reading.
    channel_indices (list): A list containing indices of channels to be analyzed.
    channel_labels (list): A list of channel labels.
    fs (float): The sampling frequency.

    Returns:
    bad (list): A list of 'bad' channel indices.
    details (dict): A dictionary containing the reasons why each channel was marked as 'bad'. Keys are 'noisy', 'nans',
                    'zeros', 'var', 'higher_std', and 'high_voltage'. Each key maps to a list of channel indices.
    """

    # set parameters
    tile = 99
    mult = 10
    num_above = 1
    abs_thresh = 5e3
    percent_60_hz = 0.99
    mult_std = 10

    bad = []
    high_ch = []
    nan_ch = []
    zero_ch = []
    high_var_ch = []
    noisy_ch = []
    all_std = np.full(len(channel_indices), np.nan)

    for i in range(len(channel_indices)):
        bad_ch = 0
        ich = channel_indices[i]
        eeg = values[:, ich]
        bl = np.nanmedian(eeg)

        all_std[i] = np.nanstd(eeg)

        if np.sum(np.isnan(eeg)) > 0.5 * len(eeg):
            bad.append(ich)
            nan_ch.append(ich)
            continue

        if np.sum(eeg == 0) > 0.5 * len(eeg):
            bad.append(ich)
            zero_ch.append(ich)
            continue

        if np.sum(np.abs(eeg - bl) > abs_thresh) > 10:
            bad.append(ich)
            bad_ch = 1
            high_ch.append(ich)

        if bad_ch == 1:
            continue

        pct = np.percentile(eeg, [100 - tile, tile])
        thresh = [bl - mult * (bl - pct[0]), bl + mult * (pct[1] - bl)]
        sum_outside = np.sum((eeg > thresh[1]) | (eeg < thresh[0]))

        if sum_outside >= num_above:
            bad_ch = 1

        if bad_ch == 1:
            bad.append(ich)
            high_var_ch.append(ich)
            continue

        Y = fft(eeg - np.nanmean(eeg))

        P = np.abs(Y) ** 2
        freqs = np.linspace(0, fs, len(P) + 1)
        freqs = freqs[:-1]
        P = P[: int(np.ceil(len(P) / 2))]
        freqs = freqs[: int(np.ceil(len(freqs) / 2))]

        total_P = np.sum(P)
        if total_P != 0 and not np.isnan(total_P):
            P_60Hz = np.sum(P[(freqs > 58) & (freqs < 62)]) / total_P
        else:
            P_60Hz = 0  # or any other value that makes sense in the context

        if P_60Hz > percent_60_hz:
            bad_ch = 1

        if bad_ch == 1:
            bad.append(ich)
            noisy_ch.append(ich)
            continue

    median_std = np.nanmedian(all_std)
    higher_std = [
        channel_indices[i]
        for i in range(len(all_std))
        if all_std[i] > mult_std * median_std
    ]
    bad_std = [ch for ch in higher_std if ch not in bad]
    bad.extend(bad_std)

    details = {
        "noisy": noisy_ch,
        "nans": nan_ch,
        "zeros": zero_ch,
        "var": high_var_ch,
        "higher_std": bad_std,
        "high_voltage": high_ch,
    }

    return bad, details

In [None]:
def detect_bad_channels(values, fs, channel_labels):
    """
    data: raw EEG traces after filtering (i think)
    fs: sampling frequency
    channel_labels: string labels of channels to use
    """
    which_chs = np.arange(values.shape[1])
    chLabels = channel_labels
    ## Parameters to reject super high variance
    tile = 99
    mult = 10
    num_above = 1
    abs_thresh = 5e3

    ## Parameter to reject high 60 Hz
    percent_60_hz = 0.7

    ## Parameter to reject electrodes with much higher std than most electrodes
    mult_std = 10

    bad = []
    high_ch = []
    nan_ch = []
    zero_ch = []
    high_var_ch = []
    noisy_ch = []
    all_std = np.empty((len(which_chs), 1))
    all_std[:] = np.nan
    details = {}

    for i in range(len(which_chs)):
        # print(chLabels[i])

        ich = which_chs[i]
        eeg = values[:, ich]
        bl = np.nanmedian(eeg)

        ## Get channel standard deviation
        all_std[i] = np.nanstd(eeg)

        ## Remove channels with nans in more than half
        if sum(np.isnan(eeg)) > 0.5 * len(eeg):
            bad.append(ich)
            nan_ch.append(ich)
            continue

        ## Remove channels with zeros in more than half
        if sum(eeg == 0) > (0.5 * len(eeg)):
            bad.append(ich)
            zero_ch.append(ich)
            continue

        ## Remove channels with too many above absolute thresh

        if sum(abs(eeg - bl) > abs_thresh) > 10:
            bad.append(ich)
            high_ch.append(ich)
            continue

        ## Remove channels if there are rare cases of super high variance above baseline (disconnection, moving, popping)
        pct = np.percentile(eeg, [100 - tile, tile])
        thresh = [bl - mult * (bl - pct[0]), bl + mult * (pct[1] - bl)]
        sum_outside = sum(((eeg > thresh[1]) + (eeg < thresh[0])) > 0)
        if sum_outside >= num_above:
            bad.append(ich)
            high_var_ch.append(ich)
            continue

        ## Remove channels with a lot of 60 Hz noise, suggesting poor impedance

        # Calculate fft
        # orig_eeg = orig_values(:,ich)
        # Y = fft(orig_eeg-mean(orig_eeg))
        Y = np.fft.fft(eeg - np.nanmean(eeg))

        # Get power
        P = abs(Y) ** 2
        freqs = np.linspace(0, fs, len(P) + 1)
        freqs = freqs[:-1]

        # Take first half
        P = P[: np.ceil(len(P) / 2).astype(int)]
        freqs = freqs[: np.ceil(len(freqs) / 2).astype(int)]

        P_60Hz = sum(P[(freqs > 58) * (freqs < 62)]) / sum(P)
        if P_60Hz > percent_60_hz:
            bad.append(ich)
            noisy_ch.append(ich)
            continue

    ## Remove channels for whom the std is much larger than the baseline
    median_std = np.nanmedian(all_std)
    higher_std = which_chs[(all_std > (mult_std * median_std)).squeeze()]
    bad_std = higher_std
    for ch in bad_std:
        if ch not in bad:
            bad.append(ch)
    channel_mask = [i for i in which_chs if i not in bad]
    details["noisy"] = noisy_ch
    details["nans"] = nan_ch
    details["zeros"] = zero_ch
    details["var"] = high_var_ch
    details["higher_std"] = bad_std
    details["high_voltage"] = high_ch

    return channel_mask, details

### Artifact rejection

### Common average montage

### Filtering

In [None]:
# # Save ieeg_data to a csv file
# ieeg_data.to_csv("ieeg_data.csv")

### Spike detection

In [None]:
ieeg_data

In [None]:
gdf = clean_detector(ieeg_data, int(fs))

In [None]:
gdf

In [None]:
# # Select rows in gdf where Spike Location is between 72600 and 73600
# gdf_selected = gdf[(gdf["Spike Location"] >= 72600) & (gdf["Spike Location"] <= 73600)]
# gdf_selected

In [None]:
# # Randomly select 50 rows from gdf
# gdf_random = gdf.sample(n=50, random_state=1)
# # Sort by Spike Location
# gdf_random = gdf_random.sort_values(by=['Spike Location'])
# gdf_random

In [None]:
# # Group by 'Channel Name' and count the number of spikes
# grouped = gdf.groupby("Channel Name").size().reset_index(name="Total Spikes")
# # Fidn the row where Channel Name is "LA09"
# gdf[gdf["Channel Name"] == "LB07"]

In [None]:
plot_spikes(ieeg_data, gdf)

In [None]:
def plot_spike(data, spike_loc, title):
    """
    Plots the spike centered in the middle of its duration.

    :param data: Data segment containing the spike.
    :param spike_loc: Location of the spike in the data segment.
    :param duration: Duration of the spike.
    :param title: Title of the plot.
    """
    plt.plot(data)
    plt.axvline(spike_loc, color="r", linestyle="--")  # Spike location line
    plt.title(title)
    plt.rcParams["figure.figsize"] = (5, 5)
    plt.show()


# Iterate through each spike in the gdf dataframe
for _, row in gdf.iterrows():
    channel_name = row["Channel Name"]
    spike_location = row["Spike Location"]
    duration = row["Spike Duration"]

    # Extract data centered around the spike from ieeg_data
    start = int(spike_location - duration / 2 - 500)
    end = int(spike_location + duration / 2 + 500)
    data_segment = ieeg_data[channel_name][start:end]

    # Plot the spike
    plot_spike(data_segment, spike_location, f"Spike in {channel_name}")

## Erin's Spike Detector - Will's Implementation

In [9]:
output = spike_detector(
    data=ieeg_data.to_numpy(),
    fs=fs,
    labels=labels,
)
print(f"{len(np.unique(output[:, 2]))} spikes detected")

35 spikes detected


In [10]:
# Load expected_output.npy as expected_output
expected_output = np.load("expected_output.npy")
# Assert that expected_output and output are equal
np.testing.assert_equal(expected_output, output)

In [None]:
# Get the unique spike sequence indices
unique_sequences = np.unique(output[:, 2])

# For each unique spike sequence index
for seq_index in unique_sequences:
    # Filter rows (spikes) that belong to this sequence
    spikes_in_sequence = output[output[:, 2] == seq_index]

    # Create a new figure for this sequence
    fig, axs = plt.subplots(
        len(spikes_in_sequence),
        1,
        sharex=True,
        figsize=(8, len(spikes_in_sequence) * 2),
    )

    # Add a title to the figure
    fig.suptitle(f"Spike sequence {int(seq_index)}", fontsize=12)

    # If there's only one spike in the sequence, axs will not be an array. Convert it to one for consistency.
    if len(spikes_in_sequence) == 1:
        axs = [axs]

    # Plot each spike in this sequence
    for i, spike in enumerate(spikes_in_sequence):
        peak_location = int(spike[0])
        channel_index = int(spike[1])

        # Extract the data around the spike peak (500 samples before and after)
        start_idx = max(0, peak_location - 200)  # Ensure we don't go below 0
        end_idx = min(
            len(ieeg_data), peak_location + 200
        )  # Ensure we don't exceed dataframe length
        data_to_plot = ieeg_data.iloc[start_idx:end_idx, channel_index]

        # Plot this spike data
        axs[i].plot(data_to_plot.index, data_to_plot.values)

        # Add a red vertical dashed line at the location of the peak of the spike
        axs[i].axvline(x=peak_location, color="red", linestyle="--", alpha=0.7)

        axs[i].set_title(f"Channel {labels[channel_index]}")

    # Set shared x-label
    axs[-1].set_xlabel("Sample Number")

    # Adjust layout for the suptitle
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)  # Adjust this value for best appearance
    plt.show()

## Line Length Spike Detector

In [None]:
from lleventdetector import *
from lltransform import *

In [None]:
ieeg_data.shape

In [None]:
ieeg_data_transposed = ieeg_data.to_numpy().T
ieeg_data_transposed.shape

In [None]:
line_length_transform = lltransform(ieeg_data_transposed, int(fs))

In [None]:
line_length_detector_result = lleventdetector(line_length_transform, int(fs), 99.9, 15)
line_length_detector_result

In [None]:
line_length_detector_result[0].shape

In [None]:
spike_windows, electrodes_list = line_length_detector_result

# Setup colors - if there are more electrodes than colors, they will be reused.
colors = plt.cm.jet(np.linspace(0, 1, 200))

for window, electrodes in zip(spike_windows, electrodes_list):
    start, end = window
    start -= 50  # Enlarge window by 50 at the start
    end += 50  # and 50 at the end

    plt.figure(figsize=(10, 6))

    # Get each electrode number from the comma-separated string, and plot its data.
    for elec in electrodes.split(","):
        if elec:  # Check if not an empty string
            elec_num = int(elec)
            plt.plot(
                ieeg_data_transposed[elec_num, int(start) : int(end)],
                color=colors[elec_num],
                label=f"Electrode {elec_num}",
            )

    plt.title(f"Spike Window: {start}-{end}")
    plt.legend()
    plt.xlabel("Time (samples)")
    plt.ylabel("Amplitude")
    plt.tight_layout()
    plt.show()

In [None]:
for window, electrodes in zip(spike_windows[:15], electrodes_list[:15]):
    start, end = window
    start -= 100  # Enlarge window by 50 at the start
    end += 100  # and 50 at the end

    electrode_nums = [int(elec) for elec in electrodes.split(",") if elec]

    fig, axs = plt.subplots(
        len(electrode_nums), 1, sharex=True, figsize=(5, 2 * len(electrode_nums))
    )

    # If there's only one electrode for this window, axs will not be an array. Convert it to a list for consistency.
    if not isinstance(axs, np.ndarray):
        axs = [axs]

    for ax, elec_num in zip(axs, electrode_nums):
        ax.plot(
            ieeg_data_transposed[elec_num, int(start) : int(end)],
            color=colors[elec_num],
        )
        ax.set_title(f"Electrode {good_labels[elec_num]}")
        ax.set_ylabel("Amplitude")

    plt.xlabel("Time (samples)")
    plt.tight_layout()
    fig.suptitle(f"Spike Window: {start}-{end}", y=1.02)
    plt.show()