# Spike Morphology

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from get_iEEG_data import *
from spike_detector import *
from spike_morphology import *
from iEEG_helper_functions import *

In [2]:
def create_pwd_file(username, password, fname=None):
    if fname is None:
        fname = "{}_ieeglogin.bin".format(username[:3])
    with open(fname, "wb") as f:
        f.write(password.encode())
    print("-- -- IEEG password file saved -- --")


create_pwd_file("dma", "mycqEv-pevfo4-roqfan")

with open("dma_ieeglogin.bin", "r") as f:
    s = Session("dma", f.read())

ds = s.open_dataset("HUP210_phaseII")
all_channel_labels = np.array(ds.get_channel_labels())
label_idxs = electrode_selection(all_channel_labels)
labels = all_channel_labels[label_idxs]

-- -- IEEG password file saved -- --


In [3]:
ieeg_data, fs = get_iEEG_data(
    "dma",
    "dma_ieeglogin.bin",
    "HUP210_phaseII",
    (179677) * 1e6,
    (179677 + 20) * 1e6,
    labels,
)

fs = int(fs)

In [5]:
good_channels_res = detect_bad_channels_optimized(ieeg_data.to_numpy(), fs)
good_channel_indicies = good_channels_res[0]
good_labels = labels[good_channel_indicies]
ieeg_data = ieeg_data[good_labels].to_numpy()

ieeg_data = common_average_montage(ieeg_data)


# Apply the filters directly on the DataFrame
ieeg_data = pd.DataFrame(notch_filter(ieeg_data.values, 59, 61, fs))
ieeg_data = pd.DataFrame(bandpass_filter(ieeg_data.values, 1, 70, fs))

ieeg_data

KeyError: "['LJ12'] not in index"

In [None]:
output = spike_detector(
    data=ieeg_data,
    fs=fs,
    labels=good_labels,
)
print(f"{len(np.unique(output[:, 2]))} spikes detected")

In [None]:
output = output.astype(int)
output

In [None]:
num_good = 0

for spike in output:
    channel_id = spike[1]
    peak_index = spike[0]
    spike_signal = ieeg_data[peak_index - 1000 : peak_index + 1000][
        channel_id
    ].to_numpy()

    basic_features, advanced_features, is_valid, bad_reason = extract_spike_morphology(
        spike_signal
    )

    if is_valid:
        num_good += 1
        # peak, left_point, right_point, slow_end, slow_max = basic_features
        # print(basic_features)
        # print(advanced_features)
        # plt.plot(spike_signal)
        # plt.plot(peak, spike_signal[peak], "x")
        # plt.plot(left_point, spike_signal[left_point], "o")
        # plt.plot(right_point, spike_signal[right_point], "o")
        # plt.plot(slow_end, spike_signal[slow_end], "o", color="k")
        # plt.title("A spike")
        # plt.xlim(250, 1750)
        # plt.show()
    else:
        print(bad_reason)
        plt.plot(spike_signal)
        plt.title(f"NOT a spike because of {bad_reason}")
        plt.xlim(250, 1750)
        plt.show()
    # elif bad_reason != "Short segment":
    #     print(bad_reason)
    # plt.plot(spike_signal)
    # plt.title(f"NOT a spike because of {bad_reason}")
    # plt.xlim(250, 1750)
    # plt.show()

In [None]:
# features = [
#     "slow_max",
#     "rise_amp",
#     "decay_amp",
#     "slow_width",
#     "slow_amp",
#     "rise_slope",
#     "decay_slope",
#     "average_amp",
# ]

# for feature in features:
#     fig, axarr = plt.subplots(
#         nrows=(len(all_spikes_dfs) + 1) // 2,
#         ncols=2,
#         figsize=(16, 4 * (len(all_spikes_dfs) + 1) // 2),
#     )
#     fig.suptitle(f"Change in {feature} across hour")

#     for idx, (df, hup_id, fs) in enumerate(
#         zip(all_spikes_dfs, completed_hup_ids, all_fs)
#     ):
#         grouped = df.groupby("peak_hour").mean()
#         row = idx // 2
#         col = idx % 2
#         sns.regplot(
#             x=grouped.index,
#             y=grouped[feature],
#             ax=axarr[row, col],
#             lowess=True,
#             scatter_kws={"s": 10},
#             line_kws={"color": "red"},
#         )
#         axarr[row, col].set_title(f"HUP {hup_id} {fs}Hz")
#         axarr[row, col].set_xlabel("Hour")
#         axarr[row, col].set_ylabel(feature)

#         # Load seizure times and plot vertical lines
#         seizure_times_sec = np.load(os.path.join(SEIZURES_DIR, f"HUP_{hup_id}.npy"))
#         seizure_times_hour = seizure_times_sec[:, 0] / 3600  # convert seconds to hours
#         for seizure_time in seizure_times_hour:
#             axarr[row, col].axvline(x=seizure_time, color="red", linestyle="--")

#     # Delete unused subplots
#     for i in range(len(all_spikes_dfs), 2 * ((len(all_spikes_dfs) + 1) // 2)):
#         row = i // 2
#         col = i % 2
#         fig.delaxes(axarr[row, col])

#     plt.tight_layout()
#     plt.subplots_adjust(top=0.9)
#     plt.show()

In [None]:
# def get_and_plot_longest_spike_train(test, patient_hup_id, fs, ax):
#     # Group by 'sequence_index_mask' and count rows in each group
#     grouped = test.groupby("sequence_index_mask").size()

#     # Identify the sequence_index_mask with the maximum count
#     longest_spike_train_mask = grouped.idxmax()
#     longest_spike_train_count = grouped.max()

#     print(
#         f"The longest spike train has sequence_index_mask: {longest_spike_train_mask} with {longest_spike_train_count} spikes."
#     )

#     # To get the rows corresponding to the longest spike train:
#     longest_spike_train_df = test[
#         test["sequence_index_mask"] == longest_spike_train_mask
#     ]

#     # Get the smallest peak_index and largest peak_index
#     smallest_peak_index = longest_spike_train_df["peak_index"].min()
#     largest_peak_index = longest_spike_train_df["peak_index"].max()

#     dataset_name = f"HUP{patient_hup_id}_phaseII"
#     dataset = session.open_dataset(dataset_name)
#     all_channel_labels = np.array(dataset.get_channel_labels())

#     ieeg_data, _ = get_iEEG_data(
#         "dma",
#         "dma_ieeglogin.bin",
#         f"HUP{patient_hup_id}_phaseII",
#         (smallest_peak_index - 1000) / fs * 1e6,
#         (largest_peak_index + 1000) / fs * 1e6,
#         all_channel_labels,
#     )

#     print(ieeg_data.shape)
#     for channel in ieeg_data.columns:
#         ax.plot(ieeg_data[channel])

#     ax.set_ylabel("Amplitude")
#     ax.set_xlabel("Time")
#     ax.set_title(f"HUP {hup_id}")

#     return ieeg_data

# # Creating a subplot figure
# n_patients = len(all_spikes_dfs)


# # Looping over the dataframes
# for all_spikes_df, fs, hup_id in zip(all_spikes_dfs, all_fs, completed_hup_ids):
#     print(f"Processing HUP {hup_id}...")

#     # Check for sequence change based on sequence_index or inter_spike_interval_samples
#     change_mask = all_spikes_df["inter_spike_interval_samples"] > 55

#     # Create the sequence_index_mask
#     all_spikes_df["sequence_index_mask"] = change_mask.astype(int).cumsum()

#     # Load seizure times and plot vertical lines
#     seizure_times_sec = np.load(os.path.join(SEIZURES_DIR, f"HUP_{hup_id}.npy"))
#     seizure_times_hour = seizure_times_sec[:, 0] / 3600  # convert seconds to hours

#     fig, axs = plt.subplots(1, 3, figsize=(15, 6))  # 1 row, 3 columns for each patient

#     # Extracting Fano Factors for the three scenarios:
#     # 1. Max medication
#     max_medication_data = all_spikes_df[all_spikes_df["peak_hour"].isin([0, 1, 2])]
#     get_and_plot_longest_spike_train(max_medication_data, hup_id, fs, axs[0])
#     axs[0].set_title(f"HUP {hup_id} - Max Medication")

#     # 2. Before seizure
#     first_seizure_hour = int(seizure_times_hour[0])
#     before_seizure_data = all_spikes_df[
#         all_spikes_df["peak_hour"].isin(
#             range(first_seizure_hour - 2, first_seizure_hour + 1)
#         )
#     ]
#     get_and_plot_longest_spike_train(before_seizure_data, hup_id, fs, axs[1])
#     axs[1].set_title(f"HUP {hup_id} - Before Seizure")

#     # 3. After seizure
#     last_hours = sorted(all_spikes_df["peak_hour"].unique())[-6:-3]
#     after_seizure_data = all_spikes_df[all_spikes_df["peak_hour"].isin(last_hours)]
#     get_and_plot_longest_spike_train(after_seizure_data, hup_id, fs, axs[2])
#     axs[2].set_title(f"HUP {hup_id} - After Seizure")

#     fig.suptitle(f"HUP {hup_id}")
#     plt.tight_layout()
#     plt.show()  # Displays the figure