In [None]:
### IMPORT EXTERNAL FUNCTIONS
import mne
import numpy as np
from os.path import join
import os
import pandas as pd
import matplotlib.pyplot as plt
from mne.preprocessing import ICA, read_ica
from mne_icalabel import label_components
from collections import defaultdict

os.getcwd()
import functions.io as io
import functions.utils as utils
import functions.preprocessing as preprocessing

### Overview

- Loads data (single sub)
- Checks unit of raw data (without any manipulation of raw, just checking)
- Dropping channels that aren't of interest, and setting channel types (you will want to keep X.Y,Z channels though!)
- Re-scales data and replaces raw (if check above suggests that data is in "raw units ADC" from amplifier. If one needs to be rescaled all do, but I keep the check in here just so I/we know what's going on and why we're rescaling)
- Set montage (so that each electrode has its location when we use scalp topographies for visualisations and electrode placement (and later for other stuff I assume))
- High- and low-pass filter + notch filter
- Annotate breaks, so that beginning, breaks between blocks, and end isn't included in the analyses. I found that if we crop the event timings get shifted, if we just annotate here then we don't shift any times and events around, and the "bad" parts are not included in analyses.
- Plot raw trace and manually annotate bad parts of raw trace (I think only the most obvious movements/artefacts with huge shifts in amplitude). Annotating can be done when viewing the raw traces interactively: press "a" on the keyboard to get the "helper/info" and then type "movement" straight away and hit enter. This creates a new "marker" which you then use to mark bad segments (click and drag over the bad segments). Also remove obviously bad channels here (more info at that stage).
- Following this, when the obviously bad movements/artefacts/channels are removed, we can run some checks on channels to check if they are likely to be bad or not - see more info at that point in the notebook. I still don't know what's "appropriate" to do here, but the code is there if useful, and if the results can be used to guide selection of channels?
- After all removal -> re-reference to average
- Run ICA and save it so I don't need to re-run it later (if everything is identical in preprocessing), but can load it and apply it to the data again with e.g., other components removed etc. 
- Plot ICA components to see topographies, time series of components, how much variance is explained by the components etc (I don't really understand this variance fully, what is good and what is bad? Is there a good and bad?). Then some mne functions that "automatically" suggests which components are likely to be ECG, EOG, and mostly muscle noise. I think for now i prefer to do it manually, and use these checks as guides and compare to what I think the components are - see more notes at that point in the notebook.
- Select bad components and remove them 
- Apply the ICA to data
- After this the preprocessing of raw data is done. 


I then have parts where I epoch the data, and then remove bad epochs (the parts I have previously marked as bad are still in the data, because when I load the epochs I tell it to also include what segments of the raw trace I marked as bad to be excluded. I found that if I exclude the "bad_movement" I marked manually before ICA, all my epoch indexes are shifted because parts of the data has been removed, so the behavioural file index and actual remaining data indexes are not matching and I'm pulling the wrong epochs in my extract epoch functions. By keeping the "bad_movement" data in when I extract epochs, no data is removed when epochs are created, so I will have to quickly go over and just remove those that are bad again). This could probably be done in a smoother way, but I don't find this too annyoing to do, at epoch level it doesn't take too long to quickly scan through. 


I plot (and save) PSDs regularly to see what the impact the different preprocessing steps have. Might not be necessary later on - but for now i find it ok. 
I also save the file after some steps (it is marked when saved), so that the steps that takes long/are tedious don't have to be re-done all the time if wanting to test different methods etc. 

# 1. Load dataset and prepare for pre-processing #

1. Load one EEG raw dataset and its associated impedances files
2. Change channel types to match recording setup & remove unnecessary channels (like TRIGGERS, STATUS, ETC.)
3. Apply the 10-20 montage to EEG recording

In [None]:
# Session to preprocess
session_id = "sub023 DBS ON mSST"

working_path = os.path.dirname(os.getcwd())
onedrive_path = utils._get_onedrive_path()

# The session id is split on spaces: first token is the subject,
# tokens 2 and 3 together form the condition (e.g. "DBS ON").
session_parts = session_id.split(' ')
sub = session_parts[0]
condition = session_parts[1] + ' ' + session_parts[2]
sub_onedrive_path_task = join(onedrive_path, sub, 'synced_data', session_id)

# Output folders (created on first use): results/<session>/sub_data/<sub>/{figures,data}
results_path = join(working_path, "results")
saving_path = join(results_path, session_id)
os.makedirs(saving_path, exist_ok=True)

sub_save_path = join(saving_path, "sub_data", f"{sub}", "figures")
os.makedirs(sub_save_path, exist_ok=True)

save_path = join(saving_path, "sub_data", f"{sub}", "data")
os.makedirs(save_path, exist_ok=True)

# Load raw data: first synchronized .set file found in the session folder
filename = [
    f for f in os.listdir(sub_onedrive_path_task)
    if f.startswith('SYNCHRONIZED_EXTERNAL') and f.endswith('.set')
]
file = join(sub_onedrive_path_task, filename[0])
raw = mne.io.read_raw(file, preload=True)

# Impedance files written at the beginning and at the end of the recording
impedance_folder = join(onedrive_path, sub, 'raw_data', 'XDF', condition)
impedance_begin_filename = [
    f for f in os.listdir(impedance_folder)
    if f.startswith('mSST_impedances_begin') and f.endswith('.txt')
]
impedance_end_filename = [
    f for f in os.listdir(impedance_folder)
    if f.startswith('mSST_impedances_end') and f.endswith('.txt')
]
file_begin = join(impedance_folder, impedance_begin_filename[0])
file_end = join(impedance_folder, impedance_end_filename[0])

# Tab-separated, no header; the third column (label 2) is not used
impedance_begin = pd.read_csv(file_begin, sep='\t', header=None).drop(columns=2)
impedance_end = pd.read_csv(file_end, sep='\t', header=None).drop(columns=2)

# One row per channel with its begin/end impedance side by side
imp = pd.merge(
    impedance_begin,
    impedance_end,
    on=0,
    how='outer',
    suffixes=('_begin', '_end'),
)
# Rows whose channel name starts with 'UNI' are discarded
is_uni_row = imp[0].str.startswith('UNI')
imp = imp[~is_uni_row]

# Flag channels whose impedance exceeds 25 kOhm at either measurement
above_limit = (imp['1_begin'] > 25) | (imp['1_end'] > 25)
high_imp_channels = imp[above_limit]
if high_imp_channels.empty:
    print("No channels with high impedance found.")
else:
    print("Channels with high impedance (> 25 kOhm):")
    print(high_imp_channels)

sf = raw.info['sfreq']
print(f"Sampling frequency: {sf} Hz")

# Remove auxiliary channels and declare the two remaining bipolar channels as ECG / EOG
raw.drop_channels(['CREF', 'X', 'Y', 'Z', 'TRIGGERS', 'STATUS', 'COUNTER', 'BIP 01'])
raw.set_channel_types({'BIP 02': 'ecg', 'BIP 03': 'eog'})

# set 10-20 montage (the returned object is bound to `data` and used by the following cells)
data = raw.set_montage('standard_1020', match_case=False, match_alias=True, on_missing='warn')
ch_names = data.ch_names


#### Plot raw PSD

In [None]:
%matplotlib inline

### Plot PSD of signal (averaging all channels)
psd_raw_avg = data.compute_psd(method="welch", picks="eeg", fmin=0, fmax=130, n_fft=round(sf)*2, n_overlap=int(round(sf)), window="hamming")
psd_raw_avg.plot(dB=True, average=True).suptitle(f"Raw PSD {sub}")
plt.show()

### Plot PSD of signal (not averaging)
psd_raw = data.compute_psd(method="welch", picks="eeg", fmin=0, fmax=130, n_fft=round(sf)*2, n_overlap=int(round(sf)), window="hamming")
psd_raw.plot(dB=True, average=False).suptitle(f"Raw PSD {sub}")
plt.savefig(join(sub_save_path, f"{sub}_raw_PSD_avg.png"), dpi=300)
plt.show()

In [None]:
%matplotlib qt
# Interactive browser of the loaded recording; n_channels is set to the number of channels in ch_names
# so that all of them are displayed at once.
data.plot(n_channels = len(ch_names), scalings="auto", title="Raw data")

# 2. Filter the data #

1. Apply high-pass filter at 1Hz to remove slow-drifts
2. Apply low-pass filter at 80Hz to remove DBS main artifact and only keep frequencies of interest for further analysis
3. Apply notch filter at 50Hz to remove line noise
4. Additionally, if another artifactual peak is visible in the PSD (aliased stimulation frequency...), add another notch filter around it to dampen it

## 2.1. COMMON filtering steps: High- and low-pass filtering + notch ##

In [None]:
# first apply 1Hz high pass filter
high_passed_filt_eeg_data = data.copy().filter(1, None) 

# then apply 80Hz low pass filter
high_low_passed_filt_eeg_data = high_passed_filt_eeg_data.copy().filter(None, 80) 

# Last, apply notch filter(s)
# Even if 50Hz activity is not present in the raw, it is sometimes present after re-referencing
# Therefore, it is safer to apply a notch filter at 50Hz for all sessions.
filt_eeg_data = high_low_passed_filt_eeg_data.copy().notch_filter(50) 


%matplotlib inline

### Plot PSD of signal (averaging all channels)
psd_filt_avg = filt_eeg_data.compute_psd(method="welch", picks="eeg", fmin=0, fmax=100, n_fft=round(sf)*2, n_overlap=int(round(sf)), window="hamming")
psd_filt_avg.plot(dB=True, average=True).suptitle(f"{sub} Filtered PSD")
plt.show()

### Plot PSD of signal (not averaging)
psd_filt = filt_eeg_data.compute_psd(method="welch", picks="eeg", fmin=0, fmax=100, n_fft=round(sf)*2, n_overlap=int(round(sf)), window="hamming")
psd_filt.plot(dB=True, average=False).suptitle(f"{sub} Filtered PSD")
plt.show()


## 2.2. SPECIFIC step: filter out potential other artifacts ##

e.g. in our recordings at 2048Hz, we commonly see a peak at 48Hz coming from the aliasing of the 16th harmonic of 125Hz DBS 

In [None]:
# Session-specific extra notch at 48 Hz (aliased DBS harmonic, see the note above this cell).
# NOTE(review): the call is made on filt_eeg_data itself and its return value is not reassigned, so this
# relies on notch_filter modifying the object in place; re-running this cell would then filter again.
filt_eeg_data.notch_filter(48)

# 3. Identify and interpolate bad channels before re-referencing to average reference #

1. Get a first general idea of "bad" channels using automatic classification
2. Plot data and scroll through the full recording session: check channels flagged in the automatic detection method, and if necessary delete other bad-looking channels 
3. Interpolate channels labeled as bad
4. Re-reference to average

## 3.1. Recognising bad channels by Z-scoring and looking at SD and variance in freq. and time domain ##

**PSD_Z:**
- Computes each channel’s average power (1–80 Hz) in dB, then z-scores across channels—so it flags electrodes whose overall spectral “bulk” is unusually high or low (e.g., a consistently noisy or dead contact) (A single value per channel, tells us whether the channel is generally noisier or quieter than the others) (looks at Frequency domain).

**P2P_Z:**
- Measures the maximum 250 ms peak-to-peak excursion per channel and z-scores those values—so it catches electrodes with extreme transients (cable pops, big spikes) or abnormally flat signals (looks at Time domain)
(Takes the highest voltage the channel reaches minus the lowest voltage it reaches—that difference is its peak-to-peak value.)

**Corr_Z:**
- Computes each channel’s mean Pearson correlation to all other channels, then z-scores—so it flags sensors that aren’t co-varying with the head (e.g., drifty or disconnected channels, channels that don't correlate with their neighbours).

**Var_Z:**
- Takes each channel’s overall variance and z-scores—so it highlights electrodes that are unusually “spiky” or “quiet” over the whole recording (looks at Time domain).

**Max_Freq_Z:**
- Looks for the largest per-frequency deviation (in z-units) from the grand mean spectrum—so it catches narrowband bursts (line-noise leakage, muscle peaks) that might be lost in the broad PSD average (=For each frequency bin look at how many SD above/below the across-channel mean power that channel’s power is at that exact frequency (i.e. a z-score at 10 Hz, at 20 Hz, etc.), and take whichever of those frequency-specific z-scores is largest in absolute value. That way, even if a channel’s overall PSD looks okay, a single narrowband spike (say a 50 Hz line-noise leak or a muscle peak at 80 Hz) will make its “max_freq_z” jump out.) (looks at Frequency domain).

#### Rule of thumb:
- If it fails two or more of the checks, probably bad - but do a visual check of PSD and raw trace
- Failing one check: Look at PSD and raw trace - if looking good, keep it in

**NOTE** Frontal channels can often be flagged, but that is likely because of eyeblinks, so don't necessarily remove them if the raw trace itself looks ok besides the eyeblink!

In [None]:
%matplotlib qt

# rename for readability
bad_chan_identifier = filt_eeg_data.copy().drop_channels(['Fp1', 'Fpz', 'Fp2']) # drop frontal channels because they are often noisy due to eye blinks and this affects the mean/variance computed

# pick only the EEG channels that aren’t already in raw.info['bads']
picks = mne.pick_types(bad_chan_identifier.info, meg=False, eeg=True, exclude='bads')

sfreq = bad_chan_identifier.info['sfreq']

# Frequency-domain PSD checks
psd_container = bad_chan_identifier.compute_psd(
    method="welch",
    picks=picks,
    fmin=1.0,
    fmax=80.0,
    n_fft=160,
    window="hamming"
)

# Convert to dB and compute mean PSD per channel
psds     = psd_container.get_data()     # (n_picks, n_freqs)
psd_db   = 10 * np.log10(psds)
mean_psd = psd_db.mean(axis=1)
psd_z    = (mean_psd - mean_psd.mean()) / mean_psd.std()

# Peak-to-peak amplitude in 250 ms windows
data = bad_chan_identifier.get_data(picks=picks)     # (n_picks, n_times)
win_samp = int(0.25 * sfreq)
n_win    = data.shape[1] // win_samp
p2p_mat  = np.zeros((len(picks), n_win))
for w in range(n_win):
    seg = data[:, w*win_samp:(w+1)*win_samp]
    p2p_mat[:, w] = seg.max(axis=1) - seg.min(axis=1)
max_p2p = p2p_mat.max(axis=1)
p2p_z   = (max_p2p - max_p2p.mean()) / max_p2p.std()

# Channel–channel correlation 
corr      = np.corrcoef(data)
mean_corr = corr.mean(axis=0)
corr_z    = (mean_corr - mean_corr.mean()) / mean_corr.std()

# Compute time-domain variance per channel and z-score
# Flags channels with too much or too little amplitude variability over the rec (relative to the other electrodes),
# i.e., it could be flat or very spiky
chan_vars = np.var(data, axis=1)
var_z     = (chan_vars - chan_vars.mean()) / chan_vars.std()

# Compute frequency-wise z-scores and max deviation per channel
# Looks for channels which have high spikes in some frequencies, could indicate e.g., line noise or muscle artefact
mean_freq   = psd_db.mean(axis=0)
std_freq    = psd_db.std(axis=0)
freq_z      = (psd_db - mean_freq) / std_freq
max_freq_z  = np.max(np.abs(freq_z), axis=1)

# Threshold all metrics at ±2.5 Z
thresh = 2.5
mask_psd   = np.abs(psd_z)    > thresh
mask_p2p   = np.abs(p2p_z)    > thresh
mask_corr  = corr_z           < -thresh    # flag very low corr (z < -2.5)
mask_var   = np.abs(var_z)    > thresh
mask_freq  = max_freq_z       > thresh

# combine
mask_all = mask_psd | mask_p2p | mask_corr | mask_var | mask_freq
bad_channels = [bad_chan_identifier.ch_names[picks[i]]
                for i, m in enumerate(mask_all) if m]

# Print per-metric flagged channels
print(">> PSD outliers (|z|>2.5):", [bad_chan_identifier.ch_names[picks[i]] for i in np.where(mask_psd)[0]])
print(">> Peak-to-peak outliers (|z|>2.5):", [bad_chan_identifier.ch_names[picks[i]] for i in np.where(mask_p2p)[0]])
print(">> Low-correlation outliers (corr_z < -2.5):", [bad_chan_identifier.ch_names[picks[i]] for i in np.where(mask_corr)[0]])
print(">> Variance outliers (|z|>2.5):", [bad_chan_identifier.ch_names[picks[i]] for i in np.where(mask_var)[0]])
print(">> Spectral‐spike outliers (max_freq_z>2.5):", [bad_chan_identifier.ch_names[picks[i]] for i in np.where(mask_freq)[0]])
print(">> Final flagged channels:", bad_channels)

# Summarize in a DataFrame 
df = pd.DataFrame({
    "Channel":     [bad_chan_identifier.ch_names[p]      for p in picks],
    "PSD_Z":       np.round(psd_z,    2),
    "P2P_Z":       np.round(p2p_z,    2),
    "Corr_Z":      np.round(corr_z,   2),
    "Var_Z":       np.round(var_z,    2),
    "Max_Freq_Z":  np.round(max_freq_z,2),
    "Flagged":     mask_all
})
print(df)

# Visual check: all individual PSDs
bad_chan_identifier.plot_psd(
    fmin=1.0, fmax=80.0,
    picks=picks,
    dB=True,
    average=False,
    show=True
)


## 3.2. Scrolling through recording and manual removal of "bad" channels ##

In [None]:
%matplotlib qt

# Plot raw signal
filt_eeg_data.plot(n_channels = len(ch_names), duration = 20, scalings="auto", block=True) # Cell block waits for plot to be closed before continuing (i.e., saving)

os.makedirs(save_path, exist_ok=True)  # Create the directory if it doesn't exist

file_p = join(save_path, f"{sub.replace(' ', '_')}_artefacts_removed_EEGdata_eeg.fif")
filt_eeg_data.save(file_p, overwrite=True)

In [None]:
# Display the channels currently marked as bad.
# Check that every channel you want excluded is listed here: the next cells interpolate these channels
# and then compute the average reference, so a bad channel missing from this list would be
# included in the average reference.
filt_eeg_data.info['bads']

## 3.3. Interpolate bad channels ##

In [None]:
# Interpolate the channels listed in info['bads'] on a copy (filt_eeg_data itself is left untouched).
# reset_bads=True is passed so that the interpolated channels are no longer marked as bad afterwards.
eeg_data_interp = filt_eeg_data.copy().interpolate_bads(reset_bads=True)

## 3.4. Re-reference to average ##

In [None]:
# Re-reference a copy of the interpolated data to the average of the EEG channels.
# (The former `av_ref_data = []` placeholder was immediately overwritten and has been removed.)
av_ref_data = eeg_data_interp.copy().set_eeg_reference(ref_channels="average")


#### Plot PSD again

In [None]:
%matplotlib inline

### Plot PSD of signal (averaging all channels)
psd_final_avg = av_ref_data.compute_psd(method="welch", picks="eeg", fmin=1, fmax=130, window="hamming")
psd_final_avg.plot(dB=True, average=True).suptitle(f"{sub} PSD after re-referencing")
plt.show()

### Plot PSD of signal (not averaging)
psd_final = av_ref_data.compute_psd(method="welch", picks="eeg", fmin=1, fmax=130, window="hamming")
psd_final.plot(dB=True, average=False).suptitle(f"{sub} PSD after re-referencing")
plt.savefig(join(sub_save_path, f"{sub}_after_reRef.png"), dpi=300)
plt.show()

In [None]:
%matplotlib qt
# Interactive browser of the average-referenced data, all channels in ch_names shown at once.
av_ref_data.plot(n_channels = len(ch_names), scalings="auto", title="After re-referencing to average")

# 4. Clean artifacts using Independent Component Analysis #

1. Fit ICA to the filtered re-referenced data
2. Automatic labelling of "bad" components
    1. Using mne-icalabel
    2. Using eegprep (from Arnaud Delorme)
3. Manual labelling of "bad" components (and comparison with automatic)
4. Select which ICA components to exclude and reconstruct the signal without them


**Notes from https://eeglab.org/tutorials/06_RejectArtifacts/RunICA.html:**

"ICA takes all training data into consideration. When too many types (i.e., scalp distributions) of noise - complex movement artifacts, electrode pops, etc – are left in the training data, these unique and irreplicable data features will draw the attention of ICA, producing a set of component maps including many single-channel or noisy-appearing components. The number of components (degrees of freedom) devoted to the decomposition of brain EEG alone will be correspondingly reduced.
Therefore, presenting ICA with as much clean EEG data as possible is the best strategy. Note that blink and other stereotyped EEG artifacts do not necessarily have to be removed since they are likely to be isolated as single ICA components.
Here clean EEG data means data after removing noisy time segments (does not apply to removed ICA components)." --> hence, remove breaks before running ICA

**Random note:**

When cognitive load increases, blinking usually decreases. Because when participants are expecting something important to happen (stimulus appearance), they will unconsciously try not to blink, to pay more attention (source: https://www.youtube.com/watch?v=AXCxrDikpaM)

## 4.1. Fit the ICA ##

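Here we use the extended infomax method, because it is the preferred method for mne-icalabel automatic classification. The ICA must be fitted on average-referenced data (done in section 3.4). Note that the ICLabel model was trained on data band-pass filtered between 1 and 100 Hz, whereas this pipeline low-pass filters at 80 Hz, so mne-icalabel may warn about the filter settings.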

In [None]:
# Detect pauses of at least 10 s, leaving a 4 s margin after the previous and before the next event.
break_annot = mne.preprocessing.annotate_break(av_ref_data, min_break_duration=10, t_start_after_previous=4, t_stop_before_next=4)

### Add the breaks to the annotations of the data
# Start from the annotations carried by av_ref_data (not raw.annotations): these include any bad
# segments marked manually while scrolling through filt_eeg_data, which would otherwise be discarded
# here and then no longer be rejected when fitting the ICA.
# Break annotations left over from a previous run of this cell are removed first so that
# re-running does not stack duplicates.
current_annot = av_ref_data.annotations
keep_idx = [i for i, desc in enumerate(current_annot.description) if desc != "BAD_break"]
# set_annotations works on av_ref_data itself, so av_ref_data2 is the same object under a second name.
av_ref_data2 = av_ref_data.set_annotations(current_annot[keep_idx] + break_annot)

In [None]:
# File in which the ICA decomposition (not the data) is stored
ica_file_path = join(saving_path, 'sub_data', f"{sub}", f"{sub}-ica.fif")

# The decomposition is always refitted by this cell.
# To reuse a previously saved decomposition instead, replace the fit below with:
#     ica = read_ica(ica_file_path)
#     print(f"Loaded ICA for {sub} from {ica_file_path}")

# Extended infomax with a fixed seed; n_components=None leaves the number of components to MNE
ica = ICA(
    n_components=None,
    random_state=11,
    method='infomax',
    fit_params={"extended": True},
)
# Fit on the EEG channels only, skipping the annotated bad/break segments
ica.fit(av_ref_data2.copy().pick_types(eeg=True), reject_by_annotation=True)

# Save the ICA object to file (saving the ICA decomposition, not any raw/cleaned data)
ica.save(ica_file_path, overwrite=True)
print(f"Fitted and saved ICA for {sub} at {ica_file_path}")

# Second name for the fitted ICA object
ica_fit_dict = ica

# Optional: save one properties figure per component
# for i in range(ica.n_components_):
#     fig = ica_fit_dict.plot_properties(picks=[i], inst=av_ref_data2, show=False) #Topographies of each component
#     plt.savefig(join(sub_save_path, f"{sub}_ica_component_{i}.png"), dpi=300)


## 4.2. Automatic labelling of "bad" components ##

### 4.2.1. Using mne-icalabel ###

Trying automatic labeling using mne-icalabel package

In [None]:
%matplotlib inline
# Get automatic classification labels for ICA components using mne-icalabel package (pip install mne-icalabel) to get a first idea
ic_labels = label_components(av_ref_data, ica, method="iclabel")
labels = ic_labels["labels"]
mne_icalabel_excluded = [
    idx for idx, label in enumerate(labels) if label not in ["brain", "other"]
]
mne_icalabel_labels = [
    (idx, label) for idx, label in enumerate(labels)
]
print(f"Excluding these ICA components: {mne_icalabel_excluded}")

ica.plot_properties(av_ref_data, picks=mne_icalabel_excluded, verbose=False)

In [None]:
# Display the (component index, label) pairs produced by the automatic classification above
mne_icalabel_labels

### 4.2.2. Using eegprep (from Arnaud Delorme) ###

In [None]:
# from eegprep import iclabel
# EEG = iclabel(av_ref_data)

In [None]:
## HERE ADD EEGPREP PIPELINE

## 4.3. Manual labelling of "bad" ICA components ##

In [None]:
%matplotlib qt
# Time series of each ICA component, shown in an interactive browser
fig_sources = ica.plot_sources(av_ref_data2) # Timeseries of each ICA component
# NOTE(review): plt.savefig writes pyplot's *current* figure, not fig_sources explicitly. If the browser
# returned above is not a matplotlib figure (e.g. a Qt-based browser), the saved image may not show
# the sources - verify the content of the saved file.
plt.savefig(join(sub_save_path, f"{sub}_ica_sources.png"), dpi=300)

In [None]:
%matplotlib qt

fig = ica.plot_components(inst=av_ref_data,  psd_args={'fmin': 0, 'fmax': 80}, ncols=6, nrows=6) #Topographies of each component
plt.savefig(join(sub_save_path, f"{sub}_ica_components.png"), dpi=300)

## 4.4. Select final components to exclude and reconstruct signal ##

In [None]:
# Final list of ICA component indices to remove.
# Default: start from a copy of the automatic mne-icalabel selection.
# Alternatives (uncomment and adapt the indices for the session):
#   - fully manual list:
#to_exclude = [0, 1, 9, 10, 11, 12, 13, 17, 21, 22, 23]
to_exclude = mne_icalabel_excluded.copy()
#   - automatic selection plus manually chosen extra components:
#to_exclude.extend([11, 18, 21, 22, 23, 25, 26])

In [None]:
# Display the component indices that will be removed in the next cell
to_exclude

In [None]:
# Apply ICA to the data
# The ICA is applied to a copy of av_ref_data2 with the components in to_exclude removed,
# so av_ref_data2 itself keeps the signal as it was before component removal.
data_after_ica = ica.apply(av_ref_data2.copy(), exclude=to_exclude)

# 5. Visual inspection and saving after all preprocessing steps and ICA #

In [None]:
%matplotlib qt
# Interactive browser of the ICA-cleaned data, all channels in ch_names shown at once.
data_after_ica.plot(n_channels=len(ch_names), scalings="auto", title="After ICA")



In [None]:
%matplotlib inline

### Plot PSD of signal (not averaging)
psd_after_ica = data_after_ica.compute_psd(method="welch", picks="eeg", fmin=1, fmax=80, window="hamming")
psd_after_ica.plot(dB=True, average=False).suptitle(f"{sub} PSD after ICA")
sub_save_path = join(saving_path, "sub_data", f"{sub}", "figures", "PSDs")
plt.savefig(join(sub_save_path, f"{sub}_after_ICA_PSD.png"), dpi=300)
plt.show()

In [None]:
# Folder for the preprocessed data files of this subject (created if missing)
save_path = join(saving_path, "sub_data", f"{sub}", "data")
os.makedirs(save_path, exist_ok=True)

# Write the ICA-cleaned recording, replacing any previous version
post_ica_fname = sub.replace(' ', '_') + "_postICA_EEGdata_eeg.fif"
file_p = join(save_path, post_ica_fname)
data_after_ica.save(file_p, overwrite=True)


# 6. ADDITIONAL: visualize STFT of all channels #

In [None]:
import scipy

In [None]:
%matplotlib inline
ch_names_data = data_after_ica.info['ch_names']
eeg_data = data_after_ica.get_data(picks='eeg')
eeg_times = data_after_ica.times

ch_names_data.remove('BIP 02')
ch_names_data.remove('BIP 03')

vmin, vmax = -17, -13

for ch in ch_names_data:
    ch_index = ch_names_data.index(ch)

    f, t, Zxx = scipy.signal.stft(
        eeg_data[ch_index, :], raw.info['sfreq'], nperseg=int(round(raw.info['sfreq'])), noverlap=int(round(raw.info['sfreq']) / 2), nfft=int(round(raw.info['sfreq']))
    )
    Pxx = np.abs(Zxx)
    plt.figure(figsize=(10, 4))
    im = plt.imshow(np.log(Pxx), aspect='auto', origin='lower',
                            extent=[t[0], t[-1], f[0], f[-1]], cmap='viridis',
                            vmin=vmin, vmax=vmax)
    plt.ylim(0,100)
    plt.colorbar()
    plt.xlabel('Time [sec]')
    plt.ylabel('Frequency [Hz]')
    plt.title(f'{session_id} - EEG - {ch}')
    #plt.savefig(os.path.join(fig_saving_path, f"{session_id}_{ch}_STFT.png"))
    plt.show()

# 7. ADDITIONAL: Extract epochs to check data quality after ICA and plot specific channels: 
C3/C4 channels are the ones above the motor cortex, they can be plotted to check that the expected beta desynchronization is visible.

In [None]:
mSST_raw_behav_session_data_path = join(
        onedrive_path, sub, "raw_data", 'BEHAVIOR', condition, 'mSST'
        )
# Behavioural file: the last .csv returned by os.listdir (same choice as before), with an explicit
# error when the folder contains none instead of a NameError further down.
behav_csv_files = [f for f in os.listdir(mSST_raw_behav_session_data_path) if f.endswith(".csv")]
if not behav_csv_files:
    raise FileNotFoundError(
        f"No behavioural .csv file found in {mSST_raw_behav_session_data_path}"
    )
fname = behav_csv_files[-1]
filepath_behav = join(mSST_raw_behav_session_data_path, fname)
df = pd.read_csv(filepath_behav)

# Rows of the main task = from the first to the last row where 'blocks.thisRepN' is filled (inclusive)
start_task_index = df['blocks.thisRepN'].first_valid_index()
stop_task_index = df['blocks.thisRepN'].last_valid_index()
df_maintask = df.iloc[start_task_index:stop_task_index + 1] ### HERE MISTAKE OF INDEXING: CHECK IN OTHER SCRIPTS IF THIS IS ALSO WRONG!!!

# remove all useless columns to clean up dataframe
column_names = df_maintask.columns
columns_to_keep = [i for i in [
    'blocks.thisN', 'trial_loop.thisN', 'trial_type',
    'continue_signal_time', 'stop_signal_time',
    'fixation_cross.started', 'go_rectangle.started',
    'key_resp_experiment.keys', 'key_resp_experiment.corr', 'key_resp_experiment.rt',
    'early_press_resp.keys', 'early_press_resp.rt', 'early_press_resp.corr',
    'late_key_resp1.keys', 'late_key_resp1.rt',
    'late_key_resp2.keys', 'late_key_resp2.rt'
    ] if i in column_names]

mini_df_maintask = df_maintask[columns_to_keep]
print(mini_df_maintask.shape)

# remove the trials with early presses, as in these trials the cues were not presented (for mSST)
early_presses = mini_df_maintask[mini_df_maintask['early_press_resp.corr'] == 1]
early_presses_trials = list(early_presses.index)
number_early_presses = len(early_presses_trials)
print(f'Number of early presses: {number_early_presses}')

# remove trials with early presses from the dataframe:
df_maintask_copy = mini_df_maintask.drop(early_presses_trials).reset_index(drop=True)
print(df_maintask_copy.shape)
print(df_maintask_copy['blocks.thisN'])

# First generate global epochs (without taking into account success outcome)
# events and event_id used for epochs creation
events, event_id = mne.events_from_annotations(data_after_ica)
epochs, filtered_event_dict = preprocessing.create_epochs(
        data_after_ica,
        sub,
        keys_to_keep = ['GC', 'GF', 'GO', 'GS', 'continue', 'stop'],
        tmin = -3.5,
        tmax = 3.5,
        baseline=None
        )
n_epochs = len(epochs)
print(epochs)

# inverse mapping (event code -> label)
# NOTE(review): this assumes create_epochs uses the same event codes as events_from_annotations above.
inv_event_id = {v: k for k, v in event_id.items()}

# One metadata row per epoch, indexed 0..n_epochs-1 (labels therefore equal positions)
metadata = pd.DataFrame(index=np.arange(len(epochs)))
metadata["event"] = [inv_event_id[e] for e in epochs.events[:, 2]]
# object dtype: the column is filled with strings below, which a float (NaN) column cannot hold
# without a dtype-incompatibility warning
metadata["trial_type"] = pd.Series(np.nan, index=metadata.index, dtype=object)

# LFP -> behavioral naming mapping
mapping = {
    "GC": "go_continue_trial",
    "GO": "go_trial",
    "GF": "go_fast_trial",
    "GS": "stop_trial",
}

# Epochs locked to a trial-start event (GC/GO/GF/GS): exactly one per behavioural trial
trial_mask = metadata["event"].isin(mapping.keys())

assert trial_mask.sum() == len(df_maintask_copy), \
    f"Mismatch: {trial_mask.sum()} LFP trials vs {len(df_maintask_copy)} behavioral trials"

# fill directly from behavioral file (row order of the behavioural file = order of the trial epochs)
for col in df_maintask_copy.columns:
    metadata.loc[trial_mask, col] = df_maintask_copy[col].values

# 'continue' / 'stop' epochs belong to the trial opened by the closest preceding GC / GS epoch:
# copy that trial's behavioural values onto them.
cue_to_trial_event = {"continue": "GC", "stop": "GS"}
event_labels = metadata["event"].to_numpy()
for i in metadata.index:
    trial_event = cue_to_trial_event.get(event_labels[i])
    if trial_event is None:
        continue
    # positions of all epochs of the matching trial type that come before epoch i
    earlier_trials = np.flatnonzero(event_labels[:i] == trial_event)
    if earlier_trials.size == 0:
        raise ValueError(
            f"Epoch {i} ('{event_labels[i]}') has no preceding '{trial_event}' epoch"
        )
    prev_idx = earlier_trials[-1]
    metadata.loc[i, df_maintask_copy.columns] = metadata.loc[prev_idx, df_maintask_copy.columns]

epochs.metadata = metadata


In [None]:
# Display the summary of the epochs object (now carrying the behavioural metadata)
epochs

In [None]:
%matplotlib qt
# Generate list of evoked objects from conditions names
evokeds = [epochs[name].crop(tmin=-0.5, tmax=1.5).average() for name in ("GO", "GF","GC", "GS")]
colors = "blue", "red", "green", "black"
title = "Evoked responses"

fig, axes = plt.subplots(1)

mne.viz.plot_evoked_topo(evokeds, title=title, background_color="w", axes=axes)

fig.savefig(join(saving_path, 'evoked_responses.pdf'))

In [None]:
######################
### TFR PARAMETERS ###
######################

# No decimation; frequencies 1..79 Hz in 1 Hz steps
decim = 1
freqs = np.arange(1, 80, 1)

# Number of wavelet cycles per frequency: half the frequency, bounded to the range [2, 10].
# Alternatives that were considered:
# For 500ms time resolution at 1 Hz: n_cycles = 1 * 0.5 = 0.5
# For 50ms time resolution at 40 Hz: n_cycles = 40 * 0.05 = 2
# Linear interpolation between these points
#n_cycles = 0.5 + (freqs - 1) * (2 - 0.5) / (40 - 1)
#n_cycles = freqs / 2.0
n_cycles = np.clip(freqs / 2.0, 2, 10)

# Keyword arguments for the time-frequency computation (Morlet wavelets, single-trial power, no ITC)
tfr_args = {
    "method": "morlet",
    "freqs": freqs,
    "n_cycles": n_cycles,
    "decim": decim,
    "return_itc": False,
    "average": False,
}

# Display window (ms) and colour limits (%) for the plots
tmin_tmax = [-500, 1500]
vmin_vmax = [-70, 70]

In [None]:
%matplotlib qt
# Interactive epoch browser: 32 channels and 4 epochs per page, with event markers drawn (events=True)
epochs.plot(n_channels = 32, n_epochs = 4, events=True)

In [None]:
%matplotlib inline

for epoch_cond in ['GO_successful', 'GF_successful', 'GC_successful', 'GS_successful', 'GS_unsuccessful']:
    ch_interest = "Cz"

    t_min_max = [-500, 1500]

    epoch_type = epoch_cond.split('_')[0]
    outcome_str = epoch_cond.split('_')[1]

    outcome = 1.0 if outcome_str == 'successful' else 0.0

    type_mask = epochs.metadata["event"] == epoch_type
    outcome_mask = epochs.metadata["key_resp_experiment.corr"] == outcome
    data = epochs[type_mask & outcome_mask]   

    channel_epochs = data.copy().pick([ch_interest])
    power_channel = channel_epochs.compute_tfr(**tfr_args)
    mean_power_channel = np.nanmean(power_channel.data, axis=0).squeeze()

    times = power_channel.times * 1000  # Convert to milliseconds
    freqs = power_channel.freqs

    baseline_indices = (times >= -500) & (times <= -200)
    baseline_power_channel = np.nanmean(mean_power_channel[:, baseline_indices], axis=1, keepdims=True)
    percentage_change_channel = (mean_power_channel - baseline_power_channel) / baseline_power_channel * 100

    time_indices = np.logical_and(times >= t_min_max[0], times <= t_min_max[1])
    sliced_data = percentage_change_channel[:, time_indices].squeeze()    

    plt.imshow(sliced_data, aspect='auto', origin='lower', 
            extent=[t_min_max[0], t_min_max[1], 
            tfr_args["freqs"][0], tfr_args["freqs"][-1]], 
            cmap='jet', vmin=vmin_vmax[0], vmax=vmin_vmax[-1]
    )
    plt.axvline(x=0, color='k', linestyle='--', linewidth=1)
    plt.title(f"Percentage Change in Power for {epoch_cond} - {ch_interest}")
    plt.xlabel("Time from GO cue (ms)")
    plt.ylabel("Frequency (Hz)")
    plt.colorbar(label="Percentage Change (%)")
    plt.tight_layout()
    plt.show()

In [None]:
%matplotlib qt
# Generate list of evoked objects from conditions names
evokeds_stopping = [epochs_subsets[name].average() for name in ("GS_successful", "GS_unsuccessful")]
colors = "green", "red"
title = "Evoked stopping responses"

fig, axes = plt.subplots(1)

mne.viz.plot_evoked_topo(evokeds_stopping, title=title, background_color="w", axes=axes)

#fig.savefig(join(saving_path, 'evoked_GS_success_unsuccess_responses.pdf'))

In [None]:
%matplotlib qt
# Generate list of evoked objects from conditions names
evokeds = [epochs_subsets[name].average() for name in ("GO_successful", "GS_successful")]
colors = ("green", 0.5), ("red", 0.5)
title = "Evoked reactive inhibition responses compared to going"

fig, axes = plt.subplots(1)

mne.viz.plot_evoked_topo(evokeds, color=colors,title=title, background_color="w", axes=axes)

In [None]:
%matplotlib qt
# Generate list of evoked objects from conditions names
evokeds = [epochs_subsets[name].average() for name in ("GO_successful", "GF_successful")]
colors = ("green", 0.5), ("blue", 0.5)
title = "Evoked proactive inhibition responses"

fig, axes = plt.subplots(1)

mne.viz.plot_evoked_topo(evokeds, color=colors,title=title, background_color="w", axes=axes)

### Note
Once each subject has been preprocessed and epochs cleaned, other scripts will load the cleaned epochs and run the analyses.