In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import random
from scipy.signal import welch
from collections import defaultdict

# Font size shared by the labels, ticks and titles drawn by the plotting helpers below.
FONTSIZE = 4
# Repetition time; TimeSeries turns it into a sampling rate via fs = 1/TR.
# Presumably expressed in seconds -- TODO confirm against the acquisition protocol.
TR = .3
# Number of shuffles handed to TimeSeries(n_permutations=...) by the analysis loop.
n_permutations = 1_000
# Number of bootstraps: part of the bootstrap directory name built by read_pkl and
# the number of per-ROI iterations in the analysis loop.
n_bootstraps = 200

# Root directory under which read_pkl looks for bootstrap pickles.
datadir = Path("/scratch/fastfmri")

# subject id -> list of (task id, [f1, f2]) pairs, with f1 < f2.
# NOTE: the pickle-generation cell rebinds this name from the selected dataset tuple.
sub_to_task_mapping = {
    "020": [
        ("entrainA", [.125, .2]),
        ("entrainB", [.125, .175]),
        ("entrainC", [.125, .15]),
    ],
    "021": [
        ("entrainD", [.125, .2]),
        ("entrainE", [.15, .2]),
        ("entrainF", [.175, .2]),
    ],
}

def read_pkl(datadir, n_bootstraps, sub_id, roi_task_id, roi_frequency, task_id, experiment_id="1_frequency_tagging", mri_id="7T", fo=.8, roi_frequency_2=None, control_roi_size=False):
    """Load the merged bootstrap pickle for one subject / ROI / task.

    The ROI part of the directory name takes one of three forms:

    * ``roi-{roi_task_id}-{roi_frequency}`` when ``roi_frequency_2`` is None,
    * ``roi-{roi_task_id}-{roi_frequency}-{roi_frequency_2}`` for a
      two-frequency ROI,
    * ``roi-{roi_task_id}-{roi_frequency}_controlroisizetomatch-{roi_frequency_2}``
      when ``control_roi_size`` is also truthy.

    Returns:
        The unpickled object, or None (after printing a warning) when the
        pickle file does not exist.
    """
    import pickle

    # Only the ROI label differs between the three path variants.
    if roi_frequency_2 is None:
        roi_label = f"{roi_task_id}-{roi_frequency}"
    elif control_roi_size:
        roi_label = f"{roi_task_id}-{roi_frequency}_controlroisizetomatch-{roi_frequency_2}"
    else:
        roi_label = f"{roi_task_id}-{roi_frequency}-{roi_frequency_2}"

    bootstrap_dir = f"experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-{n_bootstraps}_batch-merged_desc-basic_roi-{roi_label}_fo-{fo}_bootstrap"
    bootstrap_pkl: Path = datadir / bootstrap_dir / f"sub-{sub_id}" / "bootstrap" / f"task-{task_id}_bootstrapped_data.pkl"

    if bootstrap_pkl.exists():
        print(f"Reading: {bootstrap_pkl}")
        with open(bootstrap_pkl, 'rb') as handle:
            return pickle.load(handle)

    print(f"Warning: {bootstrap_pkl} does not exist.\nReturn None")
    return None

def set_base_dir(basedir):
    """Return `basedir` as a Path, creating it (and any parents) if it is missing."""
    target = Path(basedir)
    if target.exists():
        return target

    target.mkdir(exist_ok=True, parents=True)
    return target


class TimeSeries:
    """Permutation test for spectral power at chosen frequencies of one timeseries.

    The power spectrum is estimated with Welch's method; the null distribution
    at a frequency is built by shuffling the samples of the timeseries and
    re-estimating the power at that frequency.
    """

    def __init__(self, ts, TR, n_permutations=5_000, nperseg=600):
        """Store the timeseries and the analysis settings.

        Args:
            ts: 1-D timeseries handed to `scipy.signal.welch`.
            TR: sampling interval; the sampling rate is taken as ``1/TR``.
            n_permutations: number of shuffles per tested frequency.
            nperseg: segment length forwarded to `scipy.signal.welch`.
        """
        self.timeseries = ts
        self.fs = 1/TR
        self.nperseg = nperseg
        self.n_permutations = n_permutations
        # Welch frequency grid; filled in the first time a spectrum is estimated.
        self.frequencies = None

    def process(self, search_frequencies):
        """Test every frequency in `search_frequencies`.

        Returns:
            A 4-tuple ``(p_values, observed_statistics, observed_power_spectrum,
            null_power_spectrums)``. The first two are dicts keyed by frequency.
            `null_power_spectrums` holds the shuffled spectra generated for the
            LAST tested frequency (an empty list when no frequency is given).
        """
        p_values, observed_statistics = {}, {}
        # Estimated up front so the return value is defined even when
        # `search_frequencies` is empty.
        _, observed_power_spectrum = self._estimate_power_spectrum(self.timeseries)
        if self.frequencies is None:
            self.frequencies = _
        null_power_spectrums = []
        for f in search_frequencies:
            observed_statistic, observed_power_spectrum = self.calculate_observed_statistic(f)
            observed_statistics[f] = observed_statistic
            null_statistics, null_power_spectrums = self.calculate_null_statistics(f)
            n_exceeding = np.sum(np.asarray(null_statistics) >= observed_statistic)
            # +1 in numerator and denominator counts the observed value as one
            # member of the null, so p is never exactly zero. The denominator
            # uses the instance setting rather than a module-level variable.
            p_values[f] = (n_exceeding + 1) / (self.n_permutations + 1)

        return p_values, observed_statistics, observed_power_spectrum, null_power_spectrums

    def calculate_observed_statistic(self, f):
        """Return ``(power at f, full power spectrum)`` of the unshuffled timeseries."""
        frequencies, power_spectrum = self._estimate_power_spectrum(self.timeseries)
        if self.frequencies is None:
            self.frequencies = frequencies
        power = np.interp(f, frequencies, power_spectrum)

        return power, power_spectrum

    def calculate_null_statistics(self, f):
        """Return ``(null powers at f, null spectra)`` from `self.n_permutations` shuffles."""
        null_power_spectrums = []
        null_statistics = []
        for _ in range(self.n_permutations):
            # np.random.permutation returns a shuffled copy; the original is untouched.
            y_shuffle = np.random.permutation(self.timeseries)
            frequencies, power_spectrum = self._estimate_power_spectrum(y_shuffle)
            null_power_spectrums.append(power_spectrum)
            # Interpolate on the spectrum just computed instead of re-running Welch.
            null_statistics.append(np.interp(f, frequencies, power_spectrum))

        return null_statistics, null_power_spectrums

    def _estimate_power_spectrum(self, ts):
        """Welch power spectrum of `ts`; returns ``(frequencies, power_spectrum)``."""
        frequencies, power_spectrum = welch(ts, self.fs, nperseg=self.nperseg)

        return (frequencies, power_spectrum)

    def _estimate_power(self, ts, f):
        """Power of `ts` at frequency `f`, linearly interpolated on the Welch grid."""
        frequencies, power_spectrum = self._estimate_power_spectrum(ts)
        return np.interp(f, frequencies, power_spectrum)

In [None]:
def decorate_fig_1B(fig, ax, frequencies, p_values, observed_statistics, add_im=False, sub_id=None, roi_frequency=None, fontsize=FONTSIZE):
    """Annotate a power-spectrum axis with tested frequencies and their significance.

    Args:
        fig, ax: figure and axis to decorate; both are returned.
        frequencies: list of tested frequencies. With `add_im`, the first two
            entries are used to derive the harmonic / intermodulation markers.
        p_values, observed_statistics: dicts keyed by the entries of `frequencies`.
        add_im: also mark second-order (|f1-f2|, f1+f2, 2*f1, 2*f2) and
            third-order (|2*f1-f2|, |2*f2-f1|) frequencies.
        sub_id, roi_frequency: only used in the title.
        fontsize: font size for all text added here.

    Returns:
        ``(fig, ax)``.
    """
    for f in frequencies:
        # The plotted number is -log10(p), so the label says so.
        ax.text(f+.005, observed_statistics[f], f"-log10(p)={-np.log10(p_values[f]):.3f}", fontsize=fontsize)

    # Defined unconditionally so the colour lookup below also works when add_im is False.
    second_order_frequencies, third_order_frequencies = [], []
    if add_im:
        f1, f2 = frequencies[0], frequencies[1]
        second_order_frequencies = [
            np.abs(f1-f2),
            np.abs(f2+f1),
            f1*2,
            f2*2,
        ]
        third_order_frequencies = [
            np.abs(2*f1 - f2),
            np.abs(2*f2 - f1),
        ]

    # Build a new list so the caller's `frequencies` is never mutated.
    marked_frequencies = list(frequencies) + second_order_frequencies + third_order_frequencies
    for f in marked_frequencies:
        if f in second_order_frequencies:
            c = 'cyan'
        elif f in third_order_frequencies:
            c = 'g'
        else:
            c = 'b'
        ax.axvline(x=f, c=c, linestyle=':', zorder=1, lw=.75)

    ax.set_xlim((0,.5))
    ax.set_ylabel("Power", fontsize=fontsize)
    ax.set_xlabel("Frequency", fontsize=fontsize)
    ax.tick_params(axis="both", length=0, labelsize=fontsize)
    for i in ("top", "right", "bottom", "left"):
        ax.spines[i].set_visible(False)
    ax.set_title(f"{sub_id}, roi-{roi_frequency}, {frequencies}", fontsize=fontsize)

    return fig, ax

def plot_power_spectrum(ts, observed_power_spectrum, null_power_spectrums, add_im=False, sub_id=None, roi_frequency=None, close_figure=False, png_out=None):
    """Fig 1b: observed power spectrum over the confidence band of the null mean.

    Args:
        ts: object exposing `.frequencies` (the x-axis grid).
        observed_power_spectrum: spectrum drawn as the black line.
        null_power_spectrums: sequence of null spectra, one per shuffle.
        add_im, sub_id, roi_frequency: forwarded to `decorate_fig_1B`.
        close_figure: close the figure after optional saving.
        png_out: when truthy, path the figure is saved to.

    NOTE(review): `frequencies`, `p_values` and `observed_statistics` are read
    from the enclosing (notebook) namespace and must be defined there before
    calling this function.
    """
    fig, ax = plt.subplots(figsize=(2,1), dpi=400)
    ax.plot(ts.frequencies, observed_power_spectrum, c='k', zorder=2, lw=.5)
    null_power_spectrums = np.vstack(null_power_spectrums)
    n_null = null_power_spectrums.shape[0]
    null_power_spectrum = np.mean(null_power_spectrums, axis=0)
    # Spread across shuffles at each frequency. Taking the std of the
    # already-averaged spectrum would collapse to one scalar for every frequency.
    std_dev_values = np.std(null_power_spectrums, axis=0)
    # 95% confidence interval of the null mean; n is the number of spectra supplied.
    confidence_interval = 1.96 * std_dev_values / np.sqrt(n_null)
    ax.fill_between(
        ts.frequencies,
        null_power_spectrum - confidence_interval, null_power_spectrum + confidence_interval,
        color='r',
        alpha=.8,
    )
    fig, ax = decorate_fig_1B(fig, ax, frequencies, p_values, observed_statistics, add_im=add_im, sub_id=sub_id, roi_frequency=roi_frequency)

    fig.tight_layout()

    if png_out:
        fig.savefig(png_out,dpi='figure')

    if close_figure:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

def save_bootstrapped_statistics(rel_path, data_dict, pkldir = Path("/scratch/fastfmri/pickles")):
    """Pickle `data_dict` to ``pkldir / rel_path``, creating `pkldir` when missing.

    Raises:
        AssertionError: if the file is not present after writing.
    """
    import pickle

    if not pkldir.exists():
        pkldir.mkdir(exist_ok=True, parents=True)

    destination = pkldir / rel_path
    with open(destination, 'wb') as handle:
        pickle.dump(data_dict, handle)

    assert destination.exists(), f"{destination} not created."

def check_difference(arr, diff=.3):
    """Return True when every consecutive step of `arr` is close to `diff`."""
    steps = np.diff(arr)
    return np.isclose(steps, diff).all()

def extract_carpet_data(data, task_id, task_quadrant, bootstrap_id, phased_flag):
    """Return ``(timepoints, bold)`` of one bootstrap for a carpet plot.

    Args:
        data: dict loaded from a bootstrap pickle. The bold arrays are indexed
            as (timepoints x voxels x bootstraps).
        task_id, task_quadrant: combined into the dict-key prefix.
        bootstrap_id: index along the last axis of the bold array.
        phased_flag: when truthy, use the ``phaseadjusted`` arrays and keep only
            the timepoints present exactly once for every voxel.

    Returns:
        Not phased: the stored timepoints and the (timepoints x voxels) bold slice.
        Phased: a 1-D array of retained timepoints and a
        (retained timepoints x voxels) array with columns in the original
        voxel order.

    Raises:
        AssertionError: if the retained timepoints are not evenly spaced
            (see `check_difference`).
    """
    key_stem = f'data-test_task-{task_id}{task_quadrant}_roi'

    if not phased_flag:
        # Select timeseries (timepoints x voxels x bootstraps)
        return data[f'{key_stem}_timepoints'], data[f'{key_stem}_bold'][:,:,bootstrap_id]

    data_tps = data[f'{key_stem}_phaseadjusted_timepoints']
    data_bold = data[f'{key_stem}_phaseadjusted_bold'][:,:,bootstrap_id]
    voxel_columns = np.arange(data_bold.shape[1])
    updated_tps, carpet_rows = [], []
    for single_tp in np.unique(data_tps):
        coords = (data_tps == single_tp)
        # Keep the timepoint only when EVERY voxel holds it exactly once
        # (the comparison belongs inside np.all).
        if not np.all(coords.sum(0) == 1):
            continue
        updated_tps.append(single_tp)
        if coords.ndim == 2:
            # A boolean mask returns values in row-major order, which would
            # shuffle voxels whose sample sits on different rows; pick the
            # matching row per column so the output stays in voxel order.
            carpet_rows.append(data_bold[coords.argmax(0), voxel_columns])
        else:
            carpet_rows.append(data_bold[coords])
    updated_tps = np.array(updated_tps)
    assert check_difference(updated_tps)

    return updated_tps, np.vstack(carpet_rows)
                
                
def extract_bootstrapped_mean_from_data(data, task_id_2, task_quadrant, bootstrap_id, phased_flag):
    """Return ``(timepoints, voxel-averaged bold)`` for one bootstrap.

    Args:
        data: dict loaded from a bootstrap pickle. The bold arrays are indexed
            as (timepoints x voxels x bootstraps).
        task_id_2, task_quadrant: combined into the dict-key prefix.
        bootstrap_id: index along the last axis of the bold array.
        phased_flag: when truthy, use the ``phaseadjusted`` arrays and average
            only over timepoints present exactly once for every voxel.

    Returns:
        Two 1-D arrays of equal length: timepoints and the mean bold across voxels.

    Raises:
        AssertionError: in the phased branch, if the retained timepoints are
            not evenly spaced (see `check_difference`) or the shapes disagree.
    """
    key_stem = f'data-test_task-{task_id_2}{task_quadrant}_roi'

    if not phased_flag:
        x = data[f'{key_stem}_timepoints']
        y = data[f'{key_stem}_bold'][:,:,bootstrap_id]
        # Average across the voxel axis.
        return x.mean(1), y.mean(1)

    data_tps = data[f'{key_stem}_phaseadjusted_timepoints']
    data_bold = data[f'{key_stem}_phaseadjusted_bold'][:,:,bootstrap_id]
    updated_tps, mean_bold = [], []
    for single_tp in np.unique(data_tps):
        coords = (data_tps == single_tp)
        # Keep the timepoint only when EVERY voxel holds it exactly once
        # (the comparison belongs inside np.all).
        if np.all(coords.sum(0) == 1):
            updated_tps.append(single_tp)
            mean_bold.append(data_bold[coords].mean())

    updated_tps = np.array(updated_tps)
    mean_bold = np.array(mean_bold)
    assert check_difference(updated_tps)
    assert mean_bold.shape == updated_tps.shape

    return updated_tps, mean_bold

7T intermodulation experiment
- sub-020
    - task_roi: entrainA, look for f in [.125, .2] 
        - roi: f=.125
        - roi: f=.2
    - task_roi: entrainB, look for f in [.125, .175]
        - roi: f=.125
        - roi: f=.175
    - task_roi: entrainC, look for f in [.125, .15]
        - roi: f=.125
        - roi: f=.15
- sub-021
    - task_roi: entrainD, look for f in [.125, .2]
        - roi: f=.125
        - roi: f=.2
    - task_roi: entrainE, look for f in [.15, .2]
        - roi: f=.15
        - roi: f=.2
    - task_roi: entrainF, look for f in [.175, .2]
        - roi: f=.175
        - roi: f=.2

1) For a task condition show carpet plots for each frequency (carpet plot for each frequency of a `task_roi`/two plots total)
    - show example of one bootstrap only
2) Compute mean timeseries across each ROI and bootstrap (the **carpet plot** shows that there are two populations of phase shifts, so computing the mean is ill-advised unless the timeseries are phase shifted first)
    - compute statistics of the observed frequency using timeseries shuffling
        - provides a p-value for each frequency in the task_roi
    - also compute mean timeseries across bootstrapped mean timeseries `y_bootstrapped_mean`, and compute statistics
3) Compute power spectrum from `y_bootstrapped_mean` for all `task_roi`s
    - create carpet plot (carpet plot for each frequency of a `task_roi`/two plots total)
        - might have to play with normalization in order to emphasize the frequencies of interest

- Consider subsequent analyses using localizers from other task conditions to reproduce results within a subject
    - e.g., sub-021, task: `entrainD`, can resolve consistent results with `entrainE` and `entrainF`

In [None]:
def find_raw_bold(i):
    """Locate the raw ``bold.dtseries.nii`` that corresponds to a processed file.

    The experiment and MRI ids are parsed from the parent directory name of
    `i`; the subject, session, task and run ids are parsed from its file stem.
    The raw file is then searched for in the matching
    ``oscprep_grayords_fmapless/bold_preproc`` func directory under ``/data``.

    Args:
        i: Path of the processed dtseries file.

    Returns:
        Path of the single matching raw bold file.

    Raises:
        IndexError: if an expected ``key-value`` entity is absent from the path.
        OSError: if the func directory cannot be listed.
        AssertionError: if zero or more than one candidate is found.
    """
    import os

    parent_name = str(i.parent)
    experiment_id = parent_name.split("experiment-")[1].split('_mri-')[0]
    mri_id = parent_name.split("mri-")[1].split('_')[0]
    # Each entity is the text between "<entity>-" and the next underscore.
    sub_id, ses_id, task_id, run_id = (
        i.stem.split(f'{entity}-')[1].split('_')[0]
        for entity in ("sub", "ses", "task", "run")
    )

    directory = f"/data/{experiment_id}/{mri_id}/bids/derivatives/oscprep_grayords_fmapless/bold_preproc/sub-{sub_id}/ses-{ses_id}/func"
    raw_bold = [
        f"{directory}/{file}"
        for file in os.listdir(directory)
        if f"run-{run_id}" in file and f"task-{task_id}" in file and file.endswith("bold.dtseries.nii")
    ]
    # The message reports the actual count: the check also fails when nothing matches.
    assert len(raw_bold) == 1, f"Expected exactly one raw bold in {directory}, found {len(raw_bold)}: {raw_bold}"

    return Path(raw_bold[0])

def average_bold(bold_list):
    """Average several bold images element-wise, then z-score and transpose.

    Args:
        bold_list: non-empty sized collection of image paths accepted by `nib.load`.

    Returns:
        The transposed array of the averaged data after subtracting the mean
        and dividing by the standard deviation along axis 0.

    Raises:
        ValueError: if `bold_list` is empty.
    """
    if len(bold_list) == 0:
        # Without this guard the accumulator would be undefined further down.
        raise ValueError("average_bold requires at least one bold file.")

    y_all = None
    for bold in bold_list:
        _bold_data = nib.load(bold).get_fdata()
        if y_all is None:
            # Copy once so the running sum never aliases the loaded array.
            y_all = _bold_data.copy()
        else:
            y_all += _bold_data

    y_all /= len(bold_list)
    y_all = (( y_all - y_all.mean(0)) / y_all.std(0) ).T

    return y_all

def read_bootstrap_txt(bootstrap_txt, bootstrap_idx):
    """Average the raw, windowed and processed bolds listed for one bootstrap.

    Line `bootstrap_idx` of `bootstrap_txt` is read as a comma-separated list
    of processed dtseries paths. For each entry the windowed file (same path
    with ``desc-windowed`` instead of ``desc-denoised``) and the raw file
    (via `find_raw_bold`) are resolved, and all three must exist.

    Returns:
        ``(raw_avg, raw_windowed_avg, processed_avg)`` as produced by `average_bold`.
    """
    with open(bootstrap_txt, "r") as handle:
        selected_line = handle.readlines()[bootstrap_idx]

    raw_bolds, raw_windowed_bolds, processed_bolds = [], [], []
    for entry in selected_line.split(','):
        processed_bold = Path(entry.strip())
        raw_windowed_bold = Path(str(processed_bold).replace("desc-denoised_bold.dtseries.nii","desc-windowed_bold.dtseries.nii"))
        raw_bold = find_raw_bold(Path(processed_bold))

        assert processed_bold.exists(), f"{processed_bold} not found."
        assert raw_windowed_bold.exists(), f"{raw_windowed_bold} not found"
        assert raw_bold.exists(), f"{raw_bold} not found."

        raw_bolds.append(raw_bold)
        raw_windowed_bolds.append(raw_windowed_bold)
        processed_bolds.append(processed_bold)

    return (
        average_bold(raw_bolds),
        average_bold(raw_windowed_bolds),
        average_bold(processed_bolds),
    )

def find_quadrant_id_from_keys(_dict):
    """Return the quadrant id ('Q1' or 'Q2') found in the first key naming the task.

    NOTE(review): `task_id` is not a parameter; it is read from the enclosing
    (notebook) namespace at call time.

    Raises:
        AssertionError: if the two characters starting at the first "Q" of the
            matching key are not 'Q1' or 'Q2'.
        ValueError: if no key contains `task_id`.
    """
    matching_keys = (key for key in _dict.keys() if task_id in key)
    for key in matching_keys:
        start = key.find("Q")
        quadrant = key[start:start+2]
        assert quadrant in ['Q1', 'Q2']
        return quadrant

    raise ValueError("No quadrant id found.")


Plot **test** set timeseries of a single bootstrap from a 50/50 data split for each subject. Timeseries are extracted from ROIs defined by the fractional overlap across 200 bootstrapped activation profiles. This procedure generates vertices with f1, f2, f1/f2 intersection encoding.
- Plot raw, windowed, windowed+denoised, windowed+denoised+rephased
    - Phase delays are modelled from the **train** set, and applied to the **test** set 
    - saved directory: `./ComputeCanada/frequency_tagging/figures/dual_frequency_timeseries`

In [None]:
def decorate_fig_1A(fig, ax, im, f1, f2, n_f1, n_f2, n_f1f2, FONTSIZE=FONTSIZE, TR=TR):
    """Decorate a carpet plot: colorbar, axis labels, ROI side bars and period bars.

    Args:
        fig, ax: figure and axis holding the image; both are returned.
        im: the mappable the colorbar is built from.
        f1, f2: frequencies whose periods (1/f) are drawn as bars below the image.
        n_f1, n_f2, n_f1f2: row counts of the f1, f2 and f1/f2 groups; the side
            bars are stacked in the order f1, f1f2, f2.
        FONTSIZE: font size for labels and ticks.
        TR: factor converting a column index to the x tick label (index * TR).

    Returns:
        ``(fig, ax)``.
    """
    colorbar = plt.colorbar(im, ax=ax, shrink=.5, drawedges=False)
    colorbar.ax.set_title("Z-score", fontsize=FONTSIZE-2)
    colorbar.ax.tick_params(axis="both", length=0, labelsize=FONTSIZE)
    colorbar.outline.set_edgecolor('none')

    ax.title.set_position([.75,1.05])

    ax.set_ylabel("Voxel", fontsize=FONTSIZE)
    ax.set_yticks([])

    ax.set_xlabel("Acquisition Time (s)", fontsize=FONTSIZE)
    # Drop the first automatic tick, then relabel the remaining ones as index * TR.
    xticks = list(ax.get_xticks()[1:])
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{tick*TR:.2f}" for tick in xticks], fontsize=FONTSIZE)
    ax.tick_params(axis="both", length=0)

    # One period of each frequency, expressed in image columns.
    f1_span = (1/f1)/TR
    f2_span = (1/f2)/TR

    ax.plot([0,f1_span], [-20,-20], c='white', zorder=1)
    ax.plot([0,f2_span], [-50,-50], c='white', zorder=2)

    # Row boundaries of the three stacked ROI groups.
    f1_end = n_f1
    f1f2_end = n_f1+n_f1f2
    f2_end = n_f1+n_f1f2+n_f2

    patch_specs = [
        # Side bars left of the image marking the f1 / f1f2 / f2 row groups.
        ([(-2, 0), (-15, 0), (-15, f1_end), (-2, f1_end)], dict(color='red')),
        ([(-2, f1_end), (-15, f1_end), (-15, f1f2_end), (-2, f1f2_end)], dict(color='gold')),
        ([(-2, f1f2_end), (-15, f1f2_end), (-15, f2_end), (-2, f2_end)], dict(color='blue')),
        # Bars as long as one period of f1 and f2.
        ([(0, -10), (f1_span, -10), (f1_span, -30), (0, -30)], dict(color='red', zorder=10)),
        ([(0, -40), (f2_span, -40), (f2_span, -60), (0, -60)], dict(color='blue', zorder=10)),
    ]
    for vertices, style in patch_specs:
        ax.add_patch(plt.Polygon(vertices, closed=True, linewidth=0., **style))

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    return fig, ax

Experiment settings

In [None]:
# Settings for the carpet-plot cell below. The seven lists built at the end
# (sub_ids ... task_frequencies) are parallel: entry k of each list describes
# one run of the "Run all" loop, which zips them together.
datadir = Path("/scratch/fastfmri")
n_bootstraps = 200
# Index of the single bootstrap that gets plotted.
bootstrap_id = 0
roi_frequency_2 = None
control_roi_size = False
# Used for the `truncate-{start}-{end}` part of the bootstrap directory name and,
# divided by TR, for the vertical marker lines on the raw carpet plot.
window_size = (39, 219)
close_figures = True
    
experiment_id = "1_frequency_tagging" 
normal_3T_sub_ids = ["000", "002", "003", "004", "005", "006", "007", "008", "009"] 
normal_7T_sub_ids = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
# Three tasks per subject: 020 pairs with entrainA-C, 021 with entrainD-F.
vary_sub_ids = ["020"]*3 + ["021"]*3
vary_task_ids = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]

# Fractional-overlap values iterated by the "Run all" loop.
fos = [.4,.6,.8,1.]
# Layout of every list below: 3T subjects twice (task "control", then "entrain"),
# the 7T subjects once, then the varying-frequency subjects (3 repeats for 3T
# followed by 3 repeats for 7T).
sub_ids = normal_3T_sub_ids*2 + normal_7T_sub_ids + vary_sub_ids*3*2
experiment_ids = ["1_frequency_tagging"]*len(normal_3T_sub_ids)*2 + ["1_attention"]*len(normal_7T_sub_ids) + ["1_frequency_tagging"]*len(vary_sub_ids)*3*2
mri_ids = ["3T"]*len(normal_3T_sub_ids)*2 + ["7T"]*len(normal_7T_sub_ids) + ["3T"]*len(vary_sub_ids)*3 + ["7T"]*len(vary_sub_ids)*3
# Task whose ROI is used; for the varying-frequency subjects each ROI task is
# held fixed for three consecutive entries while task_ids cycles.
roi_task_ids= ["entrain"]*len(normal_3T_sub_ids) + ["entrain"]*len(normal_3T_sub_ids) + ['AttendAway']*len(normal_7T_sub_ids) + (["entrainA"]*3 + ["entrainD"]*3 + ["entrainB"]*3 + ["entrainE"]*3 + ["entrainC"]*3 + ["entrainF"]*3) * 2
task_ids= ["control"]*len(normal_3T_sub_ids) + ["entrain"]*len(normal_3T_sub_ids) + ['AttendAway']*len(normal_7T_sub_ids) + vary_task_ids*3*2
# Frequency pair of the ROI task (follows roi_task_ids).
roi_frequencies = [[.125, .2]]*(len(normal_3T_sub_ids)*2+len(normal_7T_sub_ids)) + ([[.125,.2]]*3 + [[.125,.2]]*3 + [[.125,.175]]*3 + [[.15,.2]]*3 + [[.125,.15]]*3 + [[.175,.2]]*3) * 2
# Frequency pair of the plotted task (follows task_ids).
task_frequencies = [[.125, .2]]*(len(normal_3T_sub_ids)*2+len(normal_7T_sub_ids)) + [[.125,.2],[.125,.175],[.125,.15],[.125,.2],[.15,.2],[.175,.2]]*3*2

# Sanity check: all seven lists should print the same length.
for i in [sub_ids, experiment_ids, mri_ids, roi_task_ids, task_ids, roi_frequencies, task_frequencies]:
    print(len(i))

Run all

In [None]:
# For every fractional overlap and every configured run: load the two
# single-frequency ROI pickles, split the vertices into f1-only / f1&f2 / f2-only
# groups, load the matching dtseries of one bootstrap, and save one carpet plot
# per processing stage (raw, windowed, denoised, denoised + rephased).
for fo in fos:
    for experiment_id, mri_id, sub_id, roi_task_id, task_id, frequencies, _task_frequencies in zip(
        experiment_ids, mri_ids, sub_ids, roi_task_ids, task_ids, roi_frequencies, task_frequencies,
    ):

        # One pickle per ROI frequency, keyed by that frequency.
        f_data = {}
        for roi_f in frequencies:
            f_data[roi_f] = read_pkl(
                datadir, 
                n_bootstraps, 
                sub_id, 
                roi_task_id, 
                roi_f, 
                task_id,
                experiment_id=experiment_id,
                mri_id=mri_id,
                fo=fo,
                roi_frequency_2=roi_frequency_2,
                control_roi_size=control_roi_size,
            )

        # NOTE(review): read_pkl returns None for a missing pickle; this call and the
        # `.copy()` calls below would then fail, and they sit outside the try block.
        task_quadrant = find_quadrant_id_from_keys(f_data[frequencies[0]])
                
        assert frequencies[1]>frequencies[0]
        f1_dict = f_data[frequencies[0]].copy()
        f2_dict = f_data[frequencies[1]].copy()
        f1_coords = f1_dict['roi_coords']
        f2_coords = f2_dict['roi_coords']
        # Sum of the two masks cast to int: 1 = vertex in one ROI only, 2 = vertex in both.
        # The sum is then restricted to the vertices of the f1 (resp. f2) ROI.
        f1_only_coords = f1_coords.astype(int) + f2_coords.astype(int)
        f1_only_coords = f1_only_coords[f1_coords]
        f2_only_coords = f1_coords.astype(int) + f2_coords.astype(int)
        f2_only_coords = f2_only_coords[f2_coords]
        # Masks
        inter_from_f1 = f1_only_coords == 2
        f1_from_f1 = f1_only_coords == 1
        f2_from_f2 = f2_only_coords == 1
        n_f1, n_f1f2, n_f2 = f1_from_f1.sum(), inter_from_f1.sum(), f2_from_f2.sum()

        # Load data from nifti
        # 1) untruncated, 2) preprocessed
        # NOTE(review): this path hard-codes `n-100_batch-00`, whereas read_pkl builds
        # `n-{n_bootstraps}_batch-merged` -- confirm this is the intended directory.
        bootstrap_txt = Path(f"/scratch/fastfmri/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-{window_size[0]}-{window_size[1]}_n-100_batch-00_desc-basic_bootstrap/sub-{sub_id}/task-{task_id}{task_quadrant}_test_splits.txt")
        assert bootstrap_txt.exists()
        data_from_dtseries_raw, data_from_dtseries_windowed, data_from_dtseries_preprocessed = read_bootstrap_txt(bootstrap_txt, bootstrap_id) # Load single bootstrap

        # Each stack below is ordered f1-only, f1&f2, f2-only along the second axis
        # after the transposes, matching the side bars drawn by decorate_fig_1A.
        data_from_dtseries_raw = np.hstack(
            [
                data_from_dtseries_raw[f1_coords,:][f1_from_f1,:].T,
                data_from_dtseries_raw[f1_coords,:][inter_from_f1,:].T,
                data_from_dtseries_raw[f2_coords,:][f2_from_f2,:].T,
            ]
        )
        # The windowed and preprocessed arrays skip their first column (`1:`).
        # NOTE(review): the reason is not visible here -- presumably to match the
        # pickle timepoints; confirm.
        data_from_dtseries_windowed = np.hstack(
            [
                data_from_dtseries_windowed[f1_coords,1:][f1_from_f1,:].T,
                data_from_dtseries_windowed[f1_coords,1:][inter_from_f1,:].T,
                data_from_dtseries_windowed[f2_coords,1:][f2_from_f2,:].T,
            ]
        )
        data_from_dtseries_preprocessed = np.hstack(
            [
                data_from_dtseries_preprocessed[f1_coords,1:][f1_from_f1,:].T,
                data_from_dtseries_preprocessed[f1_coords,1:][inter_from_f1,:].T,
                data_from_dtseries_preprocessed[f2_coords,1:][f2_from_f2,:].T,
            ]
        )
        # Load data from pickle
        # 3) preprocessed, 4) preprocessed & phased
        _, f1_data_from_pkl_preprocessed = extract_carpet_data(f1_dict, task_id, task_quadrant, bootstrap_id, False)
        f1_phased_tps, f1_data_from_pkl_preprocessed_phased = extract_carpet_data(f1_dict, task_id, task_quadrant, bootstrap_id, True)
        _, f2_data_from_pkl_preprocessed = extract_carpet_data(f2_dict, task_id, task_quadrant, bootstrap_id, False)
        f2_phased_tps, f2_data_from_pkl_preprocessed_phased = extract_carpet_data(f2_dict, task_id, task_quadrant, bootstrap_id, True)
        # Keep only the phase-adjusted timepoints available for both ROIs so the
        # f1 and f2 parts can be stacked side by side.
        intersected_phased_tps = [i for i in set(f1_phased_tps).intersection(f2_phased_tps)]
        f1_phased_tp_mask = [tp in intersected_phased_tps for tp in f1_phased_tps]
        f2_phased_tp_mask = [tp in intersected_phased_tps for tp in f2_phased_tps]

        data_from_pkl_preprocessed = np.hstack(
            [
                f1_data_from_pkl_preprocessed[:,f1_from_f1],
                f1_data_from_pkl_preprocessed[:,inter_from_f1],
                f2_data_from_pkl_preprocessed[:,f2_from_f2],
            ]
        )
        data_from_pkl_preprocessed_phased = np.hstack(
            [
                f1_data_from_pkl_preprocessed_phased[:,f1_from_f1][f1_phased_tp_mask,:],
                f1_data_from_pkl_preprocessed_phased[:,inter_from_f1][f1_phased_tp_mask,:],
                f2_data_from_pkl_preprocessed_phased[:,f2_from_f2][f2_phased_tp_mask,:],
            ]
        )

        # One figure per (label, data) pair; the two lists are zipped below.
        ts_labels = [
            "raw", "windowed", "denoised", "denoised_rephased"
        ]
        ts_data = [
            data_from_dtseries_raw, 
            data_from_dtseries_windowed,
            data_from_dtseries_preprocessed,
            #data_from_pkl_preprocessed, 
            data_from_pkl_preprocessed_phased,
        ]

        # Drawn (divided by TR) as the orange marker line on the raw plot only.
        stim_start = 14
        cmap = "Greys_r"
        vmin, vmax = -1.31, 1.31

        # Get sorting order based on `data_from_dtseries_preprocessed`
        # Sort for each set of vertices: f1, f1f2, and f2 (this is the order that the reoriented data)
        # Note: `data_from_dtseries_preprocessed` == `data_from_pkl_preprocessed`

        # NOTE(review): the bare `except` at the end of this block silently skips the
        # run on ANY exception (including KeyboardInterrupt); consider narrowing it
        # and logging which run was skipped.
        try:
            # This will error out if there is any of f1, f2, or f1f2 has 0 vertices.. I THINK?
            y = data_from_dtseries_preprocessed.copy()
            # Z-score along axis 0, then transpose.
            y = (( y - y.mean(0)) / y.std(0) ).T
            sorted_voxels = {}
            y_f1 = y[:n_f1,:].copy()
            y_f1f2 = y[n_f1:n_f1+n_f1f2,:].copy()
            y_f2 = y[n_f1+n_f1f2:,:].copy()
            # Within each group, order rows by descending summed absolute correlation
            # with the other rows of the same group.
            for f_group, y in zip(["f1","f1f2","f2"], [y_f1, y_f1f2, y_f2]):
                C = np.corrcoef(y)
                correlation_strength = np.abs(C).sum(axis=1)
                sorted_voxels[f_group] = np.argsort(correlation_strength)[::-1]

            for y_ix, (y_label, y) in enumerate(zip(ts_labels, ts_data)):
                fig, ax = plt.subplots(
                    nrows=1,ncols=1, figsize=(2,1.), dpi=300,
                    #gridspec_kw=dict(height_ratios=[286, 119]),
                )
                
                y = (( y - y.mean(0)) / y.std(0) ).T

                # Sort
                # The same ordering (from the denoised data) is applied to every stage.
                y[:n_f1, :] = y[:n_f1,:][sorted_voxels["f1"],:]
                y[n_f1:n_f1+n_f1f2, :] = y[n_f1:n_f1+n_f1f2,:][sorted_voxels["f1f2"],:]
                y[n_f1+n_f1f2:, :] = y[n_f1+n_f1f2:,:][sorted_voxels["f2"],:]

                im = ax.imshow(y, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
                # Marker lines (window edges and stim_start) only on the first (raw) plot.
                if y_ix == 0:
                    for i in window_size:
                        ax.axvline(x=i/TR, color='yellow', linestyle='-', lw=1.)
                    ax.axvline(x=stim_start/TR, color='orange', linestyle='-', lw=1.)
                fig, ax = decorate_fig_1A(fig, ax, im, _task_frequencies[0], _task_frequencies[1], n_f1, n_f2, n_f1f2, FONTSIZE=FONTSIZE, TR=TR)

                fig.tight_layout()

                png_out = Path(set_base_dir(f"./ComputeCanada/frequency_tagging/figures/dual_frequency_timeseries")) / f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_task-{roi_task_id}_task-{task_id}_fo-{fo}_{y_ix}{y_label}.png"
                fig.savefig(png_out, dpi='figure')

                if close_figures:
                    plt.close()
        except:
            continue

Plot single-subject PSDs based on three ROIs (f1, f2, and f1/f2 intersection)
- main point: to show that peaks are clearly delineated in every subject (Shown as Z-scored power values)
- extra data will be generated as `.pkl`s and used in the next notebook to investigate trends of raw power values
- a PSD figure will be generated for each experiment
- hyperparameter: fractional overlap (or `fo`)

In [None]:
# Each dataset is a 3-tuple (experiment id, MRI id, subject -> task mapping), where
# the mapping has the form {sub_id: [(task_id, [f1, f2]), ...]}.
# `read_data` unpacks the tuple in this order.
NORMAL_3T_SUB_IDS = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
NORMAL_3T = ("1_frequency_tagging", "3T", {key: [("entrain", [0.125, 0.2])] for key in NORMAL_3T_SUB_IDS})
NORMAL_7T_SUB_IDS = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
NORMAL_7T = ("1_attention", "7T", {key: [("AttendAway", [0.125, 0.2])] for key in NORMAL_7T_SUB_IDS})
# Varying-frequency experiment; VARY_3T and VARY_7T differ only in the MRI id.
VARY_3T = (
    "1_frequency_tagging",
    "3T",
    {
        "020": [
            ("entrainA", [.125, .2]),
            ("entrainB", [.125, .175]),
            ("entrainC", [.125, .15]),
        ],
        "021": [
            ("entrainD", [.125, .2]),
            ("entrainE", [.15, .2]),
            ("entrainF", [.175, .2]),
        ],
    }
)
VARY_7T = (
    "1_frequency_tagging",
    "7T",
    {
        "020": [
            ("entrainA", [.125, .2]),
            ("entrainB", [.125, .175]),
            ("entrainC", [.125, .15]),
        ],
        "021": [
            ("entrainD", [.125, .2]),
            ("entrainE", [.15, .2]),
            ("entrainF", [.175, .2]),
        ],
    }
)

def read_data(d):
    """Return the first three items of a dataset tuple as a 3-tuple."""
    experiment = d[0]
    scanner = d[1]
    mapping = d[2]
    return experiment, scanner, mapping

Used for generating all pickles
- Change `task_id_2`, and looping parameters (i.e., `dataset_id`, `control_roi_size`, and `phased_flag`)

In [None]:
import itertools

# For every (dataset, control_roi_size, phased_flag) combination: run the
# permutation test on the voxel-averaged timeseries of every bootstrap for each
# ROI, pickle the per-bootstrap (p-value, power) pairs, and collect the power
# spectrum of the across-bootstrap mean timeseries.
#for dataset_id, control_roi_size, phased_flag in itertools.product([NORMAL_7T], [True, False], [True,False]):
for dataset_id, control_roi_size, phased_flag in itertools.product([NORMAL_7T], [True,False], [True,False]):

    experiment_id, mri_id, sub_to_task_mapping = read_data(dataset_id)

    # The next two self-assignments are no-ops.
    control_roi_size = control_roi_size
    phased_flag = phased_flag
    fo = .8
    frequency_grid = None
    experiment_info = []
    # Keys are "f-0", "f-1", "f-2" (position in the ROI loop below); each value is
    # a list with one spectrum (or None) per processed task.
    subject_task_level_psd = defaultdict(list)
    for sub_ix, (sub_id, sub_task_info) in enumerate(sub_to_task_mapping.items()):
        for task_ix, (task_id, frequencies) in enumerate(sub_task_info):

            # IPython shell magic: this line only runs inside IPython/Jupyter.
            bold_list = !ls /data/{experiment_id}/{mri_id}/bids/sub-{sub_id}/*/func/*{task_id}*nii.gz
            # The quadrant suffix is the text between the task id and the next underscore.
            task_quadrant = list(set([i.split(task_id)[1].split('_')[0] for i in bold_list]))
            assert len(task_quadrant) == 1, f"More than 1 task quadrant detected: {task_quadrant}"
            task_quadrant = task_quadrant[0]

            assert len(frequencies) == 2 and frequencies[0]<frequencies[1]

            # NOTE(review): the if/else result is immediately overwritten by
            # `task_id_2 = task_id`; edit these lines to analyse a different task
            # than the one defining the ROI.
            if sub_id == "020":
                task_id_2 = "entrainC"
            else:
                task_id_2 = "entrainF"
            task_id_2 = task_id

            # Iterates f1, f2, then the pair [f1, f2] (f_ix = 0, 1, 2).
            for f_ix, f in enumerate(frequencies + [frequencies]):

                # Size-controlled ROIs only exist for the pair, so skip the single frequencies.
                if control_roi_size and f_ix < 2:
                    print(f"Skipping {task_id} {f}")
                    continue

                if isinstance(f, list) and len(f) == 2:
                    if control_roi_size:
                        data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f[0], task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo, roi_frequency_2=f[1], control_roi_size=control_roi_size)
                    else:
                        data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f[0], task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo, roi_frequency_2=f[1])
                elif isinstance(f, float):
                    data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f, task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo)
                else:
                    raise ValueError("")
                    
                # Missing pickle: record a placeholder so the list stays aligned with the tasks.
                if data is None:
                    print(f"sub-{sub_id}, ROI info: {task_id}/f={f} [n=0], task-frequencies: {frequencies}, No data found")
                    subject_task_level_psd[f"f-{f_ix}"].append(None)
                    continue
                n_voxels = data['roi_coords'].sum()
                print(f"sub-{sub_id}, ROI info: {task_id}/f={f} [n={n_voxels}], task-frequencies: {frequencies}")
                
                ####
                # Frequencies tested per task: f1, f2 and their difference f2 - f1.
                # NOTE(review): a task id not listed here leaves `frequencies_to_probe`
                # unset (or stale from an earlier iteration).
                if task_id_2 == "AttendAway":
                    frequencies_to_probe = [.125, .2, .075]
                if task_id_2 == "entrain":
                    frequencies_to_probe = [.125, .2, .075]
                if task_id_2 == "control":
                    frequencies_to_probe = [.125, .2, .075]
                if task_id_2 == "entrainA":
                    frequencies_to_probe = [.125, .2, .075]
                if task_id_2 == "entrainB":
                    frequencies_to_probe = [.125, .175, .05]
                if task_id_2 == "entrainC":
                    frequencies_to_probe = [.125, .15, .025]
                if task_id_2 == "entrainD":
                    frequencies_to_probe = [.125, .2, .075]
                if task_id_2 == "entrainE":
                    frequencies_to_probe = [.15, .2, .05]
                if task_id_2 == "entrainF":
                    frequencies_to_probe = [.175, .2, .025]
                print(f"Processing {frequencies_to_probe} in sub-{sub_id}_task-{task_id_2}\n")
        
                # Select timeseries (timepoints x voxels x bootstraps)
                bootstrapped_means = []
                # "test-{frequency}" -> list of (p-value, observed power), one per bootstrap.
                bootstrapped_statistics = defaultdict(list)
                for bootstrap_id in range(n_bootstraps):
                    
                    _, y_bootstrapped_mean = extract_bootstrapped_mean_from_data(data, task_id_2, task_quadrant, bootstrap_id, phased_flag)

                    bootstrapped_means.append(y_bootstrapped_mean) # Track

                    ts = TimeSeries(y_bootstrapped_mean, TR, n_permutations=n_permutations)

                    p_values, observed_statistics, observed_power_spectrum, null_power_spectrums = ts.process(frequencies_to_probe)
                    for test_f in frequencies_to_probe:
                        bootstrapped_statistics[f"test-{test_f}"].append((p_values[test_f], observed_statistics[test_f]))
                    """
                    for test_f in frequencies:
                        bootstrapped_statistics[f"sub-{sub_id}_task-{task_id}_roi-{f}_test-{test_f}_{task_id_2}"].append((p_values[test_f], observed_statistics[test_f]))
                    """
                
                # The pickle name mirrors the three ROI variants handled by read_pkl.
                if isinstance(f, list) and len(f) == 2:
                    if control_roi_size:
                        save_pkl = f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{task_id}_roi-{f[0]}_controlroisizetomatch-{f[1]}_task-{task_id_2}_fo-{fo}_phaseadjusted-{phased_flag}_n-{n_permutations}.pkl"
                        save_bootstrapped_statistics(save_pkl, bootstrapped_statistics)
                    else:
                        save_pkl = f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{task_id}_roi-{f[0]}-{f[1]}_task-{task_id_2}_fo-{fo}_phaseadjusted-{phased_flag}_n-{n_permutations}.pkl"
                        save_bootstrapped_statistics(save_pkl, bootstrapped_statistics)
                elif isinstance(f, float):
                    save_pkl = f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{task_id}_roi-{f}_task-{task_id_2}_fo-{fo}_phaseadjusted-{phased_flag}_n-{n_permutations}.pkl"
                    save_bootstrapped_statistics(save_pkl, bootstrapped_statistics)
                    
                del bootstrapped_statistics
                    
                # Average the per-bootstrap mean timeseries and test the result once more.
                bootstrapped_means = np.vstack(bootstrapped_means)
                y_bootstrapped_mean = bootstrapped_means.mean(0)
                
                ts = TimeSeries(y_bootstrapped_mean, TR, n_permutations=n_permutations)
                p_values, observed_statistics, observed_power_spectrum, null_power_spectrums = ts.process(frequencies_to_probe)
                
                # The frequency grid is only compared across ROIs for the non-phased data.
                if not phased_flag:
                    if frequency_grid is None:
                        frequency_grid = ts.frequencies
                    else:
                        assert np.allclose(frequency_grid, ts.frequencies, rtol=1e-05, atol=1e-08)

                subject_task_level_psd[f"f-{f_ix}"].append(observed_power_spectrum)

                # Plot
                """
                if sub_ix == 0 and task_ix == 0 and f_ix == 0:
                    plot_power_spectrum(ts, observed_power_spectrum, null_power_spectrums, add_im=True, sub_id=sub_id, roi_frequency=f)
                """
                
            experiment_info.append((sub_id, task_id,frequencies))

Use to generate single experiment PSD heatmaps

In [None]:
from copy import copy

def decorate_fig_1C(
    fig, ax, im, 
    frequencies_per_experiment, frequency_grid,
    arrow_pos = [],
    yticks=[], yticklabels=[], 
    xticks=[], xticklabels=[], 
    FONTSIZE=FONTSIZE,
    lower_f=True
):
    """Style one PSD heatmap panel: colorbar, axis labels/ticks and frequency arrows.

    Parameters
    ----------
    fig, ax
        Figure (returned untouched) and the axes that get decorated.
    im
        Mappable handed to ``plt.colorbar``.
    frequencies_per_experiment
        One sequence of frequencies per heatmap row; an arrow is drawn on
        row ``k`` for every frequency in entry ``k``.
    frequency_grid
        Frequency values used to turn a frequency into an x position by
        interpolating against the indices ``0 .. len(frequency_grid) - 1``.
    arrow_pos
        Accepted but not used by this function.
    yticks, yticklabels, xticks, xticklabels
        Tick positions and labels applied to ``ax``.
    FONTSIZE
        Font size used for every label and tick label set here.
    lower_f
        Selects which arrow of each row is drawn cyan instead of red:
        truthy -> the arrow at index 1, falsy but not ``None`` -> the arrow
        at index 0, ``None`` -> every arrow stays red.

    Returns
    -------
    tuple
        ``(fig, ax)``.
    """
    # Colorbar: small, no tick marks, no outline.
    colorbar = plt.colorbar(im, ax=ax, shrink=.5, drawedges=False)
    colorbar_ax = colorbar.ax
    colorbar_ax.set_title("Z-scored\nPSD", fontsize=FONTSIZE)
    colorbar_ax.tick_params(axis="both", length=0, labelsize=FONTSIZE)
    colorbar.outline.set_edgecolor('none')

    # y axis.
    ax.set_ylabel("Experiments", fontsize=FONTSIZE)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels, fontsize=FONTSIZE)

    # x axis.
    ax.set_xlabel("Frequency (Hz)", fontsize=FONTSIZE)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=FONTSIZE)

    ax.tick_params(axis="both", length=0)

    # The cyan slot is the same for every row, so decide it once.
    cyan_ix = 1 if lower_f else 0
    cyan_enabled = lower_f is not None

    for row, row_frequencies in enumerate(frequencies_per_experiment):
        for position, frequency in enumerate(row_frequencies):
            colour = 'cyan' if (cyan_enabled and position == cyan_ix) else 'red'
            # Fractional column index of this frequency on the grid.
            x_tip = np.interp(frequency, frequency_grid, np.arange(len(frequency_grid)))
            ax.annotate(
                '',
                xy=(x_tip, row),
                xytext=(x_tip + 1, row),
                arrowprops={
                    'facecolor': colour,
                    'edgecolor': colour,
                    'arrowstyle': 'simple',
                    'linewidth': 0,
                    'mutation_scale': 3,
                },
                annotation_clip=False,
            )

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    return fig, ax

def clean_subject_task_level_psd(psd_list):
    """Replace ``None`` placeholders in ``psd_list`` with all-zero arrays.

    The list is modified in place and also returned. Every ``None`` entry is
    replaced by its own ``np.zeros_like`` copy of the last non-``None`` entry
    of the list, so the filled entries share that entry's shape and dtype.
    If the list holds no non-``None`` entry (including the empty list) it is
    returned unchanged.

    Parameters
    ----------
    psd_list : list
        Arrays, with ``None`` marking missing entries.

    Returns
    -------
    list
        The same list object that was passed in.
    """
    # Pick the template once instead of rescanning the list for every gap.
    template = next((psd for psd in reversed(psd_list) if psd is not None), None)
    if template is None:
        return psd_list

    for psd_ix, psd in enumerate(psd_list):
        if psd is None:
            psd_list[psd_ix] = np.zeros_like(template)

    return psd_list

In [None]:
control_roi_size = False
phased_flag = False
n_permutations=5
close_figure = True # closes single ROI psds

for fo in [.4, .6, .8, 1.]:
    dataset_ids = [NORMAL_3T, NORMAL_3T, NORMAL_7T, VARY_3T, VARY_7T]
    dataset_labels = ["NORMAL_3T", "NORMAL_3T_CONTROL", "NORMAL_7T", "VARY_3T", "VARY_7T"]
    for dataset_ix, (dataset_label, dataset_id) in enumerate(zip(dataset_labels, dataset_ids)):
        experiment_id, mri_id, sub_to_task_mapping = read_data(dataset_id)
        frequency_grid = None
        experiment_info = []
        subject_task_level_psd = defaultdict(list)
        for sub_ix, (sub_id, sub_task_info) in enumerate(sub_to_task_mapping.items()):
            for task_ix, (task_id, frequencies) in enumerate(sub_task_info):

                bold_list = !ls /data/{experiment_id}/{mri_id}/bids/sub-{sub_id}/*/func/*{task_id}*nii.gz
                task_quadrant = list(set([i.split(task_id)[1].split('_')[0] for i in bold_list]))
                assert len(task_quadrant) == 1, f"More than 1 task quadrant detected: {task_quadrant}"
                task_quadrant = task_quadrant[0]

                assert len(frequencies) == 2 and frequencies[0]<frequencies[1]

                """
                if sub_id == "020":
                    task_id_2 = "entrainC"
                else:
                    task_id_2 = "entrainF"
                task_id_2 = task_id
                """
                if dataset_ix == 1:
                    task_id_2 = 'control'
                else:
                    task_id_2 = task_id

                for f_ix, f in enumerate(frequencies + [frequencies]):

                    if control_roi_size and f_ix < 2:
                        print(f"Skipping {task_id} {f}")
                        continue

                    if isinstance(f, list) and len(f) == 2:
                        if control_roi_size:
                            data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f[0], task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo, roi_frequency_2=f[1], control_roi_size=control_roi_size)
                        else:
                            data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f[0], task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo, roi_frequency_2=f[1])
                    elif isinstance(f, float):
                        data = read_pkl(datadir, n_bootstraps, sub_id, task_id, f, task_id_2, experiment_id=experiment_id, mri_id=mri_id, fo=fo)
                    else:
                        raise ValueError("")
                        
                    if data is None:
                        print(f"sub-{sub_id}, ROI info: {task_id}/f={f} [n=0], task-frequencies: {frequencies}, No data found")
                        subject_task_level_psd[f"f-{f_ix}"].append(None)
                        continue
                    n_voxels = data['roi_coords'].sum()
                    print(f"sub-{sub_id}, ROI info: {task_id}/f={f} [n={n_voxels}], task-frequencies: {frequencies}")
                    
                    ####
                    if task_id_2 == "AttendAway":
                        frequencies_to_probe = [.125, .2]
                    if task_id_2 == "entrain":
                        frequencies_to_probe = [.125, .2]
                    if task_id_2 == "control":
                        frequencies_to_probe = [.125, .2]
                    if task_id_2 == "entrainA":
                        frequencies_to_probe = [.125, .2]
                    if task_id_2 == "entrainB":
                        frequencies_to_probe = [.125, .175]
                    if task_id_2 == "entrainC":
                        frequencies_to_probe = [.125, .15]
                    if task_id_2 == "entrainD":
                        frequencies_to_probe = [.125, .2]
                    if task_id_2 == "entrainE":
                        frequencies_to_probe = [.15, .2]
                    if task_id_2 == "entrainF":
                        frequencies_to_probe = [.175, .2]
                    print(f"Processing {frequencies_to_probe} in sub-{sub_id}_task-{task_id_2}\n")
            
                    # Select timeseries (timepoints x voxels x bootstraps)
                    bootstrapped_means = []
                    bootstrapped_statistics = defaultdict(list)
                    for bootstrap_id in range(n_bootstraps):
                        
                        _, y_bootstrapped_mean = extract_bootstrapped_mean_from_data(data, task_id_2, task_quadrant, bootstrap_id, phased_flag)

                        bootstrapped_means.append(y_bootstrapped_mean) # Track

                        ts = TimeSeries(y_bootstrapped_mean, TR, n_permutations=n_permutations)

                        p_values, observed_statistics, observed_power_spectrum, null_power_spectrums = ts.process(frequencies_to_probe)
                        for test_f in frequencies_to_probe:
                            bootstrapped_statistics[f"test-{test_f}"].append((p_values[test_f], observed_statistics[test_f]))
                        
                    bootstrapped_means = np.vstack(bootstrapped_means)
                    y_bootstrapped_mean = bootstrapped_means.mean(0)
                    
                    ts = TimeSeries(y_bootstrapped_mean, TR, n_permutations=n_permutations)
                    p_values, observed_statistics, observed_power_spectrum, null_power_spectrums = ts.process(frequencies_to_probe)
                    
                    if not phased_flag:
                        if frequency_grid is None:
                            frequency_grid = ts.frequencies
                        else:
                            assert np.allclose(frequency_grid, ts.frequencies, rtol=1e-05, atol=1e-08)

                    subject_task_level_psd[f"f-{f_ix}"].append(observed_power_spectrum)

                    # Plot
                    if isinstance(f, list):
                        figlabel = '-'.join(str(i) for i in f)
                    else:
                        figlabel = copy(f)
                    png_out = Path(set_base_dir(f"./ComputeCanada/frequency_tagging/figures/dual_frequency_roi_psd")) / f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{task_id}-{figlabel}_task-{task_id_2}_fo-{fo}_.png"
                    plot_power_spectrum(
                        ts, 
                        observed_power_spectrum, 
                        null_power_spectrums, 
                        add_im=True, 
                        sub_id=sub_id, 
                        roi_frequency=f, 
                        close_figure=close_figure,
                        png_out=png_out,
                    )
                    
                experiment_info.append((sub_id, task_id,frequencies))

        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(3,1), dpi=300)

        xmax = (frequency_grid<.5).sum()

        all_frequencies = []
        for i in experiment_info:
            all_frequencies += i[2]
        all_frequencies = list(set(all_frequencies))

        for ix, (i, ax) in enumerate(zip(range(3), axs)):
            subject_task_level_psd[f"f-{i}"] = clean_subject_task_level_psd(subject_task_level_psd[f"f-{i}"])
            sub_task_psds = np.vstack(subject_task_level_psd[f'f-{i}'])
            sub_task_psds = ( sub_task_psds - sub_task_psds.mean(1, keepdims=True) ) / sub_task_psds.std(1, keepdims=True)
            im = ax.imshow(
                sub_task_psds[:,:xmax], 
                cmap='magma',
                interpolation='none', aspect='auto',
                vmax=4, vmin=0
            )

            if ix == 2:
                lower_f = None
            else: 
                lower_f = i==0
            xticklabels = [0, .1, .2, .3, .4, .5]
            decorate_fig_1C(
                fig, ax, im,
                [i[2] for i in experiment_info], frequency_grid,
                yticks=[i for i in range(len(experiment_info))], 
                yticklabels=[f"{i[0]} {i[1]}" for i in experiment_info],
                xticks = [np.where(frequency_grid == i)[0][0] for i in xticklabels],
                xticklabels=xticklabels,
                lower_f=lower_f,
                FONTSIZE=FONTSIZE-2
            )

        fig.tight_layout()
                
        png_out = Path(set_base_dir(f"./ComputeCanada/frequency_tagging/figures/dual_frequency_roi_and_group_psd")) / f"group-{dataset_label}_fo-{fo}.png"
        fig.savefig(png_out, dpi='figure')

Plot ROI timeseries (normal, and not rephased) of a single bootstrap from a sample dataset

In [None]:
# Uses data, task_id_2, task_quadrant, sub_id, task_id and f as left bound by
# the cells run before this one; none of them is defined here.
bootstrap_id = 0
fig, ax = plt.subplots(figsize=(4,1), dpi=200)
# Overlay both variants of the same bootstrap: the trace extracted with the last
# argument True is red and drawn on top, the one with False is grey underneath.
for phased, colour, layer in ((True, 'r', 2), (False, 'grey', 1)):
    x, y = extract_bootstrapped_mean_from_data(data, task_id_2, task_quadrant, bootstrap_id, phased)
    ax.plot(x, y, c=colour, zorder=layer)
ax.set_title(f'sub-{sub_id} roi-{task_id}-{f} task-{task_id_2}')