# Define helpers for kernel metadata, YAML front matter, imports and analysis cells

In [4]:
import nbformat
from nbformat.validator import ValidationError
import re


def validate_and_repair_notebook(notebook_path):  # type: ignore
    """Validate a notebook on disk and attempt a best-effort repair if invalid.

    The notebook is read as nbformat v4 and checked with ``nbformat.validate``.
    If validation fails, the notebook is round-tripped through
    ``nbformat.writes`` / ``nbformat.reads``, written back to
    ``notebook_path``, and validated again so that the printed status reflects
    whether the repair actually worked.

    Parameters
    ----------
    notebook_path : str or path-like
        Path to the ``.ipynb`` file. It is overwritten when a repair is attempted.
    """
    # Read as UTF-8 explicitly: notebooks are written as UTF-8 by the generator
    # loop, and the platform default encoding (e.g. cp1252 on Windows) may differ.
    with open(notebook_path, "r", encoding="utf-8") as f:
        nb = nbformat.read(f, as_version=4)

    try:
        nbformat.validate(nb)
    except ValidationError as e:
        print("The notebook is not valid:", e)
    else:
        print("The notebook is valid.")
        return

    # Best-effort repair: serialise and re-parse through nbformat, then persist.
    nb = nbformat.reads(nbformat.writes(nb), as_version=4)
    with open(notebook_path, "w", encoding="utf-8") as f:
        nbformat.write(nb, f)

    # Re-validate so we do not claim success when the round trip fixed nothing.
    try:
        nbformat.validate(nb)
    except ValidationError as e:
        print("The notebook could not be repaired:", e)
    else:
        print("The notebook has been repaired.")


def add_kernel_metadata(nb):
    """Set the notebook's kernelspec to the ``conda-paths-3.12`` Python kernel.

    Any existing ``kernelspec`` entry in ``nb["metadata"]`` is replaced
    wholesale, not merged.
    """
    metadata = nb["metadata"]
    metadata["kernelspec"] = dict(
        display_name="Python 3",
        language="python",
        name="conda-paths-3.12",
    )


def add_yaml_metadata(nb):
    """Append a raw cell containing the YAML front matter for rendering.

    The front matter sets the document title, execution options
    (echo/warning off, execution enabled) and HTML format options
    (full page layout, left table of contents, embedded resources).
    """
    front_matter = """---
title: "Preprocessing of High-Density EEG Recordings"
execute:
    echo: false
    warning: false
    enabled: true
format:
    html:
        page-layout: full
        toc: true
        toc-location: left
        embed-resources: true
---"""
    yaml_cell = nbformat.v4.new_raw_cell(front_matter)
    nb["cells"].append(yaml_cell)


def add_import_cells(nb, backend="matplotlib"):
    """Append the "Import libraries" heading and the shared setup code cell.

    The setup cell imports the analysis libraries and the project ``utils``
    helpers, and sets MNE's ``MNE_BROWSER_BACKEND`` config to ``backend``.
    It also defines defaults used by later cells: ``auto_reject_params``,
    ``auto_reject_pre_ica``, ``fmax``, the ``SpectralGroupModel`` ``fg`` and
    ``recompute``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    backend : str
        MNE browser backend written into the generated cell; must be
        ``"matplotlib"`` or ``"qt"``.

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported browser backends.
    """
    # The value is pasted verbatim into generated source, so a typo (or a stray
    # quote) would otherwise only surface when the generated notebook is run.
    if backend not in ("matplotlib", "qt"):
        raise ValueError(f"backend must be 'matplotlib' or 'qt', got {backend!r}")

    # Literal braces in the generated code are doubled because of .format().
    functions_cell = """import os
import pandas as pd
import re
import mne
import matplotlib.pyplot as plt
from mne.preprocessing import ICA
from mne_icalabel.gui import label_ica_components
import autoreject
from specparam.plts.spectra import plot_spectra
from specparam import SpectralGroupModel
import matplotlib.colors as colors
import numpy as np
from utils import (
    compare_before_after,
    create_epochs,
    switch_bad_to_interpolate,
    update_reject_log,
    plot_specparam_on_scalp,
    examine_spectra,
    exclude_bad_channels,
    extract_elements,
    save_specparam_results,
    plot_models
)
plt.close("all")

# mne.viz.set_browser_backend("matplotlib") # As an alternative, you can use "qt" for the Qt backend
mne.set_config("MNE_BROWSER_BACKEND", "{backend}")


auto_reject_params = {{
    "n_interpolate": [1, 2, 16, 32],
    "n_jobs": -1,
    "random_state": 100,
    "thresh_method": "bayesian_optimization",
    "verbose": False,
    "consensus": [0.8],
}}

auto_reject_pre_ica = autoreject.AutoReject(**auto_reject_params)

fmax = 40.0
fg = SpectralGroupModel(
    peak_width_limits=[1, 6],
    min_peak_height=0.15,
    peak_threshold=2.0,
    max_n_peaks=6,
    verbose=False,
)
recompute = False

"""
    nb["cells"] += [nbformat.v4.new_markdown_cell("## Import libraries")]
    nb["cells"] += [nbformat.v4.new_code_cell(functions_cell.format(backend=backend))]


def import_epochs(nb, sub_id: int):
    """Append the cell that loads a subject's data, epochs it and runs pre-ICA autoreject.

    The generated cell:

    * reads ``sub-<id>_filtered_raw.fif`` and high-pass filters it at 1 Hz;
    * marks the subject's ``Bad_Channels`` from ``ICA TO REMOVE.xlsx`` as bad,
      prefixed with ``"E"``;
    * plots the raw data and cuts it into epochs with ``create_epochs``
      (``length=5``, ``overlap=1.5``);
    * loads a saved AutoReject model, or fits one on the first 20 epochs when
      the file is missing or ``recompute`` is True;
    * transforms all epochs and plots the rejection log.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    sub_id : int
        Subject number substituted into the generated code.
    """
    # Literal braces in the generated code are doubled because of .format().
    code_raw_data = """# import raw data or epochs
subject = {sub_id}
auto_reject_params = {{
    "n_interpolate": [
        1,
        2,
    ],
    "n_jobs": -1,
    "random_state": 100,
    "thresh_method": "bayesian_optimization",
    "verbose": False,
    "consensus": [0.4],
}}
#epochs_params = {{"reject": dict(eeg=400e-6)}}

df = pd.read_excel("ICA TO REMOVE.xlsx")

channels_to_add = ["E" + str(ch) for ch in extract_elements(df, {sub_id}, "Bad_Channels")]
print(channels_to_add)
raw = mne.io.read_raw_fif(f"sub-{{subject}}_filtered_raw.fif", preload=True).filter(
    l_freq=1.0, h_freq=None
)
# Skip channels already marked bad in the file so info["bads"] has no duplicates
for channel in channels_to_add:
    if channel not in raw.info["bads"]:
        raw.info["bads"].append(channel)

raw.plot()

epochs = create_epochs(raw, epochs_params=None, length=5, overlap=1.5)

# Reuse a saved AutoReject model unless it is missing or recompute is True
autoreject_fname = f"sub-{{subject}}_auto_reject_pre_ica.h5"
if os.path.exists(autoreject_fname) and recompute is False:
    auto_reject_pre_ica = autoreject.read_auto_reject(autoreject_fname)
    print(f"File {{autoreject_fname}} exists. Loading file")
else:
    print(f"File {{autoreject_fname}} not found or recompute=True; fitting a new model.")
    auto_reject_pre_ica = autoreject.AutoReject(**auto_reject_params).fit(epochs[:20])
    print("fitting finished")
    auto_reject_pre_ica.save(autoreject_fname, overwrite=True)
    print(f"File {{autoreject_fname}} saved")

epochs_ar, reject_log = auto_reject_pre_ica.transform(epochs, return_log=True)
reject_plot = reject_log.plot("vertical")"""
    nb["cells"] += [nbformat.v4.new_code_cell(code_raw_data.format(sub_id=sub_id))]


def rejection_log_manipulation(nb, sub_id: int):
    """Append two cells that apply the manual epoch rejections and plot the result.

    The first generated cell reads the subject's ``Epochs`` entries from
    ``ICA TO REMOVE.xlsx``. When that list is non-empty, it builds
    ``new_reject_log`` from ``switch_bad_to_interpolate`` and
    ``update_reject_log``; otherwise it reuses ``reject_log``. It then saves
    the log to ``sub-<id>_reject_log_updated.h5``. The second cell plots the
    rejection log on the epochs, then the rejected and the kept epochs.
    """
    update_log_source = """scalings = dict(eeg=60e-6)

df = pd.read_excel("ICA TO REMOVE.xlsx")
bad_epochs_indices = extract_elements(df, {sub_id}, "Epochs")

if len(bad_epochs_indices) > 0:
    zeroed_reject_log = switch_bad_to_interpolate(reject_log)
    new_reject_log = update_reject_log(zeroed_reject_log, bad_epochs_indices)
else:
    new_reject_log = reject_log
new_reject_log.save(f"sub-{{subject}}_reject_log_updated.h5", overwrite=True)"""

    plot_epochs_source = """# plot bad epochs if exists
new_reject_log.plot_epochs(epochs, scalings=scalings)
epochs[new_reject_log.bad_epochs].plot(scalings=scalings)
epochs[~new_reject_log.bad_epochs].plot(scalings=scalings)"""

    # Both templates go through .format so doubled braces collapse consistently.
    for template in (update_log_source, plot_epochs_source):
        nb["cells"].append(nbformat.v4.new_code_cell(template.format(sub_id=sub_id)))


def compute_ica(nb, sub_id: int):
    """Append the ICA section: a heading, a fitting cell and a component-plot cell.

    The fitting cell fits a Picard ICA (``n_components=0.99``) on the epochs
    not flagged in ``new_reject_log``. If that selects more than 50
    components, it refits with ``n_components=50``. The second cell reads the
    subject's ``ICA_3Take`` entries from ``ICA TO REMOVE.xlsx`` into
    ``ica_bad_components``, sets them as ``ica.exclude`` and plots the
    components.
    """
    fit_template = """# Compute ICA
ica_params = {{
    "n_components": 0.99,
    "random_state": 99,
    "method": "picard",
    "fit_params": {{"ortho": False, "extended": True}},
}}

ica = ICA(**ica_params)
new_epochs = epochs[~new_reject_log.bad_epochs]

ica.fit(new_epochs)
if ica.n_components_ > 50:
    ica_params = {{
        "n_components": 50,
        "random_state": 99,
        "method": "picard",
        "fit_params": {{"ortho": False, "extended": True}},
    }}
    ica = ICA(**ica_params)
    ica.fit(new_epochs)"""
    components_template = """# Plot bad components
# plot  ica with bad components
df = pd.read_excel("ICA TO REMOVE.xlsx")
ica_bad_components = extract_elements(df, {sub_id}, "ICA_3Take")
print(ica_bad_components)
# quick  comparison
ica.exclude =  ica_bad_components
_ = ica.plot_components()"""

    new_cells = [
        nbformat.v4.new_markdown_cell("## ICA"),
        # .format() with no arguments only collapses the doubled braces.
        nbformat.v4.new_code_cell(fit_template.format()),
        nbformat.v4.new_code_cell(components_template.format(sub_id=sub_id)),
    ]
    nb["cells"].extend(new_cells)


def ica_bad_components(nb, sub_id: int):
    """Append a cell that displays the ICA sources for manual review.

    With the Qt browser backend (``mne.get_config("MNE_BROWSER_BACKEND") ==
    "qt"``), the generated cell sets ``ica.exclude`` to
    ``ica_bad_components``, plots all sources with scrollbars and opens the
    ``label_ica_components`` GUI. Otherwise it clears ``ica.exclude`` and
    plots the listed components, or all components when none are listed,
    with ``start=0`` and ``stop=len(new_epochs) - 1``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    sub_id : int
        Accepted for a uniform call signature; the generated cell does not use it.
    """
    code_raw_data = """# Plot bad components
# plot  ica with bad components
if mne.get_config("MNE_BROWSER_BACKEND") == "qt":
    ica.exclude = ica_bad_components
    ica_plot_whole_timeseries = ica.plot_sources(
        new_epochs,
        picks=None,
        show_scrollbars=True,
    )
    gui = label_ica_components(new_epochs, ica)

else:
    ica.exclude = []
    # MNE rejects an empty picks list, so fall back to all components (None)
    # for subjects with no listed bad components.
    ica_plot_whole_timeseries = ica.plot_sources(
        new_epochs,
        picks=ica_bad_components if ica_bad_components else None,
        show_scrollbars=False,
        start=0,
        stop=len(new_epochs) - 1,
    )"""
    nb["cells"] += [nbformat.v4.new_code_cell(code_raw_data.format(sub_id=sub_id))]


def ica_compare_bad_and_cleaned_data(nb, sub_id: int):
    """Append a cell that saves the ICA model, applies it and compares PSDs.

    The generated cell excludes ``ica_bad_components`` and saves the model to
    ``sub-<subject>_my_ica_model-ica.fif``, which the specparam cell reads
    back. It then applies the ICA to a copy of ``new_epochs`` to produce
    ``epochs_clean_manual`` and calls ``compare_before_after`` on the kept
    epochs before and after cleaning.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    sub_id : int
        Accepted for a uniform call signature; the filename is built from the
        notebook's ``subject`` variable, like every other generated cell.
    """
    # Use the notebook's `subject` (doubled braces survive .format()) so the
    # save path matches the path the specparam cell reads back.
    code_raw_data = """# Compare bad data
ica.exclude = ica_bad_components
ica.save(f"sub-{{subject}}_my_ica_model-ica.fif", overwrite=True)
epochs_clean_manual = ica.apply(new_epochs.copy(), exclude=ica.exclude)
fig_psd = compare_before_after(
    epochs[~new_reject_log.bad_epochs],
    epochs_clean_manual,
    subject,
)"""
    nb["cells"] += [nbformat.v4.new_code_cell(code_raw_data.format(sub_id=sub_id))]


# Code chunks added during iteration

In [5]:
# add autoreject parts
def add_autoreject_post_ica_cells(nb, sub_id: int):
    """Append the post-ICA autoreject section (heading + code cell).

    The generated cell fits AutoReject on the first 20 ICA-cleaned epochs,
    transforms all ``epochs_clean_manual``, and saves both the model and the
    final rejection log. It then interpolates the remaining bad channels
    (excluding ``VREF``) and saves the result to
    ``analysis/sub-<subject>_interpolated-epo.fif``.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    sub_id : int
        Accepted for a uniform call signature; the generated cell uses the
        notebook's ``subject`` variable instead.
    """
    # Literal braces in the generated code are doubled because of .format().
    autoreject_chunk = """auto_reject_post_ica = autoreject.AutoReject(
    n_interpolate=[1, 2, 4, 8, 32, 64],
    n_jobs=-1,
    random_state=100,
    thresh_method="bayesian_optimization",
    verbose=False,
    # n_interpolate=np.array([0]),
    # consensus=0.8,
).fit(epochs_clean_manual[:20])
print("fitting finished")
epochs_ar, reject_log_final = auto_reject_post_ica.transform(
    epochs_clean_manual, return_log=True
)
autoreject_post_fname = f"sub-{{subject}}_auto_reject_post_ica.h5"
auto_reject_post_ica.save(autoreject_post_fname, overwrite=True)

reject_plot = reject_log_final.plot("vertical")
reject_log_final.save(f"sub-{{subject}}_reject_log_final.h5", overwrite=True)

epochs_interpolated = epochs_ar.copy().interpolate_bads(exclude=["VREF"])
# Make sure the output folder exists before saving into it
os.makedirs("analysis", exist_ok=True)
epochs_interpolated.save(
    f"analysis/sub-{{subject}}_interpolated-epo.fif", overwrite=True
)"""

    # The cell interpolates bad channels (interpolate_bads); it does not extrapolate.
    nb["cells"] += [
        nbformat.v4.new_markdown_cell(
            "## Interpolation of bad channels using autoreject"
        )
    ]
    nb["cells"] += [nbformat.v4.new_code_cell(autoreject_chunk.format())]


def add_comparison_cells(nb):
    """Append a heading and a cell comparing ICA-cleaned vs. interpolated epochs.

    The generated cell calls ``compare_before_after`` on
    ``epochs_clean_manual`` and ``epochs_interpolated`` with the title
    ``"After AR"``.
    """
    heading = nbformat.v4.new_markdown_cell("## Comparison of EEG data")
    comparison = nbformat.v4.new_code_cell(
        "# plot comparison\n"
        'fig_psd1 = compare_before_after(epochs_clean_manual, epochs_interpolated, subject, title="After AR")'
    )
    nb["cells"].extend([heading, comparison])


# compute specparam
def add_specparam_cells(nb, sub_id: int):
    """Append the specparam section: fit, save and visualise specparam models.

    The main generated cell:

    * counts the interpolated channels;
    * reloads the interpolated epochs and the saved ICA model;
    * fits ``fg`` (the ``SpectralGroupModel`` defined in the import cell) to
      the epoch-averaged PSD over 2-40 Hz;
    * saves the results with ``save_specparam_results`` and draws the scalp
      and spectra plots.

    Two further cells call ``plot_models`` for the exponent and for r-squared.

    Parameters
    ----------
    nb : nbformat.NotebookNode
        Notebook whose ``cells`` list is extended in place.
    sub_id : int
        Subject number substituted into the generated code.
    """
    # Literal braces in the generated code are doubled because of .format().
    specparam_chunk = """from utils import (
    save_specparam_results,
    plot_specparam_on_scalp,
    examine_spectra,
    plot_models,
)
subject = {sub_id}
# Interpolated channels = bads stored in the raw file plus this subject's bad
# channels from the spreadsheet. The latter are only appended to info["bads"]
# at runtime in the import cell and are never saved back to the file.
# Only the header is needed for this count, so avoid loading/filtering the data.
raw_info = mne.io.read_info(f"sub-{{subject}}_filtered_raw.fif")
df = pd.read_excel("ICA TO REMOVE.xlsx")
sheet_bads = ["E" + str(ch) for ch in extract_elements(df, subject, "Bad_Channels")]
n_interpolated_channels = len(set(raw_info["bads"]) | set(sheet_bads))
# epochs = create_epochs(raw, epochs_params=None, length=5, overlap=1.5)
epochs_interpolated = mne.read_epochs(
    f"analysis/sub-{{subject}}_interpolated-epo.fif", preload=True
)
ica_loaded = mne.preprocessing.read_ica(f"sub-{{subject}}_my_ica_model-ica.fif")

psd = epochs_interpolated.compute_psd().average()
spectra, freqs = psd.get_data(return_freqs=True)
# fg (SpectralGroupModel) is configured in the import cell

# Define the frequency range to fit
freq_range = [2, 40]
fg.fit(freqs, spectra, freq_range)
fg.plot()
specparam_df = save_specparam_results(
    fg,
    epochs_interpolated,
    ica_loaded,
    subject,
    n_interpolated_channels=n_interpolated_channels,
)
display(specparam_df.head())
#epochs_ar_good = exclude_bad_channels(epochs_interpolated)

plot_specparam_on_scalp(fg,  epochs_interpolated, subject)
examine_spectra(fg, subject)
"""

    plot_models1 = """# plot models that are smallest, median and highest exponent
plot_models(fg, param_choice="exponent")"""

    plot_models2 = """# plot models with worst, median and best goodness of fit
plot_models(fg, param_choice="r_squared")"""

    nb["cells"] += [nbformat.v4.new_markdown_cell("## Specparam visualisation")]
    nb["cells"] += [nbformat.v4.new_code_cell(specparam_chunk.format(sub_id=sub_id))]
    nb["cells"] += [nbformat.v4.new_code_cell(plot_models1)]
    nb["cells"] += [nbformat.v4.new_code_cell(plot_models2)]

# Putting it all together

In [6]:
from pathlib import Path

# Subject IDs to generate re-analysis notebooks for.
numbers = [
    101,
    106,
    108,
    120,
    124,
    126,
    128,
    129,
    132,
    135,
    139,
    142,
    144,
    145,
    146,
    147,
]

# open() below fails if the target folder is missing, so create it once up front.
output_dir = "../check_subjects"
Path(output_dir).mkdir(parents=True, exist_ok=True)

for sub_id in numbers:
    output_path = f"{output_dir}/sub_{sub_id}_reanalysis.ipynb"

    # create notebook: metadata, front matter and imports first, then one
    # section per processing step, in pipeline order
    nb = nbformat.v4.new_notebook()
    add_kernel_metadata(nb)
    add_yaml_metadata(nb)
    add_import_cells(nb, backend="qt")
    nb["cells"] += [nbformat.v4.new_markdown_cell(f"# Subject {sub_id}")]
    # plot_raw_data(nb, sub_id)
    import_epochs(nb, sub_id=sub_id)
    rejection_log_manipulation(nb, sub_id=sub_id)
    compute_ica(nb, sub_id=sub_id)
    ica_bad_components(nb, sub_id=sub_id)
    ica_compare_bad_and_cleaned_data(nb, sub_id=sub_id)
    add_autoreject_post_ica_cells(nb, sub_id=sub_id)
    add_comparison_cells(nb)
    add_specparam_cells(nb, sub_id=sub_id)
    nb["cells"] += [nbformat.v4.new_code_cell('plt.close("all")')]
    # write notebook
    with open(output_path, "w", encoding="utf-8") as f:
        nbformat.write(nb, f)

    # validate and repair notebook if needed
    validate_and_repair_notebook(output_path)

The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
The notebook is valid.
