In [1]:
#Make sure you have eeglabio for saving the Epochs
#pip install eeglabio


# Housekeeping imports
import os
import mne
from mne.io import RawArray
import numpy as np
import pandas as pd
import yasa
import matplotlib.pyplot as plt  # Make sure to import pyplot as plt for plotting
import statsmodels.api as sm
import ipywidgets as widgets  # Renamed for clarity
import scipy
import seaborn as sns

# Activate interactive figures with %matplotlib qt (useful in Jupyter environments)
%matplotlib qt


In [2]:
fname = '/Users/idohaber/Desktop/101_test/Strength_101_forICA_wcomps_rmcomps.set'
raw = mne.io.read_raw_eeglab(fname, preload=True);

# HOUSEKEEPING: 
# 1. Make sure number and type of events make sense
# 2. Look at sensor coverage 

print(raw.annotations)# Get all unique annotation types
unique_annotations = set(raw.annotations.description)
print("Unique annotation types:", unique_annotations)
# Count the number of 'stim start' and 'stim end' annotations
stim_start_count = sum(raw.annotations.description == 'stim start')
stim_end_count = sum(raw.annotations.description == 'stim end')

print(f"Number of 'stim start' events: {stim_start_count}")
print(f"Number of 'stim end' events: {stim_end_count}")
raw.plot_sensors();

Reading /Users/idohaber/Desktop/101_test/Strength_101_forICA_wcomps_rmcomps.fdt
Reading 0 ... 6496921  =      0.000 ... 12993.842 secs...


  raw = mne.io.read_raw_eeglab(fname, preload=True);
  raw = mne.io.read_raw_eeglab(fname, preload=True);


<Annotations | 1368 segments: SE__ (18), SS__ (19), Sleep Stage (1216), ...>
Unique annotation types: {'boundary', 'stim end', 'SE__', 'stim start', 'Sleep Stage', 'SS__'}
Number of 'stim start' events: 22
Number of 'stim end' events: 24


In [3]:
# prepare the data for processing
data = raw.get_data(units="uV") 
raw.filter(0.5, 4, fir_design='firwin', h_trans_bandwidth=0.2, l_trans_bandwidth=0.2)
raw.resample(100)
sf = raw.info['sfreq']
print(data.shape , sf)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 4 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.20 Hz (-6 dB cutoff frequency: 0.40 Hz)
- Upper passband edge: 4.00 Hz
- Upper transition bandwidth: 0.20 Hz (-6 dB cutoff frequency: 4.10 Hz)
- Filter length: 8251 samples (16.502 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    9.6s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:   20.2s


(198, 6496922) 100.0


In [4]:
# Create annotations DataFrame
annotations_df = pd.DataFrame({
    "Onset": raw.annotations.onset,
    "Duration": raw.annotations.duration,
    "Description": raw.annotations.description,
})

# Filter for only 'stim start' and 'stim end' events
stim_events = annotations_df[annotations_df["Description"].isin(["stim start", "stim end"])].reset_index(drop=True)

# Extract onset times for 'stim start' and 'stim end'
stim_start_times = stim_events[stim_events["Description"] == 'stim start']["Onset"].values
stim_end_times = stim_events[stim_events["Description"] == 'stim end']["Onset"].values


In [5]:
# Initialize variables
cleaned_events = []
expected_event = 'stim start'  # The sequence should start with 'stim start'
min_duration = 170  # Minimum acceptable duration in seconds
max_duration = 220  # Maximum acceptable duration in seconds
i = 0  # Index counter

# List to store information about omitted events
omitted_events = []

# Loop through the stim_events DataFrame
while i < len(stim_events):
    current_event = stim_events.iloc[i]
    if current_event["Description"] == expected_event:
        if expected_event == 'stim start':
            # Tentatively add 'stim start'
            stim_start_event = current_event
            expected_event = 'stim end'
        elif expected_event == 'stim end':
            # Check the time difference between 'stim end' and 'stim start'
            stim_end_event = current_event
            time_diff = stim_end_event["Onset"] - stim_start_event["Onset"]
            if min_duration <= time_diff <= max_duration:
                # Duration is acceptable; add both events to cleaned_events
                cleaned_events.append(stim_start_event)
                cleaned_events.append(stim_end_event)
                expected_event = 'stim start'  # Reset expected event
            else:
                # Duration is not acceptable; discard both events
                omitted_events.append({
                    'stim_start_index': stim_start_event.name,
                    'stim_start_onset': stim_start_event["Onset"],
                    'stim_end_index': stim_end_event.name,
                    'stim_end_onset': stim_end_event["Onset"],
                    'reason': f'Invalid duration ({time_diff:.2f}s)'
                })
                expected_event = 'stim start'  # Reset expected event
        i += 1  # Move to the next event
    else:
        # Event does not match expected; discard and move to the next
        omitted_events.append({
            'event_index': current_event.name,
            'event_onset': current_event["Onset"],
            'event_description': current_event["Description"],
            'reason': f"Unexpected event '{current_event['Description']}'"
        })
        i += 1  # Move to the next event
        expected_event = 'stim start'  # Reset expected event in case of mismatch

# Convert cleaned_events to DataFrame
cleaned_events_df = pd.DataFrame(cleaned_events).reset_index(drop=True)

# Convert omitted events to DataFrame
omitted_events_df = pd.DataFrame(omitted_events)

# Verify the counts
stim_start_count = (cleaned_events_df["Description"] == 'stim start').sum()
stim_end_count = (cleaned_events_df["Description"] == 'stim end').sum()

print(f"\nTotal 'stim start' events: {stim_start_count}")
print(f"Total 'stim end' events: {stim_end_count}")

# Check if we have 20 starts and ends
if stim_start_count == 20 and stim_end_count == 20:
    print("Successfully cleaned data to have 20 'stim start' and 20 'stim end' events.")
else:
    print("Warning: The number of 'stim start' and 'stim end' events does not equal 20.")

# Print omitted events with timepoints
print("\nOmitted Events with Timepoints:")
print(omitted_events_df)

# If you prefer a more readable format
print("\nDetailed Omitted Events:")
for idx, row in omitted_events_df.iterrows():
    if 'stim_start_index' in row:
        print(f"Discarded 'stim start' at index {row['stim_start_index']} (Onset: {row['stim_start_onset']:.2f}s) and "
              f"'stim end' at index {row['stim_end_index']} (Onset: {row['stim_end_onset']:.2f}s) - {row['reason']}")
    else:
        print(f"Skipped event at index {row['event_index']} (Onset: {row['event_onset']:.2f}s, "
              f"Description: '{row['event_description']}') - {row['reason']}")



Total 'stim start' events: 18
Total 'stim end' events: 18

Omitted Events with Timepoints:
   event_index  event_onset event_description                         reason
0           20     8090.000          stim end    Unexpected event 'stim end'
1           21     8090.000          stim end    Unexpected event 'stim end'
2           28    10130.596          stim end    Unexpected event 'stim end'
3           30    10297.116        stim start  Unexpected event 'stim start'
4           31    10413.884          stim end    Unexpected event 'stim end'
5           36    11401.568          stim end    Unexpected event 'stim end'
6           38    11581.140        stim start  Unexpected event 'stim start'
7           39    11797.248          stim end    Unexpected event 'stim end'

Detailed Omitted Events:
Skipped event at index 20 (Onset: 8090.00s, Description: 'stim end') - Unexpected event 'stim end'
Skipped event at index 21 (Onset: 8090.00s, Description: 'stim end') - Unexpected event 's

In [6]:
# Calculate durations between 'stim start' and 'stim end' in the original data
durations = []
i = 0
while i < len(stim_events) - 1:
    if stim_events.iloc[i]["Description"] == 'stim start' and stim_events.iloc[i+1]["Description"] == 'stim end':
        duration = stim_events.iloc[i+1]["Onset"] - stim_events.iloc[i]["Onset"]
        durations.append(duration)
        i += 2  # Skip to the next pair
    else:
        i += 1  # Move to the next event

# Plot the durations
plt.figure(figsize=(10, 5))
plt.plot(durations, marker='o')
plt.axhline(y=170, color='r', linestyle='--', label='Min Duration (170s)')
plt.axhline(y=220, color='r', linestyle='--', label='Max Duration (220s)')
plt.xlabel('Stim Pair Index')
plt.ylabel('Duration (s)')
plt.title('Durations Between "stim start" and "stim end" Events')
plt.legend()
plt.grid(True)
plt.show()


In [7]:
# Ensure that cleaned_events_df is sorted by onset time and reset index
cleaned_events_df = cleaned_events_df.sort_values(by='Onset').reset_index(drop=True)

# --- First Pass: Create Epochs and Identify Overlaps ---

print("Creating epochs and identifying overlaps...\n")

# Initialize lists for epochs
pre_stim_epochs = []
stim_epochs = []
post_stim_epochs = []

# Protocol counter
protocol_number = 1

# Previous protocol's post-stim epoch
prev_post_stim_end = 0

# List to keep track of overlaps
overlaps = []

# Loop through the cleaned events to define epochs
for i in range(0, len(cleaned_events_df), 2):
    if i + 1 < len(cleaned_events_df):  # Ensure i+1 is within bounds
        stim_start_event = cleaned_events_df.iloc[i]
        stim_end_event = cleaned_events_df.iloc[i + 1]

        # Extract onset times in seconds
        stim_start = stim_start_event["Onset"]
        stim_end = stim_end_event["Onset"]
        stim_duration = stim_end - stim_start

        # Define initial pre-stim, stim, and post-stim epochs
        pre_stim_start = stim_start - stim_duration
        pre_stim_end = stim_start

        stim_epoch_start = stim_start
        stim_epoch_end = stim_end

        post_stim_start = stim_end
        post_stim_end = stim_end + stim_duration

        # Check for overlap with the previous protocol's post-stim epoch
        overlap_amount = prev_post_stim_end - pre_stim_start
        if overlap_amount > 0:
            # Calculate half of the overlap amount
            half_overlap = overlap_amount / 2

            # Adjust the previous protocol's post-stim epoch
            if len(post_stim_epochs) > 0:
                prev_protocol = protocol_number - 1
                prev_post_stim = post_stim_epochs[-1]
                adjusted_prev_post_stim_end = prev_post_stim[1] - half_overlap
                if adjusted_prev_post_stim_end < prev_post_stim[0]:
                    print(f"Warning: Negative duration for post-stim epoch of Protocol {prev_protocol}. Setting duration to zero.")
                    adjusted_prev_post_stim_end = prev_post_stim[0]
                post_stim_epochs[-1] = (prev_post_stim[0], adjusted_prev_post_stim_end, prev_post_stim[2])

                print(f"Overlap detected between Protocol {prev_protocol} and Protocol {protocol_number}:")
                print(f"  Adjusting Post-Stim End Time of Protocol {prev_protocol} by {-half_overlap:.2f}s")
                print(f"  New Post-Stim End Time of Protocol {prev_protocol}: {adjusted_prev_post_stim_end:.2f}s")

            # Adjust the current protocol's pre-stim epoch
            adjusted_pre_stim_start = pre_stim_start + half_overlap
            if adjusted_pre_stim_start > pre_stim_end:
                print(f"Warning: Negative duration for pre-stim epoch of Protocol {protocol_number}. Setting duration to zero.")
                adjusted_pre_stim_start = pre_stim_end
            pre_stim_start = adjusted_pre_stim_start

            print(f"  Adjusting Pre-Stim Start Time of Protocol {protocol_number} by {half_overlap:.2f}s")
            print(f"  New Pre-Stim Start Time of Protocol {protocol_number}: {pre_stim_start:.2f}s\n")

            # Record the overlap details
            overlaps.append({
                'protocols': (protocol_number - 1, protocol_number),
                'overlap_amount': overlap_amount,
                'adjusted_prev_post_stim_end': adjusted_prev_post_stim_end,
                'adjusted_pre_stim_start': pre_stim_start
            })
        else:
            print(f"No overlap detected before Protocol {protocol_number}.\n")

        # Update prev_post_stim_end to the end of the current post-stim epoch (after adjustment)
        prev_post_stim_end = post_stim_end

        # Append epochs with protocol number
        pre_stim_epochs.append((pre_stim_start, pre_stim_end, protocol_number))
        stim_epochs.append((stim_epoch_start, stim_epoch_end, protocol_number))
        post_stim_epochs.append((post_stim_start, post_stim_end, protocol_number))

        protocol_number += 1

print("\nEpoch creation and overlap adjustment complete.\n")

# --- Plotting the epochs with overlaps highlighted ---

import matplotlib.pyplot as plt

print("Plotting epochs with overlaps highlighted...\n")

# First Plot: Original epochs (before adjustment)
plt.figure(figsize=(15, 4))

# Plot pre-stim epochs as shaded regions
for (start, end, protocol) in pre_stim_epochs:
    plt.axvspan(start, end, color='blue', alpha=0.3, label='Pre-Stim')

# Plot stim epochs as shaded regions and add protocol numbers
for i, (start, end, protocol) in enumerate(stim_epochs):
    plt.axvspan(start, end, color='orange', alpha=0.3, label='Stim')
    # Add protocol number
    plt.text((start + end) / 2, 0.5, f'P{protocol}', color='black', fontsize=9, ha='center', va='bottom')

# Plot post-stim epochs as shaded regions
for (start, end, protocol) in post_stim_epochs:
    plt.axvspan(start, end, color='green', alpha=0.3, label='Post-Stim')

# Plot overlaps
for overlap in overlaps:
    overlap_start = pre_stim_epochs[overlap['protocols'][1] - 1][0] - (overlap['overlap_amount'] / 2)
    overlap_end = pre_stim_epochs[overlap['protocols'][1] - 1][0] + (overlap['overlap_amount'] / 2)
    plt.axvspan(overlap_start, overlap_end, color='red', alpha=0.5, label='Overlap')

# Set labels and title
plt.xlabel('Time (s)')
plt.yticks([])  # Hide y-axis ticks since it's unitless
plt.title('Stimulation Protocols with Overlaps')

# To prevent duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc='upper right')

plt.grid(True)
plt.tight_layout()
plt.show()

# --- Second Plot: Epochs after trimming both pre-stim and post-stim epochs ---

print("Plotting epochs after adjusting to remove overlaps...\n")

# Print statements for adjusted epochs
print("Adjusted Epochs Durations:")
for idx, ((pre_start, pre_end, protocol), (post_start, post_end, _)) in enumerate(zip(pre_stim_epochs, post_stim_epochs)):
    pre_duration = pre_end - pre_start
    post_duration = post_end - post_start
    print(f"  Protocol {protocol}: Pre-Stim Duration = {pre_duration:.2f}s, Post-Stim Duration = {post_duration:.2f}s")

# Plotting adjusted epochs
plt.figure(figsize=(15, 4))

# Plot adjusted pre-stim epochs
for (start, end, protocol) in pre_stim_epochs:
    plt.axvspan(start, end, color='cyan', alpha=0.3, label='Pre-Stim Adjusted')

# Plot stim epochs and add protocol numbers
for i, (start, end, protocol) in enumerate(stim_epochs):
    plt.axvspan(start, end, color='orange', alpha=0.3, label='Stim')
    # Add protocol number
    plt.text((start + end) / 2, 0.5, f'P{protocol}', color='black', fontsize=9, ha='center', va='bottom')

# Plot adjusted post-stim epochs
for (start, end, protocol) in post_stim_epochs:
    plt.axvspan(start, end, color='lime', alpha=0.3, label='Post-Stim Adjusted')

# Set labels and title
plt.xlabel('Time (s)')
plt.yticks([])  # Hide y-axis ticks since it's unitless
plt.title('Stimulation Protocols After Adjusting Epochs')

# To prevent duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc='upper right')

plt.grid(True)
plt.tight_layout()
plt.show()


Creating epochs and identifying overlaps...

No overlap detected before Protocol 1.

No overlap detected before Protocol 2.

No overlap detected before Protocol 3.

No overlap detected before Protocol 4.

No overlap detected before Protocol 5.

Overlap detected between Protocol 5 and Protocol 6:
  Adjusting Post-Stim End Time of Protocol 5 by -11.10s
  New Post-Stim End Time of Protocol 5: 3847.28s
  Adjusting Pre-Stim Start Time of Protocol 6 by 11.10s
  New Pre-Stim Start Time of Protocol 6: 3847.28s

No overlap detected before Protocol 7.

No overlap detected before Protocol 8.

No overlap detected before Protocol 9.

No overlap detected before Protocol 10.

No overlap detected before Protocol 11.

Overlap detected between Protocol 11 and Protocol 12:
  Adjusting Post-Stim End Time of Protocol 11 by -92.54s
  New Post-Stim End Time of Protocol 11: 8769.47s
  Adjusting Pre-Stim Start Time of Protocol 12 by 92.54s
  New Pre-Stim Start Time of Protocol 12: 8769.47s

No overlap detected

In [8]:
#sw = yasa.sw_detect(raw, hypno=hypno_up, include=(2, 3))
sw = yasa.sw_detect(raw, freq_sw=(0.5, 4), verbose=False, coupling=False)
df = sw.summary(); # general summayb for each sw
df # Inspect the dataframe

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    1.6s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    3.3s


Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel
0,62.96,63.75,63.82,63.95,64.10,1.14,-40.843038,108.218040,149.061078,2129.443975,0.877193,E2,0
1,83.02,83.14,83.41,83.49,83.54,0.52,-66.338851,15.897654,82.236505,304.579648,1.923077,E2,0
2,144.10,144.25,144.63,144.84,145.08,0.98,-42.309507,53.274203,95.583709,251.536077,1.020408,E2,0
3,157.41,157.53,157.82,157.91,157.99,0.58,-69.818186,11.151660,80.969846,279.206367,1.724138,E2,0
4,198.98,199.15,199.41,199.81,199.94,0.96,-53.234828,29.418601,82.653429,317.897804,1.041667,E2,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
232395,8013.42,8013.64,8014.00,8014.12,8014.60,1.18,-78.882369,29.605271,108.487640,301.354555,0.847458,E256,197
232396,8044.63,8044.79,8045.05,8045.15,8045.76,1.13,-62.821096,22.815097,85.636193,329.369973,0.884956,E256,197
232397,8094.95,8095.48,8095.57,8095.66,8095.76,0.81,-53.311399,54.516849,107.828248,1198.091649,1.234568,E256,197
232398,8355.53,8356.12,8356.22,8356.33,8357.06,1.53,-40.587707,34.672218,75.259925,752.599254,0.653595,E256,197


In [None]:
sw.plot_detection() # lets you scroll through the detection very conveniently

In [9]:
# Define the classification function
def classify_wave(start_time, pre_stim_epochs, stim_epochs, post_stim_epochs):
    """
    Classify each wave based on the start time into 'Pre-Stim', 'Stim', or 'Post-Stim' 
    and assign the protocol number.
    """
    for epochs, label in [(pre_stim_epochs, 'Pre-Stim'), 
                          (stim_epochs, 'Stim'), 
                          (post_stim_epochs, 'Post-Stim')]:
        for start, end, protocol in epochs:
            if start <= start_time <= end:
                return label, protocol
    return 'Unknown', None

# Apply classification to DataFrame 'df'
df[['Classification', 'Protocol Number']] = df['Start'].apply(
    lambda start_time: classify_wave(
        start_time,
        pre_stim_epochs, stim_epochs, post_stim_epochs
    )
).apply(pd.Series)

# Filter rows classified as 'Unknown'
df_filtered = df[df['Classification'] != 'Unknown'].reset_index(drop=True)


In [10]:
df

Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel,Classification,Protocol Number
0,62.96,63.75,63.82,63.95,64.10,1.14,-40.843038,108.218040,149.061078,2129.443975,0.877193,E2,0,Unknown,
1,83.02,83.14,83.41,83.49,83.54,0.52,-66.338851,15.897654,82.236505,304.579648,1.923077,E2,0,Unknown,
2,144.10,144.25,144.63,144.84,145.08,0.98,-42.309507,53.274203,95.583709,251.536077,1.020408,E2,0,Unknown,
3,157.41,157.53,157.82,157.91,157.99,0.58,-69.818186,11.151660,80.969846,279.206367,1.724138,E2,0,Unknown,
4,198.98,199.15,199.41,199.81,199.94,0.96,-53.234828,29.418601,82.653429,317.897804,1.041667,E2,0,Unknown,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
232395,8013.42,8013.64,8014.00,8014.12,8014.60,1.18,-78.882369,29.605271,108.487640,301.354555,0.847458,E256,197,Unknown,
232396,8044.63,8044.79,8045.05,8045.15,8045.76,1.13,-62.821096,22.815097,85.636193,329.369973,0.884956,E256,197,Unknown,
232397,8094.95,8095.48,8095.57,8095.66,8095.76,0.81,-53.311399,54.516849,107.828248,1198.091649,1.234568,E256,197,Unknown,
232398,8355.53,8356.12,8356.22,8356.33,8357.06,1.53,-40.587707,34.672218,75.259925,752.599254,0.653595,E256,197,Pre-Stim,11.0


In [14]:
# Filter out rows where 'Protocol Number' is NaN
df_filtered = df.dropna(subset=['Protocol Number'])

# Generate the slow wave names
df_filtered['Slow_Wave_Name' ] = (
    'proto' + df_filtered['Protocol Number'].astype(int).astype(str) + '_' +
    df_filtered['Classification'].str.lower().str.replace(' ', '-') + '_sw' +
    (df_filtered.groupby(['Protocol Number', 'Classification']).cumcount() + 1).astype(str)
)

# Define the order for the 'Classification'
classification_order = ['pre-stim', 'stim', 'post-stim']

# Create a categorical type for sorting Classification
df_filtered['Classification'] = df_filtered['Classification'].str.lower().str.replace(' ', '-')
df_filtered['Classification'] = pd.Categorical(df_filtered['Classification'], categories=classification_order, ordered=True)

# Sort the DataFrame based on Protocol Number, Classification order, and Slow Wave Name
df_sorted = df_filtered.sort_values(by=['Protocol Number', 'Classification', 'Slow_Wave_Name'])

# Keep only the relevant columns
df_epochs = df_sorted[['Slow_Wave_Name', 'Start', 'End', 'Duration', 
                       'NegPeak', 'PosPeak', 'ValNegPeak', 'ValPosPeak', 
                       'Protocol Number', 'Classification', 'Channel']]

# Display the DataFrame with the new columns
print(df_epochs.head())

# Save the resulting DataFrame to a CSV file
df_epochs.to_csv('sorted_slow_waves.csv', index=False)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Slow_Wave_Name' ] = (
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Classification'] = df_filtered['Classification'].str.lower().str.replace(' ', '-')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Classification'] = pd.Categorical(df_filtered['Classification

               Slow_Wave_Name   Start     End  Duration  NegPeak  PosPeak  \
6         proto1_pre-stim_sw1  389.76  390.50      0.74   389.86   390.24   
15       proto1_pre-stim_sw10  454.82  455.51      0.69   455.02   455.42   
9359    proto1_pre-stim_sw100  455.49  456.34      0.85   455.62   456.01   
84775  proto1_pre-stim_sw1000  356.65  357.44      0.79   356.86   357.13   
84776  proto1_pre-stim_sw1001  380.86  381.74      0.88   381.02   381.41   

       ValNegPeak  ValPosPeak  Protocol Number Classification Channel  
6      -45.324543   40.681746              1.0       pre-stim      E2  
15     -99.037309   60.448043              1.0       pre-stim      E2  
9359   -51.307500   76.539914              1.0       pre-stim     E14  
84775  -61.445404   57.653407              1.0       pre-stim    E123  
84776  -60.053981   23.689135              1.0       pre-stim    E123  


In [15]:
# Function to filter epochs based on window size and picking criteria
def filter_epochs(df, window_size, pick_most_negative=False):
    """
    Filters epochs based on a specified window size. Optionally picks the wave 
    with the most negative ValNegPeak within the window.

    Parameters:
        df (pd.DataFrame): Input DataFrame with epoch information.
        window_size (float): Minimum time gap (in seconds) between waves.
        pick_most_negative (bool): Whether to pick the wave with the most negative ValNegPeak.

    Returns:
        pd.DataFrame: Filtered DataFrame.
    """
    filtered_epochs_list = []
    last_end_time = -float('inf')
    current_window = []

    for _, row in df_sorted.iterrows():
        if row['Start'] > last_end_time + window_size:
            if current_window:
                selected_wave = (
                    min(current_window, key=lambda x: x['ValNegPeak']) 
                    if pick_most_negative else current_window[0]
                )
                filtered_epochs_list.append(selected_wave)
            current_window = [row]
            last_end_time = row['End']
        else:
            current_window.append(row)

    # Process the last window
    if current_window:
        selected_wave = (
            min(current_window, key=lambda x: x['ValNegPeak']) 
            if pick_most_negative else current_window[0]
        )
        filtered_epochs_list.append(selected_wave)

    # Convert list of dictionaries to DataFrame to retain all original columns
    return pd.DataFrame(filtered_epochs_list)

# Apply the filtering and save to CSV files, ensuring all relevant columns are included
windows = [0.5, 1.0]  # Window sizes in seconds
for window in windows:
    for pick_negative in [False, True]:
        suffix = "most_negative" if pick_negative else "first"
        filename = f'filtered_epochs_{int(window*1000)}ms_{suffix}.csv'
        filtered_epochs = filter_epochs(df, window_size=window, pick_most_negative=pick_negative)
        filtered_epochs.to_csv(filename, index=False)
        print(f"Filtered epochs saved to {filename}")


Filtered epochs saved to filtered_epochs_500ms_first.csv
Filtered epochs saved to filtered_epochs_500ms_most_negative.csv
Filtered epochs saved to filtered_epochs_1000ms_first.csv
Filtered epochs saved to filtered_epochs_1000ms_most_negative.csv


In [16]:

# Load the filtered epochs data
df_actual = pd.read_csv('filtered_epochs_500ms_most_negative.csv')

# Define the specific columns to analyze
columns_to_plot = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']

# Define all possible classifications
all_classifications = ['pre-stim', 'stim', 'post-stim']

# Filter and group by Classification for overall analysis
comparison_means = df_actual.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
comparison_counts = df_actual['Classification'].value_counts().reindex(all_classifications, fill_value=0)

# Function to add value labels on bars
def add_value_labels(ax, spacing=5):
    """Add labels to the end of each bar in a bar chart."""
    for rect in ax.patches:
        y_value = rect.get_height()
        x_value = rect.get_x() + rect.get_width() / 2
        label = f"{y_value:.2f}" if y_value != 0 else "0"  # Use a single zero for labels with no decimal part
        ax.annotate(
            label, 
            (x_value, y_value), 
            xytext=(0, spacing), 
            textcoords="offset points", 
            ha='center', 
            va='bottom'
        )

# Figure 1: Overall comparison for all classifications combined, disregarding protocol number
plt.figure(figsize=(15, 6))

# Overall Mean Values Plot
plt.subplot(1, 2, 1)
ax = comparison_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
plt.title('Overall Mean Values of Wave Properties (All Protocols Combined)')
plt.ylabel('Mean Values')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
add_value_labels(ax)  # Add value labels

# Overall Count Plot
plt.subplot(1, 2, 2)
ax2 = comparison_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
plt.title('Overall Count of Instances by Classification (All Protocols Combined)')
plt.ylabel('Count')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
add_value_labels(ax2)  # Add value labels

# Show the overall plot
plt.tight_layout()  # Adjusts subplots to give some padding and prevent overlap
plt.show()

# Now, let's create separate plots for each protocol
protocol_numbers = df_actual['Protocol Number'].dropna().unique()  # Get unique protocol numbers

for protocol in protocol_numbers:
    protocol_data = df_actual[df_actual['Protocol Number'] == protocol]
    protocol_means = protocol_data.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
    protocol_counts = protocol_data['Classification'].value_counts().reindex(all_classifications, fill_value=0)

    plt.figure(figsize=(15, 6))

    # Mean values for each protocol
    plt.subplot(1, 2, 1)
    ax = protocol_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
    plt.title(f'Mean Values of Wave Properties (Protocol {int(protocol)})')
    plt.ylabel('Mean Values')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
    add_value_labels(ax)  # Add value labels

    # Counts for each protocol
    plt.subplot(1, 2, 2)
    ax2 = protocol_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
    plt.title(f'Count of Instances by Classification (Protocol {int(protocol)})')
    plt.ylabel('Count')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    add_value_labels(ax2)  # Add value labels

    # Show the plots for this protocol
    plt.tight_layout()
    plt.show()


In [17]:
# Load the CSV file into a DataFrame
events_df = pd.read_csv('filtered_epochs_500ms_most_negative.csv')

# Initialize lists to hold annotation data
onsets, durations, descriptions = [], [], []

# Populate the annotation lists for the main events
for _, row in events_df.iterrows():
    # Add the "Start" event
    onsets.append(row['Start'])  # Start time in seconds
    durations.append(0)  # Instantaneous event
    descriptions.append("Start")
    
    # Add the "NegPeak" event
    onsets.append(row['NegPeak'])  # Time of negative peak
    durations.append(0)  # Instantaneous event
    descriptions.append("NegPeak")
    
    # Add the "PosPeak" event
    onsets.append(row['PosPeak'])  # Time of positive peak
    durations.append(0)  # Instantaneous event
    descriptions.append("PosPeak")
    
    # Add the "End" event
    onsets.append(row['End'])  # End time in seconds
    durations.append(0)  # Instantaneous event
    descriptions.append("End")

# Create an MNE Annotations object
annotations = mne.Annotations(onset=onsets, duration=durations, description=descriptions)

# Add the annotations to the raw data
raw.set_annotations(annotations)

# Export the annotated raw data to an EEGLAB .set file
output_fname = 'annotated_raw.set'
mne.export.export_raw(output_fname, raw, fmt='eeglab')

# Verify the annotations
print("Available annotation descriptions:")
print(set(annotations.description))


Available annotation descriptions:
{'Start', 'PosPeak', 'End', 'NegPeak'}


In [18]:
fname = 'annotated_raw.set'
annotated= mne.io.read_raw_eeglab(fname, preload=True)
annotated
annotated.plot(clipping=None, duration=5)
annotated.info['sfreq']

Using matplotlib as 2D backend.


100.0

Channels marked as bad:
none


In [21]:

# Define input CSV and output directory
extraction_file = 'filtered_epochs_500ms_first.csv'
output_dir = 'epoch-imgs'
os.makedirs(output_dir, exist_ok=True)

# Load the extraction data into a DataFrame
extraction_df = pd.read_csv(extraction_file)

# Loop through each row in the DataFrame
for i, row in extraction_df.iterrows():
    # Convert 'Start', 'NegPeak', 'PosPeak', and 'End' times to sample indices
    start_sample = int(row['Start'] * sf)
    neg_peak_sample = int(row['NegPeak'] * sf)
    pos_peak_sample = int(row['PosPeak'] * sf)
    end_sample = int(row['End'] * sf)

    # Calculate tmin and tmax based on the 'Start' and 'End' times
    tmin = 0  # Start aligned at 0
    tmax = (end_sample - start_sample) / sf  # Duration from Start to End

    # Create an event array for the 'Start' event only
    events = np.array([[start_sample, 0, 1]])  # Event code 1 for 'Start'

    # Create the epoch using the 'Start' event
    try:
        epoch = mne.Epochs(
            raw, events, event_id={'Start': 1},
            tmin=tmin, tmax=tmax, baseline=None, preload=True
        )
    except ValueError as e:
        print(f"Error creating epoch for row {i}: {e}")
        continue

    # Convert NegPeak, PosPeak, and End times to seconds relative to the epoch start
    neg_peak_time = (neg_peak_sample - start_sample) / sf
    pos_peak_time = (pos_peak_sample - start_sample) / sf
    end_time = (end_sample - start_sample) / sf

    # Retrieve the selected channel and slow wave name
    selected_channel = row['Channel']
    slow_wave_name = row['Slow_Wave_Name']

    # Validate the selected channel
    if selected_channel not in raw.ch_names:
        print(f"Channel {selected_channel} not found in raw data for row {i}. Skipping.")
        continue

    # Get the index of the selected channel
    channel_idx = raw.ch_names.index(selected_channel)

    # Create a plot for the epoch without opening the image
    fig, ax = plt.subplots(figsize=(10, 4))

    # Plot the data for the selected channel
    epoch_data = epoch.get_data(picks=[channel_idx])[0, 0, :]  # Get data for the selected channel
    times = epoch.times  # Times corresponding to the epoch

    ax.plot(times, epoch_data, color='black', label='EEG Signal')

    # Add vertical dashed lines for Start, NegPeak, PosPeak, and End
    ax.axvline(0, color='black', linestyle='--', label='Start')  # Start event
    ax.axvline(neg_peak_time, color='red', linestyle='--', label='NegPeak')
    ax.axvline(pos_peak_time, color='green', linestyle='--', label='PosPeak')
    ax.axvline(end_time, color='blue', linestyle='--', label='End')

    # Set labels and title
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (V)')
    ax.set_title(f"{slow_wave_name} // {selected_channel}")
    ax.legend(loc='upper right')

    # Save the figure to the output directory without opening it
    output_path = os.path.join(output_dir, f'{slow_wave_name}_{selected_channel}.png')
    fig.savefig(output_path)
    plt.close(fig)  # Close the figure to save memory

print("Processing complete. All epoch images were saved in the 'epoch-imgs' directory.")


Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 75 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 70 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 129 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 68 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 72 original time points ...
0 bad epochs dr

In [22]:
# Directory where you want to save the files
output_dir = "epoch-data"
os.makedirs(output_dir, exist_ok=True)

# Sampling frequency from the MNE raw object
sfreq = raw.info['sfreq']

# Define the baseline duration in seconds
baseline_duration = 0.2  # 0.2 seconds
baseline_samples = int(baseline_duration * sfreq)

# Loop through each row in the filtered CSV DataFrame
for i in range(len(extraction_df)):
    # Extract data and info for the ith epoch
    row = extraction_df.iloc[i]
    
    start_time = row['Start']
    neg_peak_time = row['NegPeak']
    pos_peak_time = row['PosPeak']
    end_time = row['End']

    # Calculate the new start time with baseline
    new_start_time = start_time - baseline_duration
    
    # Ensure that the new start time is not negative
    if new_start_time < 0:
        new_start_time = 0

    # Convert times to samples
    new_start_sample = int(new_start_time * sfreq)
    start_sample = int(start_time * sfreq)
    end_sample = int(end_time * sfreq)
    
    # Extract the epoch data with the baseline
    epoch_data = raw.get_data(start=new_start_sample, stop=end_sample + 1)
    info = raw.info

    # Create RawArray from epoch data
    raw_epoch = mne.io.RawArray(epoch_data, info)
    
    # Calculate the event sample points relative to the new start time
    neg_peak_sample = int((neg_peak_time - new_start_time) * sfreq)
    pos_peak_sample = int((pos_peak_time - new_start_time) * sfreq)
    start_event_sample = baseline_samples  # Start event is at baseline

    # Create events array
    events = np.array([
        [start_event_sample, 0, 1],  # Start event after the baseline
        [neg_peak_sample, 0, 2],     # NegPeak event
        [pos_peak_sample, 0, 3],     # PosPeak event
    ])
    
    # Add events to annotations
    annotations = mne.Annotations(onset=events[:, 0] / sfreq,
                                  duration=[0, 0, 0],
                                  description=['Start', 'NegPeak', 'PosPeak'])
    raw_epoch.set_annotations(annotations)
    
    # Get the epoch name and channel from the CSV file
    epoch_name = row['Slow_Wave_Name']
    channel_name = row['Channel']
    
    # Naming scheme for the file: "Slow_Wave_Name_Channel"
    file_name = f"{epoch_name}_{channel_name}.set"
    
    # Save as .set file directly in the output directory
    epoch_file = os.path.join(output_dir, file_name)
    mne.export.export_raw(epoch_file, raw_epoch, fmt='eeglab')
    print(f"Saved epoch '{file_name}' to {epoch_file}")

print("All epochs have been successfully saved as .set files in the 'epoch-data' directory.")

Creating RawArray with float64 data, n_channels=198, n_times=95
    Range : 0 ... 94 =      0.000 ...     0.940 secs
Ready.
Saved epoch 'proto1_pre-stim_sw1_E2.set' to epoch-data/proto1_pre-stim_sw1_E2.set
Creating RawArray with float64 data, n_channels=198, n_times=90
    Range : 0 ... 89 =      0.000 ...     0.890 secs
Ready.
Saved epoch 'proto1_pre-stim_sw10_E2.set' to epoch-data/proto1_pre-stim_sw10_E2.set
Creating RawArray with float64 data, n_channels=198, n_times=149
    Range : 0 ... 148 =      0.000 ...     1.480 secs
Ready.
Saved epoch 'proto1_pre-stim_sw101_E14.set' to epoch-data/proto1_pre-stim_sw101_E14.set
Creating RawArray with float64 data, n_channels=198, n_times=88
    Range : 0 ... 87 =      0.000 ...     0.870 secs
Ready.
Saved epoch 'proto1_pre-stim_sw1013_E123.set' to epoch-data/proto1_pre-stim_sw1013_E123.set
Creating RawArray with float64 data, n_channels=198, n_times=92
    Range : 0 ... 91 =      0.000 ...     0.910 secs
Ready.
Saved epoch 'proto1_pre-stim_sw1

In [20]:
# Assuming df_actual is already loaded and contains the necessary data
# Load raw EEG data (if not already loaded)
raw = mne.io.read_raw_eeglab('annotated_raw.set', preload=True)
sf = raw.info['sfreq']

# Set the fixed waveform length in seconds
waveform_length = 1  # 2 seconds
n_samples = int(waveform_length * sf)  # Convert to samples

# Create a dictionary to store waveforms for each stim condition
waveforms_by_condition = {}

for stim_condition in df_actual['Classification'].unique():
    condition_data = df_actual[df_actual['Classification'] == stim_condition]
    waveforms = []
    for _, row in condition_data.iterrows():
        # Extract start sample
        start_sample = int(row['Start'] * sf)
        end_sample = start_sample + n_samples  # Ensure fixed length of 2 seconds

        # Validate that the end_sample doesn't exceed the data limits
        if end_sample > raw.n_times:
            continue  # Skip this waveform if it goes beyond the data length

        # Extract the waveform for the specific channel
        channel_idx = raw.ch_names.index(row['Channel'])
        waveform = raw.get_data(picks=[channel_idx], start=start_sample, stop=end_sample).flatten()

        # Convert waveform from V to uV
        waveform_uV = waveform * 1e6  # Conversion from V to uV

        # Center the waveform relative to its start
        waveform_centered = waveform_uV - waveform_uV[0]
        waveforms.append(waveform_centered)

    waveforms_by_condition[stim_condition] = np.array(waveforms)

# Plot the averages
sns.set(style="whitegrid")
plt.figure(figsize=(12, 6))
time = np.linspace(0, waveform_length, n_samples)  # Time vector

colors = {'pre-stim': 'blue', 'stim': 'orange', 'post-stim': 'green'}
for stim_condition, waveforms in waveforms_by_condition.items():
    if waveforms.size > 0:
        avg_waveform = waveforms.mean(axis=0)
        std_waveform = waveforms.std(axis=0)
        plt.plot(time, avg_waveform, label=f'{stim_condition.capitalize()} (n={waveforms.shape[0]})', color=colors[stim_condition])
        plt.fill_between(time, avg_waveform - std_waveform, avg_waveform + std_waveform, alpha=0.2, color=colors[stim_condition])

plt.axvline(0, color='black', linestyle='--', label='Start')
plt.title('Average Waveforms Per Stimulation Condition (Centered at Start)')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (uV)')  # Corrected unit label
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()

