In [4]:
import glob
from pathlib import Path
import itertools
import nibabel as nib
import numpy as np
import pandas as pd

def search(base_dir, wildcard, error=True):
    search_path = Path(base_dir) / wildcard
    files = glob.glob(str(search_path))

    if not files:
        if error:
            raise FileNotFoundError(f"No files were found in: {search_path}")
        else:
            return []

    return files

def filter_run_ids(run_ids):
    KEEP = ['01', '02', '03', 'IMTest', 'IMRetest']
    KEEP = [f"run-{i}" for i in KEEP]
    filtered_run_ids = []
    for run_id in run_ids:
        if run_id.split('/')[-1] in KEEP or "X" in run_id:
            filtered_run_ids.append(run_id)

    return filtered_run_ids

In [5]:
scratch_dir = Path("/scratch/fastfmri")
IM_MODULATION_EXPERIMENT_IDS = ["020", "021"]
SUPPORTED_IDS = [
    "1_frequency_tagging_3T_entrain",
    #"1_frequency_tagging_7T_entrain",
    "1_attention_7T_AttendAway",
]
DSCALAR_TEMPLATE = "/opt/app/notebooks/data/dscalars/S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"


experiment_id = "1_frequency_tagging"
mri_id = "3T"
task_base = "entrain"
n_batches = 2
session_ids = ["run-ALL"]

assert f"{experiment_id}_{mri_id}_{task_base}" in SUPPORTED_IDS



sub_fla_dirs = search(scratch_dir, f"experiment-{experiment_id}_mri-{mri_id}*smooth-0*batch-00_desc-IMsubtraction_bootstrap/first_level_analysis/sub-*")
sub_fla_dirs.sort()
sub_ids = []
task_suffices = []
for i in sub_fla_dirs:
    _dir = Path(i)
    sub_id = _dir.stem
    if sub_id in [f"sub-{_id}" for _id in IM_MODULATION_EXPERIMENT_IDS]:
        continue
    sub_ids.append(sub_id)
    
    search_str = str(_dir).split("/")[3]
    truncate_id = search_str.split('truncate-')[1].split('_')[0]
    smooth_id = search_str.split('smooth-')[1].split('_')[0]

    tasks = search(scratch_dir, f"experiment-{experiment_id}*mri-{mri_id}*smooth-0*batch-00_desc-IMsubtraction_bootstrap/first_level_analysis/{sub_id}/*/task-{task_base}*")
    assert len(tasks) == 1
    task_suffix = tasks[0][-2:]
    task_suffices.append(task_suffix)

In [6]:
import sys
sys.path.append("./ComputeCanada/frequency_tagging")
from im_recall_precision import *

import matplotlib.pyplot as plt
import seaborn as sns

f_1, f_2, f_im = 0.125, 0.2, 0.075

def process_phase_delay(
    data: np.ndarray,
    stimulated_frequency: float,
):
    """
    """
    max_phasedelay = 1/float(stimulated_frequency) # in seconds
    max_indices = data == max_phasedelay
    non_max_indices = data != max_phasedelay
    data[max_indices] = 0 # Unphase max values: [0, max_phasedelay)
    data[non_max_indices] -= max_phasedelay / 2 # Move the stimulus to the right by pi/2 account for different between stimulus and sine wave
    data[non_max_indices] %= max_phasedelay # Rephase data: [0, max_phasedelay)

    return data

for ix, ((sub_id, task_suffix), ses_id) in enumerate(itertools.product(zip(sub_ids, task_suffices), session_ids)):
    
    task_f1_data, task_f2_data, task_im_data = load_multitype_maps(
        [ses_id],
        mri_id, 
        smooth_id,
        truncate_id,
        scratch_dir,
        sub_id,
        task_suffix,
        f_1, f_2, f_im,
        metric_types=["z_score", "p_value", "stat", "phasedelay"],
        task_base=task_base,
        experiment_id=experiment_id,
        dscalar_template=DSCALAR_TEMPLATE,
    )

    bootstrapped_brainmaps = {}
    metric_types = ["phasedelay", "z_score", "activations", "tasklock"]
    for metric_type in metric_types:
        for batch_ix in range(n_batches):
            batch_id = str(batch_ix).zfill(2)
            if metric_type != "tasklock":
                fs = search(scratch_dir,  f"experiment-{experiment_id}_mri-{mri_id}*smooth-0*batch-{batch_id}_desc-IMsubtraction_bootstrap/{sub_id}/bootstrap/*{task_base}*{f_im}_data-train_n-100_{metric_type}.dtseries.nii")
            else:
                fs = search(scratch_dir,  f"experiment-{experiment_id}_mri-{mri_id}*smooth-0*batch-{batch_id}_desc-IMsubtraction_bootstrap/{sub_id}/bootstrap/*{task_base}*n-100_{metric_type}.dtseries.nii")
            assert len(fs) == 1
            _brainmap = read_map(fs[0], dscalar_template=DSCALAR_TEMPLATE)
            if batch_ix == 0:
                bootstrapped_brainmaps[metric_type] = _brainmap
            else:
                bootstrapped_brainmaps[metric_type] = np.concatenate((bootstrapped_brainmaps[metric_type], _brainmap), axis=0)
    bootstrapped_brainmaps['phasedelay'] = process_phase_delay(bootstrapped_brainmaps['phasedelay'], f_im)

    # Get wholebrain cortex mask    
    wb_coverage_mask = (task_f1_data['stat'][ses_id]>0).astype(int)

    # Get P-value corrected map
    task_f1_mask = process_mask(task_f1_data, correction_type="fdr")[ses_id].astype(int)
    task_f2_mask = process_mask(task_f2_data, correction_type="fdr")[ses_id].astype(int)
    f1_f2_intersection = (task_f1_mask + task_f2_mask) == 2
    f1_f2_intersection = (f1_f2_intersection[wb_coverage_mask==1]).astype(int)

    task_im_mask = process_mask(task_im_data, correction_type="fdr")[ses_id].astype(int)
    task_im_mask = (task_im_mask[wb_coverage_mask==1]).astype(int)

    out_im_in_inter = (task_im_mask == 0) * (f1_f2_intersection == 1)
    in_coords = (task_im_mask == 1) * (f1_f2_intersection == 1)
    out_coords = (task_im_mask == 1) * (f1_f2_intersection == 0)
    print(sub_id, out_im_in_inter.sum(), in_coords.sum(), out_coords.sum())

    # phase delay
    in_phasedelay = process_phase_delay(task_im_data['phasedelay'][ses_id][wb_coverage_mask==1][in_coords], f_im)
    in_zscore = task_im_data['z_score'][ses_id][wb_coverage_mask==1][in_coords]
    out_phasedelay = process_phase_delay(task_im_data['phasedelay'][ses_id][wb_coverage_mask==1][out_coords], f_im)
    out_zscore = task_im_data['z_score'][ses_id][wb_coverage_mask==1][out_coords]
    other_phasedelay = process_phase_delay(task_im_data['phasedelay'][ses_id][wb_coverage_mask==1][out_im_in_inter], f_im)
    other_zscore = task_im_data['z_score'][ses_id][wb_coverage_mask==1][out_im_in_inter]

    # Plot phase delay and z-scores using different IM & intersection masks
    fig, axs = plt.subplots(ncols=2, figsize=(4,1.2), dpi=200)

    data = [in_phasedelay, out_phasedelay, other_phasedelay]
    sns.violinplot(data=data, ax=axs[0], split=False, inner='quart')
    for ix, metric in enumerate(data):
        jitter = np.random.uniform(-.35, -.05, size=metric.shape)
        axs[0].text(ix-.25,metric.max(), f"n={metric.shape[0]}", fontsize=5)
        axs[0].scatter(np.zeros_like(metric)+ix+jitter, metric, s=2, c='grey')
        axs[0].set_title(f"[{sub_id}] Phase delay", fontsize=5)
        axs[0].set_ylabel("phase delay", fontsize=5)
    data = [in_zscore, out_zscore, other_zscore]
    sns.violinplot(data=data, ax=axs[1], split=False, inner='quart')
    for ix, metric in enumerate([in_zscore, out_zscore, other_zscore]):
        jitter = np.random.uniform(-.35, -.05, size=metric.shape)
        axs[1].scatter((np.zeros_like(metric)+ix)+jitter, metric, s=2, c='grey')
        axs[1].set_title(f"[{sub_id}] Z-score", fontsize=5)
        axs[1].set_xlabel("z-score", fontsize=5)

    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticks([0,1,2])
        ax.set_xticklabels(["In","Out","Other"], fontsize=5)

    fig.tight_layout()

    # Plot metrics, in relation to phase delays, for each IM & interesction masks
    for metric_type in metric_types:

        fig, axs = plt.subplots(ncols=3, nrows=3,figsize=(6,4))
        mean_max, stdev_max, cov_max = None, None, None
        mean_min, stdev_min, cov_min = 0, 0, 0
        for ix, (coords, pd, title, c) in enumerate(zip([in_coords, out_coords, out_im_in_inter], [in_phasedelay, out_phasedelay, other_phasedelay], ["In", "Out", "Other"], ["blue", "orange", "green"])):
            
            if metric_type == 'activations':
                _mean = bootstrapped_brainmaps[metric_type][:,wb_coverage_mask==1][:, coords].sum(0)
                axs[0,ix].scatter(pd, _mean, alpha=1., c=c, s=2)
                axs[0,ix].set_ylabel("sum")
                if mean_max is None or mean_max < _mean.max():
                    mean_max = _mean.max()
                if mean_min == 0 or mean_min > _mean.min():
                    mean_min = _mean.min()
            else:
                _mean = bootstrapped_brainmaps[metric_type][:,wb_coverage_mask==1][:, coords].mean(0)
                axs[0,ix].scatter(pd, _mean, alpha=1., c=c, s=2)
                axs[0,ix].set_ylabel("mean")
                if mean_max is None or mean_max < _mean.max():
                    mean_max = _mean.max()
                if mean_min == 0 or mean_min > _mean.min():
                    mean_min = _mean.min()
                
                _stdev = bootstrapped_brainmaps[metric_type][:,wb_coverage_mask==1][:, coords].std(0)
                axs[1,ix].scatter(pd, _stdev, alpha=1., c=c,s=2)
                axs[1,ix].set_ylabel("stdev.")
                if stdev_max is None or stdev_max < _stdev.max():
                    stdev_max = _stdev.max()

                
                _cov = _stdev / _mean
                axs[2,ix].scatter(pd, _cov, alpha=1., c=c,s=2)
                axs[2,ix].set_xlabel("phase delay")
                axs[2,ix].set_ylabel(f"cov.")
                if cov_max is None or cov_max < _cov.max():
                    cov_max = _cov.max()

            for _axs in axs:
                for ix, ax in enumerate(_axs):
                    ax.set_xticks([0,1/f_im])
                    top_x = 1/f_im
                    ax.set_xticklabels([0, f"{top_x:.1f}"])

            for ix in range(3):
                if metric_type == "activations":
                    axs[0,ix].set_yticks([0, mean_max*1.2])
                    axs[0,ix].set_ylim([0, mean_max*1.2])
                else:
                    axs[0,ix].set_yticks([0, mean_max*1.2])
                    axs[0,ix].set_ylim([0, mean_max*1.2])
                    axs[1,ix].set_yticks([0, stdev_max*1.2])
                    axs[1,ix].set_ylim([0, stdev_max*1.2])
                    axs[2,ix].set_yticks([0, cov_max*1.2])
                    axs[2,ix].set_ylim([0, cov_max*1.2])

            fig.suptitle(metric_type)    

            fig.tight_layout()

> [0;32m/tmp/ipykernel_22562/1404099134.py[0m(62)[0;36m<module>[0;34m()[0m
[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m    [0;31m# Get wholebrain cortex mask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 62 [0;31m    [0mwb_coverage_mask[0m [0;34m=[0m [0;34m([0m[0mtask_f1_data[0m[0;34m[[0m[0;34m'stat'[0m[0;34m][0m[0;34m[[0m[0mses_id[0m[0;34m][0m[0;34m>[0m[0;36m0[0m[0;34m)[0m[0;34m.[0m[0mastype[0m[0;34m([0m[0mint[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m    [0;31m# Get P-value corrected map[0m[0;34m[0m[0;34m[0m[0m
[0m
{'phasedelay': array([[ 4.0360988 ,  2.29583772,  3.22038587, ...,  6.66666635,
         6.66666635,  6.66666635],
       [ 0.26915916,  2.60448964,  5.95670827, ...,  6.66666635,
         6.66666635,  6.66666635],
       [ 2.80789502,  0.98088964, 11.75829379, ...,  6.66666635,
         6.66666635,  6.66666635],
       ...,
       [ 4.653