### Configuration

In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr

from tqdm import tqdm
from pandas.arrays import IntervalArray

import mne
from mne.time_frequency import tfr_array_morlet
from scipy.stats import zscore

from utils__helpers_macro import hilbert_powerphase, hilbert_envelope

In [None]:
# Switch to the project working directory defined in utils__config so the
# relative 'Cache/...' paths in the Parameters cell resolve; the trailing
# os.getcwd() displays the resulting directory as the cell output.
import utils__config
os.chdir(utils__config.working_directory)
os.getcwd()

### Parameters

In [None]:
# fif_path = 'Cache/Subject01/Feb02/S01_Feb02_256hz.fif'
# sw_path = 'Cache/Subject01/Feb02/S01_SW.csv'
# spike_path = 'Cache/Subject01/Feb02/S01_spikes.csv'
# legui_path = 'Cache/Subject01/Feb02/S01_electrodes.csv'
# bad_channel_path = 'Cache/Subject01/Feb02/S01_bad_channels.csv'
# best_channel_path = 'Cache/Subject01/Feb02/S01_best_channels.csv'
# output_path = 'Cache/Subject01/Feb02/S01_spike_sw_coupling.csv'

In [None]:
# Recording to process. Change these three values to run another
# subject/session; every input/output path below is derived from them.
subject_dir = 'Cache/Subject05'
session = 'Jul12'
subject_code = 'S05'

session_dir = f'{subject_dir}/{session}'

fif_path = f'{session_dir}/{subject_code}_{session}_256hz.fif'
sw_path = f'{session_dir}/{subject_code}_SW.csv'
spike_path = f'{session_dir}/{subject_code}_spikes.csv'
# For this subject the electrode (LeGUI) table sits at the subject level,
# not inside the session folder
legui_path = f'{subject_dir}/{subject_code}_electrodes.csv'
bad_channel_path = f'{session_dir}/{subject_code}_bad_channels.csv'
best_channel_path = f'{session_dir}/{subject_code}_best_channels.csv'
output_path = f'{session_dir}/{subject_code}_{session}_spike_sw_coupling.csv'

In [None]:
# Padding (in seconds) subtracted from each slow wave's start and added to its
# end when searching for spikes, so spikes shortly before/after the SW are
# kept as well as those during it.
time_window = 2 # time window (in seconds) to pad before/after the SW

# Optional ROI filter, used only by the commented-out "specific ROI's"
# channel selection in the Load Data cell:
# selected_regions = ['superior_frontal_gyrus', 'middle_frontal_gyrus', 
#                     'inferior_frontal_gyrus', 'medial_frontal_gyrus',
#                     'orbitofrontal_gyrus', 'frontal_pole',
#                     'subcallosal_gyrus', 'subgenual_cingulate_gyrus', 
#                     'anterior_cingulate_gyrus']

### Load Data

In [None]:
# Load the recording into memory (later cells work on copies of it)
raw = mne.io.read_raw_fif(fif_path, preload = True, verbose = None)

# Restrict to macroelectrode channels (sEEG and ECoG)
raw.pick_types(seeg = True, ecog = True)

# Drop rejected channels. The rejection list is first reduced to names that
# are still present in `raw`, so drop_channels is only asked to remove
# channels that actually exist.
rejected = pd.read_csv(bad_channel_path)
is_present = rejected['channel'].isin(raw.ch_names)
bad_channels = rejected.loc[is_present]
raw.drop_channels(ch_names = bad_channels['channel'].astype('string'))

# Alternative selection: channels with the most SW's in each ROI
#best_channels = pd.read_csv(best_channel_path)
#raw.pick_channels(ch_names = best_channels['Channel'].tolist())

# Alternative selection: channels in specific ROI's
#legui = pd.read_csv(legui_path)
#legui = legui.loc[legui['roi_1'].isin(selected_regions)]
#raw.pick_channels(ch_names = legui['elec_label'].tolist())

### Extract Slow Wave Phase and Envelope (0.3 - 1.5 Hz)

In [None]:
# Slow-wave band (0.3-1.5 Hz) Hilbert power and phase, computed on a copy of raw
delta = hilbert_powerphase(data = raw.copy(), lower = 0.3, upper = 1.5, njobs = 6)[['time', 'channel', 'power', 'phase']]

# Same band, Hilbert envelope
sw_env = hilbert_envelope(raw.copy(), lower = 0.3, upper = 1.5, njobs = 1)[['time', 'channel', 'envelope']]

# Join power/phase and envelope on matching (time, channel) rows
delta = delta.merge(sw_env, on = ['time', 'channel'])

# Power in dB, then per-channel z-scores of the dB power and of the envelope
delta['log_power'] = 10 * np.log10(delta['power'])
for source_col, z_col in (('log_power', 'zlog_power'), ('envelope', 'z_envelope')):
    delta[z_col] = delta.groupby(['channel'])[source_col].transform(zscore)

### Intersect Spike Times, Slow Wave Intervals, and Envelope/Phase Values

In [None]:
# Load slow-wave (SW) detections
sw_times = pd.read_csv(sw_path)

# Attach each channel's hemisphere and ROI from the LeGUI electrode table.
# The inner join drops SWs whose channel is missing from that table.
legui = pd.read_csv(legui_path)
legui = legui[['elec_label', 'hemisphere', 'roi_1']].rename(
    columns = {'elec_label': 'Channel', 'hemisphere': 'laterality', 'roi_1': 'region'})
sw_times = sw_times.merge(legui, on = 'Channel', how = 'inner')

# Select the SW columns used downstream and rename them. The explicit
# old -> new mapping keeps each new name next to its source column.
sw_column_map = {
    'ID': 'sw_id',
    'Channel': 'channel_id',
    'laterality': 'sw_laterality',
    'region': 'sw_region',
    'Start': 'start',
    'End': 'end',
    'NegPeak': 'negative_peak',
    'MidCrossing': 'mid_crossing',
    'PosPeak': 'positive_peak',
    'ValNegPeak': 'npeak_amp',
    'PTP': 'ptp_amp',
}
sw_times = sw_times[list(sw_column_map)].rename(columns = sw_column_map)

# Only keep SWs from channels contained in the final Raw selection; the
# explicit copy makes the column assignments below act on an independent frame
sw_times = sw_times[sw_times['channel_id'].isin(raw.ch_names)].copy()

# Per-channel z-scores of the negative-peak value and peak-to-peak amplitude.
# NOTE: scipy's zscore yields NaN for a channel with a single SW (zero std).
sw_times['zamp_npeak'] = sw_times.groupby('channel_id')['npeak_amp'].transform(zscore)
sw_times['zamp_ptp'] = sw_times.groupby('channel_id')['ptp_amp'].transform(zscore)

# Pad each SW interval by time_window seconds on both sides so spikes just
# before/after the SW are captured too. The coupling cell below adds/subtracts
# time_window again to express spike times relative to the unpadded start/end.
sw_times['start'] = sw_times['start'] - time_window
sw_times['end'] = sw_times['end'] + time_window

# Load spike data
spikes = pd.read_csv(spike_path)
spikes = spikes[['unit_id', 'seconds', 'unit_laterality', 'unit_region']]

### BEGIN PHASE DIAGNOSTIC SECTION

In [None]:
# # Step 1: Subset sw_times dataframe
# sw_times_subset = sw_times[['channel_id', 'negative_peak', 'positive_peak']]

# # Step 2: Transform into long format
# sw_times_long = pd.melt(sw_times_subset, id_vars=['channel_id'], 
#                         value_vars=['negative_peak', 'positive_peak'], 
#                         var_name='peak_type', value_name='time')

# # Replace 'negative_peak' and 'positive_peak' with 'negative' and 'positive' for clarity
# sw_times_long['peak_type'] = sw_times_long['peak_type'].str.replace('_peak', '')

# # Step 3: Subset delta dataframe and rename the 'channel' column to 'channel_id'
# delta_subset = delta[['time', 'channel', 'phase']].rename(columns={'channel': 'channel_id'})

# # Step 4: Merge asof with merge_asof()
# # Ensure the 'time' columns in both dataframes are sorted for merge_asof to work correctly
# sw_times_long = sw_times_long.sort_values('time')
# delta_subset = delta_subset.sort_values('time')

# # Convert 'channel_id' from categorical to string (object)
# delta_subset['channel_id'] = delta_subset['channel_id'].astype('str')

# # Perform the merge_asof with the renamed column
# merged_df = pd.merge_asof(sw_times_long, delta_subset, on='time', by='channel_id', direction='nearest')

# # Step 5: Subset the merged dataframe to only include 'peak_type' and 'phase'
# final_df = merged_df[['peak_type', 'phase']]

# import matplotlib.pyplot as plt

# # Assuming final_df is your pandas DataFrame with 'peak_type' and 'phase' columns.
# # Filtering the DataFrame for positive and negative peak_types
# positive_phases = final_df[final_df['peak_type'] == 'positive']['phase']
# negative_phases = final_df[final_df['peak_type'] == 'negative']['phase']

# # Set the limits for the x-axis
# x_limits = (-np.pi, np.pi)

# # Create the histograms
# plt.figure(figsize=(12, 6))

# # Histogram for positive peak_types
# plt.subplot(1, 2, 1)  # 1 row, 2 columns, first subplot
# plt.hist(positive_phases, bins=30, range=x_limits)  # Adjust the number of bins as needed
# plt.title('Positive Peaks')
# plt.xlabel('Phase (radians)')
# plt.ylabel('Frequency')
# plt.xlim(x_limits)

# # Histogram for negative peak_types
# plt.subplot(1, 2, 2)  # 1 row, 2 columns, second subplot
# plt.hist(negative_phases, bins=30, range=x_limits)  # Adjust the number of bins as needed
# plt.title('Negative Peaks')
# plt.xlabel('Phase (radians)')
# plt.ylabel('Frequency')
# plt.xlim(x_limits)

# # Display the histograms
# plt.tight_layout()
# plt.show()

### END PHASE DIAGNOSTIC SECTION

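For each channel, select the claustrum spikes that fall within one of that channel's slow-wave intervals, each padded by `time_window` seconds on both sides. For every selected spike, attach the slow-wave metadata and the nearest slow-wave-band phase/envelope sample. Spikes are copied as needed, so a spike that falls within several slow-wave windows is counted once per window. Those windows can be on different channels or on the same channel.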

In [None]:
# Collect per-channel results in a list and concatenate once at the end:
# growing a DataFrame with pd.concat inside the loops re-copies every row
# accumulated so far on each iteration (quadratic in the number of rows).
channel_frames = []

# Loop-invariant: spike times are scanned once per SW window
spike_seconds = spikes['seconds']

for chan in tqdm(sw_times.channel_id.unique()):

    # Slow-wave-band phase/envelope samples for this channel (merge_asof target)
    sw_data = delta[delta.channel == chan]

    # SWs detected on this channel; 'start'/'end' already include the
    # time_window padding applied when sw_times was built
    sw_data_2 = sw_times[sw_times.channel_id == chan]

    # Channel metadata is the same for every SW on the channel, so look it
    # up once rather than once per SW
    chan_region = sw_data_2['sw_region'].iloc[0]
    chan_side = sw_data_2['sw_laterality'].iloc[0]

    window_frames = []

    # Select spikes that occur within each closed, padded SW interval
    for sw in sw_data_2.itertuples(index = False):

        in_window = spike_seconds.between(left = sw.start, right = sw.end, inclusive = 'both')
        sw_spikes_temp = spikes[in_window].copy()

        # Add SW channel meta-data
        sw_spikes_temp['channel_id'] = chan
        sw_spikes_temp['channel_region'] = chan_region
        sw_spikes_temp['channel_side'] = chan_side

        # Spike time relative to each SW landmark (seconds; negative means
        # the spike precedes the landmark)
        sw_spikes_temp['negative_peak'] = sw_spikes_temp['seconds'] - sw.negative_peak
        sw_spikes_temp['mid_crossing'] = sw_spikes_temp['seconds'] - sw.mid_crossing
        sw_spikes_temp['positive_peak'] = sw_spikes_temp['seconds'] - sw.positive_peak

        # Undo the time_window padding so these are relative to the SW's
        # detected start/end rather than the padded search window
        sw_spikes_temp['start'] = sw_spikes_temp['seconds'] - (sw.start + time_window)
        sw_spikes_temp['end'] = sw_spikes_temp['seconds'] - (sw.end - time_window)

        # Add SW meta-data
        sw_spikes_temp['sw_id'] = sw.sw_id
        sw_spikes_temp['zamp_npeak'] = sw.zamp_npeak
        sw_spikes_temp['zamp_ptp'] = sw.zamp_ptp

        window_frames.append(sw_spikes_temp)

    sw_spikes = pd.concat(window_frames)

    # For every spike, find the nearest sample in the phase-envelope data.
    # merge_asof needs both sides sorted on the key; the stable sort keeps
    # spikes with identical times in SW order.
    data_temp = pd.merge_asof(sw_spikes.sort_values('seconds', kind = 'stable'),
                              sw_data.sort_values('time'),
                              left_on = 'seconds', right_on = 'time', direction = 'nearest')

    channel_frames.append(data_temp.drop(columns = ['time', 'channel', 'power']))

# Final dataset: one row per (channel, SW window, spike) match
data = pd.concat(channel_frames, ignore_index = True) if channel_frames else pd.DataFrame()

In [None]:
# Save the spike/slow-wave coupling table (one row per spike per containing
# SW window per channel); the row index carries no information and is omitted.
data.to_csv(output_path, index = False)