### The objective

Create a short and sweet notebook that takes in EEGLAB-preprocessed data and outputs slow-wave (SW) statistics.

## Some questions to look into with time

- Why do some SWs seem to be detected on multiple channels while others are detected on a single channel?
- How can we look at "traveling" vs. focal SWs, i.e., Type I vs. Type II?
- Are there specific SW features that correlate with their distribution?

- Are more SWs detected in frontal channels?
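
One possible starting point for the multi-channel question is to group detections whose negative peaks fall close together in time and count how many channels contribute to each group. The sketch below is only an illustration on toy data: `group_by_peak_time` and the 0.1 s tolerance are assumptions, not part of YASA. In the notebook, the input could be built from the YASA summary with `list(zip(df['NegPeak'], df['Channel']))`.

```python
def group_by_peak_time(detections, tol=0.1):
    """Group (neg_peak_time_s, channel) pairs whose peaks are chained within `tol` seconds.

    Hypothetical helper: a detection joins the current group if its peak is at most
    `tol` seconds after the previous detection in that group.
    """
    groups = []
    for t, ch in sorted(detections):
        if groups and t - groups[-1][-1][0] <= tol:
            groups[-1].append((t, ch))
        else:
            groups.append([(t, ch)])
    return groups


# Toy detections: three channels see a wave around 10 s, one channel sees a wave at 15.2 s
toy = [(10.00, 'E10'), (15.20, 'E118'), (10.03, 'E31'), (10.06, 'E25')]
for group in group_by_peak_time(toy):
    channels = sorted({ch for _, ch in group})
    print(f"{group[0][0]:.2f} s: {len(channels)} channel(s) {channels}")
```

Groups with many channels are candidates for widespread or traveling waves, while single-channel groups are candidates for focal ones. The tolerance would need tuning against the data.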

In [None]:
# Import the goods:
# %matplotlib qt lets you open interactive figures. Highly recommended for this notebook
# Make sure you activate the YASA conda environment
%matplotlib qt
import mne
import numpy as np
import pandas as pd
import yasa
import matplotlib.pyplot as plt  # pyplot, not the top-level matplotlib package
import statsmodels.api as sm
import ipywidgets


## The next few cells show you different ways of getting help:

**1. Using the help() Function**

You can use Python’s built-in help() function to see the documentation of a function, module, or object. This will display a scrollable text area inside the notebook that includes the docstring and other helpful information.

**2. Using Question Mark ?**

Adding a question mark (?) before or after an object, method, or function in a Jupyter Notebook displays its docstring in a pop-up pane. It is a handy tool for quick look-ups.

**3. Using Double Question Marks ??**

For more detailed information, including the source code (if available), you can use double question marks (??). This is useful for understanding the implementation details.

**4. Using the dir() Function**

To get a list of all the attributes and methods associated with an object, module, or class, you can use the dir() function. It does not provide documentation, but it helps you explore what’s available.

In [None]:
help(yasa.bandpower)

In [None]:
yasa.compute_features_stage?

In [None]:
yasa.filter_data??

In [None]:
dir(yasa.art_detect)

### Load data
* Change the `mne.io.read_raw_*` function based on the EEG file type you are trying to load
* `preload=True` loads the data into memory so you can manipulate it in different cells.
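
As a rough illustration of picking the reader by file type, the sketch below maps a file extension to the name of the matching `mne.io` reader. `READERS` and `reader_name_for` are hypothetical helpers for this notebook. The reader names themselves are real MNE functions, and `mne.io.read_raw` can also infer the reader from the extension on its own.

```python
from pathlib import Path

# Extension -> name of the matching reader function in mne.io
READERS = {
    ".set": "read_raw_eeglab",
    ".edf": "read_raw_edf",
    ".bdf": "read_raw_bdf",
    ".vhdr": "read_raw_brainvision",
    ".fif": "read_raw_fif",
}


def reader_name_for(fname):
    """Return the name of the mne.io reader for a file, based on its extension."""
    ext = Path(fname).suffix.lower()
    if ext not in READERS:
        raise ValueError(f"No reader registered for '{ext}' files")
    return READERS[ext]


print(reader_name_for("recording.set"))  # read_raw_eeglab

# In the notebook (with mne imported), the reader could then be called as:
# raw = getattr(mne.io, reader_name_for(fname))(fname, preload=True)
```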

In [None]:
fname = '/Users/idohaber/Desktop/Source_test/1_Functional_Data/02_27_pilot_MB_sleep1_HP_LP_bc_we_short_bs_sr_avgref.set'
raw = mne.io.read_raw_eeglab(fname, preload=True);
#raw.filter(0.5, 30, fir_design='firwin')  # Adjust the frequency range as needed
raw

In [None]:
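# Prepare the data for processing
raw.resample(100)                # resample first so that data and sf describe the same signal
data = raw.get_data(units="uV")  # channels x samples array in microvolts
sf = raw.info['sfreq']
print(data.shape, sf)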

In [None]:
# View the raw data and make sure everything looks as expected
raw.plot(clipping=None);

In [None]:
# Drop bad channels and view remaining channels
raw.drop_channels('E63');
chan = raw.ch_names
print(chan)

In [None]:
# events_from_annotations returns the events array and the event-name -> code dict
actual_events, events_id = mne.events_from_annotations(raw)
print(events_id, '\n')
print(actual_events)

In [None]:
# Event codes; these should match the events_id dict printed above
column_dict = {'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4}

# Codes for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']
stim_codes = [stim_end_index, stim_start_index]

# Keep only the stim events in MNE format [sample, 0, code] for convenient visualization in MNE
filtered_data = [item for item in actual_events if item[2] in stim_codes]
# Same events without the unused middle column: [sample, code]
filtered_and_trimmed_data = [[item[0], item[2]] for item in filtered_data]

# Extracting epochs: each epoch spans two consecutive stim events (in samples)
epochs = [(filtered_data[i][0], filtered_data[i+1][0]) for i in range(len(filtered_data) - 1)]

# Separate the epochs into 'stim' and 'non-stim' based on even and odd indices.
# This assumes the first stim event is a 'stim start'; otherwise the two labels are swapped.
stim_epochs = epochs[0::2]      # Even index: 0, 2, 4, ...
non_stim_epochs = epochs[1::2]  # Odd index: 1, 3, 5, ...

print("Stim Epochs:", stim_epochs)
print("Non-Stim Epochs:", non_stim_epochs)
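
The even/odd split only gives the right labels when the first stim event is a 'stim start' and starts and ends strictly alternate. As a hedged alternative, the sketch below pairs events by their codes instead. `pair_stim_events` is a hypothetical helper, and the toy events array only mimics the `[sample, 0, code]` layout used by MNE.

```python
import numpy as np


def pair_stim_events(events, start_code, end_code):
    """Split an MNE-style events array into stim and non-stim sample intervals.

    A stim interval runs from a 'stim start' to the next 'stim end';
    a non-stim interval runs from a 'stim end' to the next 'stim start'.
    Events with other codes are ignored.
    """
    stim, non_stim = [], []
    last_start = last_end = None
    for sample, _, code in np.asarray(events):
        if code == start_code:
            if last_end is not None:
                non_stim.append((int(last_end), int(sample)))
                last_end = None
            last_start = sample
        elif code == end_code and last_start is not None:
            stim.append((int(last_start), int(sample)))
            last_start = None
            last_end = sample
    return stim, non_stim


# Toy events: the recording begins with an unmatched 'stim end' (code 3),
# and code 1 is an unrelated sleep-stage marker
toy_events = np.array([[50, 0, 3], [100, 0, 4], [150, 0, 1],
                       [200, 0, 3], [300, 0, 4], [400, 0, 3]])
stim, non_stim = pair_stim_events(toy_events, start_code=4, end_code=3)
print(stim)      # [(100, 200), (300, 400)]
print(non_stim)  # [(200, 300)]
```

On these toy events the even/odd split would label (50, 100) as a stim epoch, because the first retained event is a 'stim end'.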


In [None]:
def convert_sample_to_time(epochs, sf):
    """Convert (start, end) epochs from samples to seconds."""
    return [(start / sf, end / sf) for start, end in epochs]

def filter_df_by_epochs(df, epochs, time_col='time'):
    """Keep only the rows whose time_col value falls within any of the epochs (in seconds).

    For a YASA summary DataFrame, pass time_col='Start'.
    """
    return pd.concat([df[(df[time_col] >= start) & (df[time_col] <= end)] for start, end in epochs])

In [None]:
# Convert epoch values from samples to time
stim_epochs_time = convert_sample_to_time(stim_epochs, sf)
non_stim_epochs_time = convert_sample_to_time(non_stim_epochs, sf)
print('Stim Epochs:', stim_epochs_time)
print('Non-Stim Epochs:', non_stim_epochs_time)

In [None]:
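# Show events over time (passing sfreq puts the x-axis in seconds)
show_events = mne.viz.plot_events(actual_events, sfreq=sf)
# filtered_data is a list of rows, so convert it to an array first
show_stim_events = mne.viz.plot_events(np.array(filtered_data), sfreq=sf)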


In [None]:
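# Pull the channel straight from raw so the signal matches sf and the current channel list
yasa.plot_spectrogram(raw.get_data(picks="E10", units="uV")[0], sf, win_sec=5);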

In [None]:
#sw = yasa.sw_detect(raw, hypno=hypno_up, include=(2, 3))
sw = yasa.sw_detect(raw, verbose=False, coupling=False);
df = sw.summary(); # general summary for each sw
df # Inspect the dataframe

In [None]:
sw.plot_detection() # lets you scroll through the detection very conveniently

In [None]:
# Define the classification function
def classify_wave(start_time, stim_epochs_time, non_stim_epochs_time):
    """Classify each wave based on the start time into 'Stim' or 'Non-Stim'."""
    for start, end in stim_epochs_time:
        if start <= start_time <= end:
            return 'Stim'
    for start, end in non_stim_epochs_time:
        if start <= start_time <= end:
            return 'Non-Stim'
    return 'Unknown'  # If the wave does not fall within either then put it here

# Apply classification to DataFrame
df['Classification'] = df['Start'].apply(classify_wave, args=(stim_epochs_time, non_stim_epochs_time))
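
A quick sanity check of the classification logic on toy intervals can catch boundary surprises before running it on real data. `label_time` below is a hypothetical, compact stand-in that follows the same rules as `classify_wave`: inclusive bounds, with stim intervals checked first.

```python
def label_time(t, stim_intervals, non_stim_intervals):
    """Return 'Stim', 'Non-Stim' or 'Unknown' for a time in seconds (inclusive bounds)."""
    if any(start <= t <= end for start, end in stim_intervals):
        return 'Stim'
    if any(start <= t <= end for start, end in non_stim_intervals):
        return 'Non-Stim'
    return 'Unknown'


# Toy intervals in seconds
stim_intervals = [(10.0, 20.0), (30.0, 40.0)]
non_stim_intervals = [(20.0, 30.0)]

for t in (5.0, 15.0, 20.0, 25.0, 45.0):
    print(t, label_time(t, stim_intervals, non_stim_intervals))
```

Note that a wave starting exactly on a shared boundary (20.0 s here) is labeled 'Stim', because the stim intervals are checked first.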

In [None]:
# Group by classification and calculate mean and count for each group
comparison_means = df.groupby('Classification')[['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']].mean()
comparison_counts = df.groupby('Classification')['Start'].count()  # Counting instances using the 'Start' column

# Print results
print("Mean Values by Group:")
print(comparison_means)
print("\nCount of Instances by Group:")
print(comparison_counts)
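
Raw counts are hard to compare if the stim and non-stim periods differ in total length. A minimal sketch, assuming a waves-per-minute rate is what you want, is shown below on toy numbers. `waves_per_minute` is a hypothetical helper. In the notebook it could be fed `comparison_counts.to_dict()` and `{'Stim': stim_epochs_time, 'Non-Stim': non_stim_epochs_time}`.

```python
def waves_per_minute(counts, intervals_by_label):
    """Convert wave counts into waves per minute of recording for each label.

    counts: dict mapping label -> number of detected waves
    intervals_by_label: dict mapping label -> list of (start_s, end_s) tuples
    """
    rates = {}
    for label, intervals in intervals_by_label.items():
        total_s = sum(end - start for start, end in intervals)
        if total_s > 0:  # skip labels with no recording time
            rates[label] = counts.get(label, 0) / (total_s / 60.0)
    return rates


# Toy example: more waves during stim, but stim also lasted twice as long
toy_counts = {'Stim': 30, 'Non-Stim': 20}
toy_intervals = {'Stim': [(0.0, 60.0), (120.0, 180.0)], 'Non-Stim': [(60.0, 120.0)]}
print(waves_per_minute(toy_counts, toy_intervals))  # {'Stim': 15.0, 'Non-Stim': 20.0}
```

Because the summary DataFrame has one row per wave per channel, the rate is summed over all channels.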


In [None]:
df #look at the frame to make sure classification added properly 

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: mean wave properties (left) and wave counts (right)
fig, (ax_means, ax_counts) = plt.subplots(1, 2, figsize=(15, 6))

# Plotting mean values and moving the legend outside the axes
comparison_means.plot(kind='bar', rot=0, ax=ax_means,
                      color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
ax_means.set_title('Mean Values of Wave Properties')
ax_means.set_ylabel('Mean Values')
ax_means.set_xlabel('Classification', labelpad=10)
ax_means.legend(title='Properties', loc='upper left', bbox_to_anchor=(1, 1))

# Plotting counts
comparison_counts.plot(kind='bar', rot=0, color='#6baed6', ax=ax_counts)
ax_counts.set_title('Count of Instances by Group')
ax_counts.set_ylabel('Count')
ax_counts.set_xlabel('Classification', labelpad=10)

# Show plots
fig.tight_layout()  # This adjusts subplots to give some padding and prevent overlap
plt.show()


In [None]:
# Assuming df is your DataFrame name
print("Descriptive Statistics for Start Times:")
print(df['Start'].describe())

print("\nDescriptive Statistics for Slope:")
print(df['Slope'].describe())

correlation = df['Start'].corr(df['Slope'])
print("Correlation coefficient between 'Start' and 'Slope':", correlation)

In [None]:
# Ensure 'Start' is the independent variable and 'Slope' is the dependent variable
X = sm.add_constant(df['Start'])  # adding a constant
y = df['Slope']
model = sm.OLS(y, X).fit()
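
The fitted model can be inspected in the notebook with `print(model.summary())`, or through `model.params` and `model.rsquared`. To illustrate what the fit estimates, the sketch below recovers a slope and intercept from synthetic data with `np.polyfit`. The numbers are made up and are not results from this dataset.

```python
import numpy as np

# Synthetic stand-in for df['Start'] (seconds) and df['Slope']
rng = np.random.default_rng(0)
start = np.linspace(0, 3600, 200)
true_slope, true_intercept = -0.05, 400.0
sw_slope = true_intercept + true_slope * start + rng.normal(0, 5, size=start.size)

# Least-squares straight line: sw_slope ~ b0 + b1 * start
b1, b0 = np.polyfit(start, sw_slope, deg=1)
r = np.corrcoef(start, sw_slope)[0, 1]
print(f"slope={b1:.4f}, intercept={b0:.1f}, r^2={r**2:.3f}")
```

For a single predictor, the squared Pearson correlation equals the R-squared of the linear fit, so the correlation printed earlier and the regression describe the same relationship.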


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.scatter(df['Start'], df['Slope'], alpha=0.5, label='Data Points')  # alpha controls the transparency
plt.plot(df['Start'], model.predict(X), color='red', label='Regression Line')
plt.title('Relationship between Start Time and Slope of Slow Waves')
plt.xlabel('Start Time (s)')
plt.ylabel('Slope')
plt.legend()
plt.show()

In [None]:
sw_chan = sw.summary(grp_chan=True, grp_stage=True) #summary per channel
sw_chan.head(10)

In [None]:
frontal_ch = ['E31', 'E166', 'E32', 'E167', 'E25', 'E189', 'E177']
posterior_ch = ['E118', 'E127', 'E152', 'E109', 'E137', 'E115', 'E159']


# Apply classification to DataFrame
df['Classification'] = df['Start'].apply(classify_wave, args=(stim_epochs_time, non_stim_epochs_time))
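
The channel lists can be used to tag each wave with a scalp region, which is one way to approach the question of whether more SWs are detected frontally. The sketch below runs on a toy DataFrame with shortened channel lists. `region_of` is a hypothetical helper, and the toy values are made up.

```python
import pandas as pd

# Shortened channel groups for illustration; adjust to your montage
frontal_ch = ['E31', 'E166', 'E32', 'E167', 'E25', 'E189']
posterior_ch = ['E118', 'E127', 'E152', 'E109', 'E137', 'E115', 'E159']


def region_of(channel):
    """Map a channel name to 'Frontal', 'Posterior' or 'Other'."""
    if channel in frontal_ch:
        return 'Frontal'
    if channel in posterior_ch:
        return 'Posterior'
    return 'Other'


# Toy stand-in for the YASA summary DataFrame (one row per detected wave)
toy_df = pd.DataFrame({
    'Channel': ['E31', 'E31', 'E118', 'E10', 'E25'],
    'PTP': [120.0, 100.0, 90.0, 80.0, 110.0],
})
toy_df['Region'] = toy_df['Channel'].map(region_of)

# Number of waves and mean peak-to-peak amplitude per region
by_region = toy_df.groupby('Region')['PTP'].agg(['count', 'mean'])
print(by_region)
```

When comparing regions on real data, the counts should be divided by the number of channels in each region, since a region with more electrodes will collect more detections.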



In [None]:
raw.plot_sensors(kind='3d', show_names=True);  # this is a placeholder for now so I do not forget

### How to create manual epochs for SW and average them
1. raster plot
2. line graph
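
To show what epoching around SW onsets does under the hood, here is a minimal NumPy sketch on a synthetic signal. `average_around_events` is a hypothetical helper, not an MNE or YASA function. The stacked `epochs` matrix is what a raster plot would display (one row per wave), and `avg` is the line graph.

```python
import numpy as np


def average_around_events(signal, event_samples, sf, tmin=-0.2, tmax=1.0):
    """Cut fixed-length windows around each event and average them.

    Windows that would run past either end of the signal are skipped.
    Returns the time axis (s), the epochs matrix (n_events x n_times) and the mean.
    """
    pre = int(round(-tmin * sf))   # samples before the event
    post = int(round(tmax * sf))   # samples after the event
    segments = [signal[s - pre:s + post] for s in event_samples
                if s - pre >= 0 and s + post <= signal.size]
    epochs = np.vstack(segments)
    times = np.arange(-pre, post) / sf
    return times, epochs, epochs.mean(axis=0)


# Synthetic 1 Hz oscillation sampled at 100 Hz for 10 s
sf = 100
t = np.arange(0, 10, 1 / sf)
signal = np.sin(2 * np.pi * 1.0 * t)

# Events every second; the last one is too close to the end and gets dropped
times, epochs, avg = average_around_events(signal, [100, 200, 300, 990], sf)
print(epochs.shape)  # (3, 120)
```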

In [None]:
channels = ['EEG 001', 'EEG 002', 'EEG 003']  # replace with channel names present in raw.ch_names

# Loop over each channel
for chn in channels:
    # Filter DataFrame for current channel
    df_chn = df[df['Channel'] == chn]
    if df_chn.empty:
        continue  # no slow waves detected on this channel, nothing to epoch
    # Convert 'Start' and 'End' times to sample indices
    start_samples = (df_chn['Start'] * sf).astype(int)
    end_samples = (df_chn['End'] * sf).astype(int)
    # Calculate tmin and tmax
    tmin = -0.2  # 200 ms before the start time
    tmax = np.max((end_samples - start_samples) / sf) + 0.1  # 100 ms after the longest end
    # Create an events array in MNE format: [sample, 0, event code]
    events_chn = np.column_stack((start_samples, np.zeros_like(start_samples), np.ones_like(start_samples)))
    # Create Epochs
    epochs_chn = mne.Epochs(raw, events_chn, event_id=1, tmin=tmin, tmax=tmax, picks=[chn], baseline=(None, 0), preload=True)

    # Plotting
    # Plot epochs
    epochs_chn.plot(scalings={'eeg': 60e-6})  # Adjust scalings if necessary

    # Plotting epochs with the image plot that includes the average and the individual epochs
    epochs_chn.plot_image(picks=chn, combine='mean')

    

    

In [None]:
sw.plot_average(figsize=(12, 9)) # creates an avg figure for all SW from all channels

### Troubleshooting
The following cells will come in handy if you need further data manipulation.

In [None]:
# Find the index of the minimum value in the 'ValNegPeak' column
min_index = df['ValNegPeak'].idxmin()

# To display the index
print("Index of minimum value in 'ValNegPeak':", min_index)

# If you want to see the entire row corresponding to this minimum value
min_value_row = df.loc[min_index]
print(min_value_row)