In [1]:

import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import defaultdict
from pathlib import Path

import sys
sys.path.append("ComputeCanada/frequency_tagging")
from utils import (
    set_base_dir,
    f1_f2_data, 
    plot_power_spectrum,
    psd_analyze_rois as analyze_rois,
    psd_store_data_in_dict as store_data_in_dict,
    extract_im_products,
    get_roi_colour_codes,
    get_frequency_text_codes,
    extract_im_products,
    change_font,
    DFM_ROI_PSDS, PICKLE_DIR, SCRATCH_DIR,
    MAIN
)
change_font()  # presumably applies the project's matplotlib font settings from utils.change_font — confirm

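Cohort info: dataset definitions (experiment ID, MRI field strength, and each subject's tasks with their tagged frequencies f1 and f2)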

In [2]:
# Each cohort is (experiment_id, mri_id, {sub_id: [(task_id, [f1, f2]), ...]}).
NORMAL_3T_SUB_IDS = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
NORMAL_3T = (
    "1_frequency_tagging",
    "3T",
    {sub_id: [("entrain", [0.125, 0.2])] for sub_id in NORMAL_3T_SUB_IDS},
)
NORMAL_7T_SUB_IDS = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
NORMAL_7T = (
    "1_attention",
    "7T",
    {sub_id: [("AttendAway", [0.125, 0.2])] for sub_id in NORMAL_7T_SUB_IDS},
)


def _vary_sub_to_task_mapping():
    """Return a fresh copy of the per-subject task/frequency layout shared by both 'vary' cohorts."""
    return {
        "020": [("entrainA", [.125, .2]), ("entrainB", [.125, .175]), ("entrainC", [.125, .15])],
        "021": [("entrainD", [.125, .2]), ("entrainE", [.15, .2]), ("entrainF", [.175, .2])],
    }


# Same subjects/tasks scanned at both field strengths; each cohort gets its own (unshared) dict.
VARY_3T = ("1_frequency_tagging", "3T", _vary_sub_to_task_mapping())
VARY_7T = ("1_frequency_tagging", "7T", _vary_sub_to_task_mapping())

# task_id -> [f1, f2] primary tagging frequencies (f1 < f2)
task_to_test_frequencies_map = {
    task_id: [.125, .2]
    for task_id in ("entrain", "control", "AttendAway", "entrainA", "entrainD")
}
task_to_test_frequencies_map.update({
    "entrainB": [.125, .175],
    "entrainC": [.125, .15],
    "entrainE": [.15, .2],
    "entrainF": [.175, .2],
})

def read_data(d):
    """Split a cohort tuple into ``(experiment_id, mri_id, sub_to_task_mapping)``.

    Only the first three items of ``d`` are returned; any extra items are ignored.
    """
    experiment_id = d[0]
    mri_id = d[1]
    sub_to_task_mapping = d[2]
    return experiment_id, mri_id, sub_to_task_mapping

def save_bootstrapped_statistics(pkl_path, data_dict):
    """Pickle ``data_dict`` to ``pkl_path``, overwriting any existing file."""
    import pickle

    with open(pkl_path, mode="wb") as handle:
        pickle.dump(data_dict, handle)

def load_bootstrapped_statistics(pkl_path):
    """Unpickle and return the object stored at ``pkl_path``.

    Unpickling can execute arbitrary code, so only point this at files this notebook
    wrote itself (via ``save_bootstrapped_statistics``), never at untrusted files.
    """
    import pickle

    with open(pkl_path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded

Compute (or load from cache) the bootstrapped PSD statistics for every dataset, subject and ROI, and save them as pickle files

In [3]:
datadir = SCRATCH_DIR
n_permutations = 500
n_bootstraps = 400
TR = .3
fos = [.8]
pvals = ["uncp"]
nperseg = 572

# directories
fig_dir = Path(set_base_dir(str(DFM_ROI_PSDS)))
pickle_dir = PICKLE_DIR


def _analysis_file_stem(experiment_id, mri_id, sub_id, roi_task_id, f_type, task_id, rephase_tag, pval, fo):
    """Return the file stem shared by one analysis' cached .pkl and its .png figure."""
    return (
        f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{roi_task_id}-{f_type}"
        f"_task-{task_id}_rephase-{rephase_tag}_pval-{pval}_fo-{fo}"
    )


def _load_or_compute_statistics(
    pkl_handler, f_type, test_frequencies, im_test_frequencies,
    sub_id, rephase, rephase_with, file_stem, frequency_grid,
):
    """Load cached statistics for one ROI analysis, or compute, cache and plot them.

    The cache is used only when BOTH ``pickle_dir/<file_stem>.pkl`` and
    ``fig_dir/<file_stem>.png`` exist; otherwise ``analyze_rois`` is re-run, its outputs
    are pickled, and the power spectrum is plotted to the .png.

    Returns
    -------
    (frequency_grid, data_dict)
        ``frequency_grid`` is returned unchanged on a cache hit, and is the grid returned by
        ``analyze_rois`` otherwise. The caller passes it back in to keep frequency_grid
        consistent across analyses.
    """
    pkl_path = pickle_dir / f"{file_stem}.pkl"
    png_out = fig_dir / f"{file_stem}.png"
    if png_out.exists() and pkl_path.exists():
        return frequency_grid, load_bootstrapped_statistics(pkl_path)

    (frequency_grid, observed_statistics, observed_power_spectrum,
     null_power_spectrums, p_values, bootstrapped_statistics) = analyze_rois(
        pkl_handler,
        f_type,
        im_test_frequencies,
        n_bootstraps,
        TR,
        n_permutations=n_permutations,
        nperseg=nperseg,
        rephase=rephase,
        rephase_with=rephase_with,
        frequency_grid=frequency_grid,
    )
    data_dict = {
        "frequency_grid": frequency_grid,
        "observed_power_spectrum": observed_power_spectrum,
        "null_power_spectrums": null_power_spectrums,
        "observed_statistics": observed_statistics,
        "p_values": p_values,
        "bootstrapped_statistics": bootstrapped_statistics,
    }
    save_bootstrapped_statistics(pkl_path, data_dict)
    plot_power_spectrum(
        frequency_grid, observed_power_spectrum, null_power_spectrums, n_permutations,
        test_frequencies, p_values, observed_statistics,
        add_im=True, sub_id=sub_id, roi_frequency=f_type, close_figure=True, png_out=png_out,
    )
    return frequency_grid, data_dict


# data stores
frequency_grid = None  # set by the first analyze_rois call and passed back in, ensuring frequency_grid is consistent
group_data_dict = None  # row store accumulated by store_data_in_dict; becomes `df` below

# Datasets: NORMAL_3T appears twice; the second pass (label *_CONTROL) analyses the "control" task.
dataset_ids = [NORMAL_3T, NORMAL_3T, NORMAL_7T, VARY_3T, VARY_7T]
dataset_labels = ["NORMAL_3T", "NORMAL_3T_CONTROL", "NORMAL_7T", "VARY_3T", "VARY_7T"]

# Loop over roi params: fos and pvals
for fo, pval in itertools.product(fos, pvals):
    for dataset_label, dataset_id in zip(dataset_labels, dataset_ids):
        experiment_id, mri_id, sub_to_task_mapping = read_data(dataset_id)

        for sub_id, sub_task_info in sub_to_task_mapping.items():
            for roi_task_id, _ in sub_task_info:
                # ROIs always come from roi_task_id; the control dataset only swaps the analysed task_id.
                task_id = "control" if dataset_label.endswith("CONTROL") else roi_task_id

                test_frequencies = task_to_test_frequencies_map[task_id]  # Load first order frequencies
                assert len(test_frequencies) == 2 and test_frequencies[0] < test_frequencies[1]

                pkl_handler = f1_f2_data(
                    datadir,
                    n_bootstraps,
                    sub_id,
                    roi_task_id,
                    test_frequencies[0], test_frequencies[1],
                    task_id,
                    experiment_id=experiment_id,
                    mri_id=mri_id,
                    fo=fo,
                    pval=pval,
                )

                # Update frequencies to get secondary, and tertiary IM frequencies
                im_test_frequencies_map = extract_im_products(test_frequencies[0], test_frequencies[1])
                im_test_frequencies = list(im_test_frequencies_map.values())

                info = (
                    f"Processing {dataset_label}, sub-{sub_id}, roi-task-id-{roi_task_id}, task-{task_id}, pval-{pval}, fo-{fo}\n"
                    f"   - Primary frequencies: {test_frequencies}\n"
                    f"   - Test frequencies: {im_test_frequencies}"
                )
                print(info)

                # Rephased f1f2 ROIs are analysed twice (rephase_with f1, then f2);
                # every other (rephase, f_type) combination runs once with rephase_with=None.
                for rephase, f_type in itertools.product([True, False], ["f1", "f2", "f1f2"]):
                    if rephase and f_type == "f1f2":
                        n_intersected_rois = (pkl_handler.f_data['f1']['roi_coords'] * pkl_handler.f_data['f2']['roi_coords']).sum()
                        if n_intersected_rois == 0:
                            print("Skipping since 0 voxels were found at the intersection.")
                            continue
                        rephase_with_options = ["f1", "f2"]
                    else:
                        rephase_with_options = [None]

                    for rephase_with in rephase_with_options:
                        rephase_tag = f"{rephase}" if rephase_with is None else f"{rephase}-{rephase_with}"
                        file_stem = _analysis_file_stem(
                            experiment_id, mri_id, sub_id, roi_task_id, f_type, task_id, rephase_tag, pval, fo
                        )
                        frequency_grid, data_dict = _load_or_compute_statistics(
                            pkl_handler, f_type, test_frequencies, im_test_frequencies,
                            sub_id, rephase, rephase_with, file_stem, frequency_grid,
                        )
                        group_data_dict = store_data_in_dict(
                            dataset_label, sub_id,
                            roi_task_id, pval, fo, f_type,
                            rephase, rephase_with,
                            im_test_frequencies_map, data_dict,
                            n_bootstraps,
                            loaded_data_dict=group_data_dict
                        )

df = pd.DataFrame(group_data_dict)

In [4]:
def _stack_with_separators(arrays, fill_value=-1):
    """Stack 2-D arrays vertically: 3 fill rows above the first, 1 fill row between the others.

    Returns ``(stacked, is_padding)``. ``stacked`` holds ``fill_value`` in the separator rows;
    ``is_padding`` is a same-shaped boolean array that is True exactly on those rows.
    An explicit mask is used so genuine data equal to ``fill_value`` is never mistaken for padding.

    Raises
    ------
    ValueError
        If ``arrays`` is empty.
    """
    pieces, padding_flags = [], []
    for ix, arr in enumerate(arrays):
        arr = np.atleast_2d(arr)
        n_pad_rows = 3 if ix == 0 else 1
        pad = np.full((n_pad_rows, arr.shape[-1]), fill_value, dtype=float)
        pieces += [pad, arr]
        padding_flags += [np.ones(pad.shape, dtype=bool), np.zeros(arr.shape, dtype=bool)]
    if not pieces:
        raise ValueError("Expected at least one array to stack.")
    return np.vstack(pieces), np.vstack(padding_flags)


def plot_1_cohort_psds(ax, psd_list,xmax, add_cbar=False, fontsize=6,linewidth=.4):
    """Draw every cohort's z-scored PSDs as one image on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    psd_list : list of 2-D arrays
        One array per cohort; presumably rows are subjects and columns are frequency bins.
        Cohorts are drawn top to bottom, separated by blank (white) rows.
    xmax : int
        Only the first ``xmax`` frequency columns are shown.
    add_cbar : bool
        Add an inset "Z-scored PSD" colourbar on the right of ``ax``.
    fontsize, linewidth : float
        Colourbar label/tick styling.
    """
    stacked, is_padding = _stack_with_separators(psd_list)
    vstacked_psd = np.ma.array(stacked, mask=is_padding)

    # Copy before set_bad so the shared, registered colormap is never modified.
    cmap = plt.get_cmap("BuPu").copy()
    cmap.set_bad("white")
    im = ax.imshow(vstacked_psd[:, :xmax], cmap=cmap, interpolation="none", aspect="auto", vmin=-.5, vmax=3)

    if add_cbar:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
        axins = inset_axes(
            ax, width="5%", height="25%", loc="center right", borderpad=-1
        )
        # Use the figure that owns `ax`; a global `fig` may belong to a different, later figure.
        cbar = ax.figure.colorbar(im, cax=axins)
        cbar.outline.set_linewidth(0)
        cbar.set_ticks([0, .5, 1, 1.5, 2, 2.5, 3])
        cbar.ax.tick_params(axis='y', pad=0.4, length=2, width=linewidth, direction="in", which="both", colors='white')
        cbar.set_ticklabels([f"{i:.1f}" for i in [0, .5, 1, 1.5, 2, 2.5, 3]], fontsize=fontsize, c='k')
        cbar.set_label("Z-scored PSD", fontsize=fontsize)


def plot_2_cohort_statistics(ax, pval_list,power_list, fontsize=6,linewidth=.4):
    """Stack per-cohort statistics and draw a blank (fully masked) image on ``ax``.

    The blank image presumably gives the panel the same row geometry as the PSD panels, so
    markers can be scattered at (column, row) positions afterwards — confirm.

    Parameters
    ----------
    ax : matplotlib Axes
    pval_list, power_list : lists of 2-D arrays
        One (rows, n_stats) array per cohort; the two lists must stack to the same shape.
    fontsize, linewidth :
        Accepted for signature compatibility; unused.

    Returns
    -------
    (vstacked_pval, vstacked_power)
        Stacked arrays with ``-1`` in the separator rows (3 on top, 1 between cohorts).

    Raises
    ------
    AssertionError
        If the stacked p-value and power arrays differ in shape.
    """
    vstacked_pval, _ = _stack_with_separators(pval_list)
    vstacked_power, _ = _stack_with_separators(power_list)
    assert vstacked_pval.shape == vstacked_power.shape

    # Plot empty imshow (every cell masked -> rendered with the "bad" colour)
    empty_arr = np.ma.array(np.zeros_like(vstacked_pval), mask=np.ones(vstacked_pval.shape, dtype=bool))
    cmap = plt.get_cmap("binary").copy()
    cmap.set_bad("white")
    ax.imshow(empty_arr, cmap=cmap, interpolation="none", aspect="auto")

    return vstacked_pval, vstacked_power

def plot_1_arrows(ax, test_frequencies_list,frequency_grid,f1f2_c,start_yloc=2.2,arrow_down_marker_s=12,space_btwn_cohort=1):
    """Draw a down-arrow marker at each cohort's f1 and f2 above the PSD image.

    Parameters
    ----------
    ax : matplotlib Axes (only ``ax.scatter`` is used)
    test_frequencies_list : list of (f1, f2, n_sub_ids)
        One entry per cohort, top to bottom.
    frequency_grid : array of frequencies
        Arrow x positions are frequencies converted to (fractional) grid indices via np.interp.
    f1f2_c : pair of colours
        Face colours for the f1 and f2 arrows, respectively.
    start_yloc : float
        y position of the first cohort's arrows; each later cohort moves down by
        ``n_sub_ids + space_btwn_cohort`` rows.
    """
    grid_positions = range(len(frequency_grid))
    yloc = start_yloc
    for f1, f2, n_sub_ids in test_frequencies_list:
        for freq, colour in zip((f1, f2), f1f2_c):
            x_index = np.interp(freq, frequency_grid, grid_positions)
            ax.scatter(
                x_index,
                yloc,
                s=arrow_down_marker_s,
                c=colour,
                edgecolors='k',
                linewidths=0.25,
                marker='v',
            )
        yloc += n_sub_ids + space_btwn_cohort

def plot_1_spines_and_ticks(ax, test_frequencies_list,xmax,roi_c,frequency_grid,start_yloc=2.5,linewidth=.4,linestyle='dashed',space_btwn_cohort=1,ytick_on=False,fontsize=6,yticklabels=None):
    """Decorate a PSD panel: cohort boxes, ROI-colour header band, subject and frequency ticks.

    Parameters
    ----------
    ax : matplotlib Axes
    test_frequencies_list : list of (f1, f2, n_sub_ids)
        One entry per cohort, top to bottom; only ``n_sub_ids`` is used here.
    xmax : number
        Right edge (in image columns) of the boxes and header band.
    roi_c : colour
        Colour of the header band drawn between y=0.5 and y=1.5.
    frequency_grid : array of frequencies
        Used to place x ticks at 0, 0.1, ..., 0.5 (converted to grid indices).
    start_yloc : float
        Top edge of the first cohort box; later boxes follow after ``space_btwn_cohort`` rows.
    ytick_on : bool
        Label each subject row with a zero-padded number; otherwise y ticks are removed.
    yticklabels : sequence of int, optional
        One number per subject row, across all cohorts in order. Defaults to the layout
        used in the figure (34 rows: 3T control, 3T, 7T, then 6 vary-3T and 6 vary-7T rows).

    Raises
    ------
    ValueError
        If ``ytick_on`` and the number of ``yticklabels`` differs from the number of subject rows.
    """
    default_yticklabels = (1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                           14, 14, 14, 15, 15, 15, 14, 14, 14, 15, 15, 15)

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    yloc_top = start_yloc
    ytick_tracker = []
    for _, _, n_sub_ids in test_frequencies_list:
        yloc_bottom = yloc_top + n_sub_ids

        ax.plot([0, xmax], [yloc_bottom, yloc_bottom], c='k', lw=linewidth, linestyle=linestyle)  # bottom
        ax.plot([0, xmax], [yloc_top, yloc_top], c='k', lw=linewidth, linestyle=linestyle)  # top
        ax.plot([0, 0], [yloc_bottom, yloc_top], c='k', lw=linewidth, linestyle=linestyle)
        ax.plot([xmax, xmax], [yloc_bottom, yloc_top], c='k', lw=linewidth, linestyle=linestyle)

        # one tick per subject row inside this cohort box
        ytick_tracker += list(range(int(yloc_top) + 1, int(yloc_top) + n_sub_ids + 1))
        yloc_top += n_sub_ids + space_btwn_cohort

    # ROI-colour header band
    ax.plot([0, xmax], [0.5, 0.5], c=roi_c, lw=linewidth)  # bottom
    ax.plot([0, xmax], [1.5, 1.5], c=roi_c, lw=linewidth)  # top
    ax.fill_between([0, xmax], [.5, .5], [1.5, 1.5], color=roi_c)

    if ytick_on:
        labels = default_yticklabels if yticklabels is None else yticklabels
        if len(labels) != len(ytick_tracker):
            raise ValueError(
                f"Got {len(labels)} y tick labels for {len(ytick_tracker)} subject rows."
            )
        ax.set_yticks(ytick_tracker)
        ax.set_yticklabels([f"{i:03}" for i in labels], fontsize=fontsize)
        ax.tick_params(axis="y", length=4., width=linewidth, pad=0)
    else:
        ax.set_yticks([])

    xtick_frequencies = [0, .1, .2, .3, .4, .5]
    ax.set_xticks([np.interp(_f, frequency_grid, range(len(frequency_grid))) for _f in xtick_frequencies])
    ax.set_xticklabels([f"{i:.1f}" for i in xtick_frequencies], fontsize=fontsize)
    ax.tick_params(axis="x", length=2., width=linewidth, pad=0)

def plot_2_spines_and_ticks(ax, test_frequencies_list,xmax,c_dict,start_yloc=2.5,linewidth=.4,linestyle='dashed',space_btwn_cohort=1,ytick_on=False,fontsize=6):
    """Decorate the statistics panel: cohort boxes split into column groups, ROI header bands, ticks.

    Parameters
    ----------
    ax : matplotlib Axes
    test_frequencies_list : list of (f1, f2, n_sub_ids)
        One entry per cohort, top to bottom; only ``n_sub_ids`` is used.
    xmax : number
        Right edge of the boxes. Vertical dividers are drawn at x = 0, 2, 4 and ``xmax``.
    c_dict : mapping
        Colours for the "f1", "f2" and "f1f2" header bands spanning x = 0-2, 2-4 and 4-6.
    ytick_on : bool
        Label each subject row with its 1-based, zero-padded index within its cohort.
    """
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    box_style = dict(c='k', lw=linewidth, linestyle=linestyle)
    tick_positions = []
    tick_subject_ix = []
    box_top = start_yloc
    for _f1, _f2, n_sub_ids in test_frequencies_list:
        box_bottom = box_top + n_sub_ids
        ax.plot([0, xmax], [box_bottom, box_bottom], **box_style)  # bottom
        ax.plot([0, xmax], [box_top, box_top], **box_style)  # top
        for x_divider in (0, 2, 4, xmax):
            ax.plot([x_divider, x_divider], [box_bottom, box_top], **box_style)

        first_tick = int(box_top) + 1
        tick_positions.extend(range(first_tick, first_tick + n_sub_ids))
        tick_subject_ix.extend(range(n_sub_ids))

        box_top += n_sub_ids + space_btwn_cohort

    # header bands: (x_start, x_end) per ROI frequency type
    header_bands = {"f1": (0, 2), "f2": (2, 4), "f1f2": (4, 6)}
    for f_type, (x_start, x_end) in header_bands.items():
        colour = c_dict[f_type]
        ax.plot([x_start, x_end], [0.5, 0.5], c=colour, lw=linewidth)  # bottom
        ax.plot([x_start, x_end], [1.5, 1.5], c=colour, lw=linewidth)  # top
        ax.fill_between([x_start, x_end], [.5, .5], [1.5, 1.5], color=colour)

    if not ytick_on:
        ax.set_yticks([])
    else:
        ax.set_yticks(tick_positions)
        ax.set_yticklabels([f"{ix + 1:03}" for ix in tick_subject_ix], fontsize=fontsize)
        ax.tick_params(axis="y", length=4., width=linewidth, pad=0)
    ax.set_xticks([])

In [None]:
FONTSIZE = 6
DPI = 300
# 7.2 in is the maximum allowed figure width
fig = plt.figure(layout="constrained", figsize=(7.2, 6.), dpi=DPI)
ax_dict = fig.subplot_mosaic(
    [["f1", "f2", "f1f2", "stats", "Legend"]],
    gridspec_kw={"width_ratios": [1, 1, 1, .7, .75]},
)
roi_c_dict = get_roi_colour_codes()
frequency_text_codes = get_frequency_text_codes()

# ROI selection shared by every panel
roi_pval = "uncp"
roi_fractional_overlap = .8
rephase = False
rephase_with = "None"  # NOTE(review): not read by the row filter below, which selects rows whose rephase_with is null

# One entry per cohort row-group, drawn top to bottom
_vary_conditions = ["A", "B", "C", "D", "E", "F"]
dataset_labels = (
    ["NORMAL_3T_CONTROL", "NORMAL_3T", "NORMAL_7T"]
    + ["VARY_3T"] * len(_vary_conditions)
    + ["VARY_7T"] * len(_vary_conditions)
)
roi_task_ids = ["entrain", "entrain", "AttendAway"] + [f"entrain{cond}" for cond in _vary_conditions] * 2
ax_dataset_labels = ["N3TC", "N3T", "N7T"] + [f"V{field}{cond}" for field in ("3T", "7T") for cond in _vary_conditions]
roi_f_types = ["f1", "f2", "f1f2"]
statistic_f_types = ["f1", "f2"]  # COLUMN B
arrow_down_marker_s = 10
pval_marker_s = 8

# Threshold x-axis of all PSDs (keep bins below 0.5)
frequency_grid = data_dict['frequency_grid']
xmax = np.sum(frequency_grid < .5)

# roi_f_type -> list (one per cohort) of stacked z-scored PSDs, presumably (subjects, frequency bins)
cohort_psds = defaultdict(list)
# roi_f_type -> list (one per cohort) of (f1, f2, n_sub_ids)
cohort_test_frequencies = defaultdict(list)
# one (n_sub_ids, len(roi_f_types) * len(statistic_f_types)) array per cohort; column order is
# (roi_f_type, statistic_f_type) = (f1,f1), (f1,f2), (f2,f1), (f2,f2), (f1f2,f1), (f1f2,f2)
cohort_pval_statistics = []
cohort_power_statistics = []
n_stat_cols = len(roi_f_types) * len(statistic_f_types)

for ix, (dataset_label, roi_task_id, ax_dataset_label) in enumerate(zip(dataset_labels, roi_task_ids, ax_dataset_labels)):
    cohort_df = df[(df.experiment_id == dataset_label) & (df.roi_task_id == roi_task_id)]
    n_sub_ids = len(cohort_df.sub_id.unique())
    bootstrapped_pvals = np.zeros((n_sub_ids, n_stat_cols))
    bootstrapped_power = np.zeros((n_sub_ids, n_stat_cols))

    for roi_ix, roi_f_type in enumerate(roi_f_types):
        # Subset data frame: one row per subject for this ROI type, non-rephased
        subset_df = cohort_df[
            (cohort_df.roi_pval == roi_pval)
            & (cohort_df.roi_fractional_overlap == roi_fractional_overlap)
            & (cohort_df.roi_f_type == roi_f_type)
            & (cohort_df.rephase == rephase)
            & cohort_df.rephase_with.isnull()
        ]

        # Validate before indexing: each cohort must use a single (f1, f2) pair ...
        f1 = subset_df['f1'].unique()
        f2 = subset_df['f2'].unique()
        for f in [f1, f2]:
            if f.shape[0] != 1:
                raise ValueError(f"1 frequency expected: {f}")
        # ... and provide exactly one row per subject, otherwise figure rows misalign.
        if subset_df.shape[0] != n_sub_ids:
            raise ValueError(
                f"{dataset_label}/{roi_task_id}/{roi_f_type}: expected {n_sub_ids} rows, "
                f"found {subset_df.shape[0]}"
            )
        cohort_test_frequencies[roi_f_type].append((f1[0], f2[0], n_sub_ids))

        # Store PSDs: z-score each subject's spectrum, then stack
        psds = np.vstack([(psd - psd.mean()) / psd.std() for psd in subset_df.power_spectrum])
        cohort_psds[roi_f_type].append(psds)

        # Store statistics for COLUMN B figure
        for stat_ix, _f_type in enumerate(statistic_f_types):
            bootstrapped_col_ix = roi_ix * len(statistic_f_types) + stat_ix
            powers = subset_df[f"bootstrap_power_{_f_type}"]
            bootstrap_pvals = subset_df[f"bootstrap_pval_{_f_type}"]
            for bootstrapped_row_ix, (_power, _pval) in enumerate(zip(powers, bootstrap_pvals)):
                bootstrapped_pvals[bootstrapped_row_ix, bootstrapped_col_ix] = _pval
                bootstrapped_power[bootstrapped_row_ix, bootstrapped_col_ix] = _power

    cohort_pval_statistics.append(bootstrapped_pvals)
    cohort_power_statistics.append(bootstrapped_power)


arrow_marker_s = 20

# Plot PSDs: one panel per ROI frequency type. Arrow colours are given as (f1, f2):
# black = tested frequency matches the ROI ("Match" in the legend), white = "No match".
psd_panel_styles = {
    # f_type: (arrow colours for (f1, f2), show subject y-ticks, add colourbar)
    "f1": (["k", "white"], True, False),
    "f2": (["white", "k"], False, False),
    "f1f2": (["k", "k"], False, True),
}
for f_type, (f1f2_c, ytick_on, add_cbar) in psd_panel_styles.items():
    ax = ax_dict[f_type]
    plot_1_cohort_psds(ax, cohort_psds[f_type], xmax, add_cbar=add_cbar, fontsize=FONTSIZE)
    plot_1_arrows(ax, cohort_test_frequencies[f_type], frequency_grid, f1f2_c, arrow_down_marker_s=arrow_marker_s)
    plot_1_spines_and_ticks(ax, cohort_test_frequencies[f_type], xmax, roi_c_dict[f_type], frequency_grid, ytick_on=ytick_on, fontsize=FONTSIZE)

# Statistics panel: 6 columns = (ROI type f1, f2, f1f2) x (tested frequency f1, f2)
ax = ax_dict["stats"]
pval, power = plot_2_cohort_statistics(ax, cohort_pval_statistics, cohort_power_statistics)
plot_2_spines_and_ticks(ax, cohort_test_frequencies["f1"], 6, roi_c_dict, fontsize=FONTSIZE)
ax.set_xlim(-3., ax.get_xlim()[-1])
stats_header_arrows = zip([.5, 1.5, 2.5, 3.5, 4.5, 5.5], ['k', 'white', 'white', 'k', 'k', 'k'])
for x_arrow, arrow_colour in stats_header_arrows:
    ax.scatter(x_arrow, 2.2, s=arrow_marker_s, c=arrow_colour, edgecolors='k', linewidths=0.25, marker='v')

# Marker size/colour encodes the bootstrap statistic in [0, 1]
# (the legend describes it as the fraction of bootstraps with P < .05).
pval_marker_s = 30
for i, j in np.ndindex(pval.shape):
    _power = power[i, j]
    _pval = pval[i, j]
    if _pval == -1:  # separator / padding row
        continue
    if np.isnan(_power):
        ax.scatter(j + .5, i, s=25, marker='x', c='k', linewidths=0.25)
        continue
    if _pval > 1:
        raise ValueError(
            f"Bootstrap statistic must be <= 1, got {_pval} at row {i}, column {j}."
        )
    c = 'white'
    if _pval == 1.:
        s = pval_marker_s
        c = 'k'
    elif _pval > .95:
        s = pval_marker_s
        c = 'grey'
    elif _pval > .8:
        s = pval_marker_s * .7
    elif _pval > .6:
        s = pval_marker_s * .4
    else:
        s = pval_marker_s * .1
    ax.scatter(j + .5, i, s=s, c=c, edgecolors='k', linewidths=0.25)

ax = ax_dict["Legend"]
# Same blank image as the stats panel, so the legend shares its coordinate frame
plot_2_cohort_statistics(ax, cohort_pval_statistics, cohort_power_statistics)
for i in ("top", "right", "bottom", "left"):
    ax.spines[i].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

# Section 1: ROI frequency colour key (two squares per row)
ax.text(-2.3, 3, "Frequency of the region", fontsize=FONTSIZE)
for y_row, roi_key in [(4, "f1"), (5, "f2"), (6, "f1f2")]:
    for x_square in (-2, -1.5):
        ax.scatter(x_square, y_row, s=arrow_marker_s, c=roi_c_dict[roi_key], marker='s')
for y_text, label in [(4.1, "$f_1$"), (5.1, "$f_2$"), (6.1, "Multiplex")]:
    ax.text(-1.0, y_text, label, fontsize=FONTSIZE)

# Section 2: arrow key
ax.text(-2.3, 9, "Tested frequency\nmatches region", fontsize=FONTSIZE)
for y_row, face_colour in [(10, 'k'), (11, 'white')]:
    ax.scatter(-2, y_row, s=arrow_marker_s, c=face_colour, marker='v', edgecolors='k', linewidths=.25)
for y_text, label in [(10.1, "Match"), (11.1, "No match")]:
    ax.text(-1.4, y_text, label, fontsize=FONTSIZE)

# Section 3: bootstrap-fraction marker key (sizes/colours mirror the stats panel)
ax.text(-2.3, 14, "$P$ value<.05\n(% bootstraps)", fontsize=FONTSIZE)
pval_legend_markers = [
    (15, .1, 'white'),
    (16, .4, 'white'),
    (17, .7, 'white'),
    (18, 1, 'grey'),
    (19, 1, 'k'),
]
for y_row, size_scale, face_colour in pval_legend_markers:
    ax.scatter(-2, y_row, s=pval_marker_s * size_scale, c=face_colour, edgecolors='k', linewidths=.25)
for y_text, label in [(15.1, "<0.6"), (16.1, "0.6-0.8"), (17.1, "0.8-.95"), (18.1, "0.95-1.0"), (19.1, "1.0")]:
    ax.text(-1.4, y_text, label, fontsize=FONTSIZE)

ax=ax_dict["f1"]
# Y label drawn in white (invisible); presumably kept to reserve layout space — confirm
ax.set_ylabel("Subjects",fontsize=FONTSIZE,c='white')
# Vertical cohort brackets at the left margin, in figure-fraction coordinates.
# NOTE(review): hand-tuned; .19 appears to be the figure height of 10 subject rows — re-check if the layout changes.
# Bracket 1 (top): 3T Control
line = plt.Line2D([.0195,.0195],[.765,.935],transform=fig.transFigure,color='k',linestyle='-',lw=.4)
fig.add_artist(line)
# Bracket 2: 3T
line = plt.Line2D([.0195,.0195],[.765-.19,.935-.19],transform=fig.transFigure,color='k',linestyle='-',lw=.4)
fig.add_artist(line)
# Bracket 3: 7T
line = plt.Line2D([.0195,.0195],[.765-.19-(.19*5/10),.935-(.19*2)],transform=fig.transFigure,color='k',linestyle='-',lw=.4)
fig.add_artist(line)
# Bracket 4: 3T Vary
line = plt.Line2D([.0195,.0195],[.765-.19-(.19*5/10)-(.19*12/10),.935-(.19*2)-(.19*5/10)],transform=fig.transFigure,color='k',linestyle='-',lw=.4)
fig.add_artist(line)
# Bracket 5 (bottom): 7T Vary
line = plt.Line2D([.0195,.0195],[.765-.19-(.19*5/10)-(.19*12/10)-(.19*12/10),.935-(.19*2)-(.19*5/10)-(.19*12/10)],transform=fig.transFigure,color='k',linestyle='-',lw=.4)
fig.add_artist(line)

# Rotated cohort labels, each positioned beside its bracket (offsets relative to the 7T label at y=.515)
fig.text(.0135, .515,"7T",fontsize=FONTSIZE+2,color='k',rotation=90,ha='center')
fig.text(.0135, .515+(.19*7.5/10),"3T",fontsize=FONTSIZE+2,color='k',rotation=90,ha='center')
fig.text(.0135, .515+(.19*15.5/10),"3T Control",fontsize=FONTSIZE+2,color='k',rotation=90,ha='center')
fig.text(.0135, .515-(.19*9.5/10),"3T Vary",fontsize=FONTSIZE+2,color='k',rotation=90,ha='center')
fig.text(.0135, .515-(.19*21.5/10),"7T Vary",fontsize=FONTSIZE+2,color='k',rotation=90,ha='center')

fig.savefig(MAIN / "Fig3_PSD_across_all_experiments.png",dpi=300)

In [None]:
from scipy.stats import wilcoxon

def calculate_im_frequency(row, im_expression):
    """Evaluate an intermodulation (IM) expression for one row's tagging frequencies.

    Parameters
    ----------
    row : mapping (e.g. a pandas row)
        Must provide ``row["f1"]`` and ``row["f2"]``. Only the entries used by the
        requested expression are read.
    im_expression : str
        One of "f1", "f2", "f2-f1", "2f1-f2", "2f2-f1", "2f1", "2f2", "f1+f2".

    Returns
    -------
    The frequency of the requested linear combination of f1 and f2.

    Raises
    ------
    ValueError
        If ``im_expression`` is not supported.
    """
    # Dispatch on the argument itself, never on a global of a similar name: the result must
    # depend only on what the caller passes in.
    formulas = {
        "f2-f1": lambda r: r["f2"] - r["f1"],
        "2f1-f2": lambda r: 2 * r["f1"] - r["f2"],
        "2f2-f1": lambda r: 2 * r["f2"] - r["f1"],
        "2f1": lambda r: 2 * r["f1"],
        "2f2": lambda r: 2 * r["f2"],
        "f1+f2": lambda r: r["f1"] + r["f2"],
        "f1": lambda r: r["f1"],
        "f2": lambda r: r["f2"],
    }
    formula = formulas.get(im_expression)
    if formula is None:
        raise ValueError(f"Not Implemented: {im_expression}")
    return formula(row)

def find_closest_index(array, value):
    """Return the position of the element of ``array`` nearest to ``value`` (first one on ties)."""
    distances = abs(array - value)
    return distances.argmin()

def calculate_fsnr(row, frequency_grid, foi, foi_range):
    """Frequency-domain SNR at ``foi``: PSD at the peak bin / mean PSD of the surrounding bins.

    The peak bin is the grid bin nearest ``foi``. The neighbourhood runs from the bin nearest
    ``foi - foi_range`` up to and including the bin nearest ``foi + foi_range``, excluding
    the peak bin itself.

    Parameters
    ----------
    row : mapping (e.g. a pandas row)
        ``row['power_spectrum']`` is the PSD sampled on ``frequency_grid``.
    frequency_grid : numpy array of frequencies
    foi : float
        Frequency of interest.
    foi_range : float
        Half-width of the neighbourhood, in the same units as ``frequency_grid``.
    """
    spectrum = row['power_spectrum']
    lower = find_closest_index(frequency_grid, foi - foi_range)
    centre = find_closest_index(frequency_grid, foi)
    upper = find_closest_index(frequency_grid, foi + foi_range)
    below, above = spectrum[lower:centre], spectrum[centre + 1:upper + 1]
    flanks = np.concatenate((below, above))
    return spectrum[centre] / flanks.mean()

def get_im_pvalues(im,pval,roi_fractional_overlap,roi_f_type,rephase,rephase_with,frequency_grid,foi_range,total_control_datasets=9,total_datasets=9+4+6+6):
    """Collect per-dataset fSNR values at one IM frequency, split into control and test rows.

    Despite the name, the returned values are fSNRs (from ``calculate_fsnr``), not p-values.

    Rows of the global ``df`` are selected by ROI threshold (``pval``), fractional overlap,
    ROI frequency type and ``rephase``. Rephased ``f1f2`` ROIs are further restricted to
    ``rephase_with``.

    Parameters
    ----------
    im : str
        IM expression understood by ``calculate_im_frequency`` (e.g. "2f1-f2").
    frequency_grid : numpy array
        Grid on which each row's ``power_spectrum`` is sampled.
    foi_range : float
        Half-width of the fSNR neighbourhood.
    total_control_datasets, total_datasets : int
        Expected number of control rows / maximum number of test rows (sanity checks).

    Returns
    -------
    (control_fsnr, test_fsnr, control_im_df, im_df)
        Control rows are those whose ``experiment_id`` ends with "CONTROL"; all other rows are test rows.

    Raises
    ------
    ValueError
        If ``rephase`` is True and ``roi_f_type`` is not "f1", "f2" or "f1f2".
    AssertionError
        If the control/test row counts do not match the expected totals.
    """
    if rephase and roi_f_type not in ("f1", "f2", "f1f2"):
        raise ValueError(f"Unsupported roi_f_type for rephased data: {roi_f_type}")

    columns = [
        "roi_f_type", "experiment_id", "sub_id", "roi_task_id", "f1", "f2",
        f"bootstrap_power_{im}", f"bootstrap_pval_{im}", "rephase_with", "power_spectrum",
    ]
    mask = (
        (df.roi_pval == pval)
        & (df.roi_fractional_overlap == roi_fractional_overlap)
        & (df.roi_f_type == roi_f_type)
        & (df.rephase == rephase)
    )
    if rephase and roi_f_type == "f1f2":
        # rephased multiplex ROIs exist once per rephase_with value
        mask &= (df.rephase_with == rephase_with)
    # .copy() so adding columns below works on an independent frame (no SettingWithCopyWarning)
    im_df = df.loc[mask, columns].copy()

    im_df["im_frequency"] = im_df.apply(lambda row: calculate_im_frequency(row, im), axis=1)
    im_df["im_fsnr"] = im_df.apply(
        lambda row: calculate_fsnr(row, frequency_grid, row["im_frequency"], foi_range), axis=1
    )

    is_control = im_df.experiment_id.str.endswith("CONTROL")
    control_im_df = im_df[is_control]
    im_df = im_df[~is_control]
    assert control_im_df.shape[0] == total_control_datasets
    assert im_df.shape[0] <= total_datasets

    return control_im_df["im_fsnr"].values, im_df["im_fsnr"].values, control_im_df, im_df



pval = "uncp"
roi_fractional_overlap = .8
rephase = False
rephase_with = None
foi_range = .025  # half-width of the fSNR neighbourhood, in the same units as frequency_grid

if rephase and rephase_with not in ["f1", "f2"]:
    raise ValueError(f"{rephase_with} must be set to f1 or f2.")

roi_f_types = ["f1", "f2", "f1f2"]
im_codes = ["f1", "f2", "f2-f1", "f1+f2", "2f1", "2f2", "2f1-f2", "2f2-f1"]

# Top row: fSNR per dataset for each ROI type; bottom row ("<roi>C"): paired test - control difference.
fig, ax_dict = plt.subplot_mosaic(
    mosaic=[roi_f_types, [f"{i}C" for i in roi_f_types]],
    dpi=300,
    figsize=(4, 2),
    layout="constrained",
)
# x position of each IM code; gaps at x=3 and x=8 separate first-, second- and third-order terms
im_coords = {
    "f1": 1,
    "f2": 2,
    "f2-f1": 4,
    "f1+f2": 5,
    "2f1": 6,
    "2f2": 7,
    "2f1-f2": 9,
    "2f2-f1": 10,
}

store_im_data_across_rois = {}
for roi_f_type in roi_f_types:
    for ax_key in [roi_f_type, f"{roi_f_type}C"]:
        ax = ax_dict[ax_key]
        ax.set_xticks(list(im_coords.values()))
        ax.set_xticklabels(list(im_coords.keys()), fontsize=FONTSIZE, rotation=90)
        ax.set_xlim(0, 11)
        if ax_key.endswith("C"):
            yticks, ylim = [-15, 15], (-15, 15)
        else:
            yticks, ylim = [0, 15], (-2, 17)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticks, fontsize=FONTSIZE)
        ax.set_ylim(*ylim)
        for _spine in ["top", "bottom", "right"]:
            ax.spines[_spine].set_visible(False)

    for im in im_codes:
        # NOTE: these "pvals" are fSNR values (see get_im_pvalues)
        control_pvals, test_pvals, control_im_df, test_im_df = get_im_pvalues(
            im, pval, roi_fractional_overlap, roi_f_type, rephase, rephase_with,
            data_dict['frequency_grid'], foi_range
        )

        # The paired Wilcoxon test compares each control run with the first len(control) test rows,
        # which must be NORMAL_3T in the same subject order. Check it rather than assume it.
        n_control = len(control_pvals)
        paired_test_df = test_im_df.iloc[:n_control]
        if not ((paired_test_df.experiment_id == "NORMAL_3T").all()
                and list(paired_test_df.sub_id) == list(control_im_df.sub_id)):
            raise ValueError(
                f"{roi_f_type}/{im}: first {n_control} test rows are not NORMAL_3T in control subject order."
            )
        paired_test_pvals = test_pvals[:n_control]

        x_centre = im_coords[im]
        ax = ax_dict[roi_f_type]
        ax.scatter(np.zeros_like(test_pvals) + x_centre, test_pvals, c=roi_c_dict[im], s=pval_marker_s, zorder=8)
        ax.scatter(np.zeros_like(control_pvals) + x_centre - .25, control_pvals, c='lightgrey', s=pval_marker_s/2, ec='k', lw=.25, zorder=10)
        ax.scatter(np.zeros_like(control_pvals) + x_centre + .25, paired_test_pvals, c='darkgrey', s=pval_marker_s/2, ec='k', lw=.25, zorder=10)

        pval_test_vs_control = wilcoxon(paired_test_pvals, control_pvals, alternative="greater")
        c = 'red' if pval_test_vs_control.pvalue < .05 else 'grey'
        ax_dict[f"{roi_f_type}C"].scatter(
            np.zeros_like(control_pvals) + x_centre, paired_test_pvals - control_pvals, c=c, s=pval_marker_s
        )
        store_im_data_across_rois[(roi_f_type, im)] = [
            control_pvals,
            paired_test_pvals,
            pval_test_vs_control,
        ]

In [None]:
from statsmodels.stats.multitest import multipletests
# Benjamini-Hochberg FDR correction across every (roi_f_type, im) Wilcoxon test.
pvalues = [entry[2].pvalue for entry in store_im_data_across_rois.values()]
_, p_values, _, _ = multipletests(pvalues, method="fdr_bh")

for p_adj, (k, v) in zip(p_values, store_im_data_across_rois.items()):
    # Rebuild as [control, test, wilcoxon_result, fdr_adjusted_p] so re-running this cell
    # replaces the adjusted p-value instead of appending another one (and leaves `pval` alone).
    store_im_data_across_rois[k] = v[:3] + [p_adj]
    print(k, store_im_data_across_rois[k][2:])

In [8]:
def format_f1_f2_equation(f1,f2):
    """Render ``f1*f_1 + f2*f_2`` as a matplotlib mathtext label.

    Coefficients of +/-1 are written without a number, zero terms are dropped,
    integral floats (e.g. ``2.0``) are printed as integers, and positive terms
    are written before negative ones, e.g.::

        (1, 0)  -> "$f_1$"            (-1, 2) -> "2$f_2$-$f_1$"
        (1, 1)  -> "$f_1$+$f_2$"      (2, -1) -> "2$f_1$-$f_2$"

    Parameters
    ----------
    f1, f2 : int or float
        Coefficients of the first and second fundamental frequency.

    Returns
    -------
    str
        Mathtext expression suitable for axis labels and titles.
    """
    def _as_number(c):
        # Print 2.0 as "2" but keep 0.5 as "0.5".
        return int(c) if int(c) == c else c

    def _magnitude_text(c):
        # The sign is emitted separately; a unit magnitude is implicit.
        magnitude = abs(c)
        return "" if magnitude == 1 else f"{magnitude}"

    f1, f2 = _as_number(f1), _as_number(f2)
    if f1 == 0 and f2 == 0:
        # Preserved output for the all-zero combination.
        return "0$f_1$+0$f_2$"

    terms = [(c, sym) for c, sym in ((f1, "$f_1$"), (f2, "$f_2$")) if c != 0]
    # Positive terms first (stable sort), so we get "2f_2-f_1" rather than "-f_1+2f_2".
    terms.sort(key=lambda term: term[0] < 0)

    pieces = []
    for position, (c, sym) in enumerate(terms):
        if c < 0:
            sign = "-"
        elif position > 0:
            sign = "+"
        else:
            sign = ""
        pieces.append(f"{sign}{_magnitude_text(c)}{sym}")
    return "".join(pieces)

In [None]:


# ROI selection shared by the cells below: ROI p-value setting 'uncp' (presumably
# uncorrected), 80% fractional overlap, and no rephasing.
pval, roi_fractional_overlap = 'uncp', .8
rephase, rephase_with = False, None

LINEWIDTH = .5

def get_pct_of_datasets_with_gt_max_fsnr(im_df, max_fsnr):
    """Count, per experiment, the datasets whose ``im_fsnr`` is strictly above ``max_fsnr``.

    Parameters
    ----------
    im_df : pd.DataFrame
        Must contain ``experiment_id`` and ``im_fsnr`` columns.
    max_fsnr : float
        Threshold; callers in this notebook pass the largest control-dataset fSNR.

    Returns
    -------
    dict
        ``{experiment_id: "n_above/n_total", ..., "ALL": "n_above/n_total"}``.
        Values are count strings, not percentages, despite the function name.
    """
    totals_per_experiment = im_df.experiment_id.value_counts().to_dict()
    above_per_experiment = im_df.loc[im_df.im_fsnr > max_fsnr, "experiment_id"].value_counts()

    tallies = {}
    n_above_overall = 0
    for exp_id, n_total in totals_per_experiment.items():
        n_above = int(above_per_experiment.get(exp_id, 0))
        tallies[exp_id] = f"{n_above}/{n_total}"
        n_above_overall += n_above

    grand_total = np.sum(list(totals_per_experiment.values()))
    tallies["ALL"] = f"{n_above_overall}/{grand_total}"
    return tallies

# Row label (f1 coefficient, f2 coefficient) per IM code; codes not listed get no label.
im_label_coefficients = {
    "f1": (1, 0),
    "f2": (0, 1),
    "f2-f1": (-1, 1),
    "f1+f2": (1, 1),
    "2f2-f1": (-1, 2),
    "2f1-f2": (2, -1),
    "2f1": (2, 0),
    "2f2": (0, 2),
}
# Column title per region (ROI) type; only drawn on the "f1" IM row.
region_titles = {
    "f1": f"Region, {format_f1_f2_equation(1,0)}",
    "f2": f"Region, {format_f1_f2_equation(0,1)}",
    "f1f2": "Region, Multiplex",
}
# (key in the per-experiment tally, display name) for the text next to each panel.
dataset_labels = [
    ("NORMAL_3T", "3T"),
    ("NORMAL_7T", "7T"),
    ("VARY_3T", "3T Vary"),
    ("VARY_7T", "7T Vary"),
    ("ALL", "All"),
]

mosaic = [[f"im-{im}_roi-{i}" for i in roi_f_types] for im in im_codes]
fig, ax_dict = plt.subplot_mosaic(
    mosaic,dpi=300,figsize=(6,7),layout="constrained"
)

im_above_df = {}
pct_of_datasets_across_all_frequencies = {}
for im in im_codes:
    im_c = roi_c_dict[im]
    for roi_f_type in roi_f_types:
        control_pvals, test_pvals, control_im_df, im_df = get_im_pvalues(im, pval, roi_fractional_overlap, roi_f_type, rephase, rephase_with, data_dict['frequency_grid'], foi_range)
        ax = ax_dict[f"im-{im}_roi-{roi_f_type}"]
        # Threshold: the largest fSNR among the control datasets.
        max_fsnr = control_im_df.im_fsnr.values.max()
        pct_of_datasets = get_pct_of_datasets_with_gt_max_fsnr(im_df, max_fsnr)
        # An experiment group with no rows would otherwise raise KeyError below and
        # when these tallies are turned into bar heights.
        for _key, _ in dataset_labels:
            pct_of_datasets.setdefault(_key, "0/0")
        ax.axvline(max_fsnr,c='k',lw=LINEWIDTH*4,zorder=5,linestyle='dotted')
        im_above_df[f"roi-{roi_f_type}_im-{im}"] = im_df[(im_df.im_fsnr>max_fsnr)]
        data = im_df.im_fsnr.values
        # Split with the same strict ">" used for the counts and im_above_df, so a value
        # equal to the control maximum is drawn grey rather than in the IM colour.
        below_threshold = data[data<=max_fsnr]
        above_threshold = data[data>max_fsnr]
        n_bins = 50
        _xticks = [0,7,14] if im in ["f1","f2"] else [0,5,10]
        ax.hist(below_threshold,alpha=1.,bins=n_bins,zorder=2,color='grey',width=_xticks[-1]*.01)
        ax.hist(above_threshold,alpha=1.,bins=n_bins,zorder=2,color=im_c,width=_xticks[-1]*.01)
        for _spine in ["top","right"]:
            ax.spines[_spine].set_visible(False)
        for _spine in ["left","bottom"]:
            ax.spines[_spine].set_linewidth(LINEWIDTH)
        ax.tick_params(axis="both", length=2.,width=LINEWIDTH,pad=0)
        if im == "f1" and roi_f_type in region_titles:
            ax.set_title(region_titles[roi_f_type], fontsize=FONTSIZE)
        if im=="2f2-f1":
            ax.set_xlabel("SNR", fontsize=FONTSIZE)
        if roi_f_type == "f1" and im in im_label_coefficients:
            ax.set_ylabel(format_f1_f2_equation(*im_label_coefficients[im]),fontsize=FONTSIZE,rotation=0)
        ax.set_xlim(0,_xticks[-1])
        ax.set_xticks(_xticks)
        ax.set_xticklabels(_xticks,fontsize=FONTSIZE)
        ax.set_ylim(0,3)
        ax.set_yticks([0,3])
        ax.set_yticklabels([0,3],fontsize=FONTSIZE)
        # Per-experiment "n_above/n_total" tally printed to the right of the panel.
        texts = '\n'.join(
            f"{dataset_id}$=${pct_of_datasets[_key]}" for _key, dataset_id in dataset_labels
        )
        ax.text(ax.get_xlim()[-1]*1.02, ax.get_ylim()[-1]*.25, texts,fontsize=FONTSIZE-2,zorder=20,c='k')
        # Store for the stacked bar panel.
        pct_of_datasets_across_all_frequencies[(im,roi_f_type)] = pct_of_datasets

fig.savefig(MAIN / "Extended_Figure_X_fsnr_threshold.png",dpi=300)

In [None]:
# Diagnostic: for selected IMs, plot bootstrap power, fSNR and bootstrap p-value against
# the IM frequency, for the VARY experiments only.
# NOTE(review): the region used to come implicitly from the last iteration of an earlier
# cell's `for roi_f_type in roi_f_types` loop, i.e. roi_f_types[-1]; it is now explicit so
# the cell does not depend on execution order. Confirm this region is the intended one.
scatter_roi_f_type = roi_f_types[-1]
pval = "uncp"
roi_fractional_overlap = .8
rephase = False
rephase_with = None

for im in ['f2-f1','2f1',"2f2"]:
    control_pvals, test_pvals, control_im_df,im_df = get_im_pvalues(
        im, pval, roi_fractional_overlap, scatter_roi_f_type, rephase, rephase_with, data_dict['frequency_grid'], foi_range
    )

    im_df = im_df[(im_df.experiment_id.str.startswith("VARY"))]
    im_df = im_df.dropna(subset=[f"bootstrap_power_{im}"])

    fig,ax_dict = plt.subplot_mosaic([["power",'snr',"pval"]],dpi=300,figsize=(4,1),layout="constrained")
    for k in ["power","pval"]:
        ax_dict[k].scatter(im_df.im_frequency,im_df[f"bootstrap_{k}_{im}"],s=5)
        ax_dict[k].set_ylabel(k,fontsize=FONTSIZE)
    ax_dict["snr"].scatter(im_df.im_frequency,im_df.im_fsnr,s=5)
    ax_dict["snr"].set_ylabel("snr",fontsize=FONTSIZE)
    ax_dict["snr"].set_xlabel(f"frequency of {im}",fontsize=FONTSIZE)
    ax_dict["snr"].set_title(im,fontsize=FONTSIZE)

    for _k, ax in ax_dict.items():
        ax.tick_params(axis="both", length=2.,width=.4,pad=0,labelsize=FONTSIZE)

    

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Four evenly spaced YlGn colours, one per experiment group in the stacked bar panel.
# plt.get_cmap is used because plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed
# in Matplotlib 3.9.
cmap = plt.get_cmap("YlGn")
n_colors = 4
color_indices = np.linspace(0, 1, n_colors)
colors = [cmap(index) for index in color_indices]
custom_cmap = LinearSegmentedColormap.from_list('custom_YlGn', colors, N=n_colors)
custom_cmap = [custom_cmap(i) for i in range(n_colors)]
custom_cmap

In [None]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the session (including
# pandas SettingWithCopy and empty-mean RuntimeWarnings); consider narrowing it.
warnings.filterwarnings('ignore')

LINEWIDTH = .5
FONTSIZE = 8

# Bar labels "<IM> [<region>]", IM-major so the regions of each IM sit side by side.
x_axis = []
for im in im_codes:
    for roi_f_type in roi_f_types:
        x_axis.append(f"{frequency_text_codes[im]} [{frequency_text_codes[roi_f_type]}]")

# Per experiment group: number of datasets above the control maximum for every
# (IM, region) pair, in the same IM-major order as x_axis.
weight_counts = {
    experiment_id: np.array([
        int(pct_of_datasets_across_all_frequencies[(im, roi_f_type)][experiment_id].split("/")[0])
        for im in im_codes
        for roi_f_type in roi_f_types
    ])
    for experiment_id in ["NORMAL_3T","NORMAL_7T","VARY_3T","VARY_7T"]
}
bottom = np.zeros_like(weight_counts["NORMAL_3T"])
width = 1

# x positions: unit spacing plus an extra half-unit gap before every group of three.
xtick_pos = []
_xtick_pos = 0
for i in range(bottom.shape[0]):
    _xtick_pos += 1.5 if i % 3 == 0 else 1
    xtick_pos.append(_xtick_pos)


# Layout: rows "1" (SNR) and "2" (peak counts) span five columns, row 3 holds five
# heat-map panels, and the right-hand column is the legend panel.
mosaic = [
    5*["1"]+["legend"],
    5*["2"]+["legend"],
    ["3","3a","3b","3c","3d","legend"],
]
fig, ax_dict = plt.subplot_mosaic(
    mosaic=mosaic,figsize=(7,3.6),dpi=300,layout="tight",
    gridspec_kw={
        "width_ratios":[1,1,1,1,1,1],
        "height_ratios":[1,1,1.2],
    }
)

# Shared styling for the two strip panels ("1" and "2"); the legend and row-3 panels
# are styled separately. Thinner lines from here on.
LINEWIDTH = .4
for ax_key, ax in ax_dict.items():
    if ax_key not in ("1", "2"):
        continue
    ax.set_xlim(xtick_pos[0]-width,xtick_pos[-1]+width)
    for side in ["top","right","bottom"]:
        ax.spines[side].set_visible(False)
    ax.spines.left.set_linewidth(LINEWIDTH)
    ax.spines.left.set_position(('outward',-1))
    ax.tick_params(axis="x", length=2.,width=LINEWIDTH,pad=.4)
    ax.tick_params(axis="y", length=4.,width=LINEWIDTH,pad=.4)


# --- Panel "2": stacked count of datasets above the control maximum, one layer per
# experiment group.
ax = ax_dict["2"]
all_bars = []
for (experiment_id, weight_count),bar_c in zip(weight_counts.items(),custom_cmap):
    p = ax.bar(x_axis,weight_count,width,label=experiment_id,bottom=bottom,color=bar_c)
    all_bars.append(p)
    bottom+=weight_count

# Bars start on categorical positions; move them onto the gapped xtick_pos layout.
for bars_per_experiment in all_bars:
    for bar_per_im_and_roi_f_type, _xtick_pos in zip(bars_per_experiment,xtick_pos):
        bar_per_im_and_roi_f_type.set_edgecolor('k')
        bar_per_im_and_roi_f_type.set_x(_xtick_pos-(width/2))
        bar_per_im_and_roi_f_type.set_linewidth(LINEWIDTH)

# Coloured strip under each position marking its region (cycling f1, f2, f1f2).
# Drawn once per position instead of once per stacked experiment layer.
region_colours = [roi_c_dict["f1"], roi_c_dict["f2"], roi_c_dict["f1f2"]]
for ix, _xtick_pos in enumerate(xtick_pos):
    ax.fill_between(
        [_xtick_pos-(width/2),_xtick_pos+(width/2)],
        [-3]*2, [-.5]*2,
        color=region_colours[ix % 3],linewidth=LINEWIDTH,edgecolor='k',zorder=1
    )
ax.set_ylabel("Peak count\n(n=25)", fontsize=FONTSIZE)
ax.set_xlabel("Search frequency", fontsize=FONTSIZE)
# One x tick per IM, centred on its middle (i % 3 == 1) bar.
ax.set_xticks([j for i,j in enumerate(xtick_pos) if i%3==1])
ax.set_xticklabels([j.split("[")[0] for i,j in enumerate(x_axis) if i%3==1],rotation=0, fontsize=FONTSIZE)
ax.set_yticks([0,25])
ax.set_yticklabels([0,25], fontsize=FONTSIZE)
ax.set_ylim(-3,25.5)
ax.spines.left.set_bounds(0,25)

# Key order matching x_axis / xtick_pos: IM-major, region-minor.
ordered_x_key=[]
for im in im_codes:
    for roi_f_type in roi_f_types:
        ordered_x_key.append((roi_f_type,im))

# --- Panel "1": control vs entrainment fSNR per (IM, region).
ax = ax_dict["1"]
offset = .2
for _xtick_pos, x_key in zip(xtick_pos,ordered_x_key):
    control_x_pos = _xtick_pos - offset
    test_x_pos = _xtick_pos + offset
    # Stored as [control values, test values, wilcoxon result, FDR-corrected p-value].
    stored_im_data = store_im_data_across_rois[x_key]
    control_data = stored_im_data[0]
    test_data = stored_im_data[1]
    ax.scatter([control_x_pos]*control_data.shape[0],control_data,s=8,c='lightgrey',edgecolors='k',linewidths=0.25,zorder=10)
    ax.scatter([test_x_pos]*test_data.shape[0],test_data,s=8,c='k',edgecolors='k',linewidths=0.25,zorder=10)
    # Significance bracket + star when the corrected p-value is below .05.
    _max = max(control_data.max(),test_data.max())+1
    pvalue = stored_im_data[3]
    if pvalue < .05:
        ax.plot([control_x_pos,test_x_pos],[_max,_max],color='k',linewidth=.25,linestyle='-',zorder=20)
        ax.text((control_x_pos+test_x_pos)/2,_max+.5,"$*$",fontsize=FONTSIZE,horizontalalignment='center',verticalalignment='center')
    # Dotted line joining each control value to the test value at the same index.
    for _control, _test in zip(control_data,test_data):
        ax.plot([control_x_pos,test_x_pos],[_control,_test],color='k',linewidth=.25,linestyle=':',zorder=2)

# Region strip under each position (cycling f1, f2, f1f2), drawn once per position.
# The bar edge/position styling is applied where panel "2" is drawn and is not repeated.
region_colours = [roi_c_dict["f1"], roi_c_dict["f2"], roi_c_dict["f1f2"]]
for ix, _xtick_pos in enumerate(xtick_pos):
    ax.fill_between(
        [_xtick_pos-(width/2),_xtick_pos+(width/2)],
        [-3]*2, [-.5]*2,
        color=region_colours[ix % 3],linewidth=LINEWIDTH,edgecolor='k',zorder=1
    )
ax.set_ylabel("\nSNR", fontsize=FONTSIZE)
ax.set_xticks([j for i,j in enumerate(xtick_pos) if i%3==1])
ax.set_xticklabels([j.split("[")[0] for i,j in enumerate(x_axis) if i%3==1],rotation=0, fontsize=FONTSIZE)
# Round the autoscaled top up to an integer and show only 0 and that value.
max_y = int(ax.get_ylim()[-1])+1
ax.set_ylim(-3,max_y+1)
ax.spines.left.set_bounds(0,max_y)
ax.set_yticks([0,max_y])
ax.set_yticklabels([0,max_y],fontsize=FONTSIZE)

# The "legend" panel is left blank: hide every spine and tick and pin a unit data box.
legend_ax = ax_dict["legend"]
for side in ("top", "right", "bottom", "left"):
    legend_ax.spines[side].set_visible(False)
legend_ax.set_xticks([])
legend_ax.set_yticks([])
legend_ax.set_ylim(0,1)
legend_ax.set_xlim(0,1)
ax = legend_ax

# ---- ROW 3 ----
def calculate_fsnr(row,frequency_grid,f1_scalar,f2_scalar,foi_range):
    """Frequency SNR at the IM frequency ``f1_scalar*row["f1"] + f2_scalar*row["f2"]``.

    The power at the grid bin closest to the target frequency is divided by the mean
    power of the bins within +/- ``foi_range`` of it (the centre bin excluded).

    Parameters
    ----------
    row : mapping (a DataFrame row when used via ``DataFrame.apply``)
        Needs ``"f1"``, ``"f2"`` and ``"power_spectrum"`` (indexable by grid bin).
    frequency_grid : array-like
        Frequencies of the spectrum bins, passed to ``find_closest_index``.
    f1_scalar, f2_scalar : float
        IM coefficients.
    foi_range : float
        Half-width of the neighbourhood used as the noise estimate.

    Returns
    -------
    float
        Peak power over mean neighbour power (NaN if the neighbourhood is empty).
    """
    target = f1_scalar*row["f1"] + f2_scalar*row["f2"]
    spectrum = row['power_spectrum']
    centre = find_closest_index(frequency_grid, target)
    upper = find_closest_index(frequency_grid, target + foi_range)
    lower = find_closest_index(frequency_grid, target - foi_range)
    left_flank = spectrum[lower:centre]
    right_flank = spectrum[centre+1:upper+1]
    neighbours = np.concatenate((left_flank, right_flank))
    return spectrum[centre] / neighbours.mean()

def get_im_pvalues(f1_scalar,f2_scalar,pval,roi_fractional_overlap,roi_f_type,rephase,rephase_with,frequency_grid,foi_range,total_control_datasets=9,total_datasets=9+4+6+6):
    """Per-dataset fSNR at the IM frequency ``f1_scalar*f1 + f2_scalar*f2``, split into control and test.

    NOTE(review): despite the name this returns fSNR values, not p-values, and it
    redefines an earlier ``get_im_pvalues`` that other cells call with a different
    signature (``get_im_pvalues(im, pval, ...)``). Re-running those cells after this
    one fails; consider renaming one of them.

    Rows of the global ``df`` are selected by ROI threshold ``pval``, fractional overlap,
    ``roi_f_type`` and ``rephase``; for rephased ``"f1f2"`` ROIs they are further
    restricted to ``rephase_with``. fSNR is computed per row with ``calculate_fsnr``.

    Returns
    -------
    control_fsnr, test_fsnr : np.ndarray
        ``im_fsnr`` of datasets whose ``experiment_id`` ends / does not end with "CONTROL".
    control_im_df, im_df : pd.DataFrame
        The corresponding rows, with an added ``im_fsnr`` column.

    Raises
    ------
    ValueError
        If ``rephase`` is set and ``roi_f_type`` is not "f1", "f2" or "f1f2".
    AssertionError
        If the control count differs from ``total_control_datasets`` or the test count
        exceeds ``total_datasets``.
    """
    if rephase and roi_f_type not in ("f1", "f2", "f1f2"):
        raise ValueError(f"Unsupported roi_f_type for rephased data: {roi_f_type!r}")

    columns = ["roi_f_type","experiment_id","sub_id","roi_task_id","f1","f2","rephase_with","power_spectrum"]
    mask = (
        (df.roi_pval == pval)
        & (df.roi_fractional_overlap == roi_fractional_overlap)
        & (df.roi_f_type == roi_f_type)
        & (df.rephase == rephase)
    )
    if rephase and roi_f_type == "f1f2":
        mask &= (df.rephase_with == rephase_with)
    # .copy() so adding the im_fsnr column does not write into a view of the global df.
    im_df = df.loc[mask, columns].copy()
    im_df["im_fsnr"] = im_df.apply(
        lambda row: calculate_fsnr(row,frequency_grid,f1_scalar,f2_scalar,foi_range), axis=1
    )

    is_control = im_df.experiment_id.str.endswith("CONTROL")
    control_im_df = im_df[is_control]
    im_df = im_df[~is_control]
    assert control_im_df.shape[0] == total_control_datasets, (
        f"expected {total_control_datasets} control datasets, got {control_im_df.shape[0]}"
    )
    assert im_df.shape[0] <= total_datasets, (
        f"expected at most {total_datasets} test datasets, got {im_df.shape[0]}"
    )

    return control_im_df["im_fsnr"].values, im_df["im_fsnr"].values, control_im_df, im_df

def calculate_fraction_of_datasets(f1_scalar,f2_scalar,pval,roi_fractional_overlap,roi_f_type,rephase,rephase_with,foi_range):
    """Fraction of test datasets whose IM fSNR exceeds the largest control fSNR.

    The IM frequency is ``f1_scalar*f1 + f2_scalar*f2``; datasets are selected with
    ``get_im_pvalues`` on ``data_dict['frequency_grid']``.

    Returns
    -------
    float
        ``n_above / n_total`` pooled over all experiments, or NaN when there are no
        test datasets.
    """
    _, _, control_im_df, im_df = get_im_pvalues(f1_scalar,f2_scalar, pval, roi_fractional_overlap, roi_f_type, rephase, rephase_with, data_dict['frequency_grid'], foi_range)
    max_fsnr = control_im_df.im_fsnr.values.max()
    pct_of_datasets = get_pct_of_datasets_with_gt_max_fsnr(im_df, max_fsnr)
    # The pooled tally is formatted "n_above/n_total".
    n_above, n_total = (float(part) for part in pct_of_datasets["ALL"].split('/'))
    if n_total == 0:
        return float("nan")
    return n_above / n_total

def calculate_im(f1,f2,f1_scalar,f2_scalar):
    """Intermodulation frequency ``f1_scalar*f1 + f2_scalar*f2``.

    Rounded to 10 decimals to strip floating-point noise (0.1 + 0.2 -> 0.3).
    """
    weighted_f1 = f1 * f1_scalar
    weighted_f2 = f2 * f2_scalar
    return round(weighted_f1 + weighted_f2, 10)

from matplotlib.colors import LinearSegmentedColormap

# Discrete colormap for the named-IM panel: entry 0 white, entry 1 black, then the
# first RGBA row of each named IM's colour in roi_c_dict (entries 2..9).
_named_im_codes = ['f1', 'f2', 'f2-f1', '2f1', '2f2', 'f1+f2', '2f1-f2', '2f2-f1']
colors_rgba = [(1,1,1,1), (0,0,0,1)] + [tuple(roi_c_dict[code][0]) for code in _named_im_codes]

# One colormap bin per entry.
custom_cmap_3 = LinearSegmentedColormap.from_list("custom_cmap", colors_rgba, N=len(colors_rgba))
# Grid of IM coefficients: f_IM = alpha*f1 + beta*f2, alpha/beta in {-2, -1.5, ..., 2}.
# Tested (f1, f2) pairs in Hz (cf. task_to_test_frequencies_map).
all_test_frequencies = [
    [.125,.2], # entrain / AttendAway / entrainA / entrainD
    [.125,.175], # entrainB
    [.125,.15], # entrainC
    [.15,.2], # entrainE
    [.175,.2], # entrainF
]
scalars = np.arange(-2,2+.5,.5)
n_scalars = scalars.shape[0]
# X_scaling[row, col] = (alpha, beta) of each grid cell.
X_scaling = np.zeros((n_scalars,n_scalars,2))
# Fraction of test datasets above the control maximum, one grid per region type.
f1_fractions = np.zeros((n_scalars,n_scalars))
f2_fractions = np.zeros((n_scalars,n_scalars))
f1f2_fractions = np.zeros((n_scalars,n_scalars))
# Tick positions/labels at integer coefficients; appended once per cell (duplicates are
# removed when the heat-map axes are drawn).
xticks = []
yticks = []
xticklabels = []
yticklabels = []
# Rows walk alpha (f1 coefficient) from +2 down to -2; columns walk beta from -2 to +2.
for f1_ix,f1_scalar in enumerate(scalars[::-1]):
    for f2_ix,f2_scalar in enumerate(scalars):
        # Defensive rounding before the exact integer checks below.
        f1_scalar = round(f1_scalar,2)
        f2_scalar = round(f2_scalar,2)
        # NOTE(review): x ticks come from the row (alpha) index and y ticks from the
        # column (beta) index; this only lines up because the grid is square and
        # symmetric about 0 — re-check if the scalar range changes.
        if int(round(f1_scalar))==f1_scalar:
            xticks.append(f1_ix)
            xticklabels.append(int(f1_scalar))
        if int(round(f2_scalar))==f2_scalar:
            yticks.append(f2_ix)
            yticklabels.append(int(f2_scalar))
        # Each call recomputes fSNR for every selected dataset (3 region types per cell).
        f1_fractions[f1_ix,f2_ix] = calculate_fraction_of_datasets(f1_scalar,f2_scalar,pval,roi_fractional_overlap,"f1",rephase,rephase_with,foi_range)
        f2_fractions[f1_ix,f2_ix] = calculate_fraction_of_datasets(f1_scalar,f2_scalar,pval,roi_fractional_overlap,"f2",rephase,rephase_with,foi_range)
        f1f2_fractions[f1_ix,f2_ix] = calculate_fraction_of_datasets(f1_scalar,f2_scalar,pval,roi_fractional_overlap,"f1f2",rephase,rephase_with,foi_range)

        X_scaling[f1_ix,f2_ix,0] = f1_scalar
        X_scaling[f1_ix,f2_ix,1] = f2_scalar
# Validity mask X over the (alpha, beta) grid: a cell is 0 when alpha*f1 + beta*f2 <= 0
# for at least one tested (f1, f2) pair, otherwise 1. X rows are alpha descending (same
# layout as the fraction grids), columns beta ascending. The positive IM frequencies of
# each pair are kept with rows alpha *ascending*.
n_nonpositive = np.zeros((n_scalars,n_scalars))
im_frequencies_across_all_test_frequencies = {}
for test_frequencies in all_test_frequencies:
    _key = tuple(test_frequencies)
    im_frequency_grid = np.zeros((n_scalars,n_scalars))
    for f1_ix,f1_scalar in enumerate(scalars):
        for f2_ix,f2_scalar in enumerate(scalars):
            f1_scalar = round(f1_scalar,2)
            f2_scalar = round(f2_scalar,2)
            im_frequency = f1_scalar*test_frequencies[0]+f2_scalar*test_frequencies[1]
            if im_frequency <= 0:
                # This loop walks alpha ascending; flip the row into the descending layout.
                n_nonpositive[n_scalars-f1_ix-1,f2_ix] += 1
            else:
                im_frequency_grid[f1_ix,f2_ix] = im_frequency
    im_frequencies_across_all_test_frequencies[_key] = im_frequency_grid

X = np.where(n_nonpositive > 0, 0., 1.)

def get_idx_from_grid(scalars,val):
    """Index of the first entry of ``scalars`` exactly equal to ``val``."""
    return np.where(scalars==val)[0][0]

# Rows are indexed through the reversed (alpha-descending) scalars.
bw_scalars = scalars[::-1]
# Mark the named IMs with their colour index in custom_cmap_3 (2..9).
X[get_idx_from_grid(bw_scalars,1),get_idx_from_grid(scalars,0)] = 2 # f1
X[get_idx_from_grid(bw_scalars,0),get_idx_from_grid(scalars,1)] = 3 # f2
X[get_idx_from_grid(bw_scalars,-1),get_idx_from_grid(scalars,1)] = 4 # f2-f1
X[get_idx_from_grid(bw_scalars,2),get_idx_from_grid(scalars,0)] = 5 # 2f1
X[get_idx_from_grid(bw_scalars,0),get_idx_from_grid(scalars,2)] = 6 # 2f2
X[get_idx_from_grid(bw_scalars,1),get_idx_from_grid(scalars,1)] = 7 # f1+f2
X[get_idx_from_grid(bw_scalars,2),get_idx_from_grid(scalars,-1)] = 8 # 2f1-f2
X[get_idx_from_grid(bw_scalars,-1),get_idx_from_grid(scalars,2)] = 9 # 2f2-f1

# Ticks were appended once per grid cell; reduce them to sorted unique values once.
xticks = sorted(set(xticks))
yticks = sorted(set(yticks))
xticklabels = sorted(set(xticklabels))
yticklabels = sorted(set(yticklabels))

# Row-3 heat maps: IM frequency for (f1, f2) = (0.125, 0.2) Hz, the named-IM map, and the
# fraction of experiments above the control maximum for each region type.
P = im_frequencies_across_all_test_frequencies[(.125,.2)][::-1,:]
panel_specs = zip(
    ["3","3a","3b","3c","3d"],
    [P,X,f1_fractions,f2_fractions,f1f2_fractions],
    ["viridis",custom_cmap_3,"magma","magma","magma"],
    ["Power","IM frequencies",frequency_text_codes["f1"],frequency_text_codes["f2"],"Multiplex"],
)
for ax_key, _X, _cmap, _title in panel_specs:
    ax = ax_dict[ax_key]
    ax.tick_params(axis='y',direction='out',right=True,left=False,length=2,width=LINEWIDTH,pad=.2,colors='k')
    ax.tick_params(axis='x',direction='out',length=2,width=LINEWIDTH,pad=.2,colors='k')
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels(xticklabels,fontsize=FONTSIZE)
    ax.set_yticklabels(yticklabels[::-1],fontsize=FONTSIZE)
    ax.set_xlabel(r"$\beta$",fontsize=FONTSIZE)
    ax.set_ylabel(r"$\alpha$",rotation=0,fontsize=FONTSIZE)
    if ax_key == "3a":
        # Categorical map: colour indices 0..9 into custom_cmap_3, left unmasked.
        im = ax.imshow(_X,cmap=_cmap,vmax=9)
        ax.set_title(r"$f_{\text{IM}}$"+"$=$"+r"$\alpha$$f_1$+$\beta$$f_2$",fontsize=FONTSIZE)
    else:
        # Hide cells whose IM frequency is non-positive for some tested pair.
        _X = np.ma.masked_where((X==0),_X)
        if ax_key == "3":
            im = ax.imshow(_X,cmap=_cmap,vmax=np.nanmax(_X))
            ax.set_title(r"Frequency",fontsize=FONTSIZE)
        else:
            im = ax.imshow(_X,cmap=_cmap,vmax=1.)
            ax.set_title(f"Region, {_title}",fontsize=FONTSIZE)
    for _spine in ["top","right","bottom","left"]:
        ax.spines[_spine].set_linewidth(LINEWIDTH)
    # Tick labels on the right and bottom.
    ax.yaxis.set_label_position("right")
    ax.yaxis.set_tick_params(labelleft=False, labelright=True, labelcolor='k')
    ax.xaxis.set_label_position("bottom")
    ax.xaxis.set_tick_params(labelbottom=True, labelright=False, labelcolor='k')
    ax.set_xlim(.5+1,8.5)
    ax.set_ylim(8.5,.5-1)

fig.savefig(MAIN / "Fig4_IM_count_across_fundamental_frequency_encoded_populations.png",dpi=300)

In [None]:
# Inspect the f1-region fraction grid (rows: alpha from 2 down to -2; columns: beta from -2 up to 2).
f1_fractions

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

fontsize=8; linewidth=.5
# The image only carries a stand-alone 0..1 "magma" colorbar for the fraction panels.
# A seeded generator keeps the saved PNG reproducible across runs.
data = np.random.default_rng(0).random((10, 10))

fig, ax = plt.subplots(figsize=(2,2),dpi=300,layout="constrained")
im = ax.imshow(data, cmap='magma',vmin=0,vmax=1)

# Colorbar in a narrow inset hugging the right edge of the axes.
axins = inset_axes(
    ax, width="5%",height="25%",loc="center right",borderpad=-1
)
cbar = fig.colorbar(im, cax=axins)
cbar.outline.set_linewidth(0)
cbar.set_ticks([0,1])
cbar.ax.tick_params(axis='y', pad=0.4, length=2, width=linewidth, direction="in", which="both",colors='white')
cbar.set_ticklabels([f"{i:.0f}" for i in [0,1]], fontsize=fontsize,c='k')
cbar.set_label("Fraction of\nexperiments",fontsize=fontsize)
fig.savefig(MAIN / "CBAR_f_im_count.png",dpi=300, bbox_inches="tight")

plt.show()


In [None]:
# Stand-alone legend: a swatch per roi_c_dict entry plus "Control IM", the fundamental
# frequencies, control/entrainment markers, and the experiment-group bar colours.
fig, ax = plt.subplots(figsize=(2,2),dpi=300,)
ax.set_xlim(0,15)
ax.set_ylim(0,35)

# Column 1: swatches stacked every 3 units.
counter = 0
for code, swatch_c in roi_c_dict.items():
    counter += 3
    ax.scatter(1, counter, marker='s',c=swatch_c,s=20, linewidths=.5, edgecolors="k")
    ax.text(1.7, counter-.9, frequency_text_codes[code],c='k',fontsize=8)
counter += 3
ax.scatter(1, counter, marker='s',c="k",s=20, linewidths=.5, edgecolors="k")
ax.text(1.7, counter-.9, "Control IM",c='k',fontsize=8)

# Fundamental frequencies.
ax.text(8, 2.6, "$f_2$=0.2Hz", fontsize=6)
ax.text(8, 5, "$f_1$=0.125Hz", fontsize=6)

# Marker key for the paired scatter panel.
for marker_y, text_y, marker_c, label in ((9, 8.1, "lightgrey", "Control"), (9+3, 8.1+3, "k", "Entrainment")):
    ax.scatter(7, marker_y, s=20, edgecolors="k",linewidths=.5,c=marker_c)
    ax.text(7.5, text_y, label, fontsize=8)

# Experiment-group colours, rebuilt from `colors` / `n_colors`.
_ylgn = LinearSegmentedColormap.from_list('custom_YlGn', colors, N=n_colors)
custom_cmap = [_ylgn(i) for i in range(4)]
counter = 13
for c, exp_id in zip(custom_cmap,["3T","7T","3T Vary", "7T Vary"]):
    counter += 3
    ax.scatter(7, counter, marker='s',s=20, edgecolors="k",linewidths=.5,c=c)
    ax.text(7.5, counter-.8, exp_id, fontsize=8)

fig.savefig(MAIN / "LEGEND_3.png",dpi=300, bbox_inches="tight")

In [None]:
# List every dataset whose IM fSNR exceeded the control maximum, grouped by region and IM.
for im in im_codes:
    for roi_f_type in roi_f_types:
        print(f"\nroi-{roi_f_type} im-{im}")
        above_rows = im_above_df[f"roi-{roi_f_type}_im-{im}"]
        for _row_ix, row in above_rows.iterrows():
            print(row.experiment_id, row.sub_id, row.roi_task_id, row.im_frequency, row.im_fsnr)

In [None]:
# One panel per tested (f1, f2) pair showing its positive-IM-frequency grid (rows alpha ascending).
pair_keys = list(im_frequencies_across_all_test_frequencies)
mosaic = [list(range(len(pair_keys)))]
fig, ax_dict = plt.subplot_mosaic(mosaic=mosaic,figsize=(4,1),dpi=300,layout="constrained")
for ax, k in zip(ax_dict.values(), pair_keys):
    ax.imshow(im_frequencies_across_all_test_frequencies[k],cmap="magma")
    ax.set_title(f"{k}",fontsize=FONTSIZE)

In [None]:
show_top = 2
mosaic = ["im_frequencies","f1","f2","f1f2"]
fig, ax_dict = plt.subplot_mosaic(mosaic=[mosaic[1:]],figsize=(3,1.),dpi=300,layout="constrained")

# Hand-placed annotations for known top-ranked coefficients, keyed by
# (panel index, alpha, beta) ->
#   (line x offset, line top, line colour, text x offset, text height, extra text kwargs).
# Heights are fractions of the y-limit at draw time (text is placed 1.4 above).
annotation_specs = {
    (0, 1, 0): (0, .3, 'r', 0, .3, {}),
    (0, 2, 0): (0, .6, 'r', 0, .6, {}),
    (0, 0, 1): (0, .6, 'r', 0, .6, {}),
    (1, 0, 1): (0, .6, 'r', 0, .6, {}),
    (1, 0, .5): (0, .6, 'darkgrey', 0, .6, {}),
    # NOTE(review): a line top of 10x the y-limit extends the autoscaled limits before
    # the label is placed; kept as-is — confirm intent.
    (1, 0, 2): (0, 10., 'r', 0, .3, {"zorder": 6}),
    (2, 0, 1): (-.005, .6, 'r', 0, .6, {}),
    (2, -1, 1): (0, .6, 'r', -.1, .6, {}),
    (2, 1, 0): (0, .3, 'r', -.05, .3, {}),
}
# Seeded so the fallback placement of unmatched labels is reproducible.
label_rng = np.random.default_rng(0)
# Cells with a positive IM frequency for every tested pair (the histogram population).
valid_cells = X != 0

for ax_ix, (ax_key, _X) in enumerate(zip(["f1","f2","f1f2"],[f1_fractions,f2_fractions,f1f2_fractions])):
    ax = ax_dict[ax_key]
    fractions = _X.copy()
    ranked_cells = valid_cells & np.isfinite(fractions)
    ax.hist(fractions[ranked_cells],width=.02,bins=50,zorder=2)
    # Largest distinct fractions among the histogrammed cells, descending.
    top_vals = np.unique(fractions[ranked_cells])[-show_top:][::-1]
    for _top_val in top_vals:
        coef = X_scaling[(fractions==_top_val) & ranked_cells]
        for _f1, _f2 in coef:
            label = format_f1_f2_equation(_f1,_f2)
            spec = annotation_specs.get((ax_ix, _f1, _f2))
            if spec is None:
                ax.text(_top_val-.07,label_rng.uniform(ax.get_ylim()[0],ax.get_ylim()[-1]*.5),label,rotation=0,fontsize=FONTSIZE)
                continue
            line_dx, line_top, line_c, text_dx, text_top, text_kwargs = spec
            # Draw the line first: it can change the autoscaled y-limit read for the text.
            ax.plot([_top_val+line_dx,_top_val+line_dx],[ax.get_ylim()[0],ax.get_ylim()[-1]*line_top],color=line_c,linewidth=LINEWIDTH,zorder=1)
            ax.text(_top_val+text_dx,(ax.get_ylim()[-1]*text_top)+1.4,label,rotation=0,fontsize=FONTSIZE,ha="center",va="center",**text_kwargs)
    ax.tick_params(axis="both", length=2.,width=LINEWIDTH,pad=0)
    if ax_ix == 0:
        ax.set_ylabel(r"$f_{\text{IM}}$"+" count (n=37)", fontsize=FONTSIZE)
    if ax_ix == 1:
        ax.set_xlabel("Fraction of experiments", fontsize=FONTSIZE)
    for _spine in ["top","right"]:
        ax.spines[_spine].set_visible(False)
    for _spine in ["left","bottom"]:
        ax.spines[_spine].set_linewidth(LINEWIDTH)
    ax.set_xlim(0,1)
    ax.set_xticks([0,1])
    ax.set_xticklabels([0,1],fontsize=FONTSIZE)
    ax.set_ylim(0,12)
    ax.set_yticks([0,12])
    ax.set_yticklabels([0,12],fontsize=FONTSIZE)

fig.savefig(MAIN / "Fig4_control_and_true_im_count_across_fundamental_frequency_encoded_populations.png",dpi=300)