# Last Away ~ Coherence: BJH027

In general, if the imaginary part of C(x,y) is positive, then x and y are interacting and x is earlier than y, indicating that information is flowing from x to y. At specific frequencies, however, ‘earlier’ and ‘later’ are ambiguous: one cycle at 10 Hz lasts 100 ms, so at 10 Hz ‘10 ms earlier’ is the same as ‘90 ms later’. For the present interpretation we assume that the delay with the smaller absolute value is the more probable explanation; e.g. in the above example we would favor ‘10 ms earlier’ over ‘90 ms later’. Note that we can make this interpretation from the sign of the imaginary part of coherency alone, without actually calculating a delay (which would also require a reliable real part of coherency).

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy import signal, stats
import mat73
import re
from neurodsp.timefrequency import compute_wavelet_transform
import os
import mne
import mne_connectivity
import IPython
import seaborn as sns
import scipy
import joblib
import h5io
import dask.array as da 

import statsmodels
from statsmodels import stats
from statsmodels.stats import multitest

# Import required code for visualizing example models
from neurodsp.utils import create_times
from neurodsp.plts.time_series import plot_time_series
from neurodsp.spectral import compute_spectrum, rotate_powerlaw
from neurodsp.plts.spectral import plot_power_spectra



In [None]:
## Prep paths ##

# Subject analysed in this notebook, and where its raw / preprocessed files live
subject = 'BJH027'
raw_data_dir, preproc_data_dir = (
    f"/home/brooke/pacman/raw_data/{subject}",
    f"/home/brooke/pacman/preprocessing/{subject}/ieeg",
)

In [None]:
## Load Neural Data

# preprocessed epochs time-locked to the last-away events
last_away_epochs = mne.read_epochs(f"{preproc_data_dir}/{subject}_bp_filtered_clean_last_away_events.fif")

# an epoch counts as good when it carries no annotations; read the annotations once and split both ways
epoch_annotations = last_away_epochs.get_annotations_per_epoch()
good_epochs = [i for i, annots in enumerate(epoch_annotations) if not annots]
bad_epochs = [i for i, annots in enumerate(epoch_annotations) if annots]

# behavioral table, attached to the epochs as metadata (MNE expects one row per epoch)
last_away_data = pd.read_csv(f"{raw_data_dir}/behave/{subject}_last_away_events.csv")
last_away_epochs.metadata = last_away_data

# keep only the good epochs
last_away_epochs = last_away_epochs[good_epochs]

In [None]:
## Dictionary of electrode locations ##

# Dictionary mapping ROI to elecs
# Pull mapping ROI to elecs
%run ../../scripts/roi.py
ROIs = ROIs[subject]

## prep lists

# primary ROI
hc_list = []
hc_indices = []
hc_names = []
ofc_list = []
ofc_indices = []
ofc_names = []
amyg_list = []
amyg_names = [] 
amyg_indices = []
cing_list = []
cing_names = [] 
cing_indices = []

# control ROI
insula_list = []
insula_names = []  
insula_indices = []
dlpfc_list = []
dlpfc_names = []  
dlpfc_indices = []
ec_list = []
ec_names = []  
ec_indices = []

# exclude bad ROI from list
pairs_long_name = [ch.split('-') for ch in last_away_epochs.info['ch_names']]
bidx = len(last_away_epochs.info['bads']) +1
pairs_name = pairs_long_name[bidx:len(pairs_long_name)]

# sort ROI into lists
for ix in range(0, len(pairs_name)):
    if pairs_name[ix][0] in ROIs['hc'] or pairs_name[ix][1] in ROIs['hc']:
        hc_list.append(last_away_epochs.info['ch_names'][ix + bidx])
        hc_names.append(pairs_name[ix])
        hc_indices.append(ix)
    if pairs_name[ix][0] in ROIs['ofc'] or pairs_name[ix][1] in ROIs['ofc']:
        ofc_list.append(last_away_epochs.info['ch_names'][ix + bidx])
        ofc_names.append(pairs_name[ix])
        ofc_indices.append(ix)
    if pairs_name[ix][0] in ROIs['amyg'] or pairs_name[ix][1] in ROIs['amyg']:
        amyg_list.append(last_away_epochs.info['ch_names'][ix + bidx])       
        amyg_names.append(pairs_name[ix])
        amyg_indices.append(ix)
    if pairs_name[ix][0] in ROIs['cing'] or pairs_name[ix][1] in ROIs['cing']:
        cing_list.append(last_away_epochs.info['ch_names'][ix + bidx])       
        cing_names.append(pairs_name[ix])
        cing_indices.append(ix)
        
    # control roi
    if pairs_name[ix][0] in ROIs['insula'] or pairs_name[ix][1] in ROIs['insula']:
        insula_list.append(last_away_epochs.info['ch_names'][ix + bidx])       
        insula_names.append(pairs_name[ix])
        insula_indices.append(ix)
    if pairs_name[ix][0] in ROIs['dlpfc'] or pairs_name[ix][1] in ROIs['dlpfc']:
        dlpfc_list.append(last_away_epochs.info['ch_names'][ix + bidx])       
        dlpfc_names.append(pairs_name[ix])
        dlpfc_indices.append(ix)
    if pairs_name[ix][0] in ROIs['ec'] or pairs_name[ix][1] in ROIs['ec']:
        ec_list.append(last_away_epochs.info['ch_names'][ix + bidx])       
        ec_names.append(pairs_name[ix])
        ec_indices.append(ix)        
        

In [None]:
## functions

def compute_coherence(epochs, ch_names, roi_indices, freqs, n_cycles,  workers = 8,
                      methods = ('coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', 'wpli2_debiased')):
    """ function to compute time-resolved spectral connectivity via Morlet wavelets

    Thin wrapper around mne_connectivity.spectral_connectivity_epochs in 'cwt_morlet' mode.

    epochs:                     MNE epoch object with channels of interest
    ch_names:                   channel names, passed through as the connectivity node names
    roi_indices:                (seed_indices, target_indices) tuple of equal-length index arrays;
                                connection k pairs seed_indices[k] with target_indices[k]
    freqs:                      list of frequencies, should be log spaced
    n_cycles:                   number of cycles, adjust with freqs to balance temporal and frequency resolution
    workers:                    number of jobs (n_jobs) used while computing connectivity
    methods:                    connectivity measure(s) to compute, a name or a sequence of names. Results come
                                back in this order, so with the default result[0] is 'coh', result[1] 'cohy',
                                ..., result[9] 'wpli2_debiased'

    returns:                    the return value of spectral_connectivity_epochs: one connectivity object
                                per method (a list when several methods are requested)
    """
    # a bare string would otherwise be split into single characters by list()
    if isinstance(methods, str):
        methods = [methods]

    print('computing TFR')
    connect = mne_connectivity.spectral_connectivity_epochs(data = epochs,
                                                            names = ch_names,
                                                            method = list(methods),
                                                            indices = roi_indices,
                                                            mode = 'cwt_morlet',
                                                            cwt_freqs = freqs,
                                                            cwt_n_cycles = n_cycles,
                                                            n_jobs = workers)

    return connect

def extract_freqs(lower_freq, higher_freq, freq_band, subdir, ROI, label, TFR, trials, out_dir = None):
    """ function to extract and average the power across the freqs within a given band and save out to csvs
    writes one csv per channel: one row per trial, one column per kept time sample, plus a "trial" column
    the time axis is decimated with a step chosen to give ~4 samples per cycle of a representative band frequency

    lower_freq, higher_freq:    non inclusive lower and upper bounds of the band; ignored for 'gamma' and 'hfa',
                                which are averaged over fixed, overlapping sub-bands instead
    freq_band:                  band name: 'delta', 'theta', 'alpha', 'beta', 'gamma' or 'hfa'
    subdir:                     dir in sub/ieeg/ that specifies the time locking
    ROI:                        region name, as a string
    label:                      label, eg ghost, no ghost, choice locked etc, as a string
    TFR:                        MNE TFR object; uses .info['sfreq'], .ch_names, .freqs and .data,
                                with .data indexed as [trial, channel, freq, time]
    trials:                     values for the "trial" column, one per trial (first axis of TFR.data)
    out_dir:                    base output directory; defaults to the notebook-level `preproc_data_dir`

    raises:
        KeyError:   freq_band is not one of the bands listed above
        Exception:  gamma / HFA requested with a sampling rate below 1000 Hz
        ValueError: no frequency of TFR.freqs lies strictly inside the band (or one of its sub-bands)
    """
    if out_dir is None:
        out_dir = preproc_data_dir

    sfreq = TFR.info['sfreq']
    # Take the frequency axis from the TFR itself rather than a notebook-level `freqs`: the two differ
    # whenever the TFR was computed on another grid, and the band indices would then be wrong.
    tfr_freqs = np.asarray(TFR.freqs)

    # representative frequency of each band; step ~ sfreq / (4 * f) keeps ~4 samples per cycle
    # (## note I made this up, but it seems reasonable?)
    band_freq = {'delta': 2, 'theta': 5, 'alpha': 11, 'beta': 22, 'gamma': 50, 'hfa': 110}
    # at low sampling rates the floor can reach 0, and slicing with step 0 raises -> keep at least 1
    step = max(1, int(np.floor(sfreq / (band_freq[freq_band] * 4))))

    def _band_indices(lo, hi):
        # indices of TFR frequencies strictly between lo and hi; an empty band would silently average to NaN
        fidx = np.where((tfr_freqs > lo) & (tfr_freqs < hi))[0]
        if fidx.size == 0:
            raise ValueError(f"no TFR frequencies strictly between {lo} and {hi} Hz for band '{freq_band}'")
        return fidx

    # gamma and HFA are averaged over overlapping sub-bands; the other bands use (lower_freq, higher_freq)
    if freq_band == 'gamma' or freq_band == 'hfa':
        if sfreq < 1000:
            raise Exception('Sampling Rate is below 1000, should not calculate gamma or HFA')
        subband_dict = {
            'gamma'    : [(30, 40), (35, 45), (40, 50), (45, 55), (50, 60), (55, 65), (60, 70)],
            'hfa'      : [(70, 90), (80, 100), (90, 110), (100, 120), (110, 130), (120, 140), (130, 150)]
        }
        band_fidx = [_band_indices(lo, hi) for lo, hi in subband_dict[freq_band]]
    else:
        band_fidx = [_band_indices(lower_freq, higher_freq)]

    for chix, ch_name in enumerate(TFR.ch_names):
        # mean over the freqs of each (sub-)band, then over the sub-bands -> (trial, time)
        trial_power = np.mean([TFR.data[:, chix, fidx, :].mean(axis=1) for fidx in band_fidx], axis = 0)
        channel_df = pd.DataFrame(trial_power[:, ::step])
        channel_df["trial"] = trials
        channel_df.to_csv(f"{out_dir}/{subdir}/{ch_name}_{ROI}_trial_{freq_band}_{label}.csv")

In [None]:
# Set frequencies

# 80 log-spaced wavelet frequencies from 1 to 150 Hz, with cycle counts log-spaced from 2 to 30
freqs = np.logspace(np.log10(1), np.log10(150), 80)
n_cycles = np.logspace(np.log10(2), np.log10(30), 80)

# Morlet resolution checks: spectral bandwidth 2 * f / n_cycles, temporal width n_cycles / (pi * f)
band_width = (freqs / n_cycles) * 2
time_bin = n_cycles / freqs / np.pi
for arr in (freqs, n_cycles, time_bin, band_width):
    print(arr)


# Main Regions of Interest

## Hippocampus

In [14]:
def shuffle_and_combine_epochs(epoch1, epoch2, seed = None):
    """
    Shuffles trials in the first epoch object and then combines it with the second epoch object.

    Trial k of the result holds epoch2's trial k next to a randomly chosen (without replacement)
    trial of epoch1, i.e. a trial-shuffled surrogate of the epoch1/epoch2 channel combination.

    Parameters:
    epoch1 (mne.Epochs): The epoch object whose trials are shuffled (indexing builds a new object,
        so epoch1 itself is not reordered).
    epoch2 (mne.Epochs): The epoch object whose channels are appended to the shuffled first epoch.
    seed (None | int | np.random.Generator): source of randomness. None (default) draws from NumPy's
        global random state, as before; an int or Generator makes the shuffle reproducible.

    Returns:
    mne.Epochs: The combined epoch object after shuffling trials in the first epoch.
    """

    # Shuffle the first epoch. permutation(n) == arange(n) + shuffle, so the global-state draw is unchanged
    if seed is None:
        indices = np.random.permutation(len(epoch1))
    else:
        indices = np.random.default_rng(seed).permutation(len(epoch1))
    shuffled_epoch1 = epoch1[indices]

    # add_channels takes a list/tuple of instances (a bare instance is rejected) and appends in place
    combined_epochs = shuffled_epoch1.add_channels([epoch2])

    return combined_epochs

In [None]:
# only ROI of interest: independent copies of the cleaned epochs restricted to OFC and to HC channels
last_away_ofc, last_away_hc = (last_away_epochs.copy().pick_channels(roi_chs)
                               for roi_chs in (ofc_list, hc_list))


In [None]:
# sanity check: epoch 1, channels 1-9, sample 1 of the HC data (compared with the combined object below)
last_away_hc._data[1][1:10, 1]

In [None]:
# NOTE(review): the HC trials are shuffled before being paired with the OFC trials, so every connectivity
# result computed from `last_away_roi` below is a trial-shuffled surrogate rather than the observed
# HC-OFC coupling — confirm this is intended (e.g. as a null model).
last_away_roi = shuffle_and_combine_epochs(epoch1 = last_away_hc, epoch2 = last_away_ofc)

In [None]:
# sanity check: same slice from the combined object (HC channels are the ones added first),
# to see whether epoch 1's HC data changed after shuffling
last_away_roi._data[1][1:10, 1]

In [None]:
# Seed (HC) / target (OFC) channel indices into the combined object for spectral_connectivity_epochs:
# every HC channel is paired with every OFC channel, HC-major, so connection k should correspond to
# [(a, b) for a in hc_list for b in ofc_list][k] (assumes channels keep hc_list / ofc_list order).
hc_found = [idx for idx, ch in enumerate(last_away_roi.info.ch_names) if ch in hc_list]
ofc_found = [idx for idx, ch in enumerate(last_away_roi.info.ch_names) if ch in ofc_list]

# the pair labels are built from hc_list x ofc_list, so every listed channel must be present
assert len(hc_found) == len(hc_list) and len(ofc_found) == len(ofc_list), \
    "some HC/OFC channels are missing from the combined epochs; pair labels would be misaligned"

hc_index_list = np.repeat(hc_found, len(ofc_found))
ofc_index_list = np.tile(ofc_found, len(hc_found))

len(ofc_index_list) == len(hc_index_list)


In [None]:
# Resample to 512 Hz when the recording is sampled above 1000 Hz.
# NOTE(review): the plotting cells slice samples 1536:3584 and label them -2..2 s, which is 4 s only at
# 512 Hz; data recorded at <= 1000 Hz is NOT resampled here, so check sfreq before trusting those plots.
if last_away_roi.info['sfreq'] > 1000:
    last_away_roi= last_away_roi.resample(512)

In [None]:
# theta band (strictly between 3 and 8 Hz): the matching wavelet frequencies and their cycle counts
theta_mask = (freqs > 3) & (freqs < 8)
theta_freqs = freqs[theta_mask]
theta_cycles = n_cycles[theta_mask]

In [None]:
# compute time-resolved HC-OFC connectivity in the theta band for every method compute_coherence requests
roi_coherence = compute_coherence(last_away_roi,
                                  last_away_roi.info.ch_names,
                                  (hc_index_list, ofc_index_list),
                                  theta_freqs,
                                  theta_cycles,
                                  workers = 8)


In [None]:
# (HC channel, OFC channel) labels, HC-major — the same order as the seed/target index arrays
pairs = [(hc_ch, ofc_ch) for hc_ch in hc_list for ofc_ch in ofc_list]
pairs

In [None]:
# Theta-band coherence averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
coh_data = roi_coherence[0].get_data()   # indexed [pair, freq, time]
coh_avg = coh_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(coh_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "Coherence")
ax.set_title("Coherence: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

coh_data = roi_coherence[0].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert coh_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(coh_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "Coherence")
    ax.set_title(f"Coherence: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()




In [None]:

# Imaginary part of theta-band coherency averaged over all HC-OFC pairs (signed).
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
cohy_imag = np.imag(roi_coherence[1].get_data())   # indexed [pair, freq, time]
cohy_avg = cohy_imag[:, :, 1536:3584].mean(axis = 0)
# signed measure: symmetric limits keep the white centre of the diverging map at 0
cohy_lim = np.nanmax(np.abs(cohy_avg))

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(cohy_avg, cmap = 'RdBu_r', vmin = -cohy_lim, vmax = cohy_lim, interpolation="none",
               origin="lower", aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "Im(coherency)")
ax.set_title("Coherency (imaginary part): average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

cohy_imag = np.imag(roi_coherence[1].get_data())   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert cohy_imag.shape[0] == len(pairs), "pair labels do not match the computed connections"
# signed measure: one symmetric colour range (centred on 0) shared by all pairs so figures are comparable
cohy_lim = np.nanmax(np.abs(cohy_imag[:, :, 1536:3584]))

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(cohy_imag[idx, :, 1536:3584], cmap = 'RdBu_r', vmin = -cohy_lim, vmax = cohy_lim,
                   interpolation="none", origin="lower", aspect = 'auto',
                   extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "Im(coherency)")
    ax.set_title(f"Coherency (imaginary part): {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band imaginary coherence averaged over all HC-OFC pairs (signed).
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
imcoh_data = roi_coherence[2].get_data()   # indexed [pair, freq, time]
imcoh_avg = imcoh_data[:, :, 1536:3584].mean(axis = 0)
# signed measure: symmetric limits keep the white centre of the diverging map at 0
imcoh_lim = np.nanmax(np.abs(imcoh_avg))

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(imcoh_avg, cmap = 'RdBu_r', vmin = -imcoh_lim, vmax = imcoh_lim, interpolation="none",
               origin="lower", aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "Imaginary coherence")
ax.set_title("Imaginary Coherence: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
imcoh_data = roi_coherence[2].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
imcoh_theta = imcoh_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(imcoh_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
# signed measure: reference line at 0
ax.axhline(y=0, color='grey', linewidth = 2)
ax.plot(t_sec, imcoh_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "Imaginary coherence (theta)")
ax.set_title("Imaginary Coherence: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

imcoh_data = roi_coherence[2].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert imcoh_data.shape[0] == len(pairs), "pair labels do not match the computed connections"
# signed measure: one symmetric colour range (centred on 0) shared by all pairs so figures are comparable
imcoh_lim = np.nanmax(np.abs(imcoh_data[:, :, 1536:3584]))

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(imcoh_data[idx, :, 1536:3584], cmap = 'RdBu_r', vmin = -imcoh_lim, vmax = imcoh_lim,
                   interpolation="none", origin="lower", aspect = 'auto',
                   extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "Imaginary coherence")
    ax.set_title(f"Imaginary Coherence: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band phase-locking value averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
plv_data = roi_coherence[3].get_data()   # indexed [pair, freq, time]
plv_avg = plv_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(plv_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "PLV")
ax.set_title("PLV: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plv_data = roi_coherence[3].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
plv_theta = plv_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(plv_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
ax.plot(t_sec, plv_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "PLV (theta)")
ax.set_title("PLV: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

plv_data = roi_coherence[3].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert plv_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(plv_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "PLV")
    ax.set_title(f"PLV: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band corrected imaginary PLV ('ciplv') averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
ciplv_data = roi_coherence[4].get_data()   # indexed [pair, freq, time]
ciplv_avg = ciplv_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(ciplv_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "ciPLV")
ax.set_title("Corrected Imaginary PLV: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
ciplv_data = roi_coherence[4].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
ciplv_theta = ciplv_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(ciplv_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
ax.plot(t_sec, ciplv_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "ciPLV (theta)")
ax.set_title("Corrected Imaginary PLV: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

ciplv_data = roi_coherence[4].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert ciplv_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(ciplv_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "ciPLV")
    ax.set_title(f"Corrected Imaginary PLV: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band pairwise phase consistency averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
ppc_data = roi_coherence[5].get_data()   # indexed [pair, freq, time]
ppc_avg = ppc_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(ppc_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "PPC")
ax.set_title("Pairwise Phase Consistency: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
ppc_data = roi_coherence[5].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
ppc_theta = ppc_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(ppc_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
ax.plot(t_sec, ppc_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "PPC (theta)")
ax.set_title("Pairwise Phase Consistency: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

ppc_data = roi_coherence[5].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert ppc_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(ppc_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "PPC")
    ax.set_title(f"Pairwise Phase Consistency: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band phase lag index averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
pli_data = roi_coherence[6].get_data()   # indexed [pair, freq, time]
pli_avg = pli_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(pli_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "PLI")
ax.set_title("Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
pli_data = roi_coherence[6].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
pli_theta = pli_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(pli_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
ax.plot(t_sec, pli_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "PLI (theta)")
ax.set_title("Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

pli_data = roi_coherence[6].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert pli_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(pli_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "PLI")
    ax.set_title(f"Phase Lag Index: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band directed phase lag index averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
dpli_data = roi_coherence[7].get_data()   # indexed [pair, freq, time]
dpli_avg = dpli_data[:, :, 1536:3584].mean(axis = 0)
# dPLI is 0.5 when neither signal consistently leads: centre the diverging map on 0.5
dpli_dev = np.nanmax(np.abs(dpli_avg - 0.5))

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(dpli_avg, cmap = 'RdBu_r', vmin = 0.5 - dpli_dev, vmax = 0.5 + dpli_dev, interpolation="none",
               origin="lower", aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "dPLI")
ax.set_title("Directed Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
dpli_data = roi_coherence[7].get_data()   # indexed [pair, freq, time]

# average over pairs, then over theta frequencies -> one value per time sample
dpli_theta = dpli_data[:, :, 1536:3584].mean(axis = 0).mean(axis = 0)
# NOTE(review): the window is drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin
t_sec = np.linspace(-2, 2, len(dpli_theta))

fig, ax = plt.subplots(figsize = (22, 20))
ax.grid()
ax.axvline(x=0, color='black', linestyle='--', linewidth = 5)
# dPLI reference: 0.5 = no consistent phase lead or lag
ax.axhline(y=0.5, color='grey', linewidth = 2)
ax.plot(t_sec, dpli_theta, color = '#A03E99', linewidth = 3)
ax.set(xlabel = "Time (s)", ylabel = "dPLI (theta)")
ax.set_title("Directed Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

dpli_data = roi_coherence[7].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert dpli_data.shape[0] == len(pairs), "pair labels do not match the computed connections"
# dPLI is 0.5 when neither signal leads: one colour range centred on 0.5, shared by all pairs
dpli_dev = np.nanmax(np.abs(dpli_data[:, :, 1536:3584] - 0.5))

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(dpli_data[idx, :, 1536:3584], cmap = 'RdBu_r', vmin = 0.5 - dpli_dev, vmax = 0.5 + dpli_dev,
                   interpolation="none", origin="lower", aspect = 'auto',
                   extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "dPLI")
    ax.set_title(f"Directed Phase Lag Index: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band weighted phase lag index averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
wpli_data = roi_coherence[8].get_data()   # indexed [pair, freq, time]
wpli_avg = wpli_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(wpli_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "wPLI")
ax.set_title("Weighted Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

wpli_data = roi_coherence[8].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert wpli_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(wpli_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "wPLI")
    ax.set_title(f"Weighted Phase Lag Index: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()

In [None]:

# Theta-band debiased squared wPLI ('wpli2_debiased') averaged over all HC-OFC pairs.
# NOTE(review): samples 1536:3584 are drawn as -2..2 s, which assumes 512 Hz sampling — confirm sfreq/tmin.
wpli2_data = roi_coherence[9].get_data()   # indexed [pair, freq, time]
wpli2_avg = wpli2_data[:, :, 1536:3584].mean(axis = 0)

fig, ax = plt.subplots(figsize = (22, 20))
# theta_freqs are log-spaced but imshow draws equal-height rows: index the y axis by row and label
# every row with its own frequency, so ticks sit on row centres and labels are not duplicated
im = ax.imshow(wpli2_avg, cmap = 'RdBu_r', interpolation="none", origin="lower", aspect = 'auto',
               extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
ax.set_yticks(np.arange(len(theta_freqs)))
ax.set_yticklabels(np.round(theta_freqs, 1))
ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
fig.colorbar(im, ax = ax, label = "debiased wPLI squared")
ax.set_title("Debiased Squared Weighted Phase Lag Index: average HC-OFC", fontsize=22, fontweight = 'bold')
fig.show()

In [None]:
plt.rcParams['figure.figsize'] = [22, 20]
plt.rcParams.update({'font.size': 18})

wpli2_data = roi_coherence[9].get_data()   # indexed [pair, freq, time]
# `pairs` must line up one-to-one with the computed connections
assert wpli2_data.shape[0] == len(pairs), "pair labels do not match the computed connections"

for idx, pair in enumerate(pairs):
    fig, ax = plt.subplots(figsize = (22, 20))
    im = ax.imshow(wpli2_data[idx, :, 1536:3584], cmap = 'RdBu_r', interpolation="none", origin="lower",
                   aspect = 'auto', extent=[-2, 2, -0.5, len(theta_freqs) - 0.5])
    # one tick per (log-spaced) frequency row, centred on the row
    ax.set_yticks(np.arange(len(theta_freqs)))
    ax.set_yticklabels(np.round(theta_freqs, 1))
    ax.set(xlabel = "Time (s)", ylabel = "Frequency (Hz)")
    fig.colorbar(im, ax = ax, label = "debiased wPLI squared")
    ax.set_title(f"Debiased Squared Weighted Phase Lag Index: {pair}", fontsize=22, fontweight = 'bold')
    fig.show()