### The objective

create a short and sweet notebook that takes in EEGLAB preprocessed data, and outputs SW statistics

## Some questions to look into with time

- why some sw seem to be detected by multiple channels while other sw are detected by a single ch?
- how can we look at "traveling" vs focal sw? i.e Type I vs Type II
- is there specific sw features that are correlated with their distribution?

- are there more sw detected in frontal ch?

In [1]:
#Make sure you have eeglabio for saving the Epochs
#pip install eeglabio


# Housekeeping imports
import os
import mne
from mne.io import RawArray
import numpy as np
import pandas as pd
import yasa
import matplotlib.pyplot as plt  # Make sure to import pyplot as plt for plotting
import statsmodels.api as sm
import ipywidgets as widgets  # Renamed for clarity

# Activate interactive figures with %matplotlib qt (useful in Jupyter environments)
%matplotlib qt


### Load data
* Change the io.methodX based on the EEG file type you are trying to load
* `preload` lets you keep the data in memory and manipulate it in different cells.

In [2]:
fname = '/Users/idohaber/Desktop/Paper_dir/Source_test/1_Functional_Data/0.5-6_full_NREM.set'
raw = mne.io.read_raw_eeglab(fname, preload=True);
#raw.filter(0.5, 30, fir_design='firwin')  # Adjust the frequency range as needed
raw

Reading /Users/idohaber/Desktop/Paper_dir/Source_test/1_Functional_Data/0.5-6_full_NREM.fdt
Reading 0 ... 5369842  =      0.000 ... 10739.684 secs...


  raw = mne.io.read_raw_eeglab(fname, preload=True);
  raw = mne.io.read_raw_eeglab(fname, preload=True);


0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,197 points
Good channels,194 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,500.00 Hz
Highpass,0.00 Hz
Lowpass,250.00 Hz
Filenames,0.5-6_full_NREM.fdt
Duration,02:58:60 (HH:MM:SS)


In [3]:
# prepare the data for processing
data = raw.get_data(units="uV") 
raw.resample(100)
sf = raw.info['sfreq']
print(data.shape , sf)

(194, 5369843) 100.0


In [4]:
events = mne.events_from_annotations(raw)  # raw events
events_id = events[-1]                     # grab event dict
actual_events = events[:-1][0]             # grab actual events
print(events_id, '\n') 
print(actual_events)

Used Annotations descriptions: ['Sleep Stage', 'boundary', 'stim end', 'stim start']
{'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4} 

[[      0       0       2]
 [    511       0       1]
 [   3511       0       1]
 ...
 [1068613       0       1]
 [1071613       0       1]
 [1073968       0       2]]


In [5]:
# Extract events and event IDs from annotations in the raw data
events, event_id = mne.events_from_annotations(raw)

print(event_id, '\n')
print(events)

# Dictionary mapping event descriptions to numerical codes
column_dict = {'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4}

# Indices for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Filter events to get only stim start and stim end
filtered_data = [item for item in events if item[2] == stim_end_index or item[2] == stim_start_index]

# Minimum stim duration threshold in seconds (example: 100 seconds)
min_stim_duration_sec = 100
min_stim_duration_samples = int(min_stim_duration_sec * sf)

# Initialize lists for pre-stim, early stim, late stim, and post-stim epochs with protocol numbers
pre_stim_epochs = []
early_stim_epochs = []
late_stim_epochs = []
post_stim_epochs = []

# Previous epoch end to check for overlap
previous_end = 0

# Protocol counter
protocol_number = 1

# Loop through the epochs and define pre-stim, early stim, late stim, and post-stim epochs
for i in range(0, len(filtered_data), 2):
    if i + 1 < len(filtered_data):  # Ensure i+1 is within bounds
        stim_start = filtered_data[i][0]
        stim_end = filtered_data[i+1][0]
        stim_duration = stim_end - stim_start  # Calculate stim duration
        
        if stim_duration < min_stim_duration_samples:
            continue  # Skip this stim epoch if it is shorter than the minimum duration
        
        stim_midpoint = (stim_start + stim_end) // 2  # Calculate the midpoint of the stim epoch
        
        pre_stim_epoch = (stim_start - stim_duration, stim_start, protocol_number)
        early_stim_epoch = (stim_start, stim_midpoint, protocol_number)
        late_stim_epoch = (stim_midpoint, stim_end, protocol_number)
        post_stim_epoch = (stim_end, stim_end + stim_duration, protocol_number)
        
        # Check for overlap with the previous epoch
        if pre_stim_epoch[0] < previous_end:
            continue  # Skip this entire protocol if there's an overlap with the previous one
        
        # Update previous_end to the end of the current post-stim epoch
        previous_end = post_stim_epoch[1]
        
        pre_stim_epochs.append(pre_stim_epoch)
        early_stim_epochs.append(early_stim_epoch)
        late_stim_epochs.append(late_stim_epoch)
        post_stim_epochs.append(post_stim_epoch)
        
        # Increment protocol number
        protocol_number += 1

# Convert epochs to time for plotting
def convert_sample_to_time(epochs, sf):
    return [(start / sf, end / sf) for start, end, protocol in epochs]

pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
early_stim_epochs_time = convert_sample_to_time(early_stim_epochs, sf)
late_stim_epochs_time = convert_sample_to_time(late_stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

# Optionally store or print protocol numbers for reference
protocol_numbers = [epoch[2] for epoch in pre_stim_epochs]  # Collect protocol numbers (only need to do it once)
print(f"Protocol Numbers: {protocol_numbers}")

# Plotting
plt.figure(figsize=(10, 6))

# Plot pre-stim epochs as shaded regions
for (start, end) in pre_stim_epochs_time:
    plt.axvspan(start, end, color='blue', alpha=0.3, label='Pre-Stim')

# Plot early stim epochs as shaded regions and add protocol numbers
for i, (start, end) in enumerate(early_stim_epochs_time):
    plt.axvspan(start, end, color='orange', alpha=0.3, label='Early Stim')
    # Add the protocol number only for early stim, to represent the entire protocol
    plt.text((start + end) / 2, 0.2, f'P{protocol_numbers[i]}', color='black', fontsize=10, ha='center')

# Plot late stim epochs as shaded regions
for (start, end) in late_stim_epochs_time:
    plt.axvspan(start, end, color='red', alpha=0.3, label='Late Stim')

# Plot post-stim epochs as shaded regions
for (start, end) in post_stim_epochs_time:
    plt.axvspan(start, end, color='green', alpha=0.3, label='Post-Stim')

# Set labels and title
plt.xlabel('Time (s)')
plt.ylabel('Epoch Type')
plt.title('Stimulation Protocols Visualization')

# To prevent duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())

plt.show()




Used Annotations descriptions: ['Sleep Stage', 'boundary', 'stim end', 'stim start']
{'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4} 

[[      0       0       2]
 [    511       0       1]
 [   3511       0       1]
 ...
 [1068613       0       1]
 [1071613       0       1]
 [1073968       0       2]]
Protocol Numbers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]


In [6]:
#sw = yasa.sw_detect(raw, hypno=hypno_up, include=(2, 3))
sw = yasa.sw_detect(raw, verbose=False, coupling=False);
df = sw.summary(); # general summary for each sw
df # Inspect the dataframe

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    1.6s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    3.8s


Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel
0,273.76,274.03,274.86,275.16,275.65,1.89,-41.824487,33.425113,75.249599,90.662168,0.529101,E1,0
1,545.35,545.66,545.90,546.19,546.51,1.16,-40.937927,55.789954,96.727881,403.032836,0.862069,E1,0
2,1209.77,1210.18,1210.51,1210.79,1211.14,1.37,-50.271566,32.988102,83.259668,252.302023,0.729927,E1,0
3,1211.14,1211.49,1211.75,1212.06,1212.71,1.57,-41.000249,47.190866,88.191115,339.196596,0.636943,E1,0
4,1212.71,1213.75,1214.07,1214.41,1214.84,2.13,-48.924268,27.038476,75.962744,237.383574,0.469484,E1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
7752,8782.88,8783.18,8783.44,8783.62,8783.85,0.97,-50.691698,25.743016,76.434714,293.979671,1.030928,E253,193
7753,9180.13,9180.50,9180.77,9181.10,9181.55,1.42,-54.312763,79.285652,133.598416,494.808946,0.704225,E253,193
7754,9962.47,9963.12,9963.38,9963.70,9964.08,1.61,-45.219392,69.214942,114.434334,440.132054,0.621118,E253,193
7755,9964.08,9964.53,9965.31,9965.70,9966.23,2.15,-48.162902,68.680883,116.843785,149.799724,0.465116,E253,193


In [None]:
# sw.plot_detection() # lets you scroll through the detection very conveniently

In [7]:
# Define the classification function
def classify_wave(start_time, pre_stim_epochs_time, early_stim_epochs_time, late_stim_epochs_time, post_stim_epochs_time):
    """Classify each wave based on the start time into 'Pre-Stim', 'Early-Stim', 'Late-Stim', or 'Post-Stim' and assign protocol number."""
    for idx, (start, end) in enumerate(pre_stim_epochs_time):
        if start <= start_time <= end:
            return 'Pre-Stim', idx + 1
    for idx, (start, end) in enumerate(early_stim_epochs_time):
        if start <= start_time <= end:
            return 'Early-Stim', idx + 1
    for idx, (start, end) in enumerate(late_stim_epochs_time):
        if start <= start_time <= end:
            return 'Late-Stim', idx + 1
        
    for idx, (start, end) in enumerate(post_stim_epochs_time):
        if start <= start_time <= end:
            return 'Post-Stim', idx + 1
    return 'Unknown', None  # If the wave does not fall within any of the epochs

# Apply classification to DataFrame
df[['Classification', 'Protocol Number']] = df['Start'].apply(lambda start_time: classify_wave(start_time, pre_stim_epochs_time, early_stim_epochs_time, late_stim_epochs_time, post_stim_epochs_time)).apply(pd.Series)

# Filter out rows classified as 'Unknown'
df_filtered = df[df['Classification'] != 'Unknown']

# Now df_filtered contains both the classification and the protocol number for each wave.


In [8]:

# Group by classification and calculate mean and count for each group
comparison_means = df_filtered.groupby('Classification')[['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']].mean()
comparison_counts = df_filtered.groupby('Classification')['Start'].count()  # Counting instances using the 'Start' column

# Print results
print("Mean Values by Group:")
print(comparison_means)
print("\nCount of Instances by Group:")
print(comparison_counts)


Mean Values by Group:
                Duration  ValNegPeak  ValPosPeak         PTP  Frequency
Classification                                                         
Early-Stim      1.557771  -54.866958   49.426493  104.293451   0.666882
Late-Stim       1.594121  -54.061757   50.527821  104.589577   0.657163
Post-Stim       1.464050  -53.450962   46.468855   99.919817   0.718585
Pre-Stim        1.496005  -52.440277   46.623603   99.063881   0.700583

Count of Instances by Group:
Classification
Early-Stim     933
Late-Stim     1109
Post-Stim     2020
Pre-Stim      1457
Name: Start, dtype: int64


In [None]:
df #look at the frame to make sure classification added properly 

In [None]:
# Assuming df is your original DataFrame with relevant columns

# Define the specific columns you want to plot
columns_to_plot = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']

# Define all possible classifications except 'Unknown'
all_classifications = ['Early-Stim', 'Post-Stim', 'Pre-Stim', 'Late-Stim']

# Filter out 'Unknown' classification from the DataFrame
df_filtered = df[df['Classification'].isin(all_classifications)]

# Calculate overall means and counts, including all classifications except 'Unknown'
overall_means = df_filtered.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
overall_counts = df_filtered['Classification'].value_counts().reindex(all_classifications, fill_value=0)

# Function to add value labels on bars
def add_value_labels(ax, spacing=5):
    """Add labels to the end of each bar in a bar chart."""
    for rect in ax.patches:
        y_value = rect.get_height()
        x_value = rect.get_x() + rect.get_width() / 2
        label = f"{y_value:.2f}" if y_value != 0 else "0"  # Use a single zero for labels with no decimal part
        ax.annotate(
            label, 
            (x_value, y_value), 
            xytext=(0, spacing), 
            textcoords="offset points", 
            ha='center', 
            va='bottom'
        )

# Figure 1: Overall comparison for all classifications combined, disregarding protocol number
plt.figure(figsize=(15, 6))

# Overall Mean Values Plot
plt.subplot(1, 2, 1)
ax = overall_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
plt.title('Overall Mean Values of Wave Properties (All Protocols Combined)')
plt.ylabel('Mean Values')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
add_value_labels(ax)  # Add value labels

# Overall Count Plot
plt.subplot(1, 2, 2)
ax2 = overall_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
plt.title('Overall Count of Instances by Classification (All Protocols Combined)')
plt.ylabel('Count')
plt.xlabel('Classification', labelpad=10)
plt.xticks(rotation=0)
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
add_value_labels(ax2)  # Add value labels

# Show the overall plot
plt.tight_layout()  # Adjusts subplots to give some padding and prevent overlap
plt.show()

# Now, let's create separate plots for each protocol
protocol_numbers = df_filtered['Protocol Number'].dropna().unique()  # Get unique protocol numbers

for protocol in protocol_numbers:
    protocol_data = df_filtered[df_filtered['Protocol Number'] == protocol]
    protocol_means = protocol_data.groupby('Classification')[columns_to_plot].mean().reindex(all_classifications, fill_value=0)
    protocol_counts = protocol_data['Classification'].value_counts().reindex(all_classifications, fill_value=0)

    plt.figure(figsize=(15, 6))

    # Mean values for each protocol
    plt.subplot(1, 2, 1)
    ax = protocol_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
    plt.title(f'Mean Values of Wave Properties (Protocol {protocol})')
    plt.ylabel('Mean Values')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
    add_value_labels(ax)  # Add value labels

    # Counts for each protocol
    plt.subplot(1, 2, 2)
    ax2 = protocol_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
    plt.title(f'Count of Instances by Classification (Protocol {protocol})')
    plt.ylabel('Count')
    plt.xlabel('Classification', labelpad=10)
    plt.xticks(rotation=0)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Moving the legend outside
    add_value_labels(ax2)  # Add value labels

    # Show the plots for this protocol
    plt.tight_layout()
    plt.show()


In [9]:
'''
creates a CSV will all instances of SW and sorts them in order of importance:
Proto#
pre -> early -> late -> post
Wave#
'''

# Assuming 'df' is your DataFrame
# Filter out rows where 'Protocol Number' is NaN
df_filtered = df.dropna(subset=['Protocol Number'])

# Generate the slow wave names
df_filtered['Slow_Wave_Name'] = 'proto' + df_filtered['Protocol Number'].astype(int).astype(str) + '_' + df_filtered['Classification'].str.lower().replace(' ', '-') + '_sw' + (df_filtered.groupby(['Protocol Number', 'Classification']).cumcount() + 1).astype(str)

# Define the order for the 'Classification'
classification_order = ['pre-stim', 'early-stim', 'late-stim', 'post-stim']

# Create a categorical type for sorting Classification
df_filtered['Classification'] = df_filtered['Classification'].str.lower().replace(' ', '-')
df_filtered['Classification'] = pd.Categorical(df_filtered['Classification'], categories=classification_order, ordered=True)

# Sort the DataFrame based on Protocol Number, Classification order, and slow wave number (extracted from Slow_Wave_Name)
df_sorted = df_filtered.sort_values(by=['Protocol Number', 'Classification', 'Slow_Wave_Name'])

# Keep only the relevant columns
df_epochs = df_sorted[['Slow_Wave_Name', 'Start', 'End', 'Protocol Number', 'Classification', 'NegPeak', 'PosPeak', 'Channel', 'ValNegPeak']]

# Display the DataFrame with the new columns
print(df_epochs.head())

# Save the resulting DataFrame to a CSV file
df_epochs.to_csv('sorted_slow_waves.csv', index=False)


           Slow_Wave_Name   Start     End  Protocol Number Classification  \
149   proto1_pre-stim_sw1  168.91  169.98              1.0       pre-stim   
516  proto1_pre-stim_sw10  162.20  163.75              1.0       pre-stim   
696  proto1_pre-stim_sw11  168.89  169.99              1.0       pre-stim   
832  proto1_pre-stim_sw12  168.92  169.88              1.0       pre-stim   
833  proto1_pre-stim_sw13  199.70  201.33              1.0       pre-stim   

     NegPeak  PosPeak Channel  ValNegPeak  
149   169.16   169.65      E3  -49.127846  
516   162.58   163.07     E10  -48.868772  
696   169.14   169.63     E11  -50.249224  
832   169.15   169.60     E12  -43.986929  
833   200.61   201.07     E12  -42.908749  


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Slow_Wave_Name'] = 'proto' + df_filtered['Protocol Number'].astype(int).astype(str) + '_' + df_filtered['Classification'].str.lower().replace(' ', '-') + '_sw' + (df_filtered.groupby(['Protocol Number', 'Classification']).cumcount() + 1).astype(str)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Classification'] = df_filtered['Classification'].str.lower().replace(' ', '-')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value 

In [10]:
# Assuming df_epochs is your sorted DataFrame from the previous steps

# Function to filter epochs based on window size and picking criteria
def filter_epochs(df, window_size, pick_most_negative=False):
    filtered_epochs_list = []
    last_end_time = -float('inf')
    current_window = []

    for index, row in df.iterrows():
        if row['Start'] > last_end_time + window_size:
            if current_window:
                if pick_most_negative:
                    # Find the wave with the most negative ValNegPeak within the current window
                    max_wave = min(current_window, key=lambda x: x['ValNegPeak'])
                    filtered_epochs_list.append(max_wave)
                else:
                    # Pick the first wave in the window
                    filtered_epochs_list.append(current_window[0])
            # Reset the current window and add the current row to the new window
            current_window = [row]
            # Update the last_end_time to the current row's end time
            last_end_time = row['End']
        else:
            # Add the current row to the current window
            current_window.append(row)

    # After the loop, check if there are any remaining waves in the current window
    if current_window:
        if pick_most_negative:
            max_wave = min(current_window, key=lambda x: x['ValNegPeak'])
            filtered_epochs_list.append(max_wave)
        else:
            filtered_epochs_list.append(current_window[0])

    return pd.DataFrame(filtered_epochs_list, columns=df.columns)

# 1. 0.5s window, pick the first wave
filtered_epochs_05s_first = filter_epochs(df_epochs, window_size=0.5, pick_most_negative=False)
filtered_epochs_05s_first.to_csv('filtered_epochs_05s_first.csv', index=False)

# 2. 0.5s window, pick the wave with the most negative ValNegPeak
filtered_epochs_05s_negative = filter_epochs(df_epochs, window_size=0.5, pick_most_negative=True)
filtered_epochs_05s_negative.to_csv('filtered_epochs_05s_negative.csv', index=False)

# 3. 1s window, pick the first wave
filtered_epochs_1s_first = filter_epochs(df_epochs, window_size=1.0, pick_most_negative=False)
filtered_epochs_1s_first.to_csv('filtered_epochs_1s_first.csv', index=False)

# 4. 1s window, pick the wave with the most negative ValNegPeak
filtered_epochs_1s_negative = filter_epochs(df_epochs, window_size=1.0, pick_most_negative=True)
filtered_epochs_1s_negative.to_csv('filtered_epochs_1s_negative.csv', index=False)

print("Filtered CSV files created successfully.")


Filtered CSV files created successfully.


In [11]:
# Load the CSV file into a DataFrame
events_df = pd.read_csv('filtered_epochs_1s_negative.csv')

# Initialize lists to hold annotation data
onsets = []
durations = []
descriptions = []

# Populate the annotation lists for the main events
for index, row in events_df.iterrows():
    # Add the "Start" event
    onsets.append(row['Start'])  # The start time of the event in seconds
    durations.append(row['End'] - row['Start'])  # The duration of the event
    descriptions.append("Start")
    
    # Add the peakNeg as an instantaneous event
    onsets.append(row['NegPeak'])
    durations.append(0)  # Instantaneous event
    descriptions.append("NegPeak")
    
    # Add the peakPos as an instantaneous event
    onsets.append(row['PosPeak'])
    durations.append(0)  # Instantaneous event
    descriptions.append("PosPeak")
    
    # Add the "End" event
    onsets.append(row['End'])
    durations.append(0)  # Instantaneous event for end marker
    descriptions.append("End")

# Create an MNE Annotations object
annotations = mne.Annotations(onset=onsets, duration=durations, description=descriptions)

# Add the annotations to the raw data
raw.set_annotations(annotations)

# Export the annotated raw data to an EEGLAB .set file
output_fname = 'annotated_raw.set'
mne.export.export_raw(output_fname, raw, fmt='eeglab')




In [None]:
# Print available annotation descriptions to verify
print("Available annotation descriptions:")
print(set(annotations.description))


In [21]:
import os


# Assuming `filtered_epochs_1s_negative` is your filtered DataFrame and `raw` is your MNE Raw object
# Also assuming `sf` is the sampling frequency of your data

# Create the directory to save images if it doesn't exist
output_dir = 'epoch-imgs'
os.makedirs(output_dir, exist_ok=True)

# Ensure the loop goes through each row in the DataFrame
for i in range(len(filtered_epochs_1s_negative)):
    # Access the specific row
    row = filtered_epochs_1s_negative.iloc[i]

    # Convert 'Start', 'NegPeak', 'PosPeak', and 'End' times to sample indices for the current row
    start_sample = int(row['Start'] * sf)
    neg_peak_sample = int(row['NegPeak'] * sf)
    pos_peak_sample = int(row['PosPeak'] * sf)
    end_sample = int(row['End'] * sf)

    # Add 0.2 seconds before the start of the epoch for the baseline
    baseline_sample = int(0.2 * sf)
    new_start_sample = start_sample - baseline_sample

    # Ensure that the new start time is not negative
    if new_start_sample < 0:
        new_start_sample = 0

    # Calculate tmax based on the end event
    tmax = (end_sample - new_start_sample) / sf
    tmin = -0.2

    # Create an event array for the 'Start' event only
    events = np.array([
        [start_sample, 0, 1]  # Start event at t=0 (event_id = 1)
    ])

    # Create the epoch using the 'Start' event
    epoch = mne.Epochs(raw, events, event_id={'Start': 1},
                        tmin=tmin, tmax=tmax, baseline=(tmin, 0), preload=True)

    # Convert NegPeak and PosPeak times from sample indices to seconds relative to the epoch start
    neg_peak_time = (neg_peak_sample - start_sample) / sf
    pos_peak_time = (pos_peak_sample - start_sample) / sf

    # Retrieve the selected channel and slow wave name
    selected_channel = row['Channel']
    slow_wave_name = row['Slow_Wave_Name']

    # Create the title using the desired format
    title = f"{slow_wave_name} // {selected_channel}"

    # Plot the epoch image for the selected channel
    fig = epoch.plot_image(picks=[selected_channel], title=title, show=False)
    
    for ax in fig[0].axes:
        if ax.get_label() != '<colorbar>':  # Exclude the colorbar
            ax.axvline(neg_peak_time, color='red', linestyle='--', label='NegPeak')
            ax.axvline(pos_peak_time, color='green', linestyle='--', label='PosPeak')

    # Save the figure to the output directory with the specific naming format
    output_path = os.path.join(output_dir, f'{slow_wave_name}_{selected_channel}.png')
    fig[0].savefig(output_path)
    plt.close(fig[0])  # Close the figure to save memory

print("Processing complete. All epoch images were saved in the 'epoch-imgs' directory.")


Not setting metadata
1 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 151 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
1 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 204 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
1 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 246 original time points ...
0 bad epochs dropped
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Not setting metadata
1 matching events found
A

In [None]:
import os

# Directory where you want to save the files
output_dir = "output_epochs_set"
os.makedirs(output_dir, exist_ok=True)

# Sampling frequency
sfreq = epochs.info['sfreq']

# Loop through each epoch in the MNE epochs object
for i in range(len(epochs)):
    # Extract data and info
    epoch_data = epochs[i].get_data()[0]  # Get data for the ith epoch
    info = epochs.info

    # Create RawArray from epoch data
    raw_epoch = mne.io.RawArray(epoch_data, info)
    
    # Create events array
    start_sample = 0  # The start of the epoch
    neg_peak_sample = int((filtered_epochs.iloc[i]['NegPeak'] - filtered_epochs.iloc[i]['Start']) * sfreq)
    pos_peak_sample = int((filtered_epochs.iloc[i]['PosPeak'] - filtered_epochs.iloc[i]['Start']) * sfreq)
    
    events = np.array([
        [start_sample, 0, 1],  # Start event
        [neg_peak_sample, 0, 2],  # NegPeak event
        [pos_peak_sample, 0, 3],  # PosPeak event
    ])
    
    # Add events to annotations
    annotations = mne.Annotations(onset=events[:, 0] / sfreq,
                                  duration=[0, 0, 0],
                                  description=['Start', 'NegPeak', 'PosPeak'])
    raw_epoch.set_annotations(annotations)
    
    # Get the epoch name from the CSV file
    epoch_name = filtered_epochs.iloc[i]['Slow_Wave_Name']
    
    # Save as .set file
    epoch_file = os.path.join(output_dir, f"{epoch_name}.set")
    mne.export.export_raw(epoch_file, raw_epoch, fmt='eeglab')
    print(f"Saved epoch '{epoch_name}' to {epoch_file}")


######################
######################
######################
######################
######################
######################
######################
######################


In [None]:
# Assuming df is your DataFrame name
print("Descriptive Statistics for Start Times:")
print(df['Start'].describe())

print("\nDescriptive Statistics for Slope:")
print(df['Slope'].describe())

correlation = df['Start'].corr(df['Slope'])
print("Correlation coefficient between 'Start' and 'Slope':", correlation)

In [None]:
# Ensure 'Start' is the independent variable and 'Slope' is the dependent variable
X = sm.add_constant(df['Start'])  # adding a constant
y = df['Slope']
model = sm.OLS(y, X).fit()


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.scatter(df['Start'], df['Slope'], alpha=0.5, label='Data Points') #alpha controls the transperacy
plt.plot(df['Start'], model.predict(X), color='red', label='Regression Line')
plt.title('Relationship between Start Time and Slope of Slow Waves')
plt.xlabel('Start Time (s)')
plt.ylabel('Slope')
plt.legend()
plt.show()

In [None]:
sw_chan = sw.summary(grp_chan=True, grp_stage=True) #summary per channel
sw_chan.head(10)

### How to create manual epochs for SW and average them
1. raster plot
2. line graph

In [None]:
channels = ['EEG 001' , 'EEG 002' , 'EEG 003']

# Loop over each channel
for chn in channels:
    # Filter DataFrame for current channel
    df_chn = df[df['Channel'] == chn]
    # Convert 'Start' and 'End' times to sample indices
    start_samples = (df_chn['Start'] * sf).astype(int)
    end_samples = (df_chn['End'] * sf).astype(int)
    # Calculate tmin and tmax
    tmin = -0.2  # 200 ms before the start time
    tmax = np.max((end_samples - start_samples) / sf) + 0.1  # 500 ms after the longest end
    # Create an events array
    events_chn = np.column_stack((start_samples, np.zeros_like(start_samples), np.ones_like(start_samples)))
    # Create Epochs
    epochs_chn = mne.Epochs(raw, events_chn, event_id=1, tmin=tmin, tmax=tmax, picks=[chn], baseline=(None, 0), preload=True)

    # Plotting
    # Plot epochs
    epochs_chn.plot(scalings={'eeg': 60e-6})  # Adjust scalings if necessary

    # Plotting epochs with the image plot that includes the average and the individual epochs
    epochs_chn.plot_image(picks=chn, combine='mean')

    

    

In [None]:
sw.plot_average(figsize=(12, 9)) # creates an avg figure for all SW from all channels

### Troubleshooting
The following cells will come in handy if you need further data manipulation

In [None]:
# Find the index of the minimum value in the 'ValNegPeak' column
min_index = df['ValNegPeak'].idxmin()

# To display the index
print("Index of minimum value in 'ValNegPeak':", min_index)

# If you want to see the entire row corresponding to this minimum value
min_value_row = df.loc[min_index]
print(min_value_row)