### The objective

create a short and sweet notebook that takes in EEGLAB preprocessed data, and outputs SW statistics

## Some questions to look into with time

- why some sw seem to be detected by multiple channels while other sw are detected by a single ch?
- how can we look at "traveling" vs focal sw? i.e Type I vs Type II
- is there specific sw features that are correlated with their distribution?

- are there more sw detected in frontal ch?

In [1]:
#Import the goods:
#matplotlib qu allows you to open interactive figures. Highly Recommended for this notebook
#Make sure you activate the YASA conda environment
%matplotlib qt  
import mne
import numpy as np
import pandas as pd
import yasa
import matplotlib as plt
import statsmodels.api as sm
import ipywidgets


## This next few cells will shop you different methods of getting help:

**1. Using the help() Function**

You can use Python’s built-in help() function to see the documentation of a function, module, or object. This will display a scrollable text area inside the notebook that includes the docstring and other helpful information.

**2. Using Question Mark ?**

Appending a question mark (?) before or after an object, method, or function in a Jupyter Notebook will display its docstring in a pop-up window. Handy tool for quick look-ups.

**3. Using Double Question Marks ??**

More detailed information, including the source code (if available), you can use double question marks (??). This is useful for understanding the implementation details.

**4. Using the dir() Function**

To get a list of all the attributes and methods associated with an object, module, or class, you can use the dir() function. Does not provide documentation, but helps you explore what’s available.

In [None]:
help(yasa.bandpower)

In [None]:
yasa.compute_features_stage?

In [None]:
yasa.filter_data??

In [None]:
dir(yasa.art_detect)

### Load data
* Change the io.methodX based on the EEG file type you are trying to load
* `preload` lets you keep the data in memory and manipulate it in different cells.

In [2]:
fname = '/Users/idohaber/Desktop/Paper_dir/Source_test/1_Functional_Data/0.5-6_full_NREM.set'
raw = mne.io.read_raw_eeglab(fname, preload=True);
#raw.filter(0.5, 30, fir_design='firwin')  # Adjust the frequency range as needed
raw

Reading /Users/idohaber/Desktop/Paper_dir/Source_test/1_Functional_Data/0.5-6_full_NREM.fdt
Reading 0 ... 5369842  =      0.000 ... 10739.684 secs...


  raw = mne.io.read_raw_eeglab(fname, preload=True);
  raw = mne.io.read_raw_eeglab(fname, preload=True);


0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,197 points
Good channels,194 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,500.00 Hz
Highpass,0.00 Hz
Lowpass,250.00 Hz
Filenames,0.5-6_full_NREM.fdt
Duration,02:58:60 (HH:MM:SS)


In [3]:
# prepare the data for processing
data = raw.get_data(units="uV") 
raw.resample(100)
sf = raw.info['sfreq']
print(data.shape , sf)

(194, 5369843) 100.0


In [None]:
# View the raw data and make sure everything looks as expected
raw.plot(clipping=None);

In [None]:
# Drop bad channels and view remaining channels
raw.drop_channels('E063');
chan = raw.ch_names
print(chan)

In [23]:
events = mne.events_from_annotations(raw) # raw events
events_id = events[-1]                    # grab event dict
actual_events = events[:-1][0]            # grab actual events
print(events_id ,'\n') 
print(actual_events)

Used Annotations descriptions: ['Sleep Stage', 'boundary', 'stim end', 'stim start']
{'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4} 

[[      0       0       2]
 [    511       0       1]
 [   3511       0       1]
 ...
 [1068613       0       1]
 [1071613       0       1]
 [1073968       0       2]]


In [24]:
column_dict = {'Sleep Stage': 1, 'boundary': 2, 'stim end': 3, 'stim start': 4}

# Indices for 'stim end' and 'stim start'
stim_end_index = column_dict['stim end']
stim_start_index = column_dict['stim start']

# Filtering the list and removing the value at index 1
filtered_and_trimmed_data = [[item[0], item[2]] for item in actual_events if item[2] in [3, 4]] #if you want to remove the first index
filtered_data = [item for item in actual_events if item[2] == stim_end_index or item[2] == stim_start_index] # for convient visualization in MNE formatting

# Extracting epochs
epochs = [(filtered_data[i][0], filtered_data[i+1][0]) for i in range(len(filtered_data) - 1)]

# Separate the epochs into 'stim' and 'non-stim' based on odd and even indices
stim_epochs = [epochs[i] for i in range(len(epochs)) if i % 2 == 0]  # Even index: 0, 2, 4, ...
non_stim_epochs = [epochs[i] for i in range(len(epochs)) if i % 2 != 0]  # Odd index: 1, 3, 5, ...

print("Stim Epochs:", stim_epochs)
print("Non-Stim Epochs:", non_stim_epochs)


Stim Epochs: [(26899, 48449), (99058, 120316), (168303, 188319), (238081, 246943), (270439, 291073), (326426, 347449), (395419, 416172), (467733, 478034), (531521, 552544), (576529, 597224), (609224, 628240), (664045, 684264), (719542, 740218), (794219, 807994), (842528, 856914), (909670, 929750), (964962, 985182), (1029337, 1050476)]
Non-Stim Epochs: [(48449, 99058), (120316, 168303), (188319, 238081), (246943, 270439), (291073, 326426), (347449, 395419), (416172, 467733), (478034, 531521), (552544, 576529), (597224, 609224), (628240, 664045), (684264, 719542), (740218, 794219), (807994, 842528), (856914, 909670), (929750, 964962), (985182, 1029337)]


In [25]:
def convert_sample_to_time(epochs, sf):
    return [(start / sf, end / sf) for start, end in epochs]

def filter_df_by_epochs(df, epochs):
    # This function filters the DataFrame to only include rows within any of the specified epochs
    return pd.concat([df[(df['time'] >= start) & (df['time'] <= end)] for start, end in epochs])

In [26]:
# Convert epoch values from samples to time
stim_epochs_time = convert_sample_to_time(stim_epochs, sf)
non_stim_epochs_time = convert_sample_to_time(non_stim_epochs, sf)
print('Stim Epochs:' ,  stim_epochs_time)
print('Non Stim Epochs' , non_stim_epochs_time)

Stim Epochs: [(268.99, 484.49), (990.58, 1203.16), (1683.03, 1883.19), (2380.81, 2469.43), (2704.39, 2910.73), (3264.26, 3474.49), (3954.19, 4161.72), (4677.33, 4780.34), (5315.21, 5525.44), (5765.29, 5972.24), (6092.24, 6282.4), (6640.45, 6842.64), (7195.42, 7402.18), (7942.19, 8079.94), (8425.28, 8569.14), (9096.7, 9297.5), (9649.62, 9851.82), (10293.37, 10504.76)]
Non Stim Epochs [(484.49, 990.58), (1203.16, 1683.03), (1883.19, 2380.81), (2469.43, 2704.39), (2910.73, 3264.26), (3474.49, 3954.19), (4161.72, 4677.33), (4780.34, 5315.21), (5525.44, 5765.29), (5972.24, 6092.24), (6282.4, 6640.45), (6842.64, 7195.42), (7402.18, 7942.19), (8079.94, 8425.28), (8569.14, 9096.7), (9297.5, 9649.62), (9851.82, 10293.37)]


In [28]:
# show events over time
show_events = mne.viz.plot_events(actual_events) 
show_stim_events = mne.viz.plot_events(filtered_data)


In [None]:
yasa.plot_spectrogram(data[chan.index("E10")], sf , win_sec=5);

In [9]:
#sw = yasa.sw_detect(raw, hypno=hypno_up, include=(2, 3))
sw = yasa.sw_detect(raw, verbose=False, coupling=False);
df = sw.summary(); # general summary for each sw
df # Inspect the dataframe

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    1.8s


Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel
0,273.76,274.03,274.86,275.16,275.65,1.89,-41.824487,33.425113,75.249599,90.662168,0.529101,E1,0
1,545.35,545.66,545.90,546.19,546.51,1.16,-40.937927,55.789954,96.727881,403.032836,0.862069,E1,0
2,1209.77,1210.18,1210.51,1210.79,1211.14,1.37,-50.271566,32.988102,83.259668,252.302023,0.729927,E1,0
3,1211.14,1211.49,1211.75,1212.06,1212.71,1.57,-41.000249,47.190866,88.191115,339.196596,0.636943,E1,0
4,1212.71,1213.75,1214.07,1214.41,1214.84,2.13,-48.924268,27.038476,75.962744,237.383574,0.469484,E1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
7752,8782.88,8783.18,8783.44,8783.62,8783.85,0.97,-50.691698,25.743016,76.434714,293.979671,1.030928,E253,193
7753,9180.13,9180.50,9180.77,9181.10,9181.55,1.42,-54.312763,79.285652,133.598416,494.808946,0.704225,E253,193
7754,9962.47,9963.12,9963.38,9963.70,9964.08,1.61,-45.219392,69.214942,114.434334,440.132054,0.621118,E253,193
7755,9964.08,9964.53,9965.31,9965.70,9966.23,2.15,-48.162902,68.680883,116.843785,149.799724,0.465116,E253,193


In [10]:
sw.plot_detection() # lets you scroll through the detection very conveniently

interactive(children=(IntSlider(value=0, description='Epoch:', layout=Layout(align_items='center', justify_con…

<function yasa.detection._DetectionResults.plot_detection.<locals>.update(epoch, amplitude, channel, win_size, filt)>

In [19]:
# Define the classification function
def classify_wave(start_time, pre_stim_epochs_time, stim_epochs_time, post_stim_epochs_time):
    """Classify each wave based on the start time into 'Pre-Stim', 'Stim', or 'Post-Stim'."""
    for start, end in pre_stim_epochs_time:
        if start <= start_time <= end:
            return 'Pre-Stim'
    for start, end in stim_epochs_time:
        if start <= start_time <= end:
            return 'Stim'
    for start, end in post_stim_epochs_time:
        if start <= start_time <= end:
            return 'Post-Stim'
    return 'Unknown'  # If the wave does not fall within any of the epochs

# Apply classification to DataFrame
df['Classification'] = df['Start'].apply(classify_wave, args=(pre_stim_epochs_time, stim_epochs_time, post_stim_epochs_time))


NameError: name 'pre_stim_epochs_time' is not defined

In [12]:
# Group by classification and calculate mean and count for each group
comparison_means = df.groupby('Classification')[['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Frequency']].mean()
comparison_counts = df.groupby('Classification')['Start'].count()  # Counting instances using the 'Start' column

# Print results
print("Mean Values by Group:")
print(comparison_means)
print("\nCount of Instances by Group:")
print(comparison_counts)


Mean Values by Group:
                Duration  ValNegPeak  ValPosPeak         PTP  Frequency
Classification                                                         
Non-Stim        1.490206  -53.456670   47.130343  100.587012   0.707368
Stim            1.586573  -54.634870   48.912681  103.547551   0.654855
Unknown         1.499242  -50.468992   46.027232   96.496224   0.697786

Count of Instances by Group:
Classification
Non-Stim    4554
Stim        2781
Unknown      422
Name: Start, dtype: int64


In [13]:
df #look at the frame to make sure classification added properly 

Unnamed: 0,Start,NegPeak,MidCrossing,PosPeak,End,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency,Channel,IdxChannel,Classification
0,273.76,274.03,274.86,275.16,275.65,1.89,-41.824487,33.425113,75.249599,90.662168,0.529101,E1,0,Stim
1,545.35,545.66,545.90,546.19,546.51,1.16,-40.937927,55.789954,96.727881,403.032836,0.862069,E1,0,Non-Stim
2,1209.77,1210.18,1210.51,1210.79,1211.14,1.37,-50.271566,32.988102,83.259668,252.302023,0.729927,E1,0,Non-Stim
3,1211.14,1211.49,1211.75,1212.06,1212.71,1.57,-41.000249,47.190866,88.191115,339.196596,0.636943,E1,0,Non-Stim
4,1212.71,1213.75,1214.07,1214.41,1214.84,2.13,-48.924268,27.038476,75.962744,237.383574,0.469484,E1,0,Non-Stim
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7752,8782.88,8783.18,8783.44,8783.62,8783.85,0.97,-50.691698,25.743016,76.434714,293.979671,1.030928,E253,193,Non-Stim
7753,9180.13,9180.50,9180.77,9181.10,9181.55,1.42,-54.312763,79.285652,133.598416,494.808946,0.704225,E253,193,Stim
7754,9962.47,9963.12,9963.38,9963.70,9964.08,1.61,-45.219392,69.214942,114.434334,440.132054,0.621118,E253,193,Non-Stim
7755,9964.08,9964.53,9965.31,9965.70,9966.23,2.15,-48.162902,68.680883,116.843785,149.799724,0.465116,E253,193,Non-Stim


In [14]:
import matplotlib.pyplot as plt
# Plotting Mean Values with annotations and moving the legend outside
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)  # This means 1 row, 2 columns, first plot
ax = comparison_means.plot(kind='bar', ax=plt.gca(), color=['#6baed6', '#bdd7e7', '#eff3ff', '#fdbe85', '#fd8d3c'])
plt.title('Mean Values of Wave Properties')
plt.ylabel('Mean Values')
plt.xlabel('Classification' , labelpad=10)
plt.xticks(rotation=0)
#add_value_labels(ax)
plt.legend(title='Properties', loc='upper left', bbox_to_anchor=(1,1))  # Moving the legend outside

# Plotting Counts with annotations and moving the legend outside
plt.subplot(1, 2, 2)  # This means 1 row, 2 columns, second plot
ax2 = comparison_counts.plot(kind='bar', color='#6baed6', ax=plt.gca())
plt.title('Count of Instances by Group')
plt.ylabel('Count')
plt.xlabel('Classification' , labelpad=10)
plt.xticks(rotation=0)
#add_value_labels(ax2)

# Show plots
plt.tight_layout()  # This adjusts subplots to give some padding and prevent overlap
plt.show()


In [15]:
# Assuming df is your DataFrame name
print("Descriptive Statistics for Start Times:")
print(df['Start'].describe())

print("\nDescriptive Statistics for Slope:")
print(df['Slope'].describe())

correlation = df['Start'].corr(df['Slope'])
print("Correlation coefficient between 'Start' and 'Slope':", correlation)

Descriptive Statistics for Start Times:
count     7757.000000
mean      4446.780500
std       3129.549056
min         68.280000
25%       1818.850000
50%       3725.760000
75%       5974.860000
max      10738.320000
Name: Start, dtype: float64

Descriptive Statistics for Slope:
count    7757.000000
mean      314.783246
std       150.292025
min        65.525210
25%       214.873258
50%       289.455053
75%       374.077443
max      1085.810415
Name: Slope, dtype: float64
Correlation coefficient between 'Start' and 'Slope': -0.04983203119918011


In [16]:
# Ensure 'Start' is the independent variable and 'Slope' is the dependent variable
X = sm.add_constant(df['Start'])  # adding a constant
y = df['Slope']
model = sm.OLS(y, X).fit()


In [17]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.scatter(df['Start'], df['Slope'], alpha=0.5, label='Data Points') #alpha controls the transperacy
plt.plot(df['Start'], model.predict(X), color='red', label='Regression Line')
plt.title('Relationship between Start Time and Slope of Slow Waves')
plt.xlabel('Start Time (s)')
plt.ylabel('Slope')
plt.legend()
plt.show()

In [18]:
sw_chan = sw.summary(grp_chan=True, grp_stage=True) #summary per channel
sw_chan.head(10)

Unnamed: 0_level_0,Count,Duration,ValNegPeak,ValPosPeak,PTP,Slope,Frequency
Channel,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
E1,71,1.574789,-59.050551,46.118574,105.169126,305.846313,0.662559
E10,180,1.531111,-61.29033,50.965044,112.255373,332.076089,0.687356
E100,30,1.568,-53.853229,41.834737,95.687967,259.319227,0.651388
E101,50,1.5796,-51.065444,42.832307,93.897751,256.039518,0.650027
E105,36,1.479167,-55.986286,35.743112,91.729398,257.713828,0.700866
E106,36,1.485278,-56.564182,36.492885,93.057067,259.400992,0.700938
E109,39,1.638205,-52.291124,40.770383,93.061507,239.00155,0.621562
E11,136,1.514926,-53.533949,51.518927,105.052876,322.625705,0.697543
E110,44,1.619091,-50.721725,40.657877,91.379601,247.288795,0.634518
E113,42,1.474762,-57.954636,35.703676,93.658312,276.455665,0.701671


In [None]:
frontal_ch = ['E31', 'E166', 'E32', 'E167', 'E25', 'E189', '177']
posterior_ch = ['E118', 'E127', 'E152', 'E109', 'E137', 'E115', 'E159']


# Apply classification to DataFrame
df['Classification'] = df['Start'].apply(classify_wave, args=(stim_epochs_time, non_stim_epochs_time))6



In [None]:
raw.plot_sensors(kind='3d', show_names=True ); # this is a place holder for now so I do not forget

### How to create manual epochs for SW and average them
1. raster plot
2. line graph

In [None]:
channels = ['EEG 001' , 'EEG 002' , 'EEG 003']

# Loop over each channel
for chn in channels:
    # Filter DataFrame for current channel
    df_chn = df[df['Channel'] == chn]
    # Convert 'Start' and 'End' times to sample indices
    start_samples = (df_chn['Start'] * sf).astype(int)
    end_samples = (df_chn['End'] * sf).astype(int)
    # Calculate tmin and tmax
    tmin = -0.2  # 200 ms before the start time
    tmax = np.max((end_samples - start_samples) / sf) + 0.1  # 500 ms after the longest end
    # Create an events array
    events_chn = np.column_stack((start_samples, np.zeros_like(start_samples), np.ones_like(start_samples)))
    # Create Epochs
    epochs_chn = mne.Epochs(raw, events_chn, event_id=1, tmin=tmin, tmax=tmax, picks=[chn], baseline=(None, 0), preload=True)

    # Plotting
    # Plot epochs
    epochs_chn.plot(scalings={'eeg': 60e-6})  # Adjust scalings if necessary

    # Plotting epochs with the image plot that includes the average and the individual epochs
    epochs_chn.plot_image(picks=chn, combine='mean')

    

    

In [None]:
sw.plot_average(figsize=(12, 9)) # creates an avg figure for all SW from all channels

### Troubleshooting
The following cells will come in handy if you need further data manipulation

In [None]:
# Find the index of the minimum value in the 'ValNegPeak' column
min_index = df['ValNegPeak'].idxmin()

# To display the index
print("Index of minimum value in 'ValNegPeak':", min_index)

# If you want to see the entire row corresponding to this minimum value
min_value_row = df.loc[min_index]
print(min_value_row)