### Requirements:
* mne
* eeglabio
* yasa
* pandas
* numpy
* matplotlib
* ipywidgets
* statsmodel

### Worth Noting:

1. For development, use qt for interactive visualization.
2. Consider minimum sample frequency you need and resample before detection. Save quite some time. 
3. Conversion from EEGLAB to MNE requires `eeglabio`

In [93]:

# Housekeeping imports
import os
import mne
from mne.io import RawArray
import numpy as np
import pandas as pd
import yasa
import matplotlib.pyplot as plt  
import statsmodels.api as sm
import ipywidgets as widgets 

# Activate interactive figures with %matplotlib qt (useful in Jupyter environments)
%matplotlib qt

In [94]:
fname = '/Volumes/CSC-Ido/Analyze/119/N1/Strength_119_N1_forSW.set'
raw = mne.io.read_raw_eeglab(fname, preload=True);           # Preload data into memory for data manipulation and faster indexing.
# prepare the data for processing
raw.resample(100)                      # adjust as needed. Consider performing this steps before filtering/detection. 
data = raw.get_data(units="uV")        # standard practice
sf = raw.info['sfreq']
print(data.shape , sf)

Reading /Volumes/CSC-Ido/Analyze/107/N1/Strength_107_N1_forSW.fdt
Reading 0 ... 6632735  =      0.000 ... 13265.470 secs...


  raw = mne.io.read_raw_eeglab(fname, preload=True);           # Preload data into memory for data manipulation and faster indexing.


(186, 1326547) 100.0


In [95]:
# plot remaining channels after preprocessing 
fig_2d = raw.plot_sensors(kind='topomap', show_names=True, show=False) 

In [96]:
events = mne.events_from_annotations(raw)  # raw events
events_id = events[-1]                     # grab event dict
actual_events = events[:-1][0]             # grab actual events
print(events_id, '\n') 
print(actual_events)

Used Annotations descriptions: ['stim end', 'stim start']
{'stim end': 1, 'stim start': 2} 

[[  95066       0       2]
 [ 116674       0       1]
 [ 172472       0       2]
 [ 194080       0       1]
 [ 376852       0       2]
 [ 398460       0       1]
 [ 451471       0       2]
 [ 472136       0       1]
 [ 586338       0       2]
 [ 607945       0       1]
 [ 662243       0       2]
 [ 683642       0       1]
 [ 735949       0       2]
 [ 757102       0       1]
 [1094828       0       2]
 [1116437       0       1]
 [1151207       0       2]
 [1172815       0       1]]


In [97]:
# Extract events and event IDs from annotations in the raw data
events, event_id = mne.events_from_annotations(raw)

print(event_id, '\n')
print(events)

# Dictionary mapping event descriptions to numerical codes
column_dict = {'stim end': 1, 'stim start': 2}

# Indices for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Filter events to get only stim start and stim end
filtered_data = [item for item in events if item[2] == stim_end_index or item[2] == stim_start_index]

# Minimum stim duration threshold in seconds (example: 100 seconds)
min_stim_duration_sec = 100
min_stim_duration_samples = int(min_stim_duration_sec * sf)

# Initialize lists for pre-stim, early stim, late stim, and post-stim epochs with protocol numbers
pre_stim_epochs = []
early_stim_epochs = []
late_stim_epochs = []
post_stim_epochs = []

# Previous epoch end to check for overlap
previous_end = 0

# Protocol counter
protocol_number = 1

# Loop through the epochs and define pre-stim, early stim, late stim, and post-stim epochs
for i in range(0, len(filtered_data), 2):
    if i + 1 < len(filtered_data):  # Ensure i+1 is within bounds
        stim_start = filtered_data[i][0]
        stim_end = filtered_data[i+1][0]
        stim_duration = stim_end - stim_start  # Calculate stim duration
        
        if stim_duration < min_stim_duration_samples:
            continue  # Skip this stim epoch if it is shorter than the minimum duration
        
        stim_midpoint = (stim_start + stim_end) // 2  # Calculate the midpoint of the stim epoch
        
        pre_stim_epoch = (stim_start - stim_duration, stim_start, protocol_number)
        early_stim_epoch = (stim_start, stim_midpoint, protocol_number)
        late_stim_epoch = (stim_midpoint, stim_end, protocol_number)
        post_stim_epoch = (stim_end, stim_end + stim_duration, protocol_number)
        
        # Check for overlap with the previous epoch
        if pre_stim_epoch[0] < previous_end:
            continue  # Skip this entire protocol if there's an overlap with the previous one
        
        # Update previous_end to the end of the current post-stim epoch
        previous_end = post_stim_epoch[1]
        
        pre_stim_epochs.append(pre_stim_epoch)
        early_stim_epochs.append(early_stim_epoch)
        late_stim_epochs.append(late_stim_epoch)
        post_stim_epochs.append(post_stim_epoch)
        
        # Increment protocol number
        protocol_number += 1

# Convert epochs to time for plotting
def convert_sample_to_time(epochs, sf):
    return [(start / sf, end / sf) for start, end, protocol in epochs]

pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
early_stim_epochs_time = convert_sample_to_time(early_stim_epochs, sf)
late_stim_epochs_time = convert_sample_to_time(late_stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

# Optionally store or print protocol numbers for reference
protocol_numbers = [epoch[2] for epoch in pre_stim_epochs]  # Collect protocol numbers (only need to do it once)
print(f"Protocol Numbers: {protocol_numbers}")

# Plotting
plt.figure(figsize=(10, 6))

# Plot pre-stim epochs as shaded regions
for (start, end) in pre_stim_epochs_time:
    plt.axvspan(start, end, color='blue', alpha=0.3, label='Pre-Stim')

# Plot early stim epochs as shaded regions and add protocol numbers
for i, (start, end) in enumerate(early_stim_epochs_time):
    plt.axvspan(start, end, color='orange', alpha=0.3, label='Early Stim')
    # Add the protocol number only for early stim, to represent the entire protocol
    plt.text((start + end) / 2, 0.2, f'P{protocol_numbers[i]}', color='black', fontsize=10, ha='center')

# Plot late stim epochs as shaded regions
for (start, end) in late_stim_epochs_time:
    plt.axvspan(start, end, color='red', alpha=0.3, label='Late Stim')

# Plot post-stim epochs as shaded regions
for (start, end) in post_stim_epochs_time:
    plt.axvspan(start, end, color='green', alpha=0.3, label='Post-Stim')

# Set labels and title
plt.xlabel('Time (s)')
plt.ylabel('Epoch Type')
plt.title('Stimulation Protocols Visualization')

# To prevent duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())

plt.show()

Used Annotations descriptions: ['stim end', 'stim start']
{'stim end': 1, 'stim start': 2} 

[[  95066       0       2]
 [ 116674       0       1]
 [ 172472       0       2]
 [ 194080       0       1]
 [ 376852       0       2]
 [ 398460       0       1]
 [ 451471       0       2]
 [ 472136       0       1]
 [ 586338       0       2]
 [ 607945       0       1]
 [ 662243       0       2]
 [ 683642       0       1]
 [ 735949       0       2]
 [ 757102       0       1]
 [1094828       0       2]
 [1116437       0       1]
 [1151207       0       2]
 [1172815       0       1]]
Protocol Numbers: [1, 2, 3, 4, 5, 6, 7, 8]


In [98]:
# # %% [code]
# import numpy as np
# import mne
# import antropy as ant  # Ensure you have antropy installed: pip install antropy
# import matplotlib.pyplot as plt

# # For time–frequency analysis using Morlet wavelets:
# from mne.time_frequency import tfr_array_morlet

# # --------------------------
# # PARAMETERS
# # --------------------------
# # Frequency range for PSD estimation
# fmin, fmax = 0.5, 2.0  # Adjust if needed

# # Narrow frequency band around 1 Hz (since your stimulation is exactly at 1 Hz)
# entrainment_band = [0.95, 1.05]

# # Permutation entropy parameters
# perm_order = 3
# perm_delay = 1

# # Select the channel index to analyze (here, channel 0; adjust as needed)
# chan_idx = 0

# # --------------------------
# # CONTAINER LISTS FOR RESULTS
# # --------------------------
# band_power_diff = []  # (late - early) power in the entrainment band for each protocol
# entropy_diff = []     # (late - early) permutation entropy for each protocol
# protocol_stats = []   # to store per-protocol values and PSDs (for plotting)

# print("General:")
# # --------------------------
# # LOOP OVER PROTOCOLS (using early_stim_epochs and late_stim_epochs)
# # --------------------------
# for prot_idx, (early_epoch, late_epoch) in enumerate(zip(early_stim_epochs, late_stim_epochs)):
#     # Unpack sample indices (the third element is the protocol number)
#     early_start, early_end, _ = early_epoch
#     late_start, late_end, _ = late_epoch

#     # Extract the data segments from the chosen channel.
#     early_data = data[chan_idx, early_start:early_end]
#     late_data  = data[chan_idx, late_start:late_end]

#     # --------------------------
#     # (1) BAND POWER ANALYSIS: Compute PSD using MNE's psd_array_welch
#     # --------------------------
#     psd_early, freqs = mne.time_frequency.psd_array_welch(
#         early_data[np.newaxis, :],
#         sfreq=sf,
#         fmin=fmin,
#         fmax=fmax,
#         n_fft=len(early_data),
#         verbose=False
#     )
#     psd_late, _ = mne.time_frequency.psd_array_welch(
#         late_data[np.newaxis, :],
#         sfreq=sf,
#         fmin=fmin,
#         fmax=fmax,
#         n_fft=len(late_data),
#         verbose=False
#     )
#     # Remove extra channel dimension
#     psd_early = psd_early[0]
#     psd_late  = psd_late[0]

#     # For the first protocol, print frequency resolution for the early epoch
#     if prot_idx == 0 and len(freqs) > 1:
#         freq_res = freqs[1] - freqs[0]
#         print(f"Frequency resolution for early epoch (Protocol 1): {freq_res:.4f} Hz")
#         print("-" * 50)

#     # Identify frequency indices within the narrow entrainment band (around 1 Hz)
#     band_idx = np.where((freqs >= entrainment_band[0]) & (freqs <= entrainment_band[1]))[0]
#     power_early = np.mean(psd_early[band_idx])
#     power_late  = np.mean(psd_late[band_idx])
#     diff_power  = power_late - power_early
#     band_power_diff.append(diff_power)

#     # --------------------------
#     # (2) ENTROPY ANALYSIS: Permutation entropy (lower entropy may indicate higher entrainment)
#     # --------------------------
#     ent_early = ant.perm_entropy(early_data, order=perm_order, delay=perm_delay, normalize=True)
#     ent_late  = ant.perm_entropy(late_data, order=perm_order, delay=perm_delay, normalize=True)
#     diff_ent  = ent_late - ent_early
#     entropy_diff.append(diff_ent)

#     # Save protocol stats for later use (including PSDs for plotting)
#     protocol_stats.append({
#         "protocol": prot_idx + 1,
#         "power_early": power_early,
#         "power_late": power_late,
#         "diff_power": diff_power,
#         "ent_early": ent_early,
#         "ent_late": ent_late,
#         "diff_ent": diff_ent,
#         "freqs": freqs,
#         "psd_early": psd_early,
#         "psd_late": psd_late
#     })

#     # --------------------------
#     # PRINT PROTOCOL-SPECIFIC RESULTS
#     # --------------------------
#     print(f"Protocol {prot_idx+1}:")
#     print(f"  Entrainment Band Power (1Hz narrow band): Early = {power_early:.4f} uV²/Hz, Late = {power_late:.4f} uV²/Hz, Diff (Late-Early) = {diff_power:.4f}")
#     print(f"  Permutation Entropy: Early = {ent_early:.4f}, Late = {ent_late:.4f}, Diff (Late-Early) = {diff_ent:.4f}")
#     print("-" * 50)

#     # --------------------------
#     # PLOT PSD COMPARISON FOR THIS PROTOCOL
#     # --------------------------
#     plt.figure(figsize=(10, 6))
#     plt.plot(freqs, psd_early, label="Early Stimulation PSD", color="orange")
#     plt.plot(freqs, psd_late, label="Late Stimulation PSD", color="red")
#     plt.axvline(x=1.0, color="green", linestyle="--", label="1 Hz")
#     plt.axvspan(entrainment_band[0], entrainment_band[1], color='gray', alpha=0.3,
#                 label=f"Entrainment band ({entrainment_band[0]}-{entrainment_band[1]} Hz)")
#     plt.xlabel("Frequency (Hz)")
#     plt.ylabel("Power (uV²/Hz)")
#     plt.title(f"PSD Comparison for Protocol {prot_idx+1}")
#     plt.legend()
#     plt.show()

# # --------------------------
# # PRINT AND PLOT AVERAGE DIFFERENCES ACROSS PROTOCOLS
# # --------------------------
# avg_band_power_diff = np.mean(band_power_diff)
# avg_entropy_diff    = np.mean(entropy_diff)

# print("\nAVERAGE ACROSS PROTOCOLS:")
# print(f"  Average Band Power Difference (Late - Early): {avg_band_power_diff:.4f} uV²/Hz")
# print(f"  Average Permutation Entropy Difference (Late - Early): {avg_entropy_diff:.4f}")

# plt.figure(figsize=(10, 6))
# protocol_numbers = np.arange(1, len(band_power_diff) + 1)
# plt.bar(protocol_numbers, band_power_diff, color='blue', alpha=0.7, label='Band Power Difference (Late - Early)')
# plt.axhline(avg_band_power_diff, color='red', linestyle='--', label=f'Average Difference ({avg_band_power_diff:.4f})')
# plt.xlabel("Protocol")
# plt.ylabel("Band Power Difference (uV²/Hz)")
# plt.title("Band Power Differences Across Protocols")
# plt.legend()
# plt.show()

# # --------------------------
# # CREATE A SPECTROGRAM: Average Across Protocols
# # --------------------------
# # We will use the entire stimulation window for each protocol (from early start to late end).
# # First, extract and store these stimulation windows.
# all_stim_windows = []
# for early_epoch, late_epoch in zip(early_stim_epochs, late_stim_epochs):
#     stim_start = early_epoch[0]
#     stim_end   = late_epoch[1]
#     stim_data = data[chan_idx, stim_start:stim_end]
#     all_stim_windows.append(stim_data)

# # To average spectrograms, we first need to make them the same length.
# # Find the minimum length among all stimulation windows.
# min_length = np.min([len(win) for win in all_stim_windows])

# # Crop each stimulation window to this minimum length
# cropped_windows = [win[:min_length] for win in all_stim_windows]

# # Stack into a 3D array with shape (n_protocols, n_channels, n_times)
# data_array = np.array(cropped_windows)  # shape (n_protocols, min_length)
# data_array = data_array[:, np.newaxis, :]  # add channel dimension

# # Define frequencies for the TFR (e.g., from 0.5 to 2.0 Hz)
# freqs_tfr = np.linspace(0.5, 2.0, 50)
# n_cycles = 3  # number of cycles in the Morlet wavelet

# # Compute time–frequency representation for each protocol using Morlet wavelets.
# # Note: The 'return_itc' parameter has been removed.
# tfr = tfr_array_morlet(data_array, sfreq=sf, freqs=freqs_tfr, n_cycles=n_cycles,
#                        decim=1, n_jobs=1)
# # tfr has shape (n_protocols, n_channels, n_freqs, n_times)

# # Convert the complex TFR output to power (squared absolute value)
# tfr_power = np.abs(tfr) ** 2

# # Average across protocols and then squeeze the channel dimension.
# tfr_avg = tfr_power.mean(axis=0)[0]  # shape (n_freqs, n_times)

# # Create a time vector for the stimulation window (in seconds)
# times = np.arange(min_length) / sf

# # Convert power to decibels for visualization.
# tfr_db = 10 * np.log10(tfr_avg)

# plt.figure(figsize=(12, 6))
# plt.imshow(tfr_db, aspect='auto', origin='lower',
#            extent=[times[0], times[-1], freqs_tfr[0], freqs_tfr[-1]],
#            cmap='viridis')
# plt.colorbar(label='Power (dB)')
# plt.xlabel("Time (s)")
# plt.ylabel("Frequency (Hz)")
# plt.title("Average Spectrogram Across Protocols\n(Stimulation Window: Early Start to Late End)")
# plt.show()


In [99]:
# # %% [code]
# import numpy as np
# import mne
# import matplotlib.pyplot as plt

# # --------------------------
# # PARAMETERS (should match previous settings)
# # --------------------------
# fmin, fmax = 0.5, 2.0
# entrainment_band = [0.95, 1.05]
# chan_idx = 0  # channel to analyze

# # --------------------------
# # EXTRACT EARLY AND LATE STIMULATION WINDOWS FOR EACH PROTOCOL
# # --------------------------
# all_early_windows = []
# all_late_windows = []

# for early_epoch, late_epoch in zip(early_stim_epochs, late_stim_epochs):
#     early_data = data[chan_idx, early_epoch[0]:early_epoch[1]]
#     late_data  = data[chan_idx, late_epoch[0]:late_epoch[1]]
#     all_early_windows.append(early_data)
#     all_late_windows.append(late_data)

# # --------------------------
# # FIND A COMMON MINIMUM LENGTH ACROSS ALL EARLY AND LATE WINDOWS
# # --------------------------
# min_len_early = np.min([len(win) for win in all_early_windows])
# min_len_late  = np.min([len(win) for win in all_late_windows])
# common_min_length = min(min_len_early, min_len_late)

# # Crop each window to the common minimum length
# cropped_early = [win[:common_min_length] for win in all_early_windows]
# cropped_late  = [win[:common_min_length] for win in all_late_windows]

# # --------------------------
# # COMPUTE PSD FOR EACH CROP
# # --------------------------
# # For early stimulation windows
# all_psd_early = []
# for win in cropped_early:
#     psd, freqs = mne.time_frequency.psd_array_welch(win[np.newaxis, :],
#                                                      sfreq=sf,
#                                                      fmin=fmin,
#                                                      fmax=fmax,
#                                                      n_fft=common_min_length,
#                                                      verbose=False)
#     all_psd_early.append(psd[0])
# all_psd_early = np.array(all_psd_early)

# # For late stimulation windows
# all_psd_late = []
# for win in cropped_late:
#     psd, freqs = mne.time_frequency.psd_array_welch(win[np.newaxis, :],
#                                                      sfreq=sf,
#                                                      fmin=fmin,
#                                                      fmax=fmax,
#                                                      n_fft=common_min_length,
#                                                      verbose=False)
#     all_psd_late.append(psd[0])
# all_psd_late = np.array(all_psd_late)

# # --------------------------
# # AVERAGE THE PSDS ACROSS PROTOCOLS
# # --------------------------
# avg_psd_early = np.mean(all_psd_early, axis=0)
# avg_psd_late  = np.mean(all_psd_late, axis=0)
# diff_psd      = avg_psd_late - avg_psd_early

# # Print frequency resolution (should be the same for early and late since n_fft is common)
# if len(freqs) > 1:
#     freq_res = freqs[1] - freqs[0]
#     print(f"Frequency resolution for averaged PSD: {freq_res:.4f} Hz")

# # --------------------------
# # PLOT THE AVERAGE PSD CURVES AND THEIR DIFFERENCE
# # --------------------------
# plt.figure(figsize=(10, 6))
# plt.plot(freqs, avg_psd_early, label="Average Early-Stim PSD", color="orange")
# plt.plot(freqs, avg_psd_late, label="Average Late-Stim PSD", color="red")
# plt.plot(freqs, diff_psd, label="Difference (Late - Early)", color="blue")
# plt.axvline(x=1.0, color="green", linestyle="--", label="1 Hz")
# plt.axvspan(entrainment_band[0], entrainment_band[1], color='gray', alpha=0.3,
#             label=f"Entrainment band ({entrainment_band[0]}–{entrainment_band[1]} Hz)")
# plt.xlabel("Frequency (Hz)")
# plt.ylabel("Power (uV²/Hz)")
# plt.title("Average PSD Across Protocols: Early vs. Late Stimulation and Difference")
# plt.legend()
# plt.show()


In [110]:
# %% [code]
import numpy as np
import mne
import matplotlib.pyplot as plt

# --------------------------
# PARAMETERS (should match previous settings)
# --------------------------
fmin, fmax = 0.5, 2.0
entrainment_band = [0.8, 1.2]
# We'll compute the narrow-band power for every channel.
n_channels = data.shape[0]  # assuming data shape is (n_channels, n_samples)

# --------------------------
# COMPUTE AVERAGE NARROW-BAND POWER ACROSS PROTOCOLS FOR EACH CHANNEL
# --------------------------
early_topo = np.zeros(n_channels)
late_topo  = np.zeros(n_channels)

# Loop over channels
for ch in range(n_channels):
    early_vals = []  # will hold narrow-band power for each protocol during early-stim
    late_vals  = []  # will hold narrow-band power for each protocol during late-stim
    
    # Loop over protocols using early_stim_epochs and late_stim_epochs
    for early_epoch, late_epoch in zip(early_stim_epochs, late_stim_epochs):
        # Extract data for current channel for the early and late epochs
        early_data = data[ch, early_epoch[0]:early_epoch[1]]
        late_data  = data[ch, late_epoch[0]:late_epoch[1]]
        
        # Compute PSD for the early epoch for the current channel.
        psd_early, freqs = mne.time_frequency.psd_array_welch(
            early_data[np.newaxis, :],
            sfreq=sf,
            fmin=fmin,
            fmax=fmax,
            n_fft=len(early_data),
            verbose=False
        )
        # Compute PSD for the late epoch for the current channel.
        psd_late, _ = mne.time_frequency.psd_array_welch(
            late_data[np.newaxis, :],
            sfreq=sf,
            fmin=fmin,
            fmax=fmax,
            n_fft=len(late_data),
            verbose=False
        )
        # Remove extra channel dimension
        psd_early = psd_early[0]
        psd_late  = psd_late[0]
        
        # Identify frequency indices within the narrow entrainment band (0.8–1.2 Hz)
        band_idx = np.where((freqs >= entrainment_band[0]) & (freqs <= entrainment_band[1]))[0]
        
        # Compute average power in the band for early and late epochs.
        early_power = np.mean(psd_early[band_idx])
        late_power  = np.mean(psd_late[band_idx])
        
        # Append to our list for the current channel.
        early_vals.append(early_power)
        late_vals.append(late_power)
    
    # Average across protocols for the current channel.
    early_topo[ch] = np.mean(early_vals)
    late_topo[ch]  = np.mean(late_vals)

# Compute the difference topography (Late minus Early)
diff_topo = late_topo - early_topo

# --------------------------
# PLOT TOPOMAPS FOR EARLY, LATE, AND DIFFERENCE WITH A SCALE (COLORBAR)
# --------------------------
# Ensure that raw.info contains channel locations (montage). If not, set a standard montage:
# raw.set_montage('standard_1020')

# Create a figure with 1 row and 3 columns for the subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Topomap for average early-stim narrow-band power.
im0, _ = mne.viz.plot_topomap(early_topo, raw.info, axes=axes[0],
                              show=False, sensors=True, contours=0)
axes[0].set_title("Average Early-Stim PSD (0.8-1.2 Hz)")
cbar0 = plt.colorbar(im0, ax=axes[0], orientation='vertical', fraction=0.046, pad=0.04)
cbar0.set_label('Power (uV²/Hz)')

# Topomap for average late-stim narrow-band power.
im1, _ = mne.viz.plot_topomap(late_topo, raw.info, axes=axes[1],
                              show=False, sensors=True, contours=0)
axes[1].set_title("Average Late-Stim PSD (0.8-1.2 Hz)")
cbar1 = plt.colorbar(im1, ax=axes[1], orientation='vertical', fraction=0.046, pad=0.04)
cbar1.set_label('Power (uV²/Hz)')

# Topomap for the difference (Late - Early)
im2, _ = mne.viz.plot_topomap(diff_topo, raw.info, axes=axes[2],
                              show=False, sensors=True, contours=0)
axes[2].set_title("Difference (Late - Early) PSD (0.8-1.2 Hz)")
cbar2 = plt.colorbar(im2, ax=axes[2], orientation='vertical', fraction=0.046, pad=0.04)
cbar2.set_label('Power Difference (uV²/Hz)')

plt.tight_layout()
plt.show()


In [101]:
#sw = yasa.sw_detect(raw, hypno=hypno_up, include=(2, 3))      #if you have a hyponogram, you can add these arguments. 
sw = yasa.sw_detect(raw, verbose=False, coupling=False);
df = sw.summary(); # general summary for each sw
df # Inspect the dataframe

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    1.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    2.4s


Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel
0,1223.08,1223.65,1223.93,1224.20,1224.67,1.59,-43.823898,42.814476,86.638374,309.422766,0.628931,E2,0
1,1914.49,1914.80,1915.06,1915.27,1915.53,1.04,-51.971251,31.699985,83.671236,321.812446,0.961538,E2,0
2,2274.47,2274.71,2274.97,2275.14,2275.33,0.86,-58.786700,23.253850,82.040550,315.540576,1.162791,E2,0
3,2347.05,2347.45,2347.72,2348.00,2348.61,1.56,-42.371320,38.033741,80.405062,297.796525,0.641026,E2,0
4,2364.50,2364.79,2365.05,2365.25,2365.55,1.05,-48.719274,27.986922,76.706195,295.023827,0.952381,E2,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
6065,12258.50,12259.34,12259.59,12259.85,12260.18,1.68,-55.612660,57.097145,112.709805,450.839220,0.595238,E256,185
6066,12282.14,12282.50,12282.77,12282.95,12283.75,1.61,-51.039644,26.948499,77.988144,288.844977,0.621118,E256,185
6067,12288.66,12289.49,12289.71,12289.98,12290.32,1.66,-44.847678,61.786618,106.634296,484.701345,0.602410,E256,185
6068,12371.79,12372.24,12372.64,12373.24,12373.53,1.74,-53.142763,33.046498,86.189261,215.473152,0.574713,E256,185


In [102]:
sw.plot_detection() # lets you scroll through the detection very conveniently with marked detections.

interactive(children=(IntSlider(value=0, description='Epoch:', layout=Layout(align_items='center', justify_con…

<function yasa.detection._DetectionResults.plot_detection.<locals>.update(epoch, amplitude, channel, win_size, filt)>

In [103]:
# Define the classification function
def classify_wave(start_time, pre_stim_epochs_time, early_stim_epochs_time, late_stim_epochs_time, post_stim_epochs_time):
    """Classify each wave based on the start time into 'Pre-Stim', 'Early-Stim', 'Late-Stim', or 'Post-Stim' and assign protocol number."""
    for idx, (start, end) in enumerate(pre_stim_epochs_time):
        if start <= start_time <= end:
            return 'Pre-Stim', idx + 1
    for idx, (start, end) in enumerate(early_stim_epochs_time):
        if start <= start_time <= end:
            return 'Early-Stim', idx + 1
    for idx, (start, end) in enumerate(late_stim_epochs_time):
        if start <= start_time <= end:
            return 'Late-Stim', idx + 1
        
    for idx, (start, end) in enumerate(post_stim_epochs_time):
        if start <= start_time <= end:
            return 'Post-Stim', idx + 1
    return 'Unknown', None  # If the wave does not fall within any of the epochs

# Apply classification to DataFrame
df[['Classification', 'Protocol Number']] = df['Start'].apply(lambda start_time: classify_wave(start_time, pre_stim_epochs_time, early_stim_epochs_time, late_stim_epochs_time, post_stim_epochs_time)).apply(pd.Series)

# Filter out rows classified as 'Unknown'
df_filtered = df[df['Classification'] != 'Unknown']

# Now df_filtered contains both the classification and the protocol number for each wave.

In [104]:
df_filtered

Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel,Classification,Protocol Number
0,1223.08,1223.65,1223.93,1224.20,1224.67,1.59,-43.823898,42.814476,86.638374,309.422766,0.628931,E2,0,Post-Stim,1.0
1,1914.49,1914.80,1915.06,1915.27,1915.53,1.04,-51.971251,31.699985,83.671236,321.812446,0.961538,E2,0,Late-Stim,2.0
8,6493.66,6494.16,6494.51,6494.81,6495.24,1.58,-50.434896,40.389758,90.824654,259.499012,0.632911,E2,0,Pre-Stim,6.0
9,6698.19,6698.51,6698.75,6698.98,6699.29,1.10,-43.906299,35.850944,79.757243,332.321845,0.909091,E2,0,Early-Stim,6.0
10,6931.89,6932.67,6932.98,6933.23,6933.50,1.61,-50.243733,34.186700,84.430433,272.356234,0.621118,E2,0,Post-Stim,6.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6042,7509.45,7509.77,7510.00,7510.21,7510.43,0.98,-56.081128,50.173632,106.254759,461.977214,1.020408,E256,185,Late-Stim,7.0
6043,7605.36,7605.65,7605.90,7606.17,7606.43,1.07,-58.501179,58.298106,116.799285,467.197140,0.934579,E256,185,Post-Stim,7.0
6044,7609.82,7610.09,7610.34,7610.60,7611.00,1.18,-46.340887,37.250680,83.591567,334.366268,0.847458,E256,185,Post-Stim,7.0
6045,7647.79,7648.11,7648.75,7649.17,7649.53,1.74,-57.141139,74.305210,131.446349,205.384920,0.574713,E256,185,Post-Stim,7.0


In [105]:
# Group by classification and calculate mean and count for each group
comparison_means = df_filtered.groupby('Classification')[['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']].mean()
comparison_counts = df_filtered.groupby('Classification')['Start'].count()  # Counting instances using the 'Start' column

# Print results
print("Mean Values by Group:")
print(comparison_means)
print("\nCount of Instances by Group:")
print(comparison_counts)

Mean Values by Group:
                Duration  ValNegPeak  ValPosPeak         PTP  Frequency
Classification                                                         
Early-Stim      1.387269  -48.569586   47.274337   95.843923   0.760434
Late-Stim       1.382352  -49.717699   41.892090   91.609789   0.759782
Post-Stim       1.346312  -52.675482   48.010000  100.685483   0.778666
Pre-Stim        1.396502  -51.499540   51.352164  102.851704   0.748867

Count of Instances by Group:
Classification
Early-Stim     692
Late-Stim      438
Post-Stim     1467
Pre-Stim      1155
Name: Start, dtype: int64


In [106]:
# Assuming df is your original DataFrame with relevant columns

# Define the specific columns you want to plot
columns_to_plot = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']

# Define all possible classifications except 'Unknown'
all_classifications = ['Early-Stim', 'Post-Stim', 'Pre-Stim', 'Late-Stim']

# Filter out 'Unknown' classification from the DataFrame
df_filtered = df[df['Classification'].isin(all_classifications)]

# Calculate overall means and counts, including all classifications except 'Unknown'
overall_means = df_filtered.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
overall_counts = df_filtered['Classification'].value_counts().reindex(all_classifications, fill_value=0)

# Function to add value labels on bars
def add_value_labels(ax, spacing=5):
    """Add labels to the end of each bar in a bar chart."""
    for rect in ax.patches:
        y_value = rect.get_height()
        x_value = rect.get_x() + rect.get_width() / 2
        label = f"{y_value:.2f}" if y_value != 0 else "0"  # Use a single zero for labels with no decimal part
        ax.annotate(
            label, 
            (x_value, y_value), 
            xytext=(0, spacing), 
            textcoords="offset points", 
            ha='center', 
            va='bottom'
        )

# Figure 1: Overall comparison for all classifications combined, disregarding protocol number
plt.figure(figsize=(15, 6))

# Overall Mean Values Plot
plt.subplot(1, 2, 1)
ax = overall_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
plt.title('Overall Mean Values of Wave Properties (All Protocols Combined)')
plt.ylabel('Mean Values')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
add_value_labels(ax)  # Add value labels

# Overall Count Plot
plt.subplot(1, 2, 2)
ax2 = overall_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
plt.title('Overall Count of Instances by Classification (All Protocols Combined)')
plt.ylabel('Count')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
add_value_labels(ax2)  # Add value labels

# Show the overall plot
plt.tight_layout()  # Adjusts subplots to give some padding and prevent overlap
plt.show()

# Now, let's create separate plots for each protocol
protocol_numbers = df_filtered['Protocol Number'].dropna().unique()  # Get unique protocol numbers

for protocol in protocol_numbers:
    protocol_data = df_filtered[df_filtered['Protocol Number'] == protocol]
    protocol_means = protocol_data.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
    protocol_counts = protocol_data['Classification'].value_counts().reindex(all_classifications, fill_value=0)

    plt.figure(figsize=(15, 6))

    # Mean values for each protocol
    plt.subplot(1, 2, 1)
    ax = protocol_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
    plt.title(f'Mean Values of Wave Properties (Protocol {protocol})')
    plt.ylabel('Mean Values')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
    add_value_labels(ax)  # Add value labels

    # Counts for each protocol
    plt.subplot(1, 2, 2)
    ax2 = protocol_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
    plt.title(f'Count of Instances by Classification (Protocol {protocol})')
    plt.ylabel('Count')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
    add_value_labels(ax2)  # Add value labels

    # Show the plots for this protocol
    plt.tight_layout()
    plt.show()

In [107]:
'''
creates a CSV will all instances of SW and sorts them in order of importance:
Proto#
pre -> early -> late -> post
Wave#
'''

# Assuming 'df' is your DataFrame
# Filter out rows where 'Protocol Number' is NaN
df_filtered = df.dropna(subset=['Protocol Number'])

# Generate the slow wave names
df_filtered['Slow_Wave_Name'] = 'proto' + df_filtered['Protocol Number'].astype(int).astype(str) + '_' + df_filtered['Classification'].str.lower().replace(' ', '-') + '_sw' + (df_filtered.groupby(['Protocol Number', 'Classification']).cumcount() + 1).astype(str)

# Define the order for the 'Classification'
classification_order = ['pre-stim', 'early-stim', 'late-stim', 'post-stim']

# Create a categorical type for sorting Classification
df_filtered['Classification'] = df_filtered['Classification'].str.lower().replace(' ', '-')
df_filtered['Classification'] = pd.Categorical(df_filtered['Classification'], categories=classification_order, ordered=True)

# Sort the DataFrame based on Protocol Number, Classification order, and slow wave number (extracted from Slow_Wave_Name)
df_sorted = df_filtered.sort_values(by=['Protocol Number', 'Classification', 'Slow_Wave_Name'])

# Keep only the relevant columns
df_epochs = df_sorted[['Slow_Wave_Name', 'Start', 'End', 'Protocol Number', 'Classification', 'NegPeak', 'PosPeak', 'Channel', 'ValNegPeak']]

# Display the DataFrame with the new columns
print(df_epochs.head())

# Save the resulting DataFrame to a CSV file
df_epochs.to_csv('sorted_slow_waves.csv', index=False)

             Slow_Wave_Name   Start     End  Protocol Number Classification  \
211     proto1_pre-stim_sw1  907.57  908.66              1.0       pre-stim   
1157   proto1_pre-stim_sw10  871.39  872.73              1.0       pre-stim   
4633  proto1_pre-stim_sw100  832.92  833.94              1.0       pre-stim   
4634  proto1_pre-stim_sw101  890.66  891.53              1.0       pre-stim   
4705  proto1_pre-stim_sw102  755.91  756.91              1.0       pre-stim   

      NegPeak  PosPeak Channel  ValNegPeak  
211    907.82   908.36     E25  -41.158540  
1157   871.69   872.24    E103  -40.303124  
4633   833.18   833.64    E233  -40.850111  
4634   890.89   891.31    E233  -41.578816  
4705   756.19   756.66    E234  -54.092061  


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Slow_Wave_Name'] = 'proto' + df_filtered['Protocol Number'].astype(int).astype(str) + '_' + df_filtered['Classification'].str.lower().replace(' ', '-') + '_sw' + (df_filtered.groupby(['Protocol Number', 'Classification']).cumcount() + 1).astype(str)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Classification'] = df_filtered['Classification'].str.lower().replace(' ', '-')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value 

In [108]:
df_epochs

Unnamed: 0,Slow_Wave_Name,Start,End,Protocol Number,Classification,NegPeak,PosPeak,Channel,ValNegPeak
211,proto1_pre-stim_sw1,907.57,908.66,1.0,pre-stim,907.82,908.36,E25,-41.158540
1157,proto1_pre-stim_sw10,871.39,872.73,1.0,pre-stim,871.69,872.24,E103,-40.303124
4633,proto1_pre-stim_sw100,832.92,833.94,1.0,pre-stim,833.18,833.64,E233,-40.850111
4634,proto1_pre-stim_sw101,890.66,891.53,1.0,pre-stim,890.89,891.31,E233,-41.578816
4705,proto1_pre-stim_sw102,755.91,756.91,1.0,pre-stim,756.19,756.66,E234,-54.092061
...,...,...,...,...,...,...,...,...,...
251,proto8_early-stim_sw6,11025.11,11026.84,8.0,early-stim,11025.58,11026.32,E25,-66.241789
296,proto8_early-stim_sw7,11025.09,11026.92,8.0,early-stim,11025.50,11026.33,E26,-67.406181
401,proto8_early-stim_sw8,11025.14,11026.83,8.0,early-stim,11025.61,11026.30,E31,-70.337213
477,proto8_early-stim_sw9,11025.08,11026.85,8.0,early-stim,11025.56,11026.29,E32,-73.472124


In [109]:
# Assuming df_epochs is your sorted DataFrame from the previous steps

# Function to filter epochs based on window size and picking criteria
def filter_epochs(df, window_size, pick_most_negative=False):
    filtered_epochs_list = []
    last_end_time = -float('inf')
    current_window = []

    for index, row in df.iterrows():
        if row['Start'] > last_end_time + window_size:
            if current_window:
                if pick_most_negative:
                    # Find the wave with the most negative ValNegPeak within the current window
                    max_wave = min(current_window, key=lambda x: x['ValNegPeak'])
                    filtered_epochs_list.append(max_wave)
                else:
                    # Pick the first wave in the window
                    filtered_epochs_list.append(current_window[0])
            # Reset the current window and add the current row to the new window
            current_window = [row]
            # Update the last_end_time to the current row's end time
            last_end_time = row['End']
        else:
            # Add the current row to the current window
            current_window.append(row)

    # After the loop, check if there are any remaining waves in the current window
    if current_window:
        if pick_most_negative:
            max_wave = min(current_window, key=lambda x: x['ValNegPeak'])
            filtered_epochs_list.append(max_wave)
        else:
            filtered_epochs_list.append(current_window[0])

    return pd.DataFrame(filtered_epochs_list, columns=df.columns)

# 1. 0.5s window, pick the first wave
filtered_epochs_05s_first = filter_epochs(df_epochs, window_size=0.5, pick_most_negative=False)
filtered_epochs_05s_first.to_csv('filtered_epochs_05s_first.csv', index=False)

# 2. 0.5s window, pick the wave with the most negative ValNegPeak
filtered_epochs_05s_negative = filter_epochs(df_epochs, window_size=0.5, pick_most_negative=True)
filtered_epochs_05s_negative.to_csv('filtered_epochs_05s_negative.csv', index=False)

# 3. 1s window, pick the first wave
filtered_epochs_1s_first = filter_epochs(df_epochs, window_size=1.0, pick_most_negative=False)
filtered_epochs_1s_first.to_csv('filtered_epochs_1s_first.csv', index=False)

# 4. 1s window, pick the wave with the most negative ValNegPeak
filtered_epochs_1s_negative = filter_epochs(df_epochs, window_size=1.0, pick_most_negative=True)
filtered_epochs_1s_negative.to_csv('filtered_epochs_1s_negative.csv', index=False)

print("Filtered CSV files created successfully.")


Filtered CSV files created successfully.
