### Compare two signals with a cluster-based permutation test and plot their wavelet spectrograms

In [None]:
import sys
print(sys.path)  # debug: show the module search path before it is extended
# HACK: append the local IEEG_Pipelines checkout so that the `ieeg` imports below
# resolve. This absolute path exists on only one Windows machine; change it (or
# install `ieeg` into the environment) when running this notebook elsewhere.
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/")  # without this, `ieeg` is not importable

import mne.time_frequency
import mne
from ieeg.io import get_data, raw_from_layout
from ieeg.navigate import trial_ieeg, channel_outlier_marker, crop_empty_data, outliers_to_nan
from ieeg.calc.scaling import rescale
from ieeg.calc.stats import time_perm_cluster
import os
from ieeg.timefreq.utils import wavelet_scaleogram, crop_pad
import numpy as np
from utils import calculate_RTs

In [None]:
import os
import numpy as np
import mne

HOME = os.path.expanduser("~")

# Resolve the CoganLab folder inside the Box sync directory. Windows ('nt')
# keeps Box directly under the home directory; every other OS falls back to
# the macOS CloudStorage location.
LAB_root = (
    os.path.join(HOME, "Box", "CoganLab")
    if os.name == 'nt'
    else os.path.join(HOME, "Library", "CloudStorage", "Box-Box", "CoganLab")
)

# BIDS layout for the GlobalLocal task rooted at the lab folder.
layout = get_data('GlobalLocal', root=LAB_root)

def get_uncorrected_wavelets(sub, layout, events, times, plot=True):
    '''
    Get non-baseline-corrected wavelets for trials corresponding to those in events.

    Pipeline:
    1. Load the line-noise-filtered ("clean") recording for ``sub``.
    2. Crop empty stretches and drop outlier channels.
    3. Re-reference to the channel average.
    4. Epoch around every event in ``events``, padding the window by 0.5 s on
       each side.
    5. Pool those epochs and replace outliers with NaN.
    6. Compute a wavelet scaleogram and crop the 0.5 s padding back off.

    Parameters:
    -----------
    sub : str
        The subject identifier.
    layout : BIDSLayout
        The BIDS layout object containing the data.
    events : list of str or str
        Event name(s) to extract trials for. Trials from all listed events are
        pooled into one object (e.g. 'Stimulus/c25' and 'Stimulus/c75' together
        give all congruent trials). A single string is treated as a one-item list.
    times : list of float
        [start, end] time window (seconds) relative to each event.
    plot : bool, optional
        If True (the default, matching the previous behavior), open an
        interactive plot of the cleaned, re-referenced raw data.

    Returns:
    --------
    spec : mne.time_frequency.EpochsTFR
        The time-frequency representation of the wavelet-transformed data.

    Raises:
    -------
    ValueError
        If ``events`` is empty. This check runs before any data is loaded.

    Examples:
    ---------
    >>> sub = 'sub-01'
    >>> events = ['Stimulus/c25', 'Stimulus/c75']
    >>> times = [-0.5, 1.5]
    >>> spec = get_uncorrected_wavelets(sub, layout, events, times)
    >>> isinstance(spec, mne.time_frequency.EpochsTFR)
    True
    '''
    # A bare event name would otherwise be iterated character by character.
    events = [events] if isinstance(events, str) else list(events)
    if not events:
        # Fail fast instead of loading/cleaning the whole recording and then
        # failing inside mne.concatenate_epochs on an empty list.
        raise ValueError("events must contain at least one event name")

    # Load the line-noise filtered data lazily; it is loaded after channel rejection.
    filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                           extension='.edf', desc='clean', preload=False)
    print(filt)

    # Crop raw data to minimize processing time.
    good = crop_empty_data(filt)

    # Mark and drop outlier channels before pulling the data into memory.
    good.info['bads'] = channel_outlier_marker(good, 3, 2)
    good.drop_channels(good.info['bads'])
    good.load_data()

    # Average reference over channels of the first data channel type.
    ch_type = filt.get_channel_types(only_data_chs=True)[0]
    good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    # Remove intermediates from mem.
    del filt

    if plot:
        good.plot()

    # NOTE(review): this directory is created but nothing in this function
    # writes to it; kept so existing side effects are unchanged.
    save_dir = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', sub)
    os.makedirs(save_dir, exist_ok=True)

    # Diagnostic only: print the subject's median response time.
    RTs, _skipped = calculate_RTs(good)
    median_RT = np.median(RTs)
    print(median_RT)

    # Pad the window by 0.5 s on both sides; crop_pad(spec, "0.5s") below
    # removes exactly this padding after the wavelet transform.
    times_adj = [times[0] - 0.5, times[1] + 0.5]
    all_trials = mne.concatenate_epochs([
        trial_ieeg(good, event, times_adj, preload=True,
                   reject_by_annotation=False)
        for event in events
    ])

    outliers_to_nan(all_trials, outliers=10)
    # Decimate the output to ~100 Hz. Clamp to 1 so that recordings sampled
    # below 100 Hz do not yield an invalid decimation factor of 0.
    decim = max(1, int(good.info['sfreq'] / 100))
    spec = wavelet_scaleogram(all_trials, n_jobs=1, decim=decim)
    crop_pad(spec, "0.5s")

    return spec

     
def get_difference_wavelets(sub, layout, events_condition_1, events_condition_2, times, p_thresh=0.05, ignore_adjacency=1, n_perm=100, n_jobs=1):
    '''
    Compare the wavelet spectrograms of two conditions with a cluster-based permutation test.

    Each condition is processed with ``get_uncorrected_wavelets``, so the
    subject's recording is loaded and cleaned once per condition. The two
    spectrogram arrays are then passed to ``time_perm_cluster``. This function
    does not plot anything itself; plot the returned mask to visualise the
    significant clusters.

    Parameters:
    -----------
    sub : str
        The subject identifier.
    layout : BIDSLayout
        The BIDS layout object containing the data.
    events_condition_1 : list of str
        List of event names for the first condition.
    events_condition_2 : list of str
        List of event names for the second condition.
    times : list of float
        Time window relative to the events to extract data from.
    p_thresh : float, optional
        The p-value threshold for significance (default is 0.05).
    ignore_adjacency : int or tuple of int, optional
        Axis (or axes) of the input arrays along which adjacency is ignored when
        forming clusters; forwarded to ``time_perm_cluster`` (default is 1).
        For EpochsTFR data shaped (epochs, channels, freqs, times) this
        presumably means clusters do not span channels. Confirm against the
        ieeg documentation.
    n_perm : int, optional
        The number of permutations to perform (default is 100).
    n_jobs : int, optional
        The number of jobs to run in parallel (default is 1).

    Returns:
    --------
    mask : numpy.ndarray
        A boolean mask indicating significant clusters.
    pvals : numpy.ndarray
        The p-values returned by ``time_perm_cluster``.

    Examples:
    ---------
    >>> sub = 'sub-01'
    >>> events_condition_1 = ['Stimulus/c25']
    >>> events_condition_2 = ['Stimulus/c75']
    >>> times = [-0.5, 1.5]
    >>> mask, pvals = get_difference_wavelets(sub, layout, events_condition_1, events_condition_2, times)
    >>> isinstance(mask, np.ndarray)
    True
    >>> isinstance(pvals, np.ndarray)
    True
    '''
    spec_condition_1 = get_uncorrected_wavelets(sub, layout, events_condition_1, times)
    spec_condition_2 = get_uncorrected_wavelets(sub, layout, events_condition_2, times)

    # Pass the options by keyword. In ieeg's signature the 4th positional
    # parameter is p_cluster and the 6th is tails. Positional passing therefore
    # sent ignore_adjacency into p_cluster (1 -> every cluster "significant")
    # and n_jobs into tails.
    mask, pvals = time_perm_cluster(
        spec_condition_1._data,
        spec_condition_2._data,
        p_thresh=p_thresh,
        ignore_adjacency=ignore_adjacency,
        n_perm=n_perm,
        n_jobs=n_jobs,
    )
    return mask, pvals