In [1]:
import os, random
import numpy as np
import pandas as pd
import scipy.io as sio
from ieeg.auth import Session

from get_iEEG_data import *
from spike_detector import *
from iEEG_helper_functions import *

ERIN_DIRECTORY = "../../../../erinconr/projects/fc_toolbox/results/all_out"

## Patient selection

In [2]:
# Load the saved array of HUP ids (from a .npy file one directory up) that is
# used below to filter the patient table, then display it.
good_hup_ids_for_spike_detector = np.load("../good_hup_ids_for_spike_detector.npy")
good_hup_ids_for_spike_detector

array([137, 138, 139, 140, 141, 142, 143, 145, 146, 148, 150, 151, 152,
       153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166,
       167, 168, 169, 170, 171, 172, 173, 174, 175, 177, 178, 179, 180,
       181, 182, 184, 185, 186, 187, 188, 189, 190, 191, 192, 196, 197,
       199, 201, 202, 204, 205, 206, 207, 209, 210, 211, 213, 214, 215,
       219, 221, 223, 224, 225])

In [3]:
# Read HUP_implant_dates.xlsx and cast the hup_id column to integers in one step
nina_patients_df = pd.read_excel("../../../Data/HUP_implant_dates.xlsx").astype(
    {"hup_id": int}
)
nina_patients_df

Unnamed: 0,hup_id,IEEG_Portal_Number,Implant_Date,implant_time,Explant_Date,weight_kg
0,225,HUP225_phaseII,2021-10-18,07:15:00,2021-10-26 17:30:00,58.5
1,224,HUP224_phaseII,2021-10-13,07:15:00,2021-10-20 00:00:00,85.5
2,223,HUP223_phaseII,2021-09-29,07:15:00,2021-10-08 08:21:00,101.4
3,221,HUP221_phaseII,2021-08-16,07:15:00,2021-08-23 00:00:00,124.3
4,219,HUP219_phaseII,2021-07-12,07:15:00,2021-07-16 08:18:00,101.6
...,...,...,...,...,...,...
75,141,HUP141_phaseII,2017-05-24,07:15:00,2017-06-01 00:00:00,85.7
76,140,HUP140_phaseII_D01-D02,2017-05-10,07:15:00,2017-05-19 00:00:00,56.7
77,139,HUP139_phaseII,2017-04-26,07:15:00,2017-05-09 00:00:00,69.8
78,138,HUP138_phaseII,2017-04-12,07:15:00,2017-04-20 00:00:00,84.4


In [4]:
# Attach two boolean flag columns to nina_patients_df:
#   is_single_dataset          -> IEEG_Portal_Number ends with "phaseII"
#   is_good_for_spike_detector -> hup_id appears in good_hup_ids_for_spike_detector
nina_patients_df = nina_patients_df.assign(
    is_single_dataset=nina_patients_df["IEEG_Portal_Number"].str.endswith("phaseII"),
    is_good_for_spike_detector=nina_patients_df["hup_id"].isin(
        good_hup_ids_for_spike_detector
    ),
)

In [5]:
# Keep only the rows flagged both as a single dataset and as good for the spike
# detector, then sort by hup_id, drop the date/weight columns and renumber rows.
rows_to_keep = nina_patients_df["is_single_dataset"].eq(True) & nina_patients_df[
    "is_good_for_spike_detector"
].eq(True)
nina_patients_df = (
    nina_patients_df[rows_to_keep]
    .sort_values(by=["hup_id"], ascending=True)
    .drop(columns=["Implant_Date", "implant_time", "Explant_Date", "weight_kg"])
    .reset_index(drop=True)
)
nina_patients_df

Unnamed: 0,hup_id,IEEG_Portal_Number,is_single_dataset,is_good_for_spike_detector
0,138,HUP138_phaseII,True,True
1,139,HUP139_phaseII,True,True
2,141,HUP141_phaseII,True,True
3,142,HUP142_phaseII,True,True
4,143,HUP143_phaseII,True,True
5,145,HUP145_phaseII,True,True
6,146,HUP146_phaseII,True,True
7,150,HUP150_phaseII,True,True
8,151,HUP151_phaseII,True,True
9,154,HUP154_phaseII,True,True


## Helper functions

In [6]:
def format_channels(channel_array):
    """Zero-pad the trailing number of each channel label to two digits.

    Each label is split into a prefix and the run of digits at its END
    (e.g. "LA1" -> "LA" + "1") and re-joined with the number formatted to at
    least two digits ("LA01"). Labels that carry no trailing number (e.g.
    "PZ", "EKG") are returned unchanged instead of raising ``ValueError``.

    Parameters
    ----------
    channel_array : iterable of str
        Channel labels to normalise.

    Returns
    -------
    numpy.ndarray
        Array of the formatted labels, in the same order as the input.
    """
    formatted_array = []
    for label in channel_array:
        # Only the digits at the end of the label form the electrode number;
        # digits embedded in the prefix (e.g. "L1A2") must stay in the prefix.
        prefix = label.rstrip("0123456789")
        number = label[len(prefix):]

        if not number:
            # Nothing to pad: keep the label exactly as given.
            formatted_array.append(label)
            continue

        # Formatting the number to have (at least) two digits
        formatted_array.append(f"{prefix}{int(number):02}")

    return np.array(formatted_array)

In [7]:
# def compare_dfs(expected_spikes_df, actual_spikes_df):
#     # Check for different shapes
#     if expected_spikes_df.shape[0] != actual_spikes_df.shape[0]:
#         # Identify the extra or missing rows
#         merged_df = pd.merge(
#             expected_spikes_df,
#             actual_spikes_df,
#             on="channel_label",
#             how="outer",
#             indicator=True,
#         )
#         extra_rows = merged_df[merged_df["_merge"] == "right_only"][
#             ["channel_label", "spike_time_y"]
#         ]
#         missing_rows = merged_df[merged_df["_merge"] == "left_only"][
#             ["channel_label", "spike_time_x"]
#         ]

#         extra_rows.columns = ["channel_label", "spike_time"]
#         missing_rows.columns = ["channel_label", "spike_time"]

#         if not extra_rows.empty:
#             return f"Extra rows in actual_spikes_df:\n{extra_rows}"

#         if not missing_rows.empty:
#             return f"Missing rows in actual_spikes_df:\n{missing_rows}"

#     # Check for different channel labels
#     for exp_label, act_label in zip(
#         expected_spikes_df["channel_label"], actual_spikes_df["channel_label"]
#     ):
#         if exp_label != act_label:
#             return f"Channel label mismatch: expected {exp_label}, but got {act_label}"

#     # Check spike_time with fuzzy comparison
#     spike_time_difference = abs(
#         expected_spikes_df["spike_time"] - actual_spikes_df["spike_time"]
#     )

#     # If all differences are <= 4, then dataframes are considered the same
#     if all(spike_time_difference <= 4):
#         return True

#     # Else, return the rows where spike_time differs beyond the tolerance
#     diff_df = expected_spikes_df[spike_time_difference > 4].copy()
#     diff_df["actual_spike_time"] = actual_spikes_df[spike_time_difference > 4][
#         "spike_time"
#     ]
#     return diff_df

In [8]:
def _spike_times_by_channel(spikes_df):
    """Group the spike_time values of a spikes dataframe by channel_label (times sorted ascending)."""
    times_by_channel = {}
    for label, spike_time in zip(spikes_df["channel_label"], spikes_df["spike_time"]):
        times_by_channel.setdefault(label, []).append(spike_time)
    return {label: sorted(times) for label, times in times_by_channel.items()}


def compare_dfs(expected_spikes_df, actual_spikes_df, tolerance=4):
    """Compare two spike dataframes channel by channel with a fuzzy time match.

    Both dataframes must have ``channel_label`` and ``spike_time`` columns.
    Within each channel the sorted spike times of both frames are walked in
    parallel; an expected and an actual spike are paired when their times
    differ by at most ``tolerance``. Expected spikes left unpaired are reported
    as missing, actual spikes left unpaired as extra.

    Having the same shape is NOT treated as proof of equality: the labels and
    the times are always inspected.

    Parameters
    ----------
    expected_spikes_df : pandas.DataFrame
        Reference spikes.
    actual_spikes_df : pandas.DataFrame
        Spikes to check against the reference.
    tolerance : int or float, optional
        Maximum absolute difference in ``spike_time`` for two spikes on the
        same channel to count as the same spike. Defaults to 4.

    Returns
    -------
    True or str
        ``True`` when every spike could be paired; otherwise a message listing
        the extra and/or missing rows.
    """
    expected_by_channel = _spike_times_by_channel(expected_spikes_df)
    actual_by_channel = _spike_times_by_channel(actual_spikes_df)

    missing_records = []
    extra_records = []
    for label in sorted(set(expected_by_channel) | set(actual_by_channel)):
        expected_times = expected_by_channel.get(label, [])
        actual_times = actual_by_channel.get(label, [])

        # Two-pointer walk over the two sorted time lists of this channel.
        i = j = 0
        while i < len(expected_times) and j < len(actual_times):
            if abs(expected_times[i] - actual_times[j]) <= tolerance:
                i += 1
                j += 1
            elif expected_times[i] < actual_times[j]:
                # The expected spike is too early to pair with anything left.
                missing_records.append((label, expected_times[i]))
                i += 1
            else:
                # The actual spike is too early to pair with anything left.
                extra_records.append((label, actual_times[j]))
                j += 1

        # Whatever remains on either side has no partner.
        missing_records.extend((label, t) for t in expected_times[i:])
        extra_records.extend((label, t) for t in actual_times[j:])

    if not missing_records and not extra_records:
        return True

    columns = ["channel_label", "spike_time"]
    extra_rows = pd.DataFrame(extra_records, columns=columns)
    missing_rows = pd.DataFrame(missing_records, columns=columns)

    # Generate the differences message
    differences = []
    if not extra_rows.empty:
        differences.append(f"Extra rows in actual_spikes_df:\n{extra_rows}")
    if not missing_rows.empty:
        differences.append(f"Missing rows in actual_spikes_df:\n{missing_rows}")

    return "\n".join(differences)

## Main loop

In [9]:
# Open an ieeg Session for user "aguilac"; the second Session argument is the
# text read from the local agu_ieeglogin.bin file.
print("Using Carlos session")
with open("agu_ieeglogin.bin", "r") as login_file:
    login_secret = login_file.read()
session = Session("aguilac", login_secret)

Using Carlos session


In [10]:
# For every selected patient: pick 20 of the clips stored in Erin's .mat file,
# re-download each clip, re-run the spike detector on it and compare the
# detected spikes with the spikes stored in the .mat file.
for index, row in nina_patients_df.iterrows():
    # Per-patient accumulators, reported after all clips were processed
    error_percentages = []
    num_of_clips_overcounting = 0
    num_of_clips_undercounting = 0

    hup_id = row["hup_id"]
    dataset_name = row["IEEG_Portal_Number"]

    print("\n")
    print(f"------Processing HUP {hup_id} with dataset {dataset_name}------")

    ########################################
    # Get the data from IEEG
    ########################################

    dataset = session.open_dataset(dataset_name)

    all_channel_labels = np.array(dataset.get_channel_labels())
    channel_labels_to_download = all_channel_labels[
        electrode_selection(all_channel_labels)
    ]

    duration_usec = dataset.get_time_series_details(
        channel_labels_to_download[0]
    ).duration
    duration_hours = int(duration_usec / 1000000 / 60 / 60)

    print(f"Opening {dataset_name} with duration {duration_hours} hours")

    ########################################
    # Process Erin's file
    ########################################

    erin_mat_file = sio.loadmat(os.path.join(ERIN_DIRECTORY, f"HUP{hup_id}_pc.mat"))

    # Unwrap the nested structs: pc -> file -> run
    mat_content = erin_mat_file["pc"]
    file = mat_content[0, 0]["file"]
    run = file[0, 0]["run"]

    data = run[0]["data"]
    run_times = run[0]["run_times"]
    block_times = run[0]["block_times"]
    cohere_out = run[0]["cohere_out"]

    assert len(data) == len(run_times) == len(block_times) == len(cohere_out) > 100

    # Taking the first 2 elements (the last 2 are deliberately left out)
    first_two = run_times[:2].tolist()

    # Selecting 18 random elements from the middle (excluding the first two and
    # the last two) and sorting them
    middle_indices = sorted(random.sample(range(2, len(run_times) - 2), 18))
    middle_elements = [run_times[i] for i in middle_indices]

    # Combining the lists
    all_indices = [0, 1] + middle_indices
    all_elements = first_two + middle_elements
    assert len(all_indices) == len(all_elements) == 20

    # Iterating through the selected elements and their indices
    for i, clip in zip(all_indices, all_elements):
        clip = clip[0]
        start_time_sec, end_time_sec = clip
        start_time_sec = int(start_time_sec)
        end_time_sec = int(end_time_sec)
        start_time_usec = start_time_sec * 1000000
        end_time_usec = end_time_sec * 1000000
        # Confirm montage is CAR
        montage = data[i][0]["montage"][0]["name"][0][1][0]
        assert montage == "car"
        # Get the expected spikes
        spikes = data[i][0]["montage"][0]["spikes"][0][1]
        expected_number_of_spikes = len(spikes)
        print(f"Clip {i}, {start_time_sec} to {end_time_sec}")

        # Reset for every clip so a clip without expected spikes can never be
        # compared against the expected spikes of an earlier clip.
        expected_spikes_df = None
        if expected_number_of_spikes > 0:
            # Get Erin's channel labels
            erin_channel_labels = [item[0] for item in data[i][0]["clean_labels"][0]]
            erin_channel_labels = np.array(erin_channel_labels).flatten().astype(str)

            # Create a dataframe with the expected spikes
            # Adjust the indices in the first column of spikes to zero-based indexing
            zero_based_indices = spikes[:, 0] - 1

            # Map the adjusted indices to the corresponding erin_channel_labels
            channel_labels_mapped = format_channels(
                erin_channel_labels[zero_based_indices]
            )

            # Create the dataframe
            expected_spikes_df = pd.DataFrame(
                {"channel_label": channel_labels_mapped, "spike_time": spikes[:, 1]}
            )
            # Sort expected_spikes_df first by spike_time then by channel_label
            expected_spikes_df = expected_spikes_df.sort_values(
                by=["spike_time", "channel_label"], ascending=[True, True]
            )
            # Reset index
            expected_spikes_df = expected_spikes_df.reset_index(drop=True)
            print(expected_spikes_df)

        try:
            ieeg_data, fs = get_iEEG_data(
                "aguilac",
                "agu_ieeglogin.bin",
                dataset_name,
                start_time_usec,
                end_time_usec,
                channel_labels_to_download,
            )
            fs = int(fs)
        except Exception as download_error:
            # Best effort: a clip that cannot be downloaded is skipped, but the
            # reason is reported and Ctrl-C / SystemExit are no longer swallowed.
            print(f"Download failed for clip {i}, skip... ({download_error!r})")
            continue

        # Check if ieeg_data dataframe is all NaNs
        if ieeg_data.isnull().values.all():
            print(
                f"Empty dataframe after download, skip... There should be {expected_number_of_spikes} spikes"
            )
            continue

        # The first element of the result is used as the indices of the channels to keep
        good_channels_res = detect_bad_channels_optimized(ieeg_data.to_numpy(), fs)
        good_channel_indicies = good_channels_res[0]
        good_channel_labels = channel_labels_to_download[good_channel_indicies]
        ieeg_data = ieeg_data[good_channel_labels].to_numpy()

        # Check if ieeg_data is empty after dropping bad channels
        if ieeg_data.size == 0:
            print(
                f"Empty dataframe after artifact rejection, skip... There should be {expected_number_of_spikes} spikes"
            )
            continue

        ieeg_data = common_average_montage(ieeg_data)

        # Apply the filters on the (now numpy) data
        ieeg_data = notch_filter(ieeg_data, 59, 61, fs)
        ieeg_data = bandpass_filter(ieeg_data, 1, 70, fs)

        ##############################
        # Detect spikes
        ##############################

        spike_output = spike_detector(
            data=ieeg_data,
            fs=fs,
            electrode_labels=good_channel_labels,
        )
        # Convert spike_output to int
        spike_output = spike_output.astype(int)
        actual_number_of_spikes = len(spike_output)

        if actual_number_of_spikes == 0:
            print(
                f"No spikes detected, skip saving... There should be {expected_number_of_spikes} spikes"
            )
        else:
            # Map the channel indices (column 1) to the corresponding good_channel_labels
            channel_labels_mapped = good_channel_labels[spike_output[:, 1]]

            # Create the dataframe (column 0 is used as the spike time)
            actual_spikes_df = pd.DataFrame(
                {
                    "channel_label": channel_labels_mapped,
                    "spike_time": spike_output[:, 0],
                    # "spike_sequence": spike_output[:, 2],
                }
            )
            # Sort actual_spikes_df first by spike_time then by channel_label
            actual_spikes_df = actual_spikes_df.sort_values(
                by=["spike_time", "channel_label"], ascending=[True, True]
            )
            # Reset index
            actual_spikes_df = actual_spikes_df.reset_index(drop=True)
            print(actual_spikes_df)

            if expected_spikes_df is None:
                print("No expected spikes for this clip; every detected spike is extra")
            else:
                print("Comparing dataframes...")
                comparison = compare_dfs(expected_spikes_df, actual_spikes_df)
                if comparison is True:
                    print("Dataframes are the same")
                else:
                    print(comparison)

            print(
                f"Detected {actual_number_of_spikes} spikes, should be {expected_number_of_spikes} spikes"
            )

        # Bookkeeping runs for every clip that reached the detector, including
        # clips where nothing was detected although spikes were expected.
        if actual_number_of_spikes > expected_number_of_spikes:
            num_of_clips_overcounting += 1
        elif actual_number_of_spikes < expected_number_of_spikes:
            num_of_clips_undercounting += 1

        # The relative error is undefined when no spikes are expected, so such
        # clips only contribute to the over-counting tally above.
        if expected_number_of_spikes > 0:
            error_percentage = (
                abs(actual_number_of_spikes - expected_number_of_spikes)
                / expected_number_of_spikes
            )
            error_percentages.append(error_percentage)

    # np.mean of an empty list warns; report nan explicitly instead
    average_error_percentage = (
        np.mean(error_percentages) * 100 if error_percentages else float("nan")
    )
    print(
        f"Patient HUP {hup_id} average error percentages: {average_error_percentage}%"
    )
    print(f"Patient HUP {hup_id} overcounting clips: {num_of_clips_overcounting}")
    print(f"Patient HUP {hup_id} undercounting clips: {num_of_clips_undercounting}")



------Processing HUP 138 with dataset HUP138_phaseII------
Opening HUP138_phaseII with duration 172 hours
Clip 0, 87 to 147
Empty dataframe after download, skip... There should be 0 spikes
Clip 1, 1096 to 1156
Empty dataframe after artifact rejection, skip... There should be 0 spikes
Clip 43, 25884 to 25944
Empty dataframe after artifact rejection, skip... There should be 0 spikes
Clip 55, 33221 to 33281
  channel_label  spike_time
0          RA03       11319
1          RA04       11319
2          RA02       11320
3          LA04       26797
4          LA05       26799
5          LA02       26835
  channel_label  spike_time
0          RA03       11316
1          RA04       11316
2          RA02       11317
3          LA04       26794
4          LA05       26796
5          LA02       26832
Comparing dataframes...
Dataframes are the same
Detected 6 spikes, should be 6 spikes
Clip 116, 70065 to 70125
No spikes detected, skip saving... There should be 0 spikes
Clip 123, 74011 to 74071
  

Opening HUP166_phaseII with duration 163 hours


In [None]:
# Scratch inspection: expected_spikes_df is assigned inside the main loop, so
# this shows whatever value the loop assigned last.
expected_spikes_df

In [None]:
# Scratch inspection: actual_spikes_df is assigned inside the main loop, so
# this shows whatever value the loop assigned last.
actual_spikes_df

In [None]:
# Scratch inspection: re-run the comparison on the two dataframes left over
# from the main loop.
compare_dfs(expected_spikes_df, actual_spikes_df)

Verify the following 20 time clips:
1. The first 2 clips
2. Random 18 clips in the middle (until the last two)

The reason is that the last two might be incomplete.

1. First, make sure the number of spikes is the same or smaller
2. Second, make sure the channels match up
3. Third, compare the spike times fuzzily (allow a small tolerance rather than requiring exact equality)