## Import libraries

In [1]:
import os
import pandas as pd
import re
import mne
import matplotlib.pyplot as plt
from mne.preprocessing import ICA
from mne_icalabel.gui import label_ica_components
import autoreject
from specparam.plts.spectra import plot_spectra
from specparam import SpectralGroupModel
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from utils import (
    compare_before_after,
    create_epochs,
    switch_bad_to_interpolate,
    update_reject_log,
    plot_specparam_on_scalp,
    examine_spectra,
    exclude_bad_channels,
    extract_elements,
    save_specparam_results,
    plot_models,
)

# Start from a clean slate: close any figures left over from a previous run.
plt.close("all")

# Interactive MNE browser backend. The matplotlib browser is the alternative:
# mne.viz.set_browser_backend("matplotlib")
mne.set_config("MNE_BROWSER_BACKEND", "qt")

# Notebook-wide AutoReject defaults; the per-subject cells redefine these.
auto_reject_params = dict(
    n_interpolate=[1, 2, 16, 32],
    n_jobs=-1,
    random_state=100,
    thresh_method="bayesian_optimization",
    verbose=False,
    consensus=[0.8],
)
auto_reject_pre_ica = autoreject.AutoReject(**auto_reject_params)

# Spectral-parameterisation model reused by the specparam cell further down.
fmax = 40.0
specparam_settings = dict(
    peak_width_limits=[1, 6],
    min_peak_height=0.15,
    peak_threshold=2.0,
    max_n_peaks=6,
    verbose=False,
)
fg = SpectralGroupModel(**specparam_settings)

# When False, an AutoReject fit already cached on disk is loaded instead of refitted.
recompute = False


ModuleNotFoundError: No module named 'utils'

# Subject 147

In [None]:
# Load this subject's filtered raw data, mark the manually flagged bad
# channels, epoch it and run (or reload) the pre-ICA AutoReject pass.
subject = 147

# Per-subject pre-ICA AutoReject settings (override the notebook defaults).
auto_reject_params = {
    "n_interpolate": [1, 2],
    "n_jobs": -1,
    "random_state": 100,
    "thresh_method": "bayesian_optimization",
    "verbose": False,
    "consensus": [0.4],
}
# epochs_params = {"reject": dict(eeg=400e-6)}

# Bad channels listed for this subject in the review spreadsheet; the sheet
# stores channel numbers, the data uses "E<number>" channel names.
df = pd.read_excel("ICA TO REMOVE.xlsx")
channels_to_add = [
    "E" + str(ch) for ch in extract_elements(df, subject, "Bad_Channels")
]
print(channels_to_add)

raw = mne.io.read_raw_fif(f"sub-{subject}_filtered_raw.fif", preload=True).filter(
    l_freq=1.0, h_freq=None
)
# Only add channels the file does not already mark as bad, so that no
# channel ends up listed twice in info["bads"].
for channel in channels_to_add:
    if channel not in raw.info["bads"]:
        raw.info["bads"].append(channel)

raw.plot()

epochs = create_epochs(raw, epochs_params=None, length=5, overlap=1.5)

# Reuse the cached AutoReject fit unless a refit is explicitly requested.
autoreject_fname = f"sub-{subject}_auto_reject_pre_ica.h5"
file_exists = os.path.exists(autoreject_fname)
if file_exists and not recompute:
    auto_reject_pre_ica = autoreject.read_auto_reject(autoreject_fname)
    print(f"File {autoreject_fname} exists. Loading file")
else:
    if file_exists:
        print(f"File {autoreject_fname} exists but recompute is True. Refitting.")
    else:
        print(f"File {autoreject_fname} does not exist.")
    # Thresholds are learned on the first 20 epochs only, then applied to all.
    auto_reject_pre_ica = autoreject.AutoReject(**auto_reject_params).fit(epochs[:20])
    print("fitting finished")
    auto_reject_pre_ica.save(autoreject_fname, overwrite=True)
    print(f"File {autoreject_fname} saved")

epochs_ar, reject_log = auto_reject_pre_ica.transform(epochs, return_log=True)
reject_plot = reject_log.plot("vertical")

In [None]:
# Vertical scale used when plotting EEG epochs below.
scalings = dict(eeg=60e-6)

# Epoch indices flagged as bad for this subject in the review spreadsheet.
df = pd.read_excel("ICA TO REMOVE.xlsx")
bad_epochs_indices = extract_elements(df, subject, "Epochs")

# When the spreadsheet lists bad epochs, replace AutoReject's decisions with
# them: the log is first passed through switch_bad_to_interpolate (presumably
# clearing AutoReject's bad-epoch flags — confirm in utils) and then updated
# with the manual indices. Otherwise keep the AutoReject log as is.
if len(bad_epochs_indices) > 0:
    zeroed_reject_log = switch_bad_to_interpolate(reject_log)
    new_reject_log = update_reject_log(zeroed_reject_log, bad_epochs_indices)
else:
    new_reject_log = reject_log
new_reject_log.save(f"sub-{subject}_reject_log_updated.h5", overwrite=True)

In [None]:
# Overlay the reject log on all epochs, then browse the rejected and the kept
# epochs separately. An all-False mask would select no epochs at all, so the
# rejected epochs are only plotted when there are any.
new_reject_log.plot_epochs(epochs, scalings=scalings)
if new_reject_log.bad_epochs.any():
    epochs[new_reject_log.bad_epochs].plot(scalings=scalings)
epochs[~new_reject_log.bad_epochs].plot(scalings=scalings)

## ICA

In [None]:
# Fit ICA (picard solver) on the epochs that survived rejection. A float
# n_components selects the number of components by explained variance; if
# that yields more than max_ica_components, refit with a fixed cap instead.
max_ica_components = 50
ica_params = {
    "n_components": 0.99,
    "random_state": 99,
    "method": "picard",
    "fit_params": {"ortho": False, "extended": True},
}

ica = ICA(**ica_params)
new_epochs = epochs[~new_reject_log.bad_epochs]

ica.fit(new_epochs)
if ica.n_components_ > max_ica_components:
    # Same settings, only the component count changes.
    ica_params = {**ica_params, "n_components": max_ica_components}
    ica = ICA(**ica_params)
    ica.fit(new_epochs)

In [None]:
# Mark the ICA components flagged for this subject in the review spreadsheet
# as excluded, then plot all component topographies for a quick check.
df = pd.read_excel("ICA TO REMOVE.xlsx")
ica_bad_components = extract_elements(df, subject, "ICA_3Take")
print(ica_bad_components)
ica.exclude = ica_bad_components
_ = ica.plot_components()

In [None]:
# Inspect component time courses.
# - Qt backend: browse all components and open the mne-icalabel labelling GUI.
# - Otherwise: show only the flagged components across every epoch.
if mne.get_config("MNE_BROWSER_BACKEND") == "qt":
    ica.exclude = ica_bad_components
    ica_plot_whole_timeseries = ica.plot_sources(
        new_epochs,
        picks=None,
        show_scrollbars=True,
    )
    gui = label_ica_components(new_epochs, ica)

else:
    # Exclusions are cleared here, presumably so the flagged components are
    # drawn as regular traces (confirm); the next cell sets them again.
    ica.exclude = []
    ica_plot_whole_timeseries = ica.plot_sources(
        new_epochs,
        picks=ica_bad_components,
        show_scrollbars=False,
        start=0,
        # For Epochs, ``stop`` is an exclusive epoch index, so len() shows
        # every epoch (len() - 1 would leave the last one out).
        stop=len(new_epochs),
    )

In [None]:
# Remove the selected components, save the ICA model under this subject's
# name (the specparam cell reloads f"sub-{subject}_my_ica_model-ica.fif"),
# and compare spectra before and after cleaning.
ica.exclude = ica_bad_components
ica.save(f"sub-{subject}_my_ica_model-ica.fif", overwrite=True)
epochs_clean_manual = ica.apply(new_epochs.copy(), exclude=ica.exclude)
# new_epochs is exactly the data ICA was fitted on and applied to, so it is
# the "before" side of the comparison.
fig_psd = compare_before_after(
    new_epochs,
    epochs_clean_manual,
    subject,
)

## Interpolation of bad channels using autoreject

In [None]:
# Second AutoReject pass on the ICA-cleaned epochs: thresholds are learned on
# the first 20 epochs, then applied to all of them.
auto_reject_post_ica = autoreject.AutoReject(
    n_interpolate=[1, 2, 4, 8, 32, 64],
    n_jobs=-1,
    random_state=100,
    thresh_method="bayesian_optimization",
    verbose=False,
).fit(epochs_clean_manual[:20])
print("fitting finished")
epochs_ar, reject_log_final = auto_reject_post_ica.transform(
    epochs_clean_manual, return_log=True
)
autoreject_post_fname = f"sub-{subject}_auto_reject_post_ica.h5"
auto_reject_post_ica.save(autoreject_post_fname, overwrite=True)

reject_plot = reject_log_final.plot("vertical")
reject_log_final.save(f"sub-{subject}_reject_log_final.h5", overwrite=True)

# Interpolate the channels still marked bad, leaving VREF out of it, and
# store the result for the spectral analysis.
epochs_interpolated = epochs_ar.copy().interpolate_bads(exclude=["VREF"])
# The output directory may not exist yet on a fresh checkout.
os.makedirs("analysis", exist_ok=True)
epochs_interpolated.save(f"analysis/sub-{subject}_interpolated-epo.fif", overwrite=True)

## Comparison of EEG data

In [None]:
# Spectra of the ICA-cleaned epochs vs. the AutoReject-repaired, interpolated ones.
after_ar_title = "After AR"
fig_psd1 = compare_before_after(
    epochs_clean_manual,
    epochs_interpolated,
    subject,
    title=after_ar_title,
)

## Specparam visualisation

In [None]:
from utils import (
    save_specparam_results,
    plot_specparam_on_scalp,
    examine_spectra,
    plot_models,
    extract_elements,
)

# This cell reloads everything it needs from disk so it can run on its own.
subject = 147
raw = mne.io.read_raw_fif(f"sub-{subject}_filtered_raw.fif", preload=True).filter(
    l_freq=1.0, h_freq=None
)
# The spreadsheet's bad channels were only added to the in-memory raw before
# epoching and never written back to the .fif. Count them together with the
# stored bads; the set union avoids counting a channel twice.
df = pd.read_excel("ICA TO REMOVE.xlsx")
sheet_bads = {"E" + str(ch) for ch in extract_elements(df, subject, "Bad_Channels")}
n_interpolated_channels = len(set(raw.info["bads"]) | sheet_bads)

epochs_interpolated = mne.read_epochs(
    f"analysis/sub-{subject}_interpolated-epo.fif", preload=True
)
ica_loaded = mne.preprocessing.read_ica(f"sub-{subject}_my_ica_model-ica.fif")

# PSD averaged over epochs; get_data gives one spectrum per channel, and the
# group model (fg, configured in the first cell) fits each of them.
psd = epochs_interpolated.compute_psd().average()
spectra, freqs = psd.get_data(return_freqs=True)

# Frequency range (Hz) over which the spectral models are fitted.
freq_range = [2, 40]
fg.fit(freqs, spectra, freq_range)
fg.plot()
specparam_df = save_specparam_results(
    fg,
    epochs_interpolated,
    ica_loaded,
    subject,
    n_interpolated_channels=n_interpolated_channels,
)
display(specparam_df.head())

plot_specparam_on_scalp(fg, epochs_interpolated, subject)
examine_spectra(fg, subject)


In [None]:
# Show the fitted models with the lowest, median and highest aperiodic exponent.
exponent_choice = "exponent"
plot_models(fg, param_choice=exponent_choice)

In [None]:
# Show the fitted models with the worst, median and best goodness of fit (R²).
fit_quality_choice = "r_squared"
plot_models(fg, param_choice=fit_quality_choice)

In [None]:
# Close every matplotlib figure opened by the cells above.
plt.close(fig="all")