### Drift testing

### Test on a single subject

In [None]:
###
import sys
print(sys.path)
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/") #need to do this cuz otherwise ieeg isn't added to path...

from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, \
    outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale
import mne
import os
import numpy as np
from ieeg.calc.reshape import make_data_same
from ieeg.calc.stats import time_perm_cluster

from misc_functions import calculate_RTs, save_channels_to_file, save_sig_chans
import matplotlib.pyplot as plt


def get_baseline(inst: mne.io.BaseRaw, base_times_list: list[tuple[float, float]]):
    """Build high-gamma and broadband baselines pooled over several windows.

    ``inst`` is copied, loaded into memory and common-average referenced on its
    first data channel type; the caller's object is not modified. For every
    ``(start, stop)`` window in ``base_times_list`` (seconds relative to the
    ``experimentStart`` event) an epoch is cut with 0.5 s of padding on each
    side. Windows that yield no epochs are skipped with a printed warning. The
    surviving epochs are stacked into a single ``mne.EpochsArray``, from which
    high-gamma power (``gamma.extract`` default passband) and broadband power
    (passband 0-1000 Hz) are extracted before the padding is cropped off.

    Parameters
    ----------
    inst : mne.io.BaseRaw
        Continuous recording to draw the baseline from.
    base_times_list : list of (float, float)
        Baseline windows. The windows must produce epochs of equal length,
        because they are concatenated along the epoch axis.

    Returns
    -------
    HG_base, broadband_base
        The two ``gamma.extract`` outputs after ``crop_pad``.

    Raises
    ------
    ValueError
        If none of the windows produced any epochs.
    """
    pad = 0.5  # seconds on each side; removed again by crop_pad(..., "0.5s")

    ref_inst = inst.copy()
    ref_inst.load_data()
    ch_type = ref_inst.get_channel_types(only_data_chs=True)[0]
    ref_inst.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    all_data = []
    all_events = []
    n_epochs = 0
    tmin = None
    for base_times in base_times_list:
        adjusted_base_times = [base_times[0] - pad, base_times[1] + pad]
        trials = trial_ieeg(ref_inst, "experimentStart", adjusted_base_times,
                            preload=True)

        if trials is None or len(trials) == 0:
            print(f"Warning: No trials found for base_times {base_times}. Skipping...")
            continue

        all_data.append(trials.get_data())
        # Placeholder event rows (one per epoch, increasing sample index);
        # they are not real positions in the recording.
        all_events.extend([n_epochs + i, 0, 1] for i in range(len(trials)))
        n_epochs += len(trials)
        # Remember tmin of the last window that produced epochs. Reading
        # ``trials.tmin`` after the loop fails when the final window was
        # skipped, because ``trials`` may then be None.
        tmin = trials.tmin

    if not all_data:
        raise ValueError(
            f"No baseline epochs found for any window in {base_times_list}")

    combined_trials = mne.EpochsArray(np.concatenate(all_data, axis=0),
                                      ref_inst.info,
                                      events=np.array(all_events), tmin=tmin)

    HG_base = gamma.extract(combined_trials, copy=True, n_jobs=1)
    broadband_base = gamma.extract(combined_trials, copy=True, n_jobs=1,
                                   passband=(0, 1000))
    crop_pad(HG_base, "0.5s")
    crop_pad(broadband_base, "0.5s")
    return HG_base, broadband_base




def plot_HG_and_stats(sub, task, output_name, event=None, times=(-1, 1.5),
                      base_times=(-0.5, 0), LAB_root=None, channels=None,
                      full_trial_base=False):
    """Placeholder for the high-gamma plotting / statistics pipeline.

    NOTE: this function is not implemented yet. Its body does nothing and it
    always returns None; the actual processing currently lives inline in the
    notebook cells. The parameters describe the intended interface:

    - sub (str): subject identifier, e.g. 'D0057'.
    - task (str): task identifier, e.g. 'GlobalLocal'.
    - output_name (str): stem used for output file names.
    - event (str, optional): event label to epoch around. Defaults to None.
    - times (tuple, optional): (start, end) of the epoch in seconds.
      Defaults to (-1, 1.5).
    - base_times (tuple, optional): (start, end) of the baseline in seconds.
      Defaults to (-0.5, 0).
    - LAB_root (str, optional): lab root directory; intended to be derived
      from the OS when None.
    - channels (list of str, optional): channels to analyse; None means all.
    - full_trial_base (bool): whether the full trial should serve as the
      baseline. Defaults to False.
    """
    return None


# Single-subject drift test: epoch one event, z-score its high-gamma (HG) and
# broadband power against baselines pooled from several windows after
# ``experimentStart``, plot the evoked response and run a cluster test.
sub = 'D0057'
task = 'GlobalLocal'
output_name = "driftTest_Stimulus_experimentStartBaseNormalizedAnd0.5SecBeforeEventTrialsonXAxis"
events = ["Stimulus"]
times = (-1, 1.5)
base_times = [(1, 31), (41, 71)]  # windows passed to get_baseline
LAB_root = None
channels = None  # None -> use every good channel
full_trial_base = False  # not used below; mirrors the plot_HG_and_stats parameter

if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt':  # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else:  # mac
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box",
                                "CoganLab")

layout = get_data(task, root=LAB_root)
filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                       extension='.edf', desc='clean', preload=False)
save_dir = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub)
os.makedirs(save_dir, exist_ok=True)

###

good = crop_empty_data(filt)
# %%

print(f"good channels before dropping bads: {len(good.ch_names)}")
print(f"filt channels before dropping bads: {len(filt.ch_names)}")

good.info['bads'] = channel_outlier_marker(good, 3, 2)
print("Bad channels in 'good':", good.info['bads'])

# Drop from filt first: once good drops its bads, good.info['bads'] is empty.
filt.drop_channels(good.info['bads'])
good.drop_channels(good.info['bads'])

print("Bad channels in 'good' after dropping once:", good.info['bads'])

print(f"good channels after dropping bads: {len(good.ch_names)}")
print(f"filt channels after dropping bads: {len(filt.ch_names)}")

HG_base, broadband_base = get_baseline(filt, base_times)
good.load_data()

# If channels is None, use all channels
if channels is None:
    channels = good.ch_names
else:
    # Validate the provided channels
    invalid_channels = [ch for ch in channels if ch not in good.ch_names]
    if invalid_channels:
        raise ValueError(
            f"The following channels are not valid: {invalid_channels}")

    # Use only the specified channels. The baselines were built from ``filt``
    # with every good channel, so restrict them too; otherwise rescale() and
    # the cluster test below compare arrays with different channel sets.
    good.pick_channels(channels)
    HG_base.pick_channels(channels)
    broadband_base.pick_channels(channels)

ch_type = filt.get_channel_types(only_data_chs=True)[0]
good.set_eeg_reference(ref_channels="average", ch_type=ch_type)


# Alternative baseline: a window around the Stimulus event (disabled)
# adjusted_base_times = [base_times[0] - 0.5, base_times[1] + 0.5]
# trials = trial_ieeg(good, "Stimulus", adjusted_base_times, preload=True)
# outliers_to_nan(trials, outliers=10)
# HG_base_fixationCross = gamma.extract(trials, copy=False, n_jobs=1)
# crop_pad(HG_base_fixationCross, "0.5s")


# Epoch every requested event label (e.g. Stimulus/c25 and Stimulus/c75) with
# 0.5 s of padding per side, then pool them into one Epochs object.
times_adj = [times[0] - 0.5, times[1] + 0.5]
all_epochs_list = []
for event in events:
    trials = trial_ieeg(good, event, times_adj, preload=True,
                        reject_by_annotation=False)
    all_epochs_list.append(trials)

# Concatenate all trials
all_trials = mne.concatenate_epochs(all_epochs_list)

outliers_to_nan(all_trials, outliers=10)
HG_ev1 = gamma.extract(all_trials, copy=True, n_jobs=1)
broadband_ev1 = gamma.extract(all_trials, copy=True, n_jobs=1, passband=(0, 1000))
print("HG_ev1 before crop_pad: ", HG_ev1.tmin, HG_ev1.tmax)
crop_pad(HG_ev1, "0.5s")
print("HG_ev1 after crop_pad: ", HG_ev1.tmin, HG_ev1.tmax)

crop_pad(broadband_ev1, "0.5s")
HG_ev1_rescaled = rescale(HG_ev1, HG_base, copy=True, mode='zscore')
broadband_ev1_rescaled = rescale(broadband_ev1, broadband_base, copy=True, mode='zscore')
# HG_ev1_rescaled_fixationCross = rescale(HG_ev1, HG_base_fixationCross, copy=True, mode='zscore')

HG_base.decimate(2)
HG_ev1.decimate(2)

broadband_base.decimate(2)
broadband_ev1.decimate(2)

# HG_base_fixationCross.decimate(2)

# Per-trial means over the time axis, kept for drift inspection.
HG_ev1_avgOverTime = np.nanmean(HG_ev1.get_data(), axis=2)
HG_ev1_rescaled_avgOverTime = np.nanmean(HG_ev1_rescaled.get_data(), axis=2)

# Average over trials (axis 0), ignoring NaNs set by outliers_to_nan.
HG_ev1_evoke = HG_ev1.average(method=lambda x: np.nanmean(x, axis=0))
HG_ev1_evoke_rescaled = HG_ev1_rescaled.average(method=lambda x: np.nanmean(x, axis=0))

# MNE keys ``scalings`` by lower-case channel type (e.g. 'seeg'); a 'sEEG' key
# is ignored and the default scaling for the type is applied instead.
plot_scalings = {ch_type: 1}
# NOTE(review): ``event`` is the last label in ``events``. Figures are written
# to ``save_dir + suffix`` (next to the subject folder, not inside it), and
# the Stimulus file name ignores output_name; confirm this is intended.
if event == "Stimulus":
    print('plotting stimulus')
    fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
    print('plotted')
    # for ax in fig.axes:
    #     ax.axvline(x=avg_RT, color='r', linestyle='--')
    print('about to save')
    fig.savefig(save_dir + '_HG_ev1_Stimulus_zscore.png')
    print('saved')
else:
    print('about to plot if not stimulus')
    fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
    print('plotted non stimulus')
    fig.savefig(save_dir + f'_HG_ev1_{output_name}_zscore.png')

###
print(f"Shape of HG_ev1._data: {HG_ev1._data.shape}")
print(f"Shape of HG_base._data: {HG_base._data.shape}")

# Compare task HG against baseline HG, reshaped to the task data's shape.
sig1 = HG_ev1._data
sig2 = make_data_same(HG_base._data, sig1.shape)
print(f"Shape of sig1: {sig1.shape}")
print(f"Shape of sig2: {sig2.shape}")

mat = time_perm_cluster(sig1, sig2, 0.05, n_jobs=6, ignore_adjacency=1)
fig = plt.figure()
plt.imshow(mat, aspect='auto')
fig.savefig(save_dir + f'_{output_name}_stats.png', dpi=300)

channels = good.ch_names

#save channels with their indices 
save_channels_to_file(channels, sub, task, save_dir)

# save significant channels to a json
save_sig_chans(f'{output_name}', mat, channels, sub, save_dir)

In [None]:
filt[:, 0:4000]  # (data, times) for every channel, samples 0..3999

### Loop over all subjects

In [None]:
# Inspect the end time of HG_ev1. The line below is left disabled: HG_ev1 is
# already passed through crop_pad in the earlier cell, so re-running it would
# presumably crop again (TODO confirm crop_pad semantics).
# crop_pad(HG_ev1, "0.5s")
HG_ev1.tmax

In [None]:
###
import sys
print(sys.path)
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/") #need to do this cuz otherwise ieeg isn't added to path...

from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, \
    outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale
import mne
import os
import numpy as np
from ieeg.calc.reshape import make_data_same
from ieeg.calc.stats import time_perm_cluster

from misc_functions import calculate_RTs, save_channels_to_file, save_sig_chans
import matplotlib.pyplot as plt


def get_baseline(inst: mne.io.BaseRaw, base_times: tuple[float, float]):
    """Return the high-gamma envelope of a single baseline window.

    A copy of ``inst`` is loaded into memory and common-average referenced on
    its first data channel type, so the caller's recording is untouched. One
    epoch spanning ``base_times`` (seconds relative to ``experimentStart``)
    is cut with 0.5 s of extra padding per side. Its high-gamma power is
    extracted with ``gamma.extract`` and the padding is removed with
    ``crop_pad``.

    Parameters
    ----------
    inst : mne.io.BaseRaw
        Continuous recording to take the baseline from.
    base_times : tuple of float
        ``(start, stop)`` of the baseline window, in seconds.

    Returns
    -------
    The ``gamma.extract`` output after ``crop_pad``.
    """
    edge = 0.5  # seconds of padding absorbed by crop_pad below
    ref_raw = inst.copy()
    ref_raw.load_data()
    first_type = ref_raw.get_channel_types(only_data_chs=True)[0]
    ref_raw.set_eeg_reference(ref_channels="average", ch_type=first_type)

    window = [base_times[0] - edge, base_times[1] + edge]
    base_epochs = trial_ieeg(ref_raw, "experimentStart", window,
                             preload=True)
    # NOTE: outliers_to_nan is not applied to the baseline epochs.
    hg = gamma.extract(base_epochs, copy=False, n_jobs=1)
    crop_pad(hg, "0.5s")
    return hg


def plot_HG_and_stats(sub, task, output_name, event=None, times=(-1, 1.5),
                      base_times=(-0.5, 0), LAB_root=None, channels=None,
                      full_trial_base=False):
    """Placeholder for the high-gamma plotting / statistics pipeline.

    NOTE: this function is not implemented yet. Its body does nothing and it
    always returns None; the actual processing currently lives inline in the
    notebook cells. The parameters describe the intended interface:

    - sub (str): subject identifier, e.g. 'D0057'.
    - task (str): task identifier, e.g. 'GlobalLocal'.
    - output_name (str): stem used for output file names.
    - event (str, optional): event label to epoch around. Defaults to None.
    - times (tuple, optional): (start, end) of the epoch in seconds.
      Defaults to (-1, 1.5).
    - base_times (tuple, optional): (start, end) of the baseline in seconds.
      Defaults to (-0.5, 0).
    - LAB_root (str, optional): lab root directory; intended to be derived
      from the OS when None.
    - channels (list of str, optional): channels to analyse; None means all.
    - full_trial_base (bool): whether the full trial should serve as the
      baseline. Defaults to False.
    """
    return None

# Batch run over all subjects: incongruent responses (Response/i25 +
# Response/i75) z-scored against a 1-101 s window after experimentStart.
# This cell re-imports everything, so it also defines ``task`` itself instead
# of relying on a global left over from an earlier cell.
subjects = ['D0057', 'D0059', 'D0063', 'D0065', 'D0069', 'D0071']
task = 'GlobalLocal'
output_name = "responseIncongruent_experimentStartBase1secTo101Sec"
events = ["Response/i25", "Response/i75"]
times = (-1, 1.5)
base_times = (1, 101)
LAB_root = None
requested_channels = None  # None -> use every good channel of each subject
full_trial_base = False  # not used below; mirrors the plot_HG_and_stats parameter

if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt':  # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else:  # mac
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box",
                                "CoganLab")

# The layout does not depend on the subject, so load it once.
layout = get_data(task, root=LAB_root)

for sub in subjects:
    channels = requested_channels

    filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                           extension='.edf', desc='clean', preload=False)
    save_dir = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub)
    os.makedirs(save_dir, exist_ok=True)

    good = crop_empty_data(filt)

    print(f"good channels before dropping bads: {len(good.ch_names)}")
    print(f"filt channels before dropping bads: {len(filt.ch_names)}")

    good.info['bads'] = channel_outlier_marker(good, 3, 2)
    print("Bad channels in 'good':", good.info['bads'])

    # Drop from filt first: once good drops its bads, good.info['bads'] is empty.
    filt.drop_channels(good.info['bads'])
    good.drop_channels(good.info['bads'])

    print("Bad channels in 'good' after dropping once:", good.info['bads'])

    print(f"good channels after dropping bads: {len(good.ch_names)}")
    print(f"filt channels after dropping bads: {len(filt.ch_names)}")

    HG_base = get_baseline(filt, base_times)
    good.load_data()

    # If channels is None, use all channels
    if channels is None:
        channels = good.ch_names
    else:
        invalid_channels = [ch for ch in channels if ch not in good.ch_names]
        if invalid_channels:
            raise ValueError(
                f"The following channels are not valid: {invalid_channels}")
        # The baseline was built from ``filt`` with every good channel, so
        # restrict it too; otherwise rescale() and the cluster test compare
        # arrays with different channel sets.
        good.pick_channels(channels)
        HG_base.pick_channels(channels)

    ch_type = filt.get_channel_types(only_data_chs=True)[0]
    good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    # Epoch every event label with 0.5 s of padding per side, then pool them.
    times_adj = [times[0] - 0.5, times[1] + 0.5]
    all_epochs_list = []
    for event in events:
        trials = trial_ieeg(good, event, times_adj, preload=True,
                            reject_by_annotation=False)
        all_epochs_list.append(trials)
    all_trials = mne.concatenate_epochs(all_epochs_list)

    outliers_to_nan(all_trials, outliers=10)
    HG_ev1 = gamma.extract(all_trials, copy=False, n_jobs=1)
    crop_pad(HG_ev1, "0.5s")
    HG_ev1_rescaled = rescale(HG_ev1, HG_base, copy=True, mode='zscore')

    HG_base.decimate(2)
    HG_ev1.decimate(2)
    RTs, skipped = calculate_RTs(good)
    avg_RT = np.median(RTs)

    # Average over trials (axis 0), ignoring NaNs set by outliers_to_nan.
    HG_ev1_evoke = HG_ev1.average(method=lambda x: np.nanmean(x, axis=0))
    HG_ev1_evoke_rescaled = HG_ev1_rescaled.average(
        method=lambda x: np.nanmean(x, axis=0))

    # MNE keys ``scalings`` by lower-case channel type (e.g. 'seeg'); a 'sEEG'
    # key is ignored and the default scaling for the type is applied instead.
    plot_scalings = {ch_type: 1}
    # NOTE(review): ``event`` is the last label in ``events``. Figures are
    # written to ``save_dir + suffix`` (next to the subject folder, not inside
    # it); confirm this is intended.
    if event == "Stimulus":
        print('plotting stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted')
        for ax in fig.axes:
            ax.axvline(x=avg_RT, color='r', linestyle='--')
        print('about to save')
        fig.savefig(save_dir + '_HG_ev1_Stimulus_zscore.png')
        print('saved')
    else:
        print('about to plot if not stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted non stimulus')
        fig.savefig(save_dir + f'_HG_ev1_{output_name}_zscore.png')
    # Saved to disk; close so figures do not accumulate across subjects.
    plt.close(fig)

    print(f"Shape of HG_ev1._data: {HG_ev1._data.shape}")
    print(f"Shape of HG_base._data: {HG_base._data.shape}")

    # Compare task HG against baseline HG, reshaped to the task data's shape.
    sig1 = HG_ev1._data
    sig2 = make_data_same(HG_base._data, sig1.shape)
    print(f"Shape of sig1: {sig1.shape}")
    print(f"Shape of sig2: {sig2.shape}")

    mat = time_perm_cluster(sig1, sig2, 0.05, n_jobs=6, ignore_adjacency=1)
    fig = plt.figure()
    plt.imshow(mat, aspect='auto')
    fig.savefig(save_dir + f'_{output_name}_stats.png', dpi=300)
    plt.close(fig)

    channels = good.ch_names

    # save channels with their indices
    save_channels_to_file(channels, sub, task, save_dir)

    # save significant channels to a json
    save_sig_chans(f'{output_name}', mat, channels, sub, save_dir)

In [None]:
###
import sys
print(sys.path)
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/") #need to do this cuz otherwise ieeg isn't added to path...

from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, \
    outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale
import mne
import os
import numpy as np
from ieeg.calc.reshape import make_data_same
from ieeg.calc.stats import time_perm_cluster

from misc_functions import calculate_RTs, save_channels_to_file, save_sig_chans
import matplotlib.pyplot as plt


def get_baseline(inst: mne.io.BaseRaw, base_times: tuple[float, float]):
    """Return the high-gamma envelope of a single baseline window.

    A copy of ``inst`` is loaded into memory and common-average referenced on
    its first data channel type, so the caller's recording is untouched. One
    epoch spanning ``base_times`` (seconds relative to ``experimentStart``)
    is cut with 0.5 s of extra padding per side. Its high-gamma power is
    extracted with ``gamma.extract`` and the padding is removed with
    ``crop_pad``.

    Parameters
    ----------
    inst : mne.io.BaseRaw
        Continuous recording to take the baseline from.
    base_times : tuple of float
        ``(start, stop)`` of the baseline window, in seconds.

    Returns
    -------
    The ``gamma.extract`` output after ``crop_pad``.
    """
    edge = 0.5  # seconds of padding absorbed by crop_pad below
    ref_raw = inst.copy()
    ref_raw.load_data()
    first_type = ref_raw.get_channel_types(only_data_chs=True)[0]
    ref_raw.set_eeg_reference(ref_channels="average", ch_type=first_type)

    window = [base_times[0] - edge, base_times[1] + edge]
    base_epochs = trial_ieeg(ref_raw, "experimentStart", window,
                             preload=True)
    # NOTE: outliers_to_nan is not applied to the baseline epochs.
    hg = gamma.extract(base_epochs, copy=False, n_jobs=1)
    crop_pad(hg, "0.5s")
    return hg


def plot_HG_and_stats(sub, task, output_name, event=None, times=(-1, 1.5),
                      base_times=(-0.5, 0), LAB_root=None, channels=None,
                      full_trial_base=False):
    """Placeholder for the high-gamma plotting / statistics pipeline.

    NOTE: this function is not implemented yet. Its body does nothing and it
    always returns None; the actual processing currently lives inline in the
    notebook cells. The parameters describe the intended interface:

    - sub (str): subject identifier, e.g. 'D0057'.
    - task (str): task identifier, e.g. 'GlobalLocal'.
    - output_name (str): stem used for output file names.
    - event (str, optional): event label to epoch around. Defaults to None.
    - times (tuple, optional): (start, end) of the epoch in seconds.
      Defaults to (-1, 1.5).
    - base_times (tuple, optional): (start, end) of the baseline in seconds.
      Defaults to (-0.5, 0).
    - LAB_root (str, optional): lab root directory; intended to be derived
      from the OS when None.
    - channels (list of str, optional): channels to analyse; None means all.
    - full_trial_base (bool): whether the full trial should serve as the
      baseline. Defaults to False.
    """
    return None

# Batch run over all subjects: Stimulus events z-scored against a 1-101 s
# window after experimentStart. This cell re-imports everything, so it also
# defines ``task`` itself instead of relying on a global from an earlier cell.
subjects = ['D0057', 'D0059', 'D0063', 'D0065', 'D0069', 'D0071']
task = 'GlobalLocal'
output_name = "stimulus_experimentStartBase1secTo101Sec"
events = ["Stimulus"]
times = (-1, 1.5)
base_times = (1, 101)
LAB_root = None
requested_channels = None  # None -> use every good channel of each subject
full_trial_base = False  # not used below; mirrors the plot_HG_and_stats parameter

if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt':  # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else:  # mac
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box",
                                "CoganLab")

# The layout does not depend on the subject, so load it once.
layout = get_data(task, root=LAB_root)

for sub in subjects:
    channels = requested_channels

    filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                           extension='.edf', desc='clean', preload=False)
    save_dir = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub)
    os.makedirs(save_dir, exist_ok=True)

    good = crop_empty_data(filt)

    print(f"good channels before dropping bads: {len(good.ch_names)}")
    print(f"filt channels before dropping bads: {len(filt.ch_names)}")

    good.info['bads'] = channel_outlier_marker(good, 3, 2)
    print("Bad channels in 'good':", good.info['bads'])

    # Drop from filt first: once good drops its bads, good.info['bads'] is empty.
    filt.drop_channels(good.info['bads'])
    good.drop_channels(good.info['bads'])

    print("Bad channels in 'good' after dropping once:", good.info['bads'])

    print(f"good channels after dropping bads: {len(good.ch_names)}")
    print(f"filt channels after dropping bads: {len(filt.ch_names)}")

    HG_base = get_baseline(filt, base_times)
    good.load_data()

    # If channels is None, use all channels
    if channels is None:
        channels = good.ch_names
    else:
        invalid_channels = [ch for ch in channels if ch not in good.ch_names]
        if invalid_channels:
            raise ValueError(
                f"The following channels are not valid: {invalid_channels}")
        # The baseline was built from ``filt`` with every good channel, so
        # restrict it too; otherwise rescale() and the cluster test compare
        # arrays with different channel sets.
        good.pick_channels(channels)
        HG_base.pick_channels(channels)

    ch_type = filt.get_channel_types(only_data_chs=True)[0]
    good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    # Epoch every event label with 0.5 s of padding per side, then pool them.
    times_adj = [times[0] - 0.5, times[1] + 0.5]
    all_epochs_list = []
    for event in events:
        trials = trial_ieeg(good, event, times_adj, preload=True,
                            reject_by_annotation=False)
        all_epochs_list.append(trials)
    all_trials = mne.concatenate_epochs(all_epochs_list)

    outliers_to_nan(all_trials, outliers=10)
    HG_ev1 = gamma.extract(all_trials, copy=False, n_jobs=1)
    crop_pad(HG_ev1, "0.5s")
    HG_ev1_rescaled = rescale(HG_ev1, HG_base, copy=True, mode='zscore')

    HG_base.decimate(2)
    HG_ev1.decimate(2)
    RTs, skipped = calculate_RTs(good)
    avg_RT = np.median(RTs)

    # Average over trials (axis 0), ignoring NaNs set by outliers_to_nan.
    HG_ev1_evoke = HG_ev1.average(method=lambda x: np.nanmean(x, axis=0))
    HG_ev1_evoke_rescaled = HG_ev1_rescaled.average(
        method=lambda x: np.nanmean(x, axis=0))

    # MNE keys ``scalings`` by lower-case channel type (e.g. 'seeg'); a 'sEEG'
    # key is ignored and the default scaling for the type is applied instead.
    plot_scalings = {ch_type: 1}
    # NOTE(review): ``event`` is the last label in ``events``. Figures are
    # written to ``save_dir + suffix`` (next to the subject folder, not inside
    # it); confirm this is intended.
    if event == "Stimulus":
        print('plotting stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted')
        for ax in fig.axes:
            ax.axvline(x=avg_RT, color='r', linestyle='--')
        print('about to save')
        fig.savefig(save_dir + '_HG_ev1_Stimulus_zscore.png')
        print('saved')
    else:
        print('about to plot if not stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted non stimulus')
        fig.savefig(save_dir + f'_HG_ev1_{output_name}_zscore.png')
    # Saved to disk; close so figures do not accumulate across subjects.
    plt.close(fig)

    print(f"Shape of HG_ev1._data: {HG_ev1._data.shape}")
    print(f"Shape of HG_base._data: {HG_base._data.shape}")

    # Compare task HG against baseline HG, reshaped to the task data's shape.
    sig1 = HG_ev1._data
    sig2 = make_data_same(HG_base._data, sig1.shape)
    print(f"Shape of sig1: {sig1.shape}")
    print(f"Shape of sig2: {sig2.shape}")

    mat = time_perm_cluster(sig1, sig2, 0.05, n_jobs=6, ignore_adjacency=1)
    fig = plt.figure()
    plt.imshow(mat, aspect='auto')
    fig.savefig(save_dir + f'_{output_name}_stats.png', dpi=300)
    plt.close(fig)

    channels = good.ch_names

    # save channels with their indices
    save_channels_to_file(channels, sub, task, save_dir)

    # save significant channels to a json
    save_sig_chans(f'{output_name}', mat, channels, sub, save_dir)

In [None]:
###
import sys
print(sys.path)
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/") #need to do this cuz otherwise ieeg isn't added to path...

from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, \
    outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale
import mne
import os
import numpy as np
from ieeg.calc.reshape import make_data_same
from ieeg.calc.stats import time_perm_cluster

from misc_functions import calculate_RTs, save_channels_to_file, save_sig_chans
import matplotlib.pyplot as plt


def get_baseline(inst: mne.io.BaseRaw, base_times: tuple[float, float]):
    """Return the high-gamma envelope of a single baseline window.

    A copy of ``inst`` is loaded into memory and common-average referenced on
    its first data channel type, so the caller's recording is untouched. One
    epoch spanning ``base_times`` (seconds relative to ``experimentStart``)
    is cut with 0.5 s of extra padding per side. Its high-gamma power is
    extracted with ``gamma.extract`` and the padding is removed with
    ``crop_pad``.

    Parameters
    ----------
    inst : mne.io.BaseRaw
        Continuous recording to take the baseline from.
    base_times : tuple of float
        ``(start, stop)`` of the baseline window, in seconds.

    Returns
    -------
    The ``gamma.extract`` output after ``crop_pad``.
    """
    edge = 0.5  # seconds of padding absorbed by crop_pad below
    ref_raw = inst.copy()
    ref_raw.load_data()
    first_type = ref_raw.get_channel_types(only_data_chs=True)[0]
    ref_raw.set_eeg_reference(ref_channels="average", ch_type=first_type)

    window = [base_times[0] - edge, base_times[1] + edge]
    base_epochs = trial_ieeg(ref_raw, "experimentStart", window,
                             preload=True)
    # NOTE: outliers_to_nan is not applied to the baseline epochs.
    hg = gamma.extract(base_epochs, copy=False, n_jobs=1)
    crop_pad(hg, "0.5s")
    return hg


def plot_HG_and_stats(sub, task, output_name, event=None, times=(-1, 1.5),
                      base_times=(-0.5, 0), LAB_root=None, channels=None,
                      full_trial_base=False):
    """Placeholder for the high-gamma plotting / statistics pipeline.

    NOTE: this function is not implemented yet. Its body does nothing and it
    always returns None; the actual processing currently lives inline in the
    notebook cells. The parameters describe the intended interface:

    - sub (str): subject identifier, e.g. 'D0057'.
    - task (str): task identifier, e.g. 'GlobalLocal'.
    - output_name (str): stem used for output file names.
    - event (str, optional): event label to epoch around. Defaults to None.
    - times (tuple, optional): (start, end) of the epoch in seconds.
      Defaults to (-1, 1.5).
    - base_times (tuple, optional): (start, end) of the baseline in seconds.
      Defaults to (-0.5, 0).
    - LAB_root (str, optional): lab root directory; intended to be derived
      from the OS when None.
    - channels (list of str, optional): channels to analyse; None means all.
    - full_trial_base (bool): whether the full trial should serve as the
      baseline. Defaults to False.
    """
    return None

# Batch run over all subjects: Response events z-scored against a 1-101 s
# window after experimentStart. This cell re-imports everything, so it also
# defines ``task`` itself instead of relying on a global from an earlier cell.
subjects = ['D0057', 'D0059', 'D0063', 'D0065', 'D0069', 'D0071']
task = 'GlobalLocal'
output_name = "response_experimentStartBase1secTo101Sec"
events = ["Response"]
times = (-1, 1.5)
base_times = (1, 101)
LAB_root = None
requested_channels = None  # None -> use every good channel of each subject
full_trial_base = False  # not used below; mirrors the plot_HG_and_stats parameter

if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt':  # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else:  # mac
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box",
                                "CoganLab")

# The layout does not depend on the subject, so load it once.
layout = get_data(task, root=LAB_root)

for sub in subjects:
    channels = requested_channels

    filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                           extension='.edf', desc='clean', preload=False)
    save_dir = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub)
    os.makedirs(save_dir, exist_ok=True)

    good = crop_empty_data(filt)

    print(f"good channels before dropping bads: {len(good.ch_names)}")
    print(f"filt channels before dropping bads: {len(filt.ch_names)}")

    good.info['bads'] = channel_outlier_marker(good, 3, 2)
    print("Bad channels in 'good':", good.info['bads'])

    # Drop from filt first: once good drops its bads, good.info['bads'] is empty.
    filt.drop_channels(good.info['bads'])
    good.drop_channels(good.info['bads'])

    print("Bad channels in 'good' after dropping once:", good.info['bads'])

    print(f"good channels after dropping bads: {len(good.ch_names)}")
    print(f"filt channels after dropping bads: {len(filt.ch_names)}")

    HG_base = get_baseline(filt, base_times)
    good.load_data()

    # If channels is None, use all channels
    if channels is None:
        channels = good.ch_names
    else:
        invalid_channels = [ch for ch in channels if ch not in good.ch_names]
        if invalid_channels:
            raise ValueError(
                f"The following channels are not valid: {invalid_channels}")
        # The baseline was built from ``filt`` with every good channel, so
        # restrict it too; otherwise rescale() and the cluster test compare
        # arrays with different channel sets.
        good.pick_channels(channels)
        HG_base.pick_channels(channels)

    ch_type = filt.get_channel_types(only_data_chs=True)[0]
    good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    # Epoch every event label with 0.5 s of padding per side, then pool them.
    times_adj = [times[0] - 0.5, times[1] + 0.5]
    all_epochs_list = []
    for event in events:
        trials = trial_ieeg(good, event, times_adj, preload=True,
                            reject_by_annotation=False)
        all_epochs_list.append(trials)
    all_trials = mne.concatenate_epochs(all_epochs_list)

    outliers_to_nan(all_trials, outliers=10)
    HG_ev1 = gamma.extract(all_trials, copy=False, n_jobs=1)
    crop_pad(HG_ev1, "0.5s")
    HG_ev1_rescaled = rescale(HG_ev1, HG_base, copy=True, mode='zscore')

    HG_base.decimate(2)
    HG_ev1.decimate(2)
    RTs, skipped = calculate_RTs(good)
    avg_RT = np.median(RTs)

    # Average over trials (axis 0), ignoring NaNs set by outliers_to_nan.
    HG_ev1_evoke = HG_ev1.average(method=lambda x: np.nanmean(x, axis=0))
    HG_ev1_evoke_rescaled = HG_ev1_rescaled.average(
        method=lambda x: np.nanmean(x, axis=0))

    # MNE keys ``scalings`` by lower-case channel type (e.g. 'seeg'); a 'sEEG'
    # key is ignored and the default scaling for the type is applied instead.
    plot_scalings = {ch_type: 1}
    # NOTE(review): ``event`` is the last label in ``events``. Figures are
    # written to ``save_dir + suffix`` (next to the subject folder, not inside
    # it); confirm this is intended.
    if event == "Stimulus":
        print('plotting stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted')
        for ax in fig.axes:
            ax.axvline(x=avg_RT, color='r', linestyle='--')
        print('about to save')
        fig.savefig(save_dir + '_HG_ev1_Stimulus_zscore.png')
        print('saved')
    else:
        print('about to plot if not stimulus')
        fig = HG_ev1_evoke_rescaled.plot(unit=False, scalings=plot_scalings)
        print('plotted non stimulus')
        fig.savefig(save_dir + f'_HG_ev1_{output_name}_zscore.png')
    # Saved to disk; close so figures do not accumulate across subjects.
    plt.close(fig)

    print(f"Shape of HG_ev1._data: {HG_ev1._data.shape}")
    print(f"Shape of HG_base._data: {HG_base._data.shape}")

    # Compare task HG against baseline HG, reshaped to the task data's shape.
    sig1 = HG_ev1._data
    sig2 = make_data_same(HG_base._data, sig1.shape)
    print(f"Shape of sig1: {sig1.shape}")
    print(f"Shape of sig2: {sig2.shape}")

    mat = time_perm_cluster(sig1, sig2, 0.05, n_jobs=6, ignore_adjacency=1)
    fig = plt.figure()
    plt.imshow(mat, aspect='auto')
    fig.savefig(save_dir + f'_{output_name}_stats.png', dpi=300)
    plt.close(fig)

    channels = good.ch_names

    # save channels with their indices
    save_channels_to_file(channels, sub, task, save_dir)

    # save significant channels to a json
    save_sig_chans(f'{output_name}', mat, channels, sub, save_dir)

### Debug: why are the z-scores negative? (cells below)

In [None]:
# Browse the raw data in ``good``. It holds whichever subject was processed
# last (the batch loops above reassign it on every iteration).
good.plot()

In [None]:
# Browse the HG epochs (un-rescaled, after decimate(2)) for the last subject
# processed, to compare against the negative z-scores seen after rescale().
HG_ev1.plot()