In [None]:
# Automatic pre- & post-stim windows: each is exactly as long as its stim epoch
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Assumes `raw` (an MNE Raw object) and its sampling frequency `sf` already exist

# mne.events_from_annotations returns (events array, {annotation description: event id})
events = mne.events_from_annotations(raw)
actual_events, events_id = events
print(events_id, '\n')
print(actual_events)

# Take the event ids from the mapping MNE actually returned. The ids it assigns
# depend on which annotation descriptions occur in the recording, so a hard-coded
# table silently picks the wrong events when that set differs.
column_dict = dict(events_id)

# Event ids for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Keep only the stim start / stim end rows (column 0 = sample, column 2 = event id)
filtered_data = [item for item in actual_events if item[2] in (stim_end_index, stim_start_index)]

# (start_sample, end_sample) windows for the pre-stim, stim and post-stim epochs
pre_stim_epochs = []
stim_epochs = []
post_stim_epochs = []

# Pair every 'stim end' with the most recent unmatched 'stim start'. Walking the
# list two items at a time assumed perfect start/end alternation: one stray or
# missing marker shifted every later pair (an end used as a start, negative
# durations). Unmatched markers are now skipped instead.
pending_start = None
for sample, _, code in filtered_data:
    if code == stim_start_index:
        pending_start = sample  # a repeated start replaces an unmatched earlier one
        continue
    if pending_start is None:
        continue  # 'stim end' without an open 'stim start'

    stim_start, stim_end = pending_start, sample
    pending_start = None
    stim_duration = stim_end - stim_start  # pre/post windows get this same length

    pre_stim_epochs.append((stim_start - stim_duration, stim_start))
    stim_epochs.append((stim_start, stim_end))
    post_stim_epochs.append((stim_end, stim_end + stim_duration))
# Convert (start, end) sample windows into (start, end) times in seconds for plotting
def convert_sample_to_time(epochs, sf):
    """Return a list of (start / sf, end / sf) tuples, one per (start, end) window in ``epochs``."""
    seconds = []
    for first_sample, last_sample in epochs:
        seconds.append((first_sample / sf, last_sample / sf))
    return seconds

# Express every epoch boundary in seconds so it can be drawn on a time axis
pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
stim_epochs_time = convert_sample_to_time(stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

# (time windows, colour, legend label) per epoch category, drawn in this order
span_groups = [
    (pre_stim_epochs_time, 'blue', 'Pre-Stim'),
    (stim_epochs_time, 'orange', 'Stim'),
    (post_stim_epochs_time, 'green', 'Post-Stim'),
]

fig, ax = plt.subplots(figsize=(10, 6))
for spans, colour, name in span_groups:
    for t0, t1 in spans:
        ax.axvspan(t0, t1, color=colour, alpha=0.3, label=name)

ax.set_xlabel('Time (s)')
ax.set_ylabel('Epoch Type')
ax.set_title('Epochs Visualization with Equal Pre-Stim and Post-Stim Duration')

# Every span carries a label; collapse them to one legend entry per category
handles, labels = ax.get_legend_handles_labels()
by_label = {text: handle for handle, text in zip(handles, labels)}
ax.legend(by_label.values(), by_label.keys())

plt.show()


In [None]:
# Manual pre- & post-stim epoch lengths (entered in seconds)
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Assumes `raw` (an MNE Raw object) and its sampling frequency `sf` already exist

# mne.events_from_annotations returns (events array, {annotation description: event id})
events = mne.events_from_annotations(raw)
actual_events, events_id = events
print(events_id, '\n')
print(actual_events)

# Use the ids MNE actually assigned; they depend on which annotation descriptions
# the recording contains, so a hard-coded table can silently select wrong events.
column_dict = dict(events_id)

# Event ids for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Keep only the stim start / stim end rows (column 0 = sample, column 2 = event id)
filtered_data = [item for item in actual_events if item[2] in (stim_end_index, stim_start_index)]

# User inputs for pre-stim and post-stim durations (in seconds)
pre_stim_duration_sec = float(input("Enter the desired pre-stim duration in seconds: "))
post_stim_duration_sec = float(input("Enter the desired post-stim duration in seconds: "))
# Negative or non-finite durations would yield reversed or meaningless windows
if not all(np.isfinite(d) and d >= 0 for d in (pre_stim_duration_sec, post_stim_duration_sec)):
    raise ValueError("Pre- and post-stim durations must be finite, non-negative numbers of seconds")

# Convert to samples; round() because int() truncates (0.29 s * 100 Hz = 28.999... -> 28)
pre_stim_duration = int(round(pre_stim_duration_sec * sf))
post_stim_duration = int(round(post_stim_duration_sec * sf))

# (start_sample, end_sample) windows for the pre-stim, stim and post-stim epochs
pre_stim_epochs = []
stim_epochs = []
post_stim_epochs = []

# Pair every 'stim end' with the most recent unmatched 'stim start' instead of
# assuming strict alternation (a stray marker would misalign all later pairs).
pending_start = None
for sample, _, code in filtered_data:
    if code == stim_start_index:
        pending_start = sample  # a repeated start replaces an unmatched earlier one
        continue
    if pending_start is None:
        continue  # 'stim end' without an open 'stim start'

    stim_start, stim_end = pending_start, sample
    pending_start = None

    pre_stim_epochs.append((stim_start - pre_stim_duration, stim_start))
    stim_epochs.append((stim_start, stim_end))
    post_stim_epochs.append((stim_end, stim_end + post_stim_duration))

# Sample windows -> time windows (seconds) for the time axis of the plot
def convert_sample_to_time(epochs, sf):
    """Divide both ends of every (start, end) sample pair by the sampling frequency ``sf``."""
    return list((onset / sf, offset / sf) for onset, offset in epochs)

# Seconds-based copies of the sample windows, for the time axis
pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
stim_epochs_time = convert_sample_to_time(stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

fig, ax = plt.subplots(figsize=(10, 6))

# Legend label -> (time windows, shading colour); dicts keep insertion order,
# so the categories are drawn pre-stim, stim, post-stim
shading = {
    'Pre-Stim': (pre_stim_epochs_time, 'blue'),
    'Stim': (stim_epochs_time, 'orange'),
    'Post-Stim': (post_stim_epochs_time, 'green'),
}
for label_text, (windows, shade) in shading.items():
    for left, right in windows:
        ax.axvspan(left, right, color=shade, alpha=0.3, label=label_text)

ax.set(
    xlabel='Time (s)',
    ylabel='Epoch Type',
    title='Epochs Visualization with User-Specified Pre-Stim and Post-Stim Duration',
)

# One legend entry per category, even though every span is labelled
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys())

plt.show()


In [None]:
# Early/late stim split, keeping only stim epochs that reach a minimum duration
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Assumes `raw` (an MNE Raw object) and its sampling frequency `sf` already exist

# mne.events_from_annotations returns (events array, {annotation description: event id})
events = mne.events_from_annotations(raw)
actual_events, events_id = events
print(events_id, '\n')
print(actual_events)

# Use the ids MNE actually assigned; they depend on which annotation descriptions
# the recording contains, so a hard-coded table can silently select wrong events.
column_dict = dict(events_id)

# Event ids for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Keep only the stim start / stim end rows (column 0 = sample, column 2 = event id)
filtered_data = [item for item in actual_events if item[2] in (stim_end_index, stim_start_index)]

# User inputs for pre-stim and post-stim durations (in seconds)
pre_stim_duration_sec = float(input("Enter the desired pre-stim duration in seconds: "))
post_stim_duration_sec = float(input("Enter the desired post-stim duration in seconds: "))
# Negative or non-finite durations would yield reversed or meaningless windows
if not all(np.isfinite(d) and d >= 0 for d in (pre_stim_duration_sec, post_stim_duration_sec)):
    raise ValueError("Pre- and post-stim durations must be finite, non-negative numbers of seconds")

# Convert to samples; round() because int() truncates (0.29 s * 100 Hz = 28.999... -> 28)
pre_stim_duration = int(round(pre_stim_duration_sec * sf))
post_stim_duration = int(round(post_stim_duration_sec * sf))

# Minimum stim duration (seconds) a protocol needs in order to be kept
min_stim_duration_sec = 150
min_stim_duration_samples = int(round(min_stim_duration_sec * sf))

# (start_sample, end_sample) windows for each epoch category
pre_stim_epochs = []
early_stim_epochs = []
late_stim_epochs = []
post_stim_epochs = []

# Pair every 'stim end' with the most recent unmatched 'stim start' instead of
# assuming strict alternation (a stray marker would misalign all later pairs).
pending_start = None
for sample, _, code in filtered_data:
    if code == stim_start_index:
        pending_start = sample  # a repeated start replaces an unmatched earlier one
        continue
    if pending_start is None:
        continue  # 'stim end' without an open 'stim start'

    stim_start, stim_end = pending_start, sample
    pending_start = None

    stim_duration = stim_end - stim_start
    if stim_duration < min_stim_duration_samples:
        continue  # shorter than the minimum duration: skip this protocol

    stim_midpoint = (stim_start + stim_end) // 2  # sample splitting the stim epoch into halves

    pre_stim_epochs.append((stim_start - pre_stim_duration, stim_start))
    early_stim_epochs.append((stim_start, stim_midpoint))
    late_stim_epochs.append((stim_midpoint, stim_end))
    post_stim_epochs.append((stim_end, stim_end + post_stim_duration))

# Express epoch boundaries in seconds so they can be drawn on a time axis
def convert_sample_to_time(epochs, sf):
    """Convert (start, end) sample index pairs to (start, end) times in seconds."""
    def _seconds(sample_index):
        return sample_index / sf

    return [(_seconds(begin), _seconds(finish)) for begin, finish in epochs]

# Sample windows -> seconds, one list per epoch category
pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
early_stim_epochs_time = convert_sample_to_time(early_stim_epochs, sf)
late_stim_epochs_time = convert_sample_to_time(late_stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

# Make an empty figure explicable: every protocol may have failed the duration filter
if not early_stim_epochs_time:
    print(f"No stim epoch lasted at least {min_stim_duration_sec} s; nothing to shade.")

fig, ax = plt.subplots(figsize=(10, 6))

# Shade each epoch category in its own colour, drawn in this order
for spans, colour, name in (
    (pre_stim_epochs_time, 'blue', 'Pre-Stim'),
    (early_stim_epochs_time, 'orange', 'Early Stim'),
    (late_stim_epochs_time, 'red', 'Late Stim'),
    (post_stim_epochs_time, 'green', 'Post-Stim'),
):
    for start, end in spans:
        ax.axvspan(start, end, color=colour, alpha=0.3, label=name)

ax.set_xlabel('Time (s)')
ax.set_ylabel('Epoch Type')
# Derive the threshold from min_stim_duration_sec so the title cannot drift from the
# actual filter (it previously said "> 100s" while the cut-off was >= 150 s).
ax.set_title(
    f'Epochs Visualization with Early and Late Stim Splitting (Stim >= {min_stim_duration_sec:g}s)'
)

# Every span is labelled; keep a single legend entry per category
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys())

plt.show()


In [None]:

# Manual pre- & post-stim lengths,
# minimum stim length,
# automatic overlap removal
import mne
import numpy as np

# Assumes `raw` (an MNE Raw object) and its sampling frequency `sf` already exist

# mne.events_from_annotations returns (events array, {annotation description: event id})
events = mne.events_from_annotations(raw)
actual_events, events_id = events
print(events_id, '\n')
print(actual_events)

# Use the ids MNE actually assigned; they depend on which annotation descriptions
# the recording contains, so a hard-coded table can silently select wrong events.
column_dict = dict(events_id)

# Event ids for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Keep only the stim start / stim end rows (column 0 = sample, column 2 = event id)
filtered_data = [item for item in actual_events if item[2] in (stim_end_index, stim_start_index)]

# User inputs for pre-stim and post-stim durations (in seconds)
pre_stim_duration_sec = float(input("Enter the desired pre-stim duration in seconds: "))
post_stim_duration_sec = float(input("Enter the desired post-stim duration in seconds: "))
# Negative or non-finite durations would yield reversed or meaningless windows
if not all(np.isfinite(d) and d >= 0 for d in (pre_stim_duration_sec, post_stim_duration_sec)):
    raise ValueError("Pre- and post-stim durations must be finite, non-negative numbers of seconds")

# Convert to samples; round() because int() truncates (0.29 s * 100 Hz = 28.999... -> 28)
pre_stim_duration = int(round(pre_stim_duration_sec * sf))
post_stim_duration = int(round(post_stim_duration_sec * sf))

# Minimum stim duration (seconds) a protocol needs in order to be kept
min_stim_duration_sec = 150
min_stim_duration_samples = int(round(min_stim_duration_sec * sf))

# (start_sample, end_sample) windows for each epoch category
pre_stim_epochs = []
early_stim_epochs = []
late_stim_epochs = []
post_stim_epochs = []

# End sample of the last kept post-stim window. A protocol whose pre-stim window
# starts before it is dropped; starting at 0 also drops pre-stim windows that
# would begin before sample 0.
# NOTE(review): MNE event samples are believed to include raw.first_samp, so sample 0
# is the start of `raw` only when first_samp == 0 — confirm for these recordings.
previous_end = 0

# Pair every 'stim end' with the most recent unmatched 'stim start' instead of
# assuming strict alternation (a stray marker would misalign all later pairs).
pending_start = None
for sample, _, code in filtered_data:
    if code == stim_start_index:
        pending_start = sample  # a repeated start replaces an unmatched earlier one
        continue
    if pending_start is None:
        continue  # 'stim end' without an open 'stim start'

    stim_start, stim_end = pending_start, sample
    pending_start = None

    stim_duration = stim_end - stim_start
    if stim_duration < min_stim_duration_samples:
        continue  # shorter than the minimum duration: skip this protocol

    stim_midpoint = (stim_start + stim_end) // 2  # sample splitting the stim epoch into halves

    pre_stim_epoch = (stim_start - pre_stim_duration, stim_start)
    early_stim_epoch = (stim_start, stim_midpoint)
    late_stim_epoch = (stim_midpoint, stim_end)
    post_stim_epoch = (stim_end, stim_end + post_stim_duration)

    # Overlaps the previously kept protocol: drop this whole protocol
    if pre_stim_epoch[0] < previous_end:
        continue

    previous_end = post_stim_epoch[1]

    pre_stim_epochs.append(pre_stim_epoch)
    early_stim_epochs.append(early_stim_epoch)
    late_stim_epochs.append(late_stim_epoch)
    post_stim_epochs.append(post_stim_epoch)

# Turn sample-index windows into second-based windows for plotting
def convert_sample_to_time(epochs, sf):
    """Map each (start, end) pair of sample indices to a (start, end) pair in seconds."""
    result = []
    for pair in epochs:
        start_sample, end_sample = pair
        result += [(start_sample / sf, end_sample / sf)]
    return result

# Convert every category of sample windows to seconds for the time axis
pre_stim_epochs_time = convert_sample_to_time(pre_stim_epochs, sf)
early_stim_epochs_time = convert_sample_to_time(early_stim_epochs, sf)
late_stim_epochs_time = convert_sample_to_time(late_stim_epochs, sf)
post_stim_epochs_time = convert_sample_to_time(post_stim_epochs, sf)

fig, ax = plt.subplots(figsize=(10, 6))

# Parallel tuples: legend label, shading colour and windows of each category
categories = ('Pre-Stim', 'Early Stim', 'Late Stim', 'Post-Stim')
colours = ('blue', 'orange', 'red', 'green')
windows_per_category = (
    pre_stim_epochs_time,
    early_stim_epochs_time,
    late_stim_epochs_time,
    post_stim_epochs_time,
)
for legend_label, colour, windows in zip(categories, colours, windows_per_category):
    for t_start, t_end in windows:
        ax.axvspan(t_start, t_end, color=colour, alpha=0.3, label=legend_label)

ax.set_xlabel('Time (s)')
ax.set_ylabel('Epoch Type')
ax.set_title('Epochs Visualization with Overlap Removal')

# Duplicate span labels collapse to one legend entry per category
handles, labels = ax.get_legend_handles_labels()
by_label = {text: handle for handle, text in zip(handles, labels)}
ax.legend(by_label.values(), by_label.keys())

plt.show()


In [None]:
def add_value_labels(ax, spacing=5):
    """Annotate every bar in a bar chart with its height, formatted to two decimals.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose ``patches`` are the bars to label.
    spacing : float, default 5
        Gap between the end of the bar and its label, in typographic points.
        It was previously added to the bar height in data units, so the gap
        depended on the y-scale. That made it oversized for small values and
        invisible for large ones, and inconsistent between panels with
        different scales.

    Bars with a NaN/inf height are skipped. Labels for negative bars go
    below the bar end, so the text does not sit inside the bar.
    """
    import math

    for rect in ax.patches:
        height = rect.get_height()
        if not math.isfinite(height):
            continue  # nothing sensible to print or position
        below = height < 0
        ax.annotate(
            f'{height:.2f}',
            xy=(rect.get_x() + rect.get_width() / 2, height),  # centre of the bar end
            xytext=(0, -spacing if below else spacing),  # offset away from the bar
            textcoords='offset points',
            ha='center',
            va='top' if below else 'bottom',
        )

# Two side-by-side bar charts: per-class means (left) and per-class counts (right)
fig, (ax, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left: mean wave properties, with the legend parked outside the axes on the right
ax = comparison_means.plot(kind='bar', ax=ax, color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
ax.set_title('Mean Values of Wave Properties')
ax.set_ylabel('Mean Values')
ax.set_xlabel('Classification', labelpad=10)
plt.setp(ax.get_xticklabels(), rotation=0)
add_value_labels(ax)
ax.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))

# Right: number of instances per classification
ax2 = comparison_counts.plot(kind='bar', color='#6baed6', ax=ax2)
ax2.set_title('Count of Instances by Group')
ax2.set_ylabel('Count')
ax2.set_xlabel('Classification', labelpad=10)
plt.setp(ax2.get_xticklabels(), rotation=0)
add_value_labels(ax2)

# Pad the panels so titles, tick labels and value labels do not collide
fig.tight_layout()
plt.show()


In [None]:
# Assumes `filtered_epochs_1s_negative` (a filtered DataFrame with 'Start', 'NegPeak',
# 'PosPeak' and 'End' columns — presumably seconds, since they are multiplied by `sf`
# below) and `raw` (an MNE Raw object) already exist, along with `sf`, the sampling
# frequency of the data.

# Collects the per-epoch figures produced by the loop below.
# NOTE(review): the loop iterates over every row of the DataFrame; confirm whether a
# "first five epochs" limit is applied further down in the loop body.
figures_list = []

# Loop through each row in the filtered_epochs DataFrame
for index, row in filtered_epochs_1s_negative.iterrows():
    # Convert 'Start', 'NegPeak', 'PosPeak', and 'End' times to sample indices for the current row
    start_sample = int(row['Start'] * sf)
    neg_peak_sample = int(row['NegPeak'] * sf)
    pos_peak_sample = int(row['PosPeak'] * sf)
    end_sample = int(row['End'] * sf)

    # Add 0.2 seconds before the start of the epoch for the baseline
    baseline_sample = int(0.2 * sf)
    new_start_sample = start_sample - baseline_sample

    # Ensure that the new start time is not negative
    if new_start_sample < 0:
        new_start_sample = 0

    # Calculate tmax based on the end event
    tmax = (end_sample - new_start_sample) / sf
    tmin = -0.2

    # Create an event array for the 'Start' event only
    events = np.array([
        [start_sample, 0, 1]  # Start event at t=0 (event_id = 1)
    ])

    # Create the epoch using the 'Start' event
    epochs = mne.Epochs(raw, events, event_id={'Start': 1},
                        tmin=tmin, tmax=(end_sample - start_sample) / sf, baseline=(tmin, 0), preload=True)

    # Convert NegPeak and PosPeak times from sample indices to seconds relative to the epoch start
    neg_peak_time = (row['NegPeak'] - row['Start'])
    pos_peak_time = (row['PosPeak'] - row['Start'])