In [None]:
import mne
from mne.io import read_raw
import numpy as np
import matplotlib.pyplot as plt
from os.path import join
import os
import pandas as pd
import json
from collections import defaultdict

from functions import ephy_plotting, preprocessing, analysis, io, utils

# 1. Load the dataset #

In [None]:
# Results are stored in a "results" folder located in the parent of the current working directory.
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")

# JSON file listing the included and excluded subjects, based on the behavioral results
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r') as file:
    # Keep only the entries whose name starts with "sub"
    included_subjects = [
        entry for entry in json.load(file) if entry.startswith('sub')
    ]
print(f'Included_subjects: {included_subjects}')

onedrive_path = utils._get_onedrive_path()

# Folder receiving the cleaned epochs (created on first run)
saving_path = join(results_path, 'lfp_epochs')
os.makedirs(saving_path, exist_ok=True)

# session ID -> epochs object
sub_dict_epochs = {}
# subject -> condition -> per-session results; missing subjects get an empty dict
all_sub_session_dict = defaultdict(dict)


In [None]:
# Exclude this session from the analysis. The membership test makes the cell
# idempotent: re-running it (or running it when the session is already absent)
# no longer raises ValueError.
if 'sub028 DBS OFF mSST' in included_subjects:
    included_subjects.remove('sub028 DBS OFF mSST')

In [None]:
# # Load all data for all included subjects
# data = io.load_behav_data(included_subjects, onedrive_path)

# # Compute statistics for each loaded subject
# stats = {}
# stats = utils.extract_stats(data)
# # If no file was found, create a new JSON file
# filename = "stats.json"
# file_path = os.path.join(results_path, filename)
# #if not os.path.isfile(file_path):
# #    with open(file_path, "w", encoding="utf-8") as file:
# #            json.dump({}, file, indent=4)

# # Save the updated or new JSON file
# with open(file_path, "w", encoding="utf-8") as file:
#     json.dump(stats, file, indent=4)

# # remove sub027 DBS OFF mSST from included_subjects because it has not been synchronized yet
# #included_subjects.remove('sub027 DBS OFF mSST')
# included_subjects


# 2. Create full session plots for each subject (Raw traces, TFR plot, PSD) #

In [None]:
for session_ID in included_subjects:
    print(f"Now processing {session_ID}")
    session_dict = {}
    # Session IDs are parsed positionally, e.g. "sub023 DBS ON mSST": the first
    # 6 characters give the subject, the 2nd and 3rd space-separated tokens the
    # stimulation condition.
    sub = session_ID[:6]
    id_tokens = session_ID.split(' ')
    condition = id_tokens[1] + ' ' + id_tokens[2]
    sub_onedrive_path_task = join(onedrive_path, sub, 'synced_data', session_ID)
    # os.listdir() returns entries in arbitrary order; sorting makes the choice
    # of filename[0] reproducible when several candidates exist.
    filename = sorted(
        f for f in os.listdir(sub_onedrive_path_task)
        if f.endswith('.set') and f.startswith('SYNCHRONIZED_INTRACRANIAL')
    )

    if not filename:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")

    file = join(sub_onedrive_path_task, filename[0])

    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")

    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)

    # One output folder per subject, shared by all of the subject's sessions
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST')
    os.makedirs(saving_path_single, exist_ok=True)

    ephy_plotting.plot_raw_stim(session_ID, raw, saving_path_single)

    # Power spectral density of each hemisphere
    psd_left, freqs_left, psd_right, freqs_right = analysis.compute_psd_welch(raw)
    session_dict['psd_left_V^2/Hz'] = psd_left
    session_dict['freqs_left'] = freqs_left
    session_dict['psd_right_V^2/Hz'] = psd_right
    session_dict['freqs_right'] = freqs_right

    # Compute band power for theta, alpha, low-beta and high-beta ranges:
    band_metrics_left = utils.compute_band_metrics(psd_left, freqs_left)
    band_metrics_right = utils.compute_band_metrics(psd_right, freqs_right)
    session_dict['left'] = band_metrics_left
    session_dict['right'] = band_metrics_right

    print(f'Values for Left STN: {band_metrics_left}')
    print(f'Values for Right STN: {band_metrics_right}')

    ephy_plotting.plot_psd_log(
        session_ID, raw, freqs_left, psd_left,
        freqs_right, psd_right, saving_path_single, is_filt=False
        )
    ephy_plotting.plot_stft_stim(
        session_ID, raw, saving_path=saving_path_single, is_filt=False,
        vmin = -18, vmax = -12,
        fmin=0, fmax=100
        )

    all_sub_session_dict[sub][condition] = session_dict


In [None]:
# Result of analysis.compare_band_power() on the per-session dictionary,
# written to an Excel file in the results folder (row index not exported).
df = analysis.compare_band_power(all_sub_session_dict)
band_power_xlsx = join(results_path, "band_power_comparison.xlsx")
df.to_excel(band_power_xlsx, index=False)

In [None]:
# from scipy.stats import wilcoxon

# for band in df['band'].unique():
#     for hemi in df['hemisphere'].unique():
#         subset = df[(df['band'] == band) & (df['hemisphere'] == hemi)]
#         stat, p_val = wilcoxon(subset['DBS OFF_power_uV2'], subset['DBS ON_power_uV2'])
#         print(f"{band} - {hemi}: Wilcoxon p={p_val:.4f}")

# 3. Work with epochs #

In [None]:
# Channel names imposed on every recording, assigned in file order so that
# names are consistent across subjects.
new_channel_names = [
    "Left_STN",
    "Right_STN",
    "left_peak_STN",
    "right_peak_STN",
    "STIM_Left_STN",
    "STIM_Right_STN"
]

# LFP -> behavioral naming mapping
mapping = {
    "GC": "go_continue_trial",
    "GO": "go_trial",
    "GF": "go_fast_trial",
    "GS": "stop_trial",
}

for session_ID in included_subjects:
    # Session IDs are parsed positionally, e.g. "sub023 DBS ON mSST"
    sub = session_ID[:6]
    id_tokens = session_ID.split(' ')
    condition = id_tokens[1] + ' ' + id_tokens[2]
    print(f"Now processing {session_ID}")
    # Reuse the entry filled by the PSD section (if it ran) instead of replacing
    # it with an empty dict, which would silently discard the band-power results.
    session_dict = all_sub_session_dict[sub].setdefault(condition, {})

    sub_onedrive_path_task = join(onedrive_path, sub, 'synced_data', session_ID)
    # Sorted so that the file picked below does not depend on os.listdir() order
    filename = sorted(
        f for f in os.listdir(sub_onedrive_path_task)
        if f.endswith('.set') and f.startswith('SYNCHRONIZED_INTRACRANIAL')
    )

    if not filename:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")

    file = join(sub_onedrive_path_task, filename[0])

    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")

    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)
    # Store copies so the two entries cannot alias the same list
    session_dict['CHANNELS'] = list(raw.ch_names)

    # Rename channels to be consistent across subjects.
    # NOTE(review): zip() stops at the shorter sequence, so a recording whose
    # channel count differs from len(new_channel_names) is only partly renamed.
    rename_dict = dict(zip(raw.ch_names, new_channel_names))
    raw.rename_channels(rename_dict)
    session_dict['RENAMED_CHANNELS'] = list(raw.ch_names)

    # Filter between 1 and 80 Hz:
    filtered_data = raw.copy().filter(l_freq=1, h_freq=80)

    # Only keep the LFP channels (the first two channels of the recording)
    filtered_data_lfp = filtered_data.copy().pick(filtered_data.ch_names[:2])

    # ---- Behavioral file of the session -------------------------------------
    mSST_raw_behav_session_data_path = join(
        onedrive_path, sub, "raw_data", 'BEHAVIOR', condition, 'mSST'
    )
    csv_files = sorted(
        f for f in os.listdir(mSST_raw_behav_session_data_path) if f.endswith(".csv")
    )
    if not csv_files:
        # Previously `fname` was undefined here (or left over from the previous subject)
        raise FileNotFoundError(
            f"No .csv behavioral file found in {mSST_raw_behav_session_data_path}"
        )
    # If several CSV files are present, the alphabetically last one is used
    fname = csv_files[-1]
    filepath_behav = join(mSST_raw_behav_session_data_path, fname)
    df = pd.read_csv(filepath_behav)

    # Main-task rows are those between the first and last row with a valid
    # 'blocks.thisRepN' value (both included). The bounds are index *labels*,
    # so a label-based, end-inclusive .loc slice is used; with the default
    # RangeIndex produced by read_csv this selects the same rows as
    # iloc[start:stop + 1].
    start_task_index = df['blocks.thisRepN'].first_valid_index()
    stop_task_index = df['blocks.thisRepN'].last_valid_index()
    if start_task_index is None:
        raise ValueError(f"No main-task rows ('blocks.thisRepN') found in {filepath_behav}")
    df_maintask = df.loc[start_task_index:stop_task_index]

    # remove all useless columns to clean up dataframe
    column_names = df_maintask.columns
    columns_to_keep = [i for i in [
        'blocks.thisN', 'trial_loop.thisN', 'trial_type',
        'continue_signal_time', 'stop_signal_time',
        'fixation_cross.started', 'go_rectangle.started',
        'key_resp_experiment.keys', 'key_resp_experiment.corr', 'key_resp_experiment.rt',
        'early_press_resp.keys', 'early_press_resp.rt', 'early_press_resp.corr',
        'late_key_resp1.keys', 'late_key_resp1.rt',
        'late_key_resp2.keys', 'late_key_resp2.rt'
        ] if i in column_names]

    mini_df_maintask = df_maintask[columns_to_keep]

    # remove the trials with early presses, as in these trials the cues were not presented (for mSST)
    if 'early_press_resp.corr' in mini_df_maintask.columns:
        early_presses_trials = list(
            mini_df_maintask.index[mini_df_maintask['early_press_resp.corr'] == 1]
        )
    else:
        # Column absent from this file: nothing to remove
        early_presses_trials = []
    df_maintask_copy = mini_df_maintask.drop(early_presses_trials).reset_index(drop=True)

    # ---- Epoching -------------------------------------------------------------
    # First generate global epochs (without taking into account success outcome)
    # event_id maps annotation label -> integer event code
    _, event_id = mne.events_from_annotations(filtered_data_lfp)
    epochs, filtered_event_dict = preprocessing.create_epochs(
         filtered_data_lfp,
         sub,
         keys_to_keep = ['GC', 'GF', 'GO', 'GS', 'continue', 'stop'],
         tmin = -3.5,
         tmax = 3.5,
         baseline=None
         )

    # inverse mapping (event code -> label)
    inv_event_id = {code: label for label, code in event_id.items()}

    # ---- Metadata: one row per epoch ------------------------------------------
    metadata = pd.DataFrame(index=np.arange(len(epochs)))
    metadata["event"] = [inv_event_id[e] for e in epochs.events[:, 2]]
    metadata["sample"] = epochs.events[:, 0]
    metadata["event_timing"] = epochs.events[:, 0] / raw.info['sfreq']  # in seconds
    # Object dtype: the column receives strings from the behavioral file, and
    # writing strings into a float column is deprecated in recent pandas.
    metadata["trial_type"] = pd.Series(np.nan, index=metadata.index, dtype=object)

    # Rows corresponding to a trial onset (one per behavioral trial)
    trial_mask = metadata["event"].isin(mapping.keys())

    assert trial_mask.sum() == len(df_maintask_copy), \
        f"Mismatch: {trial_mask.sum()} LFP trials vs {len(df_maintask_copy)} behavioral trials"

    # fill directly from behavioral file (same order as the trial-onset rows)
    behav_cols = list(df_maintask_copy.columns)
    for col in behav_cols:
        metadata.loc[trial_mask, col] = df_maintask_copy[col].values

    # 'continue' / 'stop' cue epochs inherit the behavioral values of the most
    # recent preceding GC / GS trial-onset epoch, respectively.
    for cue_label, go_label in (("continue", "GC"), ("stop", "GS")):
        go_rows = np.flatnonzero((metadata["event"] == go_label).to_numpy())
        cue_rows = np.flatnonzero((metadata["event"] == cue_label).to_numpy())
        for cue_row in cue_rows:
            earlier = go_rows[go_rows < cue_row]
            if earlier.size == 0:
                raise ValueError(
                    f"{session_ID}: '{cue_label}' epoch {cue_row} has no preceding "
                    f"'{go_label}' epoch"
                )
            metadata.loc[cue_row, behav_cols] = metadata.loc[earlier[-1], behav_cols]

    epochs.metadata = metadata

    sub_dict_epochs[session_ID] = epochs


# For each session, look at the epochs and remove 'bad' epochs, then save the cleaned file #

In [None]:
# Session whose epochs are inspected and cleaned in the cells below.
# Note: no copy is taken, so cleaned_epochs is the very object stored in
# sub_dict_epochs; in-place changes affect both.
subject_id = 'sub023 DBS ON mSST'
cleaned_epochs = sub_dict_epochs[subject_id]

In [None]:
# Display the summary of the selected session's epochs
cleaned_epochs

In [None]:
"""Implementation of all the FASTER steps."""

from collections import defaultdict

import mne
import numpy as np
import scipy.signal
from mne import pick_info
from mne._fiff.pick import _picks_by_type
from mne.preprocessing.bads import _find_outliers
from mne.utils import logger
from scipy.stats import kurtosis


def _bad_mask_to_names(info, bad_mask):
    """Remap mask to ch names."""
    bad_idx = [np.where(m)[0] for m in bad_mask]
    return [[info["ch_names"][k] for k in epoch] for epoch in bad_idx]


def _combine_indices(bads):
    """Summarize indices."""
    return list(set(v for val in bads.values() if len(val) > 0 for v in val))


def hurst(x):
    """Estimate the Hurst exponent of one or several timeseries.

    The estimation is based on the second order discrete derivative.

    Parameters
    ----------
    x : array, shape (n_samples,) or (n_series, n_samples)
        The timeseries to estimate the Hurst exponent for. Time runs along the
        last axis. A 1D input is treated as a single series.

    Returns
    -------
    h : float | array, shape (n_series,)
        The estimated Hurst exponent: a scalar for 1D input, one value per row
        for 2D input.

    """
    x = np.asanyarray(x)
    single_series = x.ndim == 1
    # All computations below operate along axis 1, so promote 1D input to one row
    x = np.atleast_2d(x)

    y = np.cumsum(np.diff(x, axis=1), axis=1)

    b1 = [1, -2, 1]
    b2 = [1, 0, -2, 0, 1]

    # second order derivative
    y1 = scipy.signal.lfilter(b1, 1, y, axis=1)
    # drop the first len(b1) - 1 samples (filter start-up artifacts) and the last sample
    y1 = y1[:, len(b1) - 1 : -1]

    # wider second order derivative
    y2 = scipy.signal.lfilter(b2, 1, y, axis=1)
    # drop the first len(b2) - 1 samples (filter start-up artifacts) and the last sample
    y2 = y2[:, len(b2) - 1 : -1]

    s1 = np.mean(y1**2, axis=1)
    s2 = np.mean(y2**2, axis=1)

    h = 0.5 * np.log2(s2 / s1)
    return h[0] if single_series else h


def _efficient_welch(data, sfreq):
    """Call scipy.signal.welch with parameters optimized for greatest speed.

    Comes at the expense of precision. The window is set to ~10 seconds and windows are
    non-overlapping.

    Parameters
    ----------
    data : array, shape (..., n_samples)
        The timeseries to estimate signal power for. The last dimension
        is assumed to be time.
    sfreq : float
        The sample rate of the timeseries.

    Returns
    -------
    fs : array of float
        The frequencies for which the power spectra was calculated.
    ps : array, shape (..., frequencies)
        The power spectra for each timeseries.

    """
    from scipy.signal import welch

    nperseg = min(data.shape[-1], 2 ** int(np.log2(10 * sfreq) + 1))  # next power of 2

    return welch(data, sfreq, nperseg=nperseg, noverlap=0, axis=-1)


def _freqs_power(data, sfreq, freqs):
    """Sum the spectral power found at the given frequencies.

    Parameters
    ----------
    data : array, shape (..., n_samples)
        The timeseries; the last dimension is time.
    sfreq : float
        The sample rate of the timeseries.
    freqs : list of float
        Frequencies (Hz) whose power is summed. For each one, the first
        spectral bin at or above that frequency is used.

    Returns
    -------
    power : array, shape (...)
        Summed power over ``freqs`` for each timeseries.

    Raises
    ------
    ValueError
        If a requested frequency lies above the highest frequency of the
        estimated spectrum.

    """
    fs, ps = _efficient_welch(data, sfreq)
    bins = [np.searchsorted(fs, f) for f in freqs]
    # Checked explicitly rather than by catching IndexError, so unrelated
    # indexing errors are not reported as a sample-rate problem.
    if any(b >= len(fs) for b in bins):
        raise ValueError(
            (
                "Insufficient sample rate to estimate power at {} Hz for line "
                "noise detection. Use the 'metrics' parameter to disable the "
                "'line_noise' metric."
            ).format(freqs)
        )
    return np.sum([ps[..., b] for b in bins], axis=0)


def _distance_correction(info, picks, x):
    """Remove the effect of distance to reference sensor.

    Computes the distance of each sensor to the reference sensor. Then regresses the
    effect of this distance out of the values in x.

    Parameters
    ----------
    info : instance of Info
        The measurement info. This should contain positions for all the sensors.
    picks : list of int
        Indices of the channels that correspond to the values in x.
    x : list of float
        Values to correct.

    Returns
    -------
    x_corr : list of float
        values in x corrected for the distance to reference sensor.

    """
    pos = np.array([info["chs"][ch]["loc"][:3] for ch in picks])
    ref_pos = np.array([info["chs"][ch]["loc"][3:6] for ch in picks])

    if np.any(np.all(pos == 0, axis=1)):
        raise ValueError(
            "Cannot perform correction for distance to reference "
            "sensor: not all selected channels have position "
            "information."
        )
    if np.any(np.all(ref_pos == 0, axis=1)):
        raise ValueError(
            "Cannot perform correction for distance to reference "
            "sensor: the location of the reference sensor is not "
            "specified for all selected channels."
        )

    # Compute angular distances to the reference sensor
    pos /= np.linalg.norm(pos, axis=1)[:, np.newaxis]
    ref_pos /= np.linalg.norm(ref_pos, axis=1)[:, np.newaxis]
    angles = [np.arccos(np.dot(a, b)) for a, b in zip(pos, ref_pos)]

    # Fit a quadratic curve to correct for the angular distance
    fit = np.polyfit(angles, x, 2)
    return x - np.polyval(fit, angles)


def find_bad_channels(
    epochs,
    picks=None,
    max_iter=1,
    thres=3,
    eeg_ref_corr=False,
    use_metrics=None,
    return_by_metric=False,
):
    """Automatically find bad channels (first step of the FASTER algorithm).

    All epochs are concatenated in time, one score per channel is computed for
    each metric, and channels whose score is an outlier are reported.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs for which bad channels need to be found.
    picks : list of int | None
        Channels to operate on. Defaults to all EEG channels.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    thres : float
        The threshold value, in standard deviations, to apply. A channel crossing this
        threshold value is reported as bad. Defaults to 3.
    eeg_ref_corr : bool
        If True, the scores are corrected for the distance of each electrode to the
        reference (see ``_distance_correction``). Intended for data referenced to a
        single electrode. Defaults to False, which disables the correction.
    use_metrics : list of str | None
        List of metrics to use. Can be any combination of:
            'variance', 'correlation', 'hurst', 'kurtosis', 'line_noise'
        Defaults to all of them.
    return_by_metric : bool
        Whether to return the bad channels as a flat list (False, default) or as a
        dictionary with the names of the used metrics as keys and the bad channels found
        by this metric as values.

    Returns
    -------
    bads : list of str | dict
        The names of the bad channels, or a dict mapping each metric to the names it
        flagged when ``return_by_metric`` is True.

    """

    def _mean_correlation(x):
        # mean correlation of each channel with the others (diagonal masked out)
        off_diagonal = np.ma.masked_array(
            np.corrcoef(x), np.identity(len(x), dtype=bool)
        )
        return np.nanmean(off_diagonal, axis=0)

    def _line_noise(x):
        return _freqs_power(x, epochs.info["sfreq"], [50, 60])

    metrics = {
        "variance": lambda x: np.var(x, axis=1),
        "correlation": _mean_correlation,
        "hurst": hurst,
        "kurtosis": lambda x: kurtosis(x, axis=1),
        "line_noise": _line_noise,
    }

    if picks is None:
        picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude=[])
    if use_metrics is None:
        use_metrics = metrics.keys()

    # Concatenate epochs in time: (epochs, channels, samples) -> (channels, epochs * samples)
    picked = epochs.get_data(copy=False)[:, picks]
    n_channels = picked.shape[1]
    data = np.swapaxes(picked, 0, 1).reshape(n_channels, -1)

    # Find bad channels, one channel type at a time
    flagged = defaultdict(list)
    info = pick_info(epochs.info, picks, copy=True)
    for ch_type, chs in _picks_by_type(info):
        logger.info("Bad channel detection on %s channels:" % ch_type.upper())
        for metric in use_metrics:
            scores = metrics[metric](data[chs])
            if eeg_ref_corr:
                scores = _distance_correction(epochs.info, picks, scores)
            outliers = _find_outliers(scores, thres, max_iter)
            # map outlier positions back to channel names of the full recording
            bad_channels = [epochs.ch_names[picks[chs[i]]] for i in outliers]
            logger.info("\tBad by %s: %s" % (metric, bad_channels))
            flagged[metric].append(bad_channels)

    bads = {
        metric: np.concatenate(per_type).tolist()
        for metric, per_type in flagged.items()
    }

    if return_by_metric:
        return bads
    return _combine_indices(bads)


def _deviation(data):
    """Compute the deviation from mean for each channel in a set of epochs.

    This is not implemented as a lambda function, because the channel means should be
    cached during the computation.

    Parameters
    ----------
    data : 3D numpy array
        The epochs (#epochs x #channels x #samples).

    Returns
    -------
    dev : list of float
        For each epoch, the mean deviation of the channels.

    """
    ch_mean = np.mean(data, axis=2)
    return ch_mean - np.mean(ch_mean, axis=0)


def find_bad_epochs(
    epochs, picks=None, thres=3, max_iter=1, use_metrics=None, return_by_metric=False
):
    """Automatically find bad epochs (second step of the FASTER algorithm).

    One score per epoch is computed for each metric (averaged over channels),
    and epochs whose score is an outlier are reported.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to analyze.
    picks : list of int | None
        Channels to operate on. Defaults to the EEG channels not marked as bad.
    thres : float
        The threshold value, in standard deviations, to apply. An epoch
        crossing this threshold value is reported as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    use_metrics : list of str | None
        List of metrics to use. Can be any combination of:
            'amplitude', 'variance', 'deviation'
        Defaults to all of them.
    return_by_metric : bool
        Whether to return the bad epochs as a flat list (False, default) or as a
        dictionary with the names of the used metrics as keys and the bad epochs found
        by this metric as values.

    Returns
    -------
    bads : list of int | dict
        The indices of the bad epochs, or a dict mapping each metric to the indices
        it flagged when ``return_by_metric`` is True.

    """

    def _channel_average(per_channel_scores):
        # (epochs, channels) -> one value per epoch
        return np.mean(per_channel_scores, axis=1)

    metrics = {
        "amplitude": lambda x: _channel_average(np.ptp(x, axis=2)),
        "deviation": lambda x: _channel_average(_deviation(x)),
        "variance": lambda x: _channel_average(np.var(x, axis=2)),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude="bads")
    if use_metrics is None:
        use_metrics = metrics.keys()

    info = pick_info(epochs.info, picks, copy=True)
    data = epochs.get_data(copy=True)[:, picks]

    # Outliers are searched separately for each channel type (logging is disabled here)
    flagged = defaultdict(list)
    for _ch_type, chs in _picks_by_type(info):
        type_data = data[:, chs]
        for metric in use_metrics:
            scores = metrics[metric](type_data)
            flagged[metric].append(_find_outliers(scores, thres, max_iter))

    bads = {
        metric: np.concatenate(per_type).tolist()
        for metric, per_type in flagged.items()
    }
    if return_by_metric:
        return bads
    return _combine_indices(bads)


def _power_gradient(data, sfreq, prange):
    """Estimate the mean gradient of the power spectrum within a frequency range.

    Parameters
    ----------
    data : array, shape (n_components, n_samples)
        The timeseries to estimate signal power for. The last dimension is presumed to
        be time.
    sfreq : float
        The sample rate of the timeseries.
    prange : pair of floats
        The (lower, upper) frequency limits of the power spectrum to use. In the FASTER
        paper, they set these to the passband of the lowpass filter.

    Returns
    -------
    grad : array of float
        The mean spectral gradient of each timeseries.

    Raises
    ------
    ValueError
        If the lower limit lies beyond the last bin of the estimated spectrum.

    """
    fs, ps = _efficient_welch(data, sfreq)

    # Translate the frequency limits into spectral bin indices
    first_bin, last_bin = [np.searchsorted(fs, limit) for limit in prange]
    if first_bin >= ps.shape[1]:
        raise ValueError(
            (
                "Sample rate insufficient to estimate {} Hz power. "
                "Use the 'power_gradient_range' parameter to tweak "
                "the tested frequencies for this metric or use the "
                "'metrics' parameter to disable the "
                "'power_gradient' metric."
            ).format(prange[0])
        )
    band_power = ps[:, first_bin:last_bin]

    # Average of the bin-to-bin differences within the band
    return np.diff(band_power).mean(axis=1)


def find_bad_components(
    ica,
    epochs,
    thres=3,
    max_iter=1,
    use_metrics=None,
    prange=None,
    return_by_metric=False,
):
    """Perform the third step of the FASTER algorithm.

    This function attempts to automatically find bad ICA components by performing
    outlier detection on per-component scores.

    Parameters
    ----------
    ica : Instance of ICA
        The ICA operator, already fitted to the supplied Epochs object.
    epochs : Instance of Epochs
        The untransformed epochs to analyze.
    thres : float
        The threshold value, in standard deviations, to apply. A component crossing this
        threshold value is reported as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    use_metrics : list of str | None
        List of metrics to use. Can be any combination of:
            'eog_correlation', 'kurtosis', 'power_gradient', 'hurst',
            'median_gradient', 'line_noise'
        Defaults to all of them.
    prange : None | pair of floats
        The (lower, upper) frequency limits of the power spectrum to use for the power
        gradient computation. If None, defaults to the 'highpass' and 'lowpass'
        filter settings in ica.info.
    return_by_metric : bool
        Whether to return the bad components as a flat list (False, default) or as a
        dictionary with the names of the used metrics as keys and the bad components
        found by this metric as values.

    Returns
    -------
    bads : list of int | dict
        The indices of the bad components, or a dict mapping each metric to the
        indices it flagged when ``return_by_metric`` is True.

    See Also
    --------
    ICA.find_bads_ecg
    ICA.find_bads_eog

    """
    # Source timecourses with all epochs concatenated in time:
    # (epochs, components, samples) -> (components, epochs * samples)
    sources = ica.get_sources(epochs).get_data(copy=False).transpose(1, 0, 2)
    source_data = sources.reshape(sources.shape[0], -1)

    if prange is None:
        prange = (ica.info["highpass"], ica.info["lowpass"])
    if len(prange) != 2:
        raise ValueError("prange must be a pair of floats")

    # Every metric receives the ICA object; most only use the source timecourses.
    def _eog_correlation(x):
        return x.find_bads_eog(epochs)[1]

    def _weights_kurtosis(x):
        weights = np.dot(x.mixing_matrix_.T, x.pca_components_[: x.n_components_])
        return kurtosis(weights, axis=1)

    def _spectral_gradient(x):
        return _power_gradient(source_data, ica.info["sfreq"], prange)

    def _sources_hurst(x):
        return hurst(source_data)

    def _median_gradient(x):
        return np.median(np.abs(np.diff(source_data)), axis=1)

    def _line_noise(x):
        return _freqs_power(source_data, epochs.info["sfreq"], [50, 60])

    metrics = {
        "eog_correlation": _eog_correlation,
        "kurtosis": _weights_kurtosis,
        "power_gradient": _spectral_gradient,
        "hurst": _sources_hurst,
        "median_gradient": _median_gradient,
        "line_noise": _line_noise,
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    flagged = defaultdict(list)
    for metric in use_metrics:
        # A metric may return several score vectors; each is screened separately
        score_rows = np.atleast_2d(metrics[metric](ica))
        for row in score_rows:
            bad_comps = _find_outliers(row, thres, max_iter)
            logger.info("Bad by %s:\n\t%s" % (metric, bad_comps))
            flagged[metric].append(bad_comps)

    bads = {
        metric: np.concatenate(found).tolist() for metric, found in flagged.items()
    }
    if return_by_metric:
        return bads
    return _combine_indices(bads)


def find_bad_channels_in_epochs(
    epochs,
    picks=None,
    thres=3,
    max_iter=1,
    eeg_ref_corr=False,
    use_metrics=None,
    return_by_metric=False,
):
    """Perform the fourth step of the FASTER algorithm.

    This function attempts to automatically find bad channels within each epoch by
    performing outlier detection across the channels of that epoch.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to analyze.
    picks : list of int | None
        Channels to operate on. Defaults to the EEG channels not marked as bad.
    thres : float
        The threshold value, in standard deviations, to apply. A channel crossing this
        threshold value within an epoch is reported as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    eeg_ref_corr : bool
        If True, the scores are corrected for the distance of each electrode to the
        reference (see ``_distance_correction``). Intended for data referenced to a
        single electrode. Defaults to False, which disables the correction.
    use_metrics : list of str | None
        List of metrics to use. Can be any combination of:
            'amplitude', 'variance', 'deviation', 'median_gradient', 'line_noise'
        Defaults to all of them.
    return_by_metric : bool
        Whether to return the bad channels as one list per epoch (False, default) or as
        a dictionary with the names of the metrics as keys and, as values, the
        per-epoch lists of bad channels found by that metric.

    Returns
    -------
    bads : list of lists of str | dict
        For each epoch, the names of the bad channels (or a dict of such lists, one
        per metric, when ``return_by_metric`` is True).

    """
    metrics = {
        "amplitude": lambda x: np.ptp(x, axis=2),
        "deviation": lambda x: _deviation(x),
        "variance": lambda x: np.var(x, axis=2),
        "median_gradient": lambda x: np.median(np.abs(np.diff(x)), axis=2),
        "line_noise": lambda x: _freqs_power(x, epochs.info["sfreq"], [50, 60]),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude="bads")
    if use_metrics is None:
        use_metrics = metrics.keys()

    info = pick_info(epochs.info, picks, copy=True)
    data = epochs.get_data(copy=False)[:, picks]

    # One (epochs x picked channels) boolean mask per known metric
    mask_shape = (len(data), len(picks))
    masks = {name: np.zeros(mask_shape, dtype=bool) for name in metrics}

    for ch_type, chs in _picks_by_type(info):
        type_names = [info["ch_names"][k] for k in chs]
        type_idx = np.array(chs)
        for metric in use_metrics:
            logger.info(
                "Bad channel-in-epoch detection on %s channels:" % ch_type.upper()
            )
            per_epoch_scores = metrics[metric](data[:, type_idx])
            for i_epochs, scores in enumerate(per_epoch_scores):
                if eeg_ref_corr:
                    scores = _distance_correction(epochs.info, picks, scores)
                outliers = _find_outliers(scores, thres, max_iter)
                if len(outliers) == 0:
                    continue
                bad_segment = [type_names[k] for k in outliers]
                logger.info(
                    "Epoch %d, Bad by %s:\n\t%s" % (i_epochs, metric, bad_segment)
                )
                masks[metric][i_epochs, type_idx[outliers]] = True

    if return_by_metric:
        return {name: _bad_mask_to_names(info, mask) for name, mask in masks.items()}

    # A channel is bad in an epoch as soon as any metric flagged it
    combined = np.sum(list(masks.values()), axis=0).astype(bool)
    return _bad_mask_to_names(info, combined)


def run_faster(epochs, thres=3, copy=True):
    """Run the entire FASTER pipeline on the data.

    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to clean.
    thres : float
        Outlier threshold passed to steps 2, 3 and 4. Defaults to 3.
    copy : bool
        If True (default), work on a copy of ``epochs``.

    Returns
    -------
    epochs : Instance of Epochs
        The cleaned, average-referenced epochs.

    """
    if copy:
        epochs = epochs.copy()

    # Step one
    logger.info("Step 1: mark bad channels")
    # NOTE(review): this step uses a fixed threshold of 5 rather than `thres` -
    # confirm this is intentional.
    epochs.info["bads"] += find_bad_channels(epochs, thres=5)

    # Step two
    logger.info("Step 2: mark bad epochs")
    bad_epochs = set(find_bad_epochs(epochs, thres=thres))
    # Built in ascending order so the surviving epochs keep their original
    # order (iterating over a set difference gives no ordering guarantee).
    good_epochs = [idx for idx in range(len(epochs)) if idx not in bad_epochs]
    epochs = epochs[good_epochs]

    # Step three (using the build-in MNE functionality for this)
    logger.info("Step 3: mark bad ICA components")
    picks = mne.pick_types(epochs.info, meg=False, eeg=True, eog=True, exclude="bads")
    ica = mne.preprocessing.ICA(len(picks)).fit(epochs, picks=picks)
    ica.exclude = find_bad_components(ica, epochs, thres=thres)
    ica.apply(epochs)
    epochs.apply_baseline(epochs.baseline)

    # Step four
    logger.info("Step 4: mark bad channels for each epoch")
    bad_channels_per_epoch = find_bad_channels_in_epochs(epochs, thres=thres)
    for i, b in enumerate(bad_channels_per_epoch):
        if len(b) > 0:
            # Interpolate the bad channels on a single-epoch selection, then
            # write the repaired data back into the full set
            epoch = epochs[i]
            epoch.info["bads"] += b
            epoch.interpolate_bads()
            epochs._data[i, :, :] = epoch._data[0, :, :]

    # Now that the data is clean, apply average reference
    epochs.set_eeg_reference("average")

    # That's all for now
    return epochs

In [None]:
from mne._fiff.pick import _picks_by_type
from scipy.stats import zscore

# Step-by-step version of find_bad_epochs() applied to a copy of the selected
# session. The outlier search is written out explicitly (z-score thresholding)
# in place of the _find_outliers() call, which is kept below as a comment.
epochs = cleaned_epochs.copy()

# Same per-epoch metrics as in find_bad_epochs(): each returns one value per epoch
metrics = {
    "amplitude": lambda x: np.mean(np.ptp(x, axis=2), axis=1),
    "deviation": lambda x: np.mean(_deviation(x), axis=1),
    "variance": lambda x: np.mean(np.var(x, axis=2), axis=1),
}
thresh = 3  # z-score threshold above which an epoch is flagged
max_iter = 1  # number of thresholding passes
tail=0  # 0: two-sided (|z|), 1: upper tail only, -1: lower tail only
picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude="bads")
use_metrics = metrics.keys()

info = pick_info(epochs.info, picks, copy=True)
data = epochs.get_data(copy=True)[:, picks]

bads={}  # immediately replaced by the defaultdict two lines below
all_bads = []  # flat list: one array of flagged epoch indices per (channel type, metric)
bads = defaultdict(list)  # metric name -> list of arrays of flagged epoch indices
for ch_type, chs in _picks_by_type(info):
    #logger.info("Bad epoch detection on %s channels:" % ch_type.upper())
    for metric in use_metrics:
        scores = metrics[metric](data[:, chs])

        # ---- explicit outlier detection: start ----
        my_mask = np.zeros(len(scores), dtype=bool)
        for _ in range(max_iter):
            # epochs flagged in earlier passes are masked before computing z-scores
            scores = np.ma.masked_array(scores, my_mask)
            if tail == 0:
                this_z = np.abs(zscore(scores))
            elif tail == 1:
                this_z = zscore(scores)
            elif tail == -1:
                this_z = -zscore(scores)
            else:
                raise ValueError(f"Tail parameter {tail} not recognised.")
            local_bad = this_z > thresh
            # element-wise OR of the previous mask and the newly flagged epochs
            my_mask = np.max([my_mask, local_bad], 0)
            if not np.any(local_bad):
                break

        bad_epochs = np.where(my_mask)[0]
        # ---- explicit outlier detection: end ----
        
        #bad_epochs = _find_outliers(scores, thres, max_iter)
        #logger.info("\tBad by %s: %s" % (metric, bad_epochs))
        bads[metric].append(bad_epochs)
        all_bads.append(bad_epochs)

In [None]:
# Display the arrays of flagged epoch indices collected in the previous cell
all_bads

In [None]:
# Step 2: mark bad epochs (run on a copy; indices returned in ascending order)
bad_epochs = sorted(find_bad_epochs(cleaned_epochs.copy(), return_by_metric=False))

In [None]:
# Display the sorted indices of the automatically detected bad epochs
bad_epochs

In [None]:
# Display the epochs summary before the manual inspection
cleaned_epochs

In [None]:
%matplotlib qt

# Open the epochs browser, 4 epochs at a time, with event markers shown.
# block=True is passed so execution waits on the figure window.
# NOTE(review): epochs marked in the browser are presumably dropped from
# cleaned_epochs in place when the window closes - confirm against MNE docs.
fig = cleaned_epochs.plot(n_epochs=4, events=True, block=True)
plt.close(fig)  

In [None]:
%matplotlib qt

# cleaned_epochs.plot(n_epochs=10, n_channels = len(cleaned_epochs.ch_names), events=True)

In [None]:
# nan_bads = [252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272 ,273, 274, 275, 276, 277, 278, 279, 280]

In [None]:
# cleaned_epochs.drop(nan_bads)

In [None]:
%matplotlib qt
# Second look at the epochs, showing all channels at once (non-blocking call)
cleaned_epochs.plot(n_epochs=4, n_channels = len(cleaned_epochs.ch_names), events=True)

In [None]:
# Display the epochs summary after the manual inspection
cleaned_epochs

In [None]:
metadata_df = pd.DataFrame(cleaned_epochs.metadata)
# Two copies are written: csv (easier for later python import) and xlsx (easier to read in excel)
metadata_csv_path = os.path.join(saving_path, f"{subject_id}_cleaned-long-epo_metadata.csv")
metadata_xlsx_path = os.path.join(saving_path, f"{subject_id}_cleaned-long-epo_metadata.xlsx")
metadata_df.to_csv(metadata_csv_path, index=True)
metadata_df.to_excel(metadata_xlsx_path, index=True)

In [None]:
# Save the cleaned epochs of the selected session to a .fif file in saving_path;
# overwrite=True replaces a file left by a previous run.
file_epoch = os.path.join(saving_path, f"{subject_id}_cleaned-long-epo.fif")
cleaned_epochs.save(file_epoch, overwrite=True)

In [None]:
#epoch_reload = mne.read_epochs(os.path.join(saving_path, f"sub011 DBS OFF mSST_cleaned-long-epo.fif"), preload=True)

In [None]:
# cropped_epochs = cleaned_epochs.copy().crop(tmin=-0.5, tmax=1.5)
# file_cropped_epoch = os.path.join(saving_path, f"{subject_id}_cleaned-short-epo.fif")
# cropped_epochs.save(file_cropped_epoch, overwrite=True)