In [None]:
import mne
import glob
import matplotlib.pyplot as plt
import numpy as np
# mne.viz.set_browser_backend('qt')
#%matplotlib qt

In [None]:
import os

# Root of the HSJ Q1K dataset. Override with the Q1K_PROJECT_PATH environment
# variable when running outside the cluster. os.path.join(..., "") guarantees a
# trailing separator, because later cells build paths by plain concatenation.
project_path = os.path.join(
    os.environ.get("Q1K_PROJECT_PATH", "/project/def-emayada/q1k/experimental/HSJ/"), ""
)
# Derivative sub-folders. Later cells concatenate them in this order, so each one
# is nested inside the previous one.
pylossless_path = "derivatives/pylossless/"
sync_loss_path = "derivatives/sync_loss/"
postproc_path = "derivatives/postproc/"

In [None]:
# --- analysis configuration -------------------------------------------------
# Event labels to compare (later plot titles describe them as 6 Hz vs. 15 Hz).
conditions = ['sv06_d', 'sv15_d']
roi = ['E83']                  # EEG channel(s) used for the ERP / TFR analyses
eye = ['pupil_left']           # pupil channel(s) of interest
decim = 2                      # decimation factor passed to tfr_morlet
freqs = np.arange(2, 50, 2)    # 2, 4, ..., 48 Hz
n_cycles = freqs / 2           # wavelet cycles grow linearly with frequency

# Locate every per-subject VEP epoch file and list what was found.
epoch_files = glob.glob(
    project_path + pylossless_path + sync_loss_path + postproc_path
    + 'epoch_fif_files/VEP/*epo.fif'
)
for found_file in epoch_files:
    print("[", found_file, "] ")

In [None]:
# Load every subject's VEP epochs and, for each condition, store the tuple
# (evoked average over EEG + misc channels, Morlet power over `roi`, Morlet ITC over `roi`).
# Set TRIAL_FRACTION to e.g. 0.5 to keep only the first half of each subject's
# epochs (half-trials analysis); None keeps all epochs.
TRIAL_FRACTION = None

# sorted() makes the subject order deterministic across filesystems.
epoch_files = sorted(glob.glob(project_path + pylossless_path + sync_loss_path + postproc_path + 'epoch_fif_files/VEP/*epo.fif'))
averaging_dict = {label: [] for label in conditions}

for filepath in epoch_files:
    print(filepath)
    subject_epochs = mne.read_epochs(filepath)
    if TRIAL_FRACTION is not None:
        n_trials = len(subject_epochs)
        trial_cap = round(n_trials * TRIAL_FRACTION)
        print(f"NTrials = {n_trials}, trial cap = {trial_cap}")
        subject_epochs = subject_epochs[:trial_cap]
    for condition in conditions:
        # Select the condition once. Copy before pick() (which works in place)
        # so the evoked average below still sees the EEG/misc channels.
        condition_epochs = subject_epochs[condition]
        power, itc = mne.time_frequency.tfr_morlet(
            condition_epochs.copy().pick(roi),
            freqs=freqs,
            n_cycles=n_cycles,
            decim=decim,
            return_itc=True,
        )
        evoked = condition_epochs.average(picks=['eeg', 'misc'])
        averaging_dict[condition].append((evoked, power, itc))

In [None]:
def condition_summary(condition_label, topomap_times=None):
    """Show the grand-average evoked response of one condition.

    Grand-averages the per-subject evoked responses, i.e. the first element of
    each tuple in ``averaging_dict[condition_label]``. It then displays the
    result, draws its butterfly plot and draws a row of topomaps titled with
    the condition label.

    Parameters
    ----------
    condition_label : str
        Key of ``averaging_dict`` to summarise.
    topomap_times : array-like of float | None
        Topomap latencies in seconds. The default is 0.0, 0.1, ..., 0.9 s.
        Latencies outside the grand average's time span are dropped, so shorter
        epochs don't make plot_topomap raise. If none remain, MNE's "auto"
        latencies are used.
    """
    print('Working on: ', condition_label)
    grand_average = mne.grand_average([entry[0] for entry in averaging_dict[condition_label]])
    display(grand_average)
    grand_average.plot()

    if topomap_times is None:
        topomap_times = np.arange(0, 1.0, .1)
    topomap_times = np.asarray(topomap_times, dtype=float)
    in_range = (topomap_times >= grand_average.times[0]) & (topomap_times <= grand_average.times[-1])
    selected_times = topomap_times[in_range]

    fig = grand_average.plot_topomap(
        times=selected_times if selected_times.size else "auto", colorbar=True
    )
    fig.suptitle(condition_label)

# One summary (butterfly plot + topomaps) per configured condition.
for label in conditions:
    condition_summary(label)

In [None]:
# Plot styling per condition label (both solid lines, distinguished by colour).
color_dict = dict(sv06_d='blue', sv15_d='red')
linestyle_dict = dict.fromkeys(('sv06_d', 'sv15_d'), '-')

# condition label -> list of per-subject evoked responses (tuple slot 0)
evokeds = {
    label: [subject_entry[0] for subject_entry in averaging_dict[label]]
    for label in ('sv06_d', 'sv15_d')
}

# Overlay the two conditions' ERPs at the ROI channel(s).
mne.viz.plot_compare_evokeds(
    evokeds,
    picks=roi,
    combine='mean',
    colors=color_dict,
    linestyles=linestyle_dict,
    legend='lower right',
    show_sensors='upper right',
    title='6Hz vs. 15Hz ERPs',
)

In [None]:
# Same two-condition comparison, for the pupil channel(s) configured in `eye`
# rather than a duplicated literal. The title names the signal actually shown.
mne.viz.plot_compare_evokeds(evokeds,
                             combine='mean',
                             legend='lower right',
                             picks=eye,
                             colors=color_dict,
                             linestyles=linestyle_dict,
                             title='6Hz vs. 15Hz pupil response'
                            )

In [None]:
def _show_tf_map(ax, data, extent, **imshow_kwargs):
    """Draw a (n_freqs, n_times) map on ``ax``, time on x and frequency on y."""
    return ax.imshow(data, extent=extent, aspect="auto", origin="lower", **imshow_kwargs)


# True for ERSP, False for ITC
def do_power_plotting(ersp=True, cond_labels=('sv06_d', 'sv15_d'), threshold=6.0,
                      n_permutations=100, alpha=0.05, seed=0):
    """Plot two conditions' group time-frequency maps and their cluster-tested contrast.

    Reads the per-subject ``(evoked, power, itc)`` tuples from the notebook-level
    ``averaging_dict`` and uses the first channel of each TFR.

    Parameters
    ----------
    ersp : bool
        True to analyse power (tuple slot 1), False for ITC (tuple slot 2).
    cond_labels : tuple of two str
        Keys of ``averaging_dict`` to contrast (first minus second).
    threshold : float
        Cluster-forming F threshold passed to the permutation test.
    n_permutations : int
        Number of permutations for the cluster test.
    alpha : float
        Clusters with p <= alpha are drawn in colour on the F map.
    seed : int | None
        Seed for the permutation test so the exported report is reproducible.
        None gives a different permutation draw on every run.

    Produces two figures:
    1. The subject-mean map of each condition, on one shared colour scale.
    2. The F map (grey) with significant clusters coloured by the sign of the
       contrast, plus the grand-average TFR difference.
    """
    measure = "ERSP" if ersp else "ITC"
    indexer = 1 if ersp else 2
    label_1, label_2 = cond_labels

    tfrs_1 = [entry[indexer] for entry in averaging_dict[label_1]]
    tfrs_2 = [entry[indexer] for entry in averaging_dict[label_2]]

    cond1 = mne.grand_average(tfrs_1)
    cond2 = mne.grand_average(tfrs_2)

    # Stack subjects and keep channel 0 -> (n_subjects, n_freqs, n_times).
    epochs_power_1 = np.array([tfr.data for tfr in tfrs_1])[:, 0, :, :]
    epochs_power_2 = np.array([tfr.data for tfr in tfrs_2])[:, 0, :, :]

    # Axis ranges are taken from the TFR itself; times are converted s -> ms.
    times = 1e3 * tfrs_1[0].times
    tfr_freqs = tfrs_1[0].freqs
    extent = [times[0], times[-1], tfr_freqs[0], tfr_freqs[-1]]

    # --- Figure 1: per-condition subject means on a shared colour scale -----
    mean_1 = epochs_power_1.mean(axis=0)
    mean_2 = epochs_power_2.mean(axis=0)
    vmin = min(mean_1.min(), mean_2.min())
    vmax = max(mean_1.max(), mean_2.max())

    fig1, (ax1t, ax1b) = plt.subplots(2, 1, figsize=(6, 4))
    fig1.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
    _show_tf_map(ax1t, mean_1, extent, cmap="RdBu_r", vmin=vmin, vmax=vmax)
    _show_tf_map(ax1b, mean_2, extent, cmap="RdBu_r", vmin=vmin, vmax=vmax)
    ax1t.set_ylabel("Frequency (Hz)")
    ax1t.set_title(f"{label_1} {measure} (subject mean)")
    ax1b.set_title(f"{label_2} {measure} (subject mean)")
    ax1b.set_xlabel("Time (ms)")

    # --- Cluster-based permutation test between the two conditions ---------
    # Returns: F map, cluster masks, cluster p-values, permutation distribution (H0).
    F_obs, clusters, cluster_p_values, _ = mne.stats.permutation_cluster_test(
        [epochs_power_1, epochs_power_2],
        out_type="mask",
        n_permutations=n_permutations,
        threshold=threshold,
        tail=0,
        seed=seed,
    )

    # Signed F inside significant clusters (sign of cond1 - cond2), NaN elsewhere.
    signs = np.sign(mean_1 - mean_2)
    F_obs_plot = np.full_like(F_obs, np.nan, dtype=float)
    for cluster_mask, p_val in zip(clusters, cluster_p_values):
        if p_val <= alpha:
            F_obs_plot[cluster_mask] = F_obs[cluster_mask] * signs[cluster_mask]

    # --- Figure 2: F map with significant clusters + TFR contrast ----------
    fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 4))
    _show_tf_map(ax, F_obs, extent, cmap="gray")
    if np.isnan(F_obs_plot).all():
        # No significant cluster: nanmax would return NaN colour limits here.
        ax.set_title(f"{measure} {label_1} vs {label_2}: no clusters with p <= {alpha}")
    else:
        max_F = np.nanmax(np.abs(F_obs_plot))
        _show_tf_map(ax, F_obs_plot, extent, cmap="RdBu_r", vmin=-max_F, vmax=max_F)
        ax.set_title(f"{measure} contrast ({label_1} - {label_2})")
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Frequency (Hz)")

    # NOTE(review): combine_evoked is documented for Evoked; here it is applied to
    # AverageTFR grand averages. Confirm with the installed MNE version; otherwise
    # use `cond1 - cond2`.
    evoked_contrast = mne.combine_evoked(
        [cond1, cond2], weights=[1, -1]
    )
    evoked_contrast.plot(axes=ax2)

In [None]:
# Power (ERSP) first, then inter-trial coherence (ITC).
for show_ersp in (True, False):
    do_power_plotting(ersp=show_ersp)

In [None]:
# Export this notebook to an HTML report without input (code) cells, stripping
# the outputs of cells tagged "exclude".
# IPython interpolates `{...}` in `!` lines, so `{"session_reports/..."}` becomes
# the plain path string.
# NOTE(review): `'{"exclude"}'` is interpolated by IPython too (becoming `exclude`).
# Presumably nbconvert still reads it as a one-tag set — confirm in the HTML output.
# NOTE(review): the report name says "half-trials". Confirm that the loading cell
# actually restricted epochs to half the trials for this run.
!jupyter nbconvert --output {"session_reports/group_half-trials_vp.html"} --TagRemovePreprocessor.remove_all_outputs_tags='{"exclude"}' --no-input --to html group_erp_tf_vp.ipynb