In [1]:
import nibabel as nib
import numpy as np

import sys
sys.path.append("/opt/wbplot")

from wbplot import dscalar

from IPython.display import Image
import matplotlib.pyplot as plt

from pathlib import Path
import pandas as pd
from collections import defaultdict
from functools import lru_cache

sys.path.append("ComputeCanada/frequency_tagging")
from dfm import (
    get_roi_colour_codes,
    change_font,
)
# Apply the project font settings via dfm.change_font (presumably what prints
# the font_library paths shown in this cell's output — confirm).
change_font()

/opt/app/notebooks/font_library/aptos-extrabold.ttf
/opt/app/notebooks/font_library/aptos-black-italic.ttf
/opt/app/notebooks/font_library/aptos-italic.ttf
/opt/app/notebooks/font_library/aptos-light-italic.ttf
/opt/app/notebooks/font_library/aptos.ttf
/opt/app/notebooks/font_library/aptos-light.ttf
/opt/app/notebooks/font_library/aptos-extrabold-italic 2.ttf
/opt/app/notebooks/font_library/aptos-black.ttf
/opt/app/notebooks/font_library/aptos-semibold.ttf
/opt/app/notebooks/font_library/aptos-bold.ttf
/opt/app/notebooks/font_library/aptos-extrabold-italic.ttf


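Get HCP info
- `hcp_mapping`: dict mapping each HCP label (e.g. `L_V1_ROI`) to its binary ROI dscalar in `/tmp`
- `hcp_rois`: unique HCP ROI names without the hemisphere prefix (e.g. `V1`)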

In [2]:
"""Get HCP labels
"""
dlabel_dir = Path("/opt/app/notebooks/data/dlabels")
hcp_label = dlabel_dir / "Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR.dlabel.nii"

_HCP_INFO = !wb_command -file-information {hcp_label}
HCP_LABELS = []
HCP_COUNTER = 0
for i in _HCP_INFO:
    if len(i) == 60 and any(["L_" in i, "R_" in i]):
        hcp_colors = tuple([float(f"0.{k}") for k in [j.split(' ') [0] for j in i.split('0.')][-3:]] + [1])
        if ' R_' in i:
            roi = i.split("_ROI")[0].split(' R_')[1]
            HCP_LABELS.append(f"R_{roi}_ROI")
        if ' L_' in i:
            roi = i.split("_ROI")[0].split(' L_')[1]
            HCP_LABELS.append(f"L_{roi}_ROI")
        HCP_COUNTER += 1

"""Get HCP label coordinates
"""
dscalar_dir = Path("/opt/app/notebooks/data/dscalars")
tmpdir = Path("/tmp")
template_dscalar = dscalar_dir / "S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"

hcp_mapping = {}
for roi_label in HCP_LABELS:
    out_dscalar = tmpdir / f"{roi_label}.dscalar.nii"
    if out_dscalar.exists():
        hcp_mapping[roi_label] = out_dscalar
        continue
    !wb_command -cifti-label-to-roi {hcp_label} {out_dscalar} -name {roi_label}
    assert out_dscalar.exists(), f"{out_dscalar.stem} does not exist."
    hcp_mapping[roi_label] = out_dscalar
hcp_rois = list(set([k.split('_')[1] for k in hcp_mapping.keys()]))

Functions

In [3]:

def convert_to_fractional_overlap(data):
    """Return, per column, the fraction of rows (axis 0) that are set.

    The result is ``sum over axis 0 / number of rows``. For 0/1 input this is
    the per-vertex fractional overlap across rows (presumably one row per
    bootstrap map — TODO confirm).
    """
    n_rows = data.shape[0]
    column_totals = data.sum(axis=0)
    return column_totals / n_rows

def map_data_to_value(data_list):
    """Return the weighted sum ``k_1 * v_1 + k_2 * v_2 + ...`` of ``data_list``.

    Parameters
    ----------
    data_list : iterable of (array, weight) pairs
        Arrays must broadcast together. In this notebook they are 0/1 vertex
        masks and the weights are the colour codes written into the map.

    Returns
    -------
    numpy.ndarray
        A new array; the inputs are not modified.

    Raises
    ------
    ValueError
        If ``data_list`` is empty (previously an UnboundLocalError).
    """
    new_data = None
    for k, v in data_list:
        weighted = k * v
        # Out-of-place addition: an in-place ``+=`` raises a casting error
        # when the first term is an int array and a later weight is a float.
        new_data = weighted if new_data is None else new_data + weighted
    if new_data is None:
        raise ValueError("data_list must contain at least one (array, weight) pair")
    return new_data

def combine_f1_f2(f1, f2, fo=1., mask=None, f1_c=.01, f2_c=.82, f1f2_c=.28, mask_c=.01):
    """Merge the f1 and f2 activation files into one colour-coded vertex map.

    Each file is loaded, reduced to per-vertex fractional overlap
    (`convert_to_fractional_overlap`) and binarized at ``fo``. Vertices are
    labelled f1-only, f2-only or both; each class is multiplied by its code
    (``f1_c``, ``f2_c``, ``f1f2_c``) and the results are summed
    (`map_data_to_value`).

    If ``mask`` is truthy, it is binarized at full overlap (1.0, independent
    of ``fo``) and the three classes are subtracted from it. The remainder is
    added with code ``mask_c``. Note that a vertex inside an f1/f2 class but
    outside the mask ends up at -1 in that remainder.

    Returns
    -------
    numpy.ndarray
        The colour-coded map.
    """
    def _binarized_overlap(path, threshold):
        overlap = convert_to_fractional_overlap(nib.load(path).get_fdata())
        return (overlap >= threshold).astype(int)

    f1_only = _binarized_overlap(f1, fo)
    f2_only = _binarized_overlap(f2, fo)
    both = ((f1_only + f2_only) == 2).astype(int)
    f1_only -= both
    f2_only -= both

    layers = [(f1_only, f1_c), (f2_only, f2_c), (both, f1f2_c)]
    if mask:
        mask_only = _binarized_overlap(mask, 1.) - both - f1_only - f2_only
        layers.append((mask_only, mask_c))

    return map_data_to_value(layers)

def get_quadrant_id(mask_path):
    """Return the stimulated-quadrant id ("Q1" or "Q2") encoded in a file name.

    The id is the two characters starting at the first "Q" of the file's base
    name (directories are ignored).

    Parameters
    ----------
    mask_path : str or os.PathLike
        Path to the mask file. ``pathlib.Path`` is now accepted too.

    Returns
    -------
    str
        "Q1" or "Q2".

    Raises
    ------
    AssertionError
        If the extracted id is not "Q1" or "Q2". This is raised explicitly, so
        the check also runs under ``python -O``.
    """
    import os

    rel_mask_path = os.path.basename(os.fspath(mask_path))
    idx_start = rel_mask_path.find("Q")
    quadrant_id = rel_mask_path[idx_start:idx_start+2]
    if quadrant_id not in ['Q1', 'Q2']:
        raise AssertionError(f"{quadrant_id} not Q1 or Q2")

    return quadrant_id

def merge_and_binarize_mask(data, f1_c, f2_c, f1f2_c, mask_c):
    """Split a colour-coded map into binary "f1", "f2" and "f1Uf2" maps.

    Each output starts as a copy of ``data``:
      - "f1":    f1_c and f1f2_c -> 1, then f2_c -> 0
      - "f2":    f2_c and f1f2_c -> 1, then f1_c -> 0
      - "f1Uf2": f1f2_c -> 1, then f1_c and f2_c -> 0
    Finally, entries equal to ``mask_c`` are set to 0 in every output. Any
    other value (e.g. background 0) is left unchanged. ``data`` itself is not
    modified.
    """
    def _matches_any(values, codes):
        hit = np.zeros(values.shape, dtype=bool)
        for code in codes:
            hit |= values == code
        return hit

    # (output name, codes set to 1, codes then set to 0), applied in order.
    rules = (
        ("f1", (f1_c, f1f2_c), (f2_c,)),
        ("f2", (f2_c, f1f2_c), (f1_c,)),
        ("f1Uf2", (f1f2_c,), (f1_c, f2_c)),
    )
    binary_maps = {}
    for name, on_codes, off_codes in rules:
        layer = data.copy()
        layer[_matches_any(layer, on_codes)] = 1
        layer[_matches_any(layer, off_codes)] = 0
        binary_maps[name] = layer

    for layer in binary_maps.values():
        layer[layer == mask_c] = 0

    return binary_maps

@lru_cache(maxsize=360)
def read_roi_path(roi_path):
    """Load row 0 of a CIFTI ROI file, memoised per path.

    Results are cached with ``functools.lru_cache`` (up to 360 paths), so
    every call with the same path returns the *same* array object. The array
    is therefore marked read-only. Otherwise an accidental in-place edit by
    one caller would silently change the mask every later caller receives.

    Parameters
    ----------
    roi_path : str or pathlib.Path
        Must be hashable (it is the cache key).

    Returns
    -------
    numpy.ndarray
        Read-only view ``nib.load(roi_path).get_fdata()[0, :]``.
    """
    roi_row = nib.load(roi_path).get_fdata()[0, :]
    roi_row.setflags(write=False)
    return roi_row

def append_data(
    df_data,
    hcp_mapping,
    map_data,
    power_f1_data,
    power_f2_data,
    pd_f1_data,
    pd_f2_data,
    q_id,
    experiment_label,
    sub_id,
    roi_fo,
    roi_task_id,
    task_id,
):
    """Append one row per (frequency ROI, HCP ROI) overlap to ``df_data``.

    For every binary map in ``map_data`` (e.g. "f1", "f2", "f1Uf2") and every
    HCP ROI mask in ``hcp_mapping``, the vertices where ``roi_mask * f_data``
    equals 1 form one row. Empty overlaps are skipped. HCP labels are renamed
    relative to the stimulated quadrant: for ``q_id == "Q1"`` left ("L_")
    ROIs are CONTRA and right ones IPSI, and for "Q2" the reverse
    (e.g. "L_V1_ROI" -> "CONTRA_V1").

    Columns appended (each a list in ``df_data``):
        roi_task_id, task_id, roi_fo, experiment_id (= ``experiment_label``),
        sub_id, quadrant_id (= ``q_id``), hcp_roi, frequency_of_roi,
        vertex_count (number of selected vertices),
        vertex_coordinates (``np.where`` tuple of selected indices),
        f1_BOLD_power / f2_BOLD_power (``None`` if ``roi_task_id == "control"``),
        f1_phase_delay / f2_phase_delay (``None`` if ``task_id != roi_task_id``).

    Parameters
    ----------
    df_data : mapping of column name -> list
        Usually a ``defaultdict(list)``; modified in place and returned.
    hcp_mapping : dict
        HCP label (e.g. "L_V1_ROI") -> ROI file, loaded via `read_roi_path`.
    map_data : dict
        frequency_of_roi -> 1-D map with the same shape as each ROI mask.
    power_f1_data, power_f2_data, pd_f1_data, pd_f2_data : array or None
        Vertex-wise metrics, indexed with the overlap selection.
    q_id : str
        "Q1" or "Q2".

    Returns
    -------
    The same ``df_data`` object.

    Raises
    ------
    ValueError
        If ``q_id`` is not "Q1" or "Q2" (now validated before any work).
    """
    if q_id == "Q1":
        contra = "L_"
    elif q_id == "Q2":
        contra = "R_"
    else:
        raise ValueError(f"{q_id} not supported.")

    for frequency_of_roi, f_data in map_data.items():
        for roi_label, roi_path in hcp_mapping.items():
            # "L_V1_ROI" -> "V1", prefixed by its side relative to the stimulus.
            side = "CONTRA" if roi_label.startswith(contra) else "IPSI"
            hemi_roi_label = f"{side}_{roi_label[2:-4]}"

            roi_mask = read_roi_path(roi_path)
            assert roi_mask.shape == f_data.shape

            hcp_and_f_roi = roi_mask * f_data
            in_roi = hcp_and_f_roi == 1
            vertex_coordinates = np.where(in_roi)
            # Count exactly the vertices stored below. Summing the whole
            # product would also count any non-binary residue in f_data and
            # disagree with vertex_coordinates / the metric arrays.
            vertex_count = hcp_and_f_roi[in_roi].sum()
            if vertex_count == 0:
                continue

            if roi_task_id == "control":
                f1_BOLD_power = None
                f2_BOLD_power = None
            else:
                f1_BOLD_power = power_f1_data[in_roi]
                f2_BOLD_power = power_f2_data[in_roi]
            if task_id != roi_task_id:
                f1_phase_delay = None
                f2_phase_delay = None
            else:
                f1_phase_delay = pd_f1_data[in_roi]
                f2_phase_delay = pd_f2_data[in_roi]

            df_data["roi_task_id"].append(roi_task_id)
            df_data["task_id"].append(task_id)
            df_data["roi_fo"].append(roi_fo)
            df_data["experiment_id"].append(experiment_label)
            df_data["sub_id"].append(sub_id)
            df_data["quadrant_id"].append(q_id)
            df_data["hcp_roi"].append(hemi_roi_label)
            df_data["frequency_of_roi"].append(frequency_of_roi)
            df_data["vertex_count"].append(vertex_count)
            df_data["vertex_coordinates"].append(vertex_coordinates)
            df_data["f1_BOLD_power"].append(f1_BOLD_power)
            df_data["f2_BOLD_power"].append(f2_BOLD_power)
            df_data["f1_phase_delay"].append(f1_phase_delay)
            df_data["f2_phase_delay"].append(f2_phase_delay)

    return df_data

def contains_all_strings(input_str, string_list):
    """Return True if every entry of ``string_list`` is a substring of ``input_str``.

    Stops at the first missing entry; an empty ``string_list`` gives True.
    """
    return all(needle in input_str for needle in string_list)

def find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, fo, sub_id, data_split_id="train", match_str="activations.dtseries.nii", additional_match_strs=None, additional_match_str=None, corr_type="uncp", base_dir="/scratch/fastfmri"):
    """Return paths of bootstrap output files whose names contain all required parts.

    Searches
    ``{base_dir}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_..._roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/``
    for file names containing ``match_str`` and ``data-{data_split_id}``, plus
    every entry of ``additional_match_strs`` when it is given.

    Parameters
    ----------
    additional_match_strs : list of str or None
        Extra substrings that must all appear in the file name.
    additional_match_str : ignored
        Unused; kept only for backward compatibility.
    base_dir : str
        Root of the bootstrap outputs. The default is the previously
        hard-coded location.

    Returns
    -------
    list of str
        Full paths, sorted by file name so the result does not depend on
        ``os.listdir`` ordering.

    Raises
    ------
    FileNotFoundError
        If the directory does not exist (raised by ``os.listdir``).
    """
    import os
    directory = f"{base_dir}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200_batch-merged_desc-basic_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/"
    required = [match_str, f"data-{data_split_id}"]
    if additional_match_strs is not None:
        required = list(additional_match_strs) + required
    activations_files = sorted(
        file for file in os.listdir(directory)
        if all(part in file for part in required)
    )

    return [f"{directory}{i}" for i in activations_files]

def set_base_dir(basedir):
    """Return ``basedir`` as a ``Path``, creating it (with parents) if it is missing.

    An existing path is returned untouched, even if it is a file.
    """
    target = Path(basedir)
    if target.exists():
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target

def load_mean_dtseries(dtseries):
    """Load a CIFTI file with nibabel and return its mean over axis 0."""
    return nib.load(dtseries).get_fdata().mean(axis=0)

def generate_single_subject_maps(
    label, experiment_id, mri_id, sub_ids,
    roi_task_ids, roi_f_1s, roi_f_2s, roi_fo,
    df_data=None,
    corr_type="uncp",
    ROI_FO=.8, SUB_THRESHOLD=.5,
    LEFT=590, TOP=80, RIGHT=1140, BOTTOM=460, VERTEX_TO = 59412,
    FORCE_TASK_ID=None,
):
    """Build, save and tabulate the dual-frequency (f1/f2) map of each run.

    For every (sub_id, roi_task_id, roi_f_1, roi_f_2) quadruple this:
      1. locates the activation, mask, phase-delay and power files
         (`find_activations`) and asserts exactly one of each was found;
      2. combines the f1/f2 activations into one colour-coded map
         (`combine_f1_f2`) and splits it into binary f1 / f2 / f1Uf2 maps;
      3. appends vertex-level rows per HCP ROI to ``df_data`` (`append_data`);
      4. saves the colour-coded map as a CIFTI file and as a flatmap PNG,
         plus a cropped copy of the PNG.

    Parameters
    ----------
    label : str
        Dataset label, written to ``df_data["experiment_id"]`` and file names.
    experiment_id, mri_id : str
        Select the bootstrap output directory.
    sub_ids, roi_task_ids, roi_f_1s, roi_f_2s : sequences
        Parallel per-run values.
    roi_fo : float
        Stored in ``df_data["roi_fo"]``.
    df_data : defaultdict(list) or None
        Rows are appended to it; a new one is created if None.
    ROI_FO : float
        Fractional-overlap threshold used to binarize the activations.
    SUB_THRESHOLD : float
        Unused; kept for backward compatibility.
    LEFT, TOP, RIGHT, BOTTOM : int
        Crop box (pixels) for the cropped PNG.
    VERTEX_TO : int
        Number of leading grayordinates kept from each map.
    FORCE_TASK_ID : str or None
        If given, power maps are read from this task instead of
        ``roi_task_id`` (e.g. "control"), and it is recorded as ``task_id``.

    Returns
    -------
    defaultdict(list)
        ``df_data`` with the new rows appended.

    Notes
    -----
    Reads the module-level globals ``f1_c``, ``f2_c``, ``f1f2_c``, ``mask_c``,
    ``PALETTE`` and ``hcp_mapping``; they must be defined before calling.
    """
    if df_data is None:
        df_data = defaultdict(list)

    # Output folders are the same for every run, so create them once.
    png_dir = set_base_dir("./ComputeCanada/frequency_tagging/figures/dual_frequency_mapping")
    cifti_dir = set_base_dir("./ComputeCanada/frequency_tagging/figures/dual_frequency_mapping_cifti")

    def _find(sub_id, roi_task_id, frequency, **kwargs):
        # ROIs always come from the fo-0.8 bootstrap directory; ROI_FO is only
        # applied later, when binarizing the activations.
        return find_activations(experiment_id, mri_id, roi_task_id, frequency, .8, sub_id, corr_type=corr_type, **kwargs)

    for sub_id, roi_task_id, roi_f_1, roi_f_2 in zip(sub_ids, roi_task_ids, roi_f_1s, roi_f_2s):
        # Task whose power maps are sampled (recorded as `task_id`).
        _roi_task_id = roi_task_id if FORCE_TASK_ID is None else FORCE_TASK_ID

        # NOTE: file names use roi_task_id (not _roi_task_id), so a FORCE_TASK_ID
        # run writes to the same files as the matching unforced run.
        out_stem = f"label-{label}_mri-{mri_id}_sub-{sub_id}_task-{roi_task_id}_f-{roi_f_1}-{roi_f_2}_corr-{corr_type}_fo-{ROI_FO}"
        png_out = png_dir / f"{out_stem}.png"
        dscalar_out = cifti_dir / f"{out_stem}.dtseries.nii"

        found = {
            "f1": _find(sub_id, roi_task_id, roi_f_1, match_str="activations.dtseries.nii"),
            "f2": _find(sub_id, roi_task_id, roi_f_2, match_str="activations.dtseries.nii"),
            "mask": _find(sub_id, roi_task_id, roi_f_1, match_str="mask.dtseries.nii"),
            "pd_f1": _find(sub_id, roi_task_id, roi_f_1, data_split_id="train", match_str="phasedelay.dtseries.nii", additional_match_strs=[roi_task_id, f"f-{roi_f_1}"]),
            "pd_f2": _find(sub_id, roi_task_id, roi_f_2, data_split_id="train", match_str="phasedelay.dtseries.nii", additional_match_strs=[roi_task_id, f"f-{roi_f_2}"]),
            "power_f1": _find(sub_id, roi_task_id, roi_f_1, data_split_id="test", match_str="power.dtseries.nii", additional_match_strs=[_roi_task_id, f"f-{roi_f_1}"]),
            "power_f2": _find(sub_id, roi_task_id, roi_f_2, data_split_id="test", match_str="power.dtseries.nii", additional_match_strs=[_roi_task_id, f"f-{roi_f_2}"]),
        }
        # Power metrics were not calculated for the control task of
        # 1_frequency_tagging (no vertices allocated to task-control ROIs), so
        # only the ROI-defining files are required there.
        is_control_run = roi_task_id == "control" and experiment_id == "1_frequency_tagging"
        for f_label, f in found.items():
            if is_control_run:
                if f_label in ["f1", "f2", "mask"]:
                    assert len(f) == 1, f"{sub_id}, {f_label} - {f}"
            else:
                assert len(f) == 1, f"{sub_id}, {f_label} - {f}, {experiment_id} {roi_task_id}"

        f1, f2, mask = found["f1"][0], found["f2"][0], found["mask"][0]
        data = combine_f1_f2(f1, f2, fo=ROI_FO, mask=mask, f1_c=f1_c, f2_c=f2_c, f1f2_c=f1f2_c, mask_c=mask_c)
        data = data[:VERTEX_TO]

        map_data = merge_and_binarize_mask(data, f1_c, f2_c, f1f2_c, mask_c)
        pd_f1_data = load_mean_dtseries(found["pd_f1"][0])[:VERTEX_TO]
        pd_f2_data = load_mean_dtseries(found["pd_f2"][0])[:VERTEX_TO]
        if is_control_run:
            power_f1_data = None
            power_f2_data = None
        else:
            power_f1_data = load_mean_dtseries(found["power_f1"][0])[:VERTEX_TO]
            power_f2_data = load_mean_dtseries(found["power_f2"][0])[:VERTEX_TO]
        q_id = get_quadrant_id(mask)
        df_data = append_data(
            df_data,
            hcp_mapping,
            map_data,
            power_f1_data,
            power_f2_data,
            pd_f1_data,
            pd_f2_data,
            q_id,
            label,
            sub_id,
            roi_fo,
            roi_task_id,
            _roi_task_id,
        )

        palette_params = {
            "disp-zero": False,
            "disp-neg": True,
            "disp-pos": True,
            "pos-user": (0, 1.),
            "neg-user": (-1,0),
            "interpolate": True,
        }
        # Save the colour-coded f1/f2 map as a one-point CIFTI, using f1's header.
        f1_img = nib.load(f1)
        dscalar_to_save_as_cifti = np.zeros((1, f1_img.shape[-1]))
        dscalar_to_save_as_cifti[0, :VERTEX_TO] = data
        f1f2_img = nib.Cifti2Image(dscalar_to_save_as_cifti, header=f1_img.header)
        f1f2_img.header.matrix[0].number_of_series_points = 1
        nib.save(f1f2_img, dscalar_out)
        dscalar(
            png_out, data,
            orientation="portrait",
            hemisphere='right',
            palette=PALETTE,
            palette_params=palette_params,
            transparent=False,
            flatmap=True,
            flatmap_style='plain',
        )
        # with_suffix only changes the extension; str.replace("png", ...) would
        # also rewrite any "png" appearing elsewhere in the path.
        crop_and_save(png_out, str(png_out.with_suffix(".cropped.png")), LEFT, TOP, RIGHT, BOTTOM)

        # Column lengths: all equal when every row was appended completely.
        track = [len(v) for k, v in df_data.items()]
        print(track)

    return df_data

def crop_and_save(input_file, output_file, left, top, right, bottom):
    """Crop ``input_file`` to the box (left, top, right, bottom) and save it as ``output_file``.

    Best-effort: any exception raised while opening, cropping or saving is
    printed instead of raised, so a failed crop does not abort the caller.
    The Pillow import happens before that guard, so a missing Pillow still
    raises.
    """
    from PIL import Image
    crop_box = (left, top, right, bottom)
    try:
        with Image.open(input_file) as source:
            source.crop(crop_box).save(output_file)
            print("Cropped image saved successfully as", output_file)
    except Exception as err:
        print("An error occurred:", err)

Save visualizations and create DataFrame storing data for simple analysis
- Datasets
    - `3TNormal` ($f_1$=.125, $f_2$=.2)
    - `7TNormal` ($f_1$=.125, $f_2$=.2)
    - `3TVary` varying frequencies 
    - `7TVary` varying frequencies
- Region info `df.frequency_of_roi` and `df.hcp_roi`
    - Region data of $f_1$ *includes* the vertices where $f_1$ and $f_2$ intersect; the same goes for $f_2$
    - roi choice includes `["f1","f2","f1Uf2"]`, where `f1Uf2` denotes intersected vertices
- Data includes
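    - Note: to load `fdrp`-corrected data (instead of `uncp`), change the `corr_type` variable in the cell below
    - `df.roi_task_id` task used to generate the ROI
    - `df.roi_fo` fractional overlap used to generate the ROI
    - `df.task_id` task whose power maps were sampled (equals `df.roi_task_id` unless `FORCE_TASK_ID` is set, e.g. `control`)
    - `df.experiment_id`, `df.sub_id`, `df.quadrant_id` dataset label, subject, and stimulated quadrant (`Q1`/`Q2`)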
    - `df.hcp_roi` HCP ROI used to filter data from
    - `df.frequency_of_roi` frequency of ROI used to filter data from (related to the frequency of `df.roi_task_id`)
    - `df.vertex_count` total vertices in the HCP ROI & identified with the frequency of the task
    - `df.vertex_coordinates` coordinates on a 32k_fs_LR surface (consistent with file structure of `template_dscalar`)
    - `df.f1_[BOLD_power,phase_delay]` vertex-wise power extracted from $f_1$ region (this is set to 0 for $f_2$-only regions)
    - `df.f2_[BOLD_power,phase_delay]` vertex-wise power extracted from $f_2$ region (this is set to 0 for $f_1$-only regions)
        - `f1Uf2` regions contain both $f_1$ and $f_2$ metrics


In [4]:
"""Set up for visualizing dual frequency tagging across each subject using fractional overlap
"""
# Palette and colour codes for the combined f1/f2 maps. These are read as
# module-level globals by `generate_single_subject_maps`, so keep the names.
PALETTE = "power_surf"
f1_c = -.1 # red -.1
f2_c = .82 # blue .82
f1f2_c = .14 # white .88 yellow .1
mask_c = .41 # .9 [green], .1 [goldish], .4 [black]

cohort_roi_info_across_experiments = {}
ROI_FOS = [.8,1.]
corr_type = "uncp"

"""Save png
"""
def _uniform_run(label, experiment_id, mri_id, sub_ids, roi_task_id, force_task_id=None):
    """Run spec in which every subject uses the same ROI task with f1=.125, f2=.2."""
    n_runs = len(sub_ids)
    return dict(
        label=label, experiment_id=experiment_id, mri_id=mri_id, sub_ids=sub_ids,
        roi_task_ids=[roi_task_id] * n_runs,
        roi_f_1s=[.125] * n_runs,
        roi_f_2s=[.2] * n_runs,
        force_task_id=force_task_id,
    )

SUB_IDS_3T_NORMAL = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
SUB_IDS_7T_NORMAL = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
# Vary datasets: three frequency pairs per subject, one ROI task per pair.
VARY_RUN = dict(
    sub_ids=["020"] * 3 + ["021"] * 3,
    roi_task_ids=[f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]],
    roi_f_1s=[.125] * 3 + [.125, .15, .175],
    roi_f_2s=[.2, .175, .15] + [.2] * 3,
    force_task_id=None,
)

# Processed in this order (rows are appended to `df_data` in the same order);
# each spec is repeated for every fractional overlap in ROI_FOS.
RUN_SPECS = [
    # 3T control under entrain condition (entrain ROIs, power measured in task-control)
    _uniform_run("3TNormal", "1_frequency_tagging", "3T", SUB_IDS_3T_NORMAL, "entrain", force_task_id="control"),
    # 3T normal
    _uniform_run("3TNormal", "1_frequency_tagging", "3T", SUB_IDS_3T_NORMAL, "entrain"),
    # 7T normal
    _uniform_run("7TNormal", "1_attention", "7T", SUB_IDS_7T_NORMAL, "AttendAway"),
    # 3T vary
    dict(label="3TVary", experiment_id="1_frequency_tagging", mri_id="3T", **VARY_RUN),
    # 7T vary
    dict(label="7TVary", experiment_id="1_frequency_tagging", mri_id="7T", **VARY_RUN),
]

df_data = None
for run in RUN_SPECS:
    for ROI_FO in ROI_FOS:
        df_data = generate_single_subject_maps(
            run["label"], run["experiment_id"], run["mri_id"], run["sub_ids"],
            run["roi_task_ids"], run["roi_f_1s"], run["roi_f_2s"], ROI_FO,
            df_data=df_data,
            corr_type=corr_type,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5,
            FORCE_TASK_ID=run["force_task_id"],
        )

df = pd.DataFrame(df_data)

from IPython.display import clear_output
clear_output()


KeyboardInterrupt: 

Create a geodesic-distance map from V1 (for every vertex, the mean geodesic distance to all V1 vertices of the same hemisphere)

In [None]:
# Rebuild the HCP label list with the same parsing as the "Get HCP labels" cell.
dlabel_dir = Path("/opt/app/notebooks/data/dlabels")
hcp_label = dlabel_dir / "Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR.dlabel.nii"

# IPython `!` capture: wb_command's text report as a list of output lines.
_HCP_INFO = !wb_command -file-information {hcp_label}
HCP_LABELS = []
HCP_COUNTER = 0
for i in _HCP_INFO:
    # NOTE(review): label rows recognised only by a 60-character line length
    # plus an "L_"/"R_" substring — tied to wb_command's output layout.
    if len(i) == 60 and any(["L_" in i, "R_" in i]):
        hcp_colors = tuple([float(f"0.{k}") for k in [j.split(' ') [0] for j in i.split('0.')][-3:]] + [1])
        if ' R_' in i:
            roi = i.split("_ROI")[0].split(' R_')[1]
            HCP_LABELS.append(f"R_{roi}_ROI")
        if ' L_' in i:
            roi = i.split("_ROI")[0].split(' L_')[1]
            HCP_LABELS.append(f"L_{roi}_ROI")
        HCP_COUNTER += 1


dscalar_dir = Path("/opt/app/notebooks/data/dscalars")
tmpdir = Path("/tmp")
template_dscalar = dscalar_dir / "S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"

# ".gd" ROI files: extracted from the dlabel, then re-created on the template's
# dense layout (overwriting themselves), presumably so they match the
# grayordinate layout of the geodesic map built below — confirm.
# Replaces the global `hcp_mapping` from earlier cells with these files.
hcp_mapping = {}
for roi_label in HCP_LABELS:
    out_dscalar = tmpdir / f"{roi_label}.gd.dscalar.nii"
    if out_dscalar.exists():
        print(f"Skipping {roi_label}")
        hcp_mapping[roi_label] = out_dscalar
        continue
    !wb_command -cifti-label-to-roi {hcp_label} {out_dscalar} -name {roi_label}
    !wb_command -cifti-create-dense-from-template {template_dscalar} {out_dscalar} -cifti {out_dscalar}
    assert out_dscalar.exists(), f"{out_dscalar.stem} does not exist."
    hcp_mapping[roi_label] = out_dscalar
hcp_rois = list(set([k.split('_')[1] for k in hcp_mapping.keys()]))

# All-to-all geodesic distances on each hemisphere's midthickness surface.
# NOTE(review): recomputed on every run (no existence check), and the dconn is
# presumably a full vertex x vertex matrix, so loading it below can need a
# lot of memory.
surface_dir = Path("/opt/app/notebooks/data/surfaces")
tmpdir = Path("/tmp")
L_mid = surface_dir / "S1200.L.midthickness_MSMAll.32k_fs_LR.surf.gii"
R_mid = surface_dir / "S1200.R.midthickness_MSMAll.32k_fs_LR.surf.gii"
L_geo = tmpdir / "L.dconn.nii"
R_geo = tmpdir / "R.dconn.nii"
!wb_command -surface-geodesic-distance-all-to-all {L_mid} {L_geo}
!wb_command -surface-geodesic-distance-all-to-all {R_mid} {R_geo}

# For every left-hemisphere vertex: mean distance to the left V1 vertices
# (rows of the distance matrix selected by the V1 mask, averaged over axis 0).
# Assumes the first 32492 columns of the .gd dscalar are the left-hemisphere
# surface vertices in surface order and the rest the right — TODO confirm.
L_V1 = tmpdir / "L_V1_ROI.gd.dscalar.nii"
L_V1_coords = nib.load(L_V1).get_fdata()[0,:32492]==1
_L_geo = nib.load(L_geo).get_fdata()
L_geo_arr = _L_geo[L_V1_coords,:].mean(0)
del _L_geo  # free the full distance matrix
# Same for the right hemisphere (columns from 32492 onwards).
R_V1 = tmpdir / "R_V1_ROI.gd.dscalar.nii"
R_V1_coords = nib.load(R_V1).get_fdata()[0,32492:]==1
_R_geo = nib.load(R_geo).get_fdata()
R_geo_arr = _R_geo[R_V1_coords,:].mean(0)
del _R_geo  # free the full distance matrix

# Left then right, matching the assumed column layout above.
geo_arr = np.concatenate((L_geo_arr, R_geo_arr))

# Save the map using the L_V1 ROI file's header (and therefore its layout).
geodesic_dscalar = tmpdir / "geodesic_V1.dscalar.nii"
img = nib.load(L_V1)
data = np.zeros(img.shape)
data[0,:] = geo_arr
geo_img = nib.Cifti2Image(data, header=img.header)
nib.save(geo_img, geodesic_dscalar)

# Plot histogram of %BOLD power
- Get max voxel count for $f_1$ or $f_2$
- Get ROI colours from `dfm:get_roi_colour_codes`
1. Subset `df` by dataset and fractional overlap to generate `subset_df`
    - `df.experiment_id`: `3TNormal`, `7TNormal`, `3TVary`, `7TVary`
    - `df.roi_fo`: `.8`, `1.`
2. Get unique ROIs (specify thresholding: % of subjects in dataset that has this ROI)
    - Unique ROIs are selected by `subset_df.frequency_of_roi=="f1"`
        - $f_1$ is more sensitive than $f_2$; therefore $f_2$ ROIs are a subset of $f_1$ ROIs
    - ROIs contain contralateral **or** ipsilateral component
3. Sort ROIs by a metric (**y-axis**)
    - geodesic distance (order of regions along y-axis)
    - length of region along y-axis varied by ROI size (contralateral+ipsilateral)
    - get colours for each region based on HCP labels 
4. Subset `subset_df` by region (2) and extract metric of interest (3) to generate `region_df`
5. Plot (for normal/vary experiments plot dots/line using **subject-level**/**task-level**, respectively)
    - Note: Regions are selected using $f_1$ regions; as such, not every region will have an $f_2$ counterpart
        - Also, not all regions will have an ipsilateral counterpart as contralateral sensitivity > ipsilateral sensitivity
    - Contralateral (`marker_style='o'`)
        - (a) dots (mean **or** median) of $f_1$ and $f_2$,  (size of marker varied by voxel count)
        - (b) line connecting $f_1$ and $f_2$
    - Repeat for ipsilateral (`marker_style='^'`)

functions

1.

In [None]:
def get_unique_rois(subset_df, n_ids, frequency_of_roi="f1", threshold=1.):
    """Return ROI names that appear in at least ``threshold`` of ``n_ids`` runs.

    Rows of ``subset_df`` with ``frequency_of_roi`` matching are counted per
    hemisphere-tagged label ("CONTRA_<roi>" / "IPSI_<roi>"). An ROI passes
    when ``max(contra_count, ipsi_count) / n_ids >= threshold``. Each passing
    ROI is printed with its counts.

    NOTE: a later cell in this notebook redefines ``get_unique_rois`` with a
    different signature, which shadows this version once that cell runs.

    Returns
    -------
    list of str
        Passing ROI names (e.g. "V1"), sorted alphabetically so the result is
        deterministic across sessions.
    """
    from collections import Counter

    _df = subset_df[subset_df.frequency_of_roi == frequency_of_roi]
    # Single O(n) pass: occurrences of each label such as "CONTRA_V1".
    label_counts = Counter(_df.hcp_roi)

    thresholded_unique_rois = []
    for roi in sorted({hcp_roi.split("_")[-1] for hcp_roi in label_counts}):
        contra_count = label_counts[f"CONTRA_{roi}"]
        ipsi_count = label_counts[f"IPSI_{roi}"]
        max_count = max(contra_count, ipsi_count)

        _thr = max_count / n_ids
        if _thr >= threshold:
            print(roi, _thr, ipsi_count, contra_count, max_count, n_ids)
            thresholded_unique_rois.append(roi)

    return thresholded_unique_rois

def sort_lists(list1, list2):
    """Sort two parallel sequences by the values of the first one.

    The sort is stable and keyed on ``list1`` only, so ties keep their input
    order and ``list2`` items never need to be comparable. As with ``zip``,
    the longer input is truncated.

    Returns
    -------
    tuple of (list, list)
        ``(sorted list1, list2 reordered the same way)``. Empty input now
        returns ``([], [])`` instead of raising ``ValueError`` on unpacking.
    """
    pairs = sorted(zip(list1, list2), key=lambda pair: pair[0])
    return [first for first, _ in pairs], [second for _, second in pairs]

def sort_unique_rois_by_geodesic_distance(geodesic_dscalar, unique_rois, hcp_mapping):
    """Order ROI names by mean geodesic distance (ascending).

    For each name in ``unique_rois``, the left and right HCP masks
    (``hcp_mapping[f"L_{name}_ROI"]`` and ``hcp_mapping[f"R_{name}_ROI"]``)
    select entries of ``geodesic_dscalar`` equal to 1. The mean over both
    hemispheres is the sort key. Ordering is done by `sort_lists`.

    Returns
    -------
    (sorted_rois, vertex_per_roi)
        ROI names sorted by mean distance, and, in the same order, the summed
        mask values (left + right; the vertex count for binary masks).

    Raises
    ------
    ValueError
        If the geodesic file or either hemisphere's ROI file is missing.
    """
    if not geodesic_dscalar.exists():
        raise ValueError(f"{geodesic_dscalar} does not exist.")
    geo_data = nib.load(geodesic_dscalar).get_fdata()

    distances, sizes = [], []
    for roi_label in unique_rois:
        hemi_paths = [hcp_mapping[f"{hemi}_{roi_label}_ROI"] for hemi in ("L", "R")]
        for hemi_path in hemi_paths:
            if not hemi_path.exists():
                raise ValueError(f"{hemi_path} does not exist.")
        left_mask, right_mask = [nib.load(hemi_path).get_fdata() for hemi_path in hemi_paths]
        roi_geo = np.concatenate((geo_data[left_mask == 1], geo_data[right_mask == 1]))
        distances.append(roi_geo.mean())
        sizes.append(left_mask.sum() + right_mask.sum())

    sorted_rois = sort_lists(distances, unique_rois)[1]
    vertex_per_roi = sort_lists(distances, sizes)[1]
    return sorted_rois, vertex_per_roi

def remove_outliers(arr, threshold=1.2):
    """Keep values within ``threshold`` standard deviations of the mean.

    Bounds are ``mean +/- threshold * std`` (population std, NumPy's default)
    and are inclusive.

    Parameters
    ----------
    arr : numpy.ndarray (or pandas Series), list or tuple
        Lists and tuples are now converted to arrays; previously their
        boolean indexing raised TypeError.
    threshold : float
        Width of the kept band in standard deviations.

    Returns
    -------
    Same kind as ``arr`` (ndarray for list/tuple input), filtered.
    """
    if isinstance(arr, (list, tuple)):
        arr = np.asarray(arr)

    mean_val = np.mean(arr)
    std_dev = np.std(arr)

    lower_bound = mean_val - threshold * std_dev
    upper_bound = mean_val + threshold * std_dev

    return arr[(arr >= lower_bound) & (arr <= upper_bound)]


get HCP label colours as RGB

In [None]:
# Colour per HCP ROI name: up to four floats parsed from wb_command's label
# table, stored with shape (1, n).
hcp_c_dict = dict()
for roi in hcp_rois:
    # One wb_command call per ROI, piped to grep. The pattern "_<roi>_ROI"
    # matches both the L_ and R_ rows, so the last matching row wins.
    # Overwrites the global `_HCP_INFO` set by earlier cells.
    _HCP_INFO = !wb_command -file-information {hcp_label} | grep "_{roi}_ROI"
    for i in _HCP_INFO:
        #hcp_c_dict[roi] = np.array([int(float(j)*256) for j in i.split('   ')[-5:-1]])[np.newaxis,:]
        # Fields split on runs of three spaces; the four before the last are
        # taken as the colour (presumably R, G, B, A — confirm the layout).
        hcp_c_dict[roi] = np.array([float(j) for j in i.split('   ')[-5:-1]])[np.newaxis,:]


4. Subset `subset_df` by region (2) and extract metric of interest (3) to generate `region_df`
5. Plot (for normal/vary experiments plot dots/line using **subject-level**/**task-level**, respectively)
    - Note: Regions are selected using $f_1$ regions; as such, not every region will have an $f_2$ counterpart
        - Also, not all regions will have an ipsilateral counterpart as contralateral sensitivity > ipsilateral sensitivity
    - Contralateral (`marker_style='o'`)
        - (a) dots (mean **or** median) of $f_1$ and $f_2$,  (size of marker varied by voxel count)
        - (b) line connecting $f_1$ and $f_2$
    - Repeat for ipsilateral (`marker_style='^'`)

Get the maximum vertex count (length of `f1_BOLD_power`) across all `experiment_id`s at `roi_fo` = 0.8

In [None]:
roi_fo = .8
# Each row's f1 power vector has one entry per vertex of its region.
max_vertex_count = []
for experiment_id in df.experiment_id.unique():
    rows_at_fo = (df.experiment_id == experiment_id) & (df.roi_fo == roi_fo)
    max_vertex_count.extend(power.shape[0] for power in df.loc[rows_at_fo, "f1_BOLD_power"])
max_vertex_count = np.max(max_vertex_count)

print(f"Max vertex count HCP ROIs and experiments: {max_vertex_count}")

Get ratios of ROI across `experiment_ids`

In [None]:
# Inline version of `get_unique_rois` for inspection. Requires `subset_df`,
# `frequency_of_roi`, `n_ids` and `threshold` to be defined by an earlier cell.
# (A module-level `return` here was a SyntaxError; the result is now shown as
# the cell's last expression.)
_df = subset_df[subset_df.frequency_of_roi==frequency_of_roi]
rois = [roi for roi in _df.hcp_roi]
unique_rois = list(set(rois))

# Single pass instead of the nested O(n^2) loop; same lists per label.
roi_map = defaultdict(list)
for _roi in rois:
    roi_map[_roi].append(_roi)

thresholded_unique_rois = []
unique_rois = ( list(set([i.split("_")[-1] for i in roi_map.keys()])) )
for roi in unique_rois:
    contra_count = len(roi_map[f"CONTRA_{roi}"])
    ipsi_count = len(roi_map[f"IPSI_{roi}"])
    max_count = max(contra_count,ipsi_count)

    _thr = max_count/n_ids
    if _thr >= threshold:
        print(roi, _thr, ipsi_count, contra_count, max_count, n_ids)
        thresholded_unique_rois.append(roi)

thresholded_unique_rois

In [None]:
import itertools
def sort_lists(list1, list2):
    """Sort two parallel sequences by the values of the first one.

    The sort is stable and keyed on ``list1`` only, so ties keep their input
    order and ``list2`` items never need to be comparable. As with ``zip``,
    the longer input is truncated.

    Returns
    -------
    tuple of (list, list)
        ``(sorted list1, list2 reordered the same way)``. Empty input now
        returns ``([], [])`` instead of raising ``ValueError`` on unpacking.
    """
    pairs = sorted(zip(list1, list2), key=lambda pair: pair[0])
    return [first for first, _ in pairs], [second for _, second in pairs]

def sort_unique_rois_by_geodesic_distance(geodesic_dscalar, unique_rois, hcp_mapping):
    """Order ROI names by mean geodesic distance (ascending).

    For each name in ``unique_rois``, the left and right HCP masks
    (``hcp_mapping[f"L_{name}_ROI"]`` and ``hcp_mapping[f"R_{name}_ROI"]``)
    select entries of ``geodesic_dscalar`` equal to 1. The mean over both
    hemispheres is the sort key. Ordering is done by `sort_lists`.

    Returns
    -------
    (sorted_rois, vertex_per_roi)
        ROI names sorted by mean distance, and, in the same order, the summed
        mask values (left + right; the vertex count for binary masks).

    Raises
    ------
    ValueError
        If the geodesic file or either hemisphere's ROI file is missing.
    """
    if not geodesic_dscalar.exists():
        raise ValueError(f"{geodesic_dscalar} does not exist.")
    geo_data = nib.load(geodesic_dscalar).get_fdata()

    distances, sizes = [], []
    for roi_label in unique_rois:
        hemi_paths = [hcp_mapping[f"{hemi}_{roi_label}_ROI"] for hemi in ("L", "R")]
        for hemi_path in hemi_paths:
            if not hemi_path.exists():
                raise ValueError(f"{hemi_path} does not exist.")
        left_mask, right_mask = [nib.load(hemi_path).get_fdata() for hemi_path in hemi_paths]
        roi_geo = np.concatenate((geo_data[left_mask == 1], geo_data[right_mask == 1]))
        distances.append(roi_geo.mean())
        sizes.append(left_mask.sum() + right_mask.sum())

    sorted_rois = sort_lists(distances, unique_rois)[1]
    vertex_per_roi = sort_lists(distances, sizes)[1]
    return sorted_rois, vertex_per_roi

def get_unique_rois(df, experiment_id, roi_fo, frequency_of_roi, hemi_prefix, task_id, sub_threshold, geodesic_dscalar, hcp_mapping, VARY=False):
    """Split one experiment's ROIs into those found often enough and the rest.

    Rows of ``df`` for ``experiment_id`` are filtered by ``roi_fo``,
    ``frequency_of_roi`` and the ``hemi_prefix`` of ``hcp_roi``. Non-VARY
    experiments are also filtered by ``task_id``; VARY experiments ignore it.
    For each ROI, the number of matching rows is divided by the number of
    unique ids in the experiment. Ids are ``sub_id`` values, or ``roi_task_id``
    values when ``VARY``. ROIs whose fraction is ``>= sub_threshold`` are
    "thresholded".

    NOTE(review): this redefines ``get_unique_rois`` with a different
    signature. Later cells in this notebook still call
    ``get_unique_rois(subset_df, n_ids, threshold=...)``, which does not match
    this signature and raises TypeError on a fresh Restart & Run All. TODO:
    rename one of the two definitions.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(thresholded, not_thresholded)``. Each list holds ROI labels with the
        laterality prefix removed (``"CONTRA_V1"`` -> ``"V1"``), ordered by
        ``sort_unique_rois_by_geodesic_distance``.

    Raises
    ------
    ValueError
        If an ROI has more rows than there are unique ids in the experiment,
        i.e. duplicated rows for one id.
    """
    experiment_df = df[(df.experiment_id==experiment_id)]
    # Unit that an ROI must be present in: ROI tasks for "Vary", subjects otherwise.
    id_column = "roi_task_id" if VARY else "sub_id"
    n_sub_ids_in_experiment = experiment_df[id_column].unique().shape[0]

    condition = (
        (experiment_df.roi_fo==roi_fo)
        & (experiment_df.frequency_of_roi==frequency_of_roi)
        & (experiment_df.hcp_roi.str.startswith(hemi_prefix))
    )
    if not VARY:
        condition &= (experiment_df.task_id==task_id)
    subset_df = experiment_df[condition]

    thresholded_unique_rois = []
    not_thresholded_unique_rois = []
    for roi in subset_df.hcp_roi.unique():
        n_sub_ids_per_roi = int((subset_df.hcp_roi == roi).sum())
        if n_sub_ids_per_roi > n_sub_ids_in_experiment:
            raise ValueError(
                f"{roi} has {n_sub_ids_per_roi} rows but {experiment_id} has only "
                f"{n_sub_ids_in_experiment} unique {id_column} values; expected at most one row per id."
            )
        if n_sub_ids_per_roi / n_sub_ids_in_experiment >= sub_threshold:
            thresholded_unique_rois.append(roi)
        else:
            not_thresholded_unique_rois.append(roi)

    def _strip_and_sort(rois):
        # Drop the laterality prefix and order by geodesic distance (no-op for empty lists).
        if len(rois) == 0:
            return rois
        sorted_rois, _ = sort_unique_rois_by_geodesic_distance(
            geodesic_dscalar,
            [i.split("_")[-1] for i in rois],
            hcp_mapping
        )
        return sorted_rois

    return _strip_and_sort(thresholded_unique_rois), _strip_and_sort(not_thresholded_unique_rois)

# Thresholded / non-thresholded ROI lists for every (experiment, frequency, laterality).
roi_fo = .8  # only rows of `df` with this roi_fo value are considered
experiment_ids = ["3TNormal","7TNormal","3TVary","7TVary"]
# Task filter per experiment (paired by position); ignored for the "Vary" experiments (VARY=True below).
task_ids = ["entrain","AttendAway","",""]
frequency_of_rois = ["f1","f2","f1Uf2"]
hemi_prefices = ["IPSI","CONTRA"]
# Minimum fraction of subjects (or ROI tasks for "Vary") an ROI must appear in to be "thresholded".
sub_threshold = 1.

# Keys: (experiment_id, frequency_of_roi, hemi_prefix); values: (thresholded, not_thresholded) ROI label lists.
roi_dict = {}
for experiment_id, task_id in zip(experiment_ids, task_ids):
    for frequency_of_roi, hemi_prefix in itertools.product(frequency_of_rois, hemi_prefices):
        roi_dict[(experiment_id,frequency_of_roi,hemi_prefix)] = get_unique_rois(df, experiment_id, roi_fo, frequency_of_roi, hemi_prefix, task_id, sub_threshold, geodesic_dscalar, hcp_mapping, VARY="Vary" in experiment_id)

for k,v in roi_dict.items():
    print(k,v)

In [None]:
# Selection shown in the common-ROI panels below.
roi_fo = .8
hemi_prefix = "CONTRA"
frequency_of_roi = "f1"
# True: every panel uses the ROIs thresholded in all experiments; False: each panel uses its own roi_dict list.
use_common_across_datasets = True
    
def compute_mean(X, top_half=False):
    """Summarise a per-vertex metric array by its mean.

    Parameters
    ----------
    X : array-like of float
        Per-vertex values.
    top_half : bool, default False
        If True, average only the largest ``len(X) // 2`` values (at least
        one) instead of all values.

    Returns
    -------
    float
        The mean; ``nan`` (with a RuntimeWarning from NumPy) for empty input.
    """
    values = np.asarray(X, dtype=float)
    if top_half and values.size:
        # Explicit descending sort: the previous `X.sort` (never called) was a no-op,
        # and len(X) // 2 == 0 for single values caused a ZeroDivisionError.
        values = np.sort(values)[::-1][: max(values.size // 2, 1)]
    return np.mean(values)

# Get common ROIs across all datasets: keep, in the order of the first matching
# experiment, the thresholded ROIs shared by every experiment for this
# frequency/laterality.
all_lists = []
for k, (thr_list, unthr_list) in roi_dict.items():
    if frequency_of_roi in k and hemi_prefix in k:
        all_lists.append(thr_list)
if not all_lists:
    raise ValueError(f"No roi_dict entries for frequency_of_roi={frequency_of_roi!r}, hemi_prefix={hemi_prefix!r}.")
common_list = [item for item in all_lists[0] if all(item in lst for lst in all_lists[1:])]

# Set up axes. "3TNormalC" is an extra panel: 3TNormal data for the control task,
# using ROIs defined from the entrain task.
experiment_ids = ["3TNormalC","3TNormal","7TNormal","3TVary","7TVary"]
fig,ax_dict = plt.subplot_mosaic([experiment_ids], layout="constrained",figsize=(7,2),dpi=300)

# panel -> (roi_task_id, task_id) filters for the non-"Vary" experiments.
normal_task_filters = {
    "3TNormalC": ("entrain", "control"),
    "3TNormal": ("entrain", "entrain"),
    "7TNormal": ("AttendAway", "AttendAway"),
}

# Plot
for panel_id in experiment_ids:
    ax = ax_dict[panel_id]
    # The control panel reads the 3TNormal rows of `df` and its roi_dict entry.
    # Resolving this before the roi_dict lookup avoids a KeyError on "3TNormalC".
    experiment_id = "3TNormal" if panel_id == "3TNormalC" else panel_id
    if panel_id in normal_task_filters:
        roi_task_id, task_id = normal_task_filters[panel_id]

    if use_common_across_datasets:
        thresholded_rois = common_list
    else:
        thresholded_rois = roi_dict[(experiment_id,frequency_of_roi,hemi_prefix)][0]

    roi_rows = []  # per-ROI arrays of row means, used for the connecting lines
    for ix, roi in enumerate(thresholded_rois):
        roi_mask = (
            (df.experiment_id==experiment_id)
            & (df.roi_fo==roi_fo)
            & (df.hcp_roi.str.startswith(hemi_prefix))
            & (df.hcp_roi.str.endswith(f"_{roi}"))
            & (df.frequency_of_roi==frequency_of_roi)
        )
        if not experiment_id.endswith("Vary"):
            roi_mask &= (df.roi_task_id==roi_task_id) & (df.task_id==task_id)
        # Sorting makes row j refer to the same subject/run in every ROI, since the
        # lines below connect values by row position. .assign builds a new frame
        # instead of writing into a filtered view (SettingWithCopyWarning).
        _df = df[roi_mask].sort_values(["sub_id", "roi_task_id"]).assign(
            metric_mean=lambda d: d[f"{frequency_of_roi}_BOLD_power"].apply(compute_mean)
        )

        _y = _df.metric_mean.values
        _x = np.zeros_like(_y) + ix

        ax.scatter(_x,_y,c=hcp_c_dict[roi])
        ax.scatter(_x.mean(),_y.mean(),c='k',zorder=10)
        roi_rows.append(_y)

    if roi_rows:
        # np.vstack is 2-D even for a single ROI, so _yline[:, j] is always valid.
        # It raises if ROIs have different numbers of rows.
        _yline = np.vstack(roi_rows)
        for j in range(_yline.shape[-1]):
            ax.plot(range(len(thresholded_rois)),_yline[:,j],lw=.2,c='k')

    ax.set_xticks(range(len(thresholded_rois)))
    ax.set_xticklabels(thresholded_rois,rotation=90, fontsize=6)
    # Panel id, so the control panel is distinguishable from the 3TNormal one.
    ax.set_title(panel_id)

In [None]:
# Debug view: the last ROI subset filtered by the plotting loop in the previous cell.
_df

In [None]:
# Debug check: shape of the stacked per-ROI values used to draw the connecting lines.
_yline.shape

In [None]:
# Scratch check of np.vstack on the last ROI's values (debug only).
np.vstack((_y,_y)).shape

In [None]:
# NOTE(review): `unique_rois` is not defined in this cell; it must already exist in the kernel (fails on a fresh kernel otherwise).
unique_rois, vertex_per_roi = sort_unique_rois_by_geodesic_distance(geodesic_dscalar, unique_rois, hcp_mapping) # sorted by geodesic distance

In [None]:
# Thresholded ROI labels for 3TVary, f1, contralateral.
roi_dict[("3TVary","f1","CONTRA")][0]

In [None]:
# Count the ROIs passing the threshold in each experiment; the next figure sizes its mosaic rows from these counts.
# NOTE(review): this threshold (.8) differs from `threshold_unique_rois = 1.` set for the plotting cell — confirm intended.
threshold_unique_rois = .8
experiment_ids = ["3TNormal","7TNormal","3TVary","7TVary"]
n_rois_per_experiment_id = {}
for experiment_id in experiment_ids:
    # 1. Rows for this experiment at the chosen roi_fo (3TNormal: entrain task only)
    if experiment_id == "3TNormal":
        subset_df = df[(df.experiment_id==experiment_id) & (df.roi_fo==roi_fo) & (df.task_id=="entrain")]
    else:
        subset_df = df[(df.experiment_id==experiment_id) & (df.roi_fo==roi_fo)]
    # 2. Unit counted towards the threshold: subjects ("Normal") or ROI tasks ("Vary")
    if experiment_id.endswith("Normal"):
        col_id = "sub_id" 
    elif experiment_id.endswith("Vary"):
        col_id = "roi_task_id" 
    else:
        raise ValueError(f"{experiment_id} not implemented.")
    n_ids = len(subset_df[col_id].unique())
    # NOTE(review): this call uses an older get_unique_rois(df, n_ids, threshold=...) signature.
    # The get_unique_rois defined earlier in this section, taking (df, experiment_id, roi_fo, ...),
    # shadows it, so this line raises TypeError on a fresh Restart & Run All. TODO: rename one.
    unique_rois = get_unique_rois(subset_df, n_ids, threshold=threshold_unique_rois)
    n_rois_per_experiment_id[experiment_id] = len(unique_rois)

# Largest ROI count; shorter mosaic rows are padded up to this length.
max_rois_across_experiments = np.max([i for i in n_rois_per_experiment_id.values()])
n_rois_per_experiment_id, max_rois_across_experiments

In [None]:
# Settings for the per-ROI f1/f2 power figure below.
roi_fo = .8
# Minimum fraction passed as `threshold` to get_unique_rois when selecting ROIs.
threshold_unique_rois = 1.
# NOTE(review): not read by the plotting loop in this cell, which always plots the `*_BOLD_power` columns.
metric = "BOLD_power" # Or phase_delay
# How each ROI's per-vertex values are summarised into one point: "mean" or "median".
summary_metric_type = "mean"

# One mosaic row per experiment, one column per ROI. Every mosaic row must have
# the same length, so each row is padded with "<experiment_id>FILL" cells up to
# the largest ROI count; those axes are removed after plotting.
# (Previously the 3TNormal and 7TVary rows were never padded.)
mosaic = [
    [experiment_id]*n_rois_per_experiment_id[experiment_id]
    + [f"{experiment_id}FILL"]*(max_rois_across_experiments-n_rois_per_experiment_id[experiment_id])
    for experiment_id in ["3TNormal","7TNormal","3TVary","7TVary"]
]
fig, ax_dict = plt.subplot_mosaic(
    mosaic,
    figsize=(5,4),
    dpi=200,
    layout="constrained"
)

# Plot style shared by every panel (also read by later cells).
roi_y_coords = [.2, .8] # x-offsets of the [f_1, f_2] points inside each ROI's unit-wide slot
roi_c_dict = get_roi_colour_codes()
FONTSIZE = 6
LINEWIDTH = .4

experiment_unique_rois = {}
experiment_ids = ["3TNormal","7TNormal","3TVary","7TVary"]
ylabels = ["3T","7T","3TVary","7TVary"]
for experiment_id,ylabel in zip(experiment_ids,ylabels):
    # 1. Rows for this experiment. 3TNormal is restricted to the entrain task, as in
    #    the ROI-count cell, so ROIs are selected from the same rows that are plotted.
    subset_df = df[(df.experiment_id==experiment_id) & (df.roi_fo==roi_fo)]
    if experiment_id == "3TNormal":
        subset_df = subset_df[subset_df.task_id=="entrain"]
    # 2. Unit counted towards the threshold: subjects ("Normal") or ROI tasks ("Vary")
    if experiment_id.endswith("Normal"):
        col_id = "sub_id"
    elif experiment_id.endswith("Vary"):
        col_id = "roi_task_id"
    else:
        raise ValueError(f"{experiment_id} not implemented.")
    n_ids = len(subset_df[col_id].unique())
    # NOTE(review): this call uses an older get_unique_rois(df, n_ids, threshold=...) signature.
    # The get_unique_rois defined earlier in this section, taking (df, experiment_id, roi_fo, ...),
    # shadows it, so this raises TypeError on a fresh Restart & Run All. TODO: rename one.
    unique_rois = get_unique_rois(subset_df, n_ids, threshold=threshold_unique_rois)
    # 3. Order ROIs by mean geodesic distance
    unique_rois, vertex_per_roi = sort_unique_rois_by_geodesic_distance(geodesic_dscalar, unique_rois, hcp_mapping)
    experiment_unique_rois[experiment_id] = unique_rois
    # 4. Plot
    ax = ax_dict[experiment_id]

    all_metrics = []       # every plotted point (fallback for the y-axis scale)
    all_line_metrics = []  # points belonging to a plotted f1-f2 pair (preferred y-axis scale)
    for laterality in ["CONTRA","IPSI"]:
        marker = "^" if laterality == "IPSI" else "o"

        # Loop across all filtered HCP ROIs
        for hcp_ix, hcp_roi in enumerate(unique_rois):
            region_df = subset_df[(subset_df.hcp_roi.str.endswith(f"_{hcp_roi}")) & (subset_df.frequency_of_roi.isin(["f1","f2"]))]
            data_levels = region_df[col_id].unique() # sub_id or roi_task_id
            for data_level in data_levels:
                _df = region_df[(region_df[col_id]==data_level) & (region_df.hcp_roi==f"{laterality}_{hcp_roi}")]
                f1_data = _df[_df.frequency_of_roi=="f1"]
                f2_data = _df[_df.frequency_of_roi=="f2"]
                if f1_data.shape[0] > 1:
                    raise ValueError("Expect 0 or 1 row in the filtered dataframe `f1_data`")
                if f2_data.shape[0] > 1:
                    raise ValueError("Expect 0 or 1 row in the filtered dataframe `f2_data`")

                # Connect f1 and f2 only when both have exactly one row with >= 1 vertex.
                plot_line = all(
                    f_rows.shape[0] == 1 and f_rows[f"{f_label}_BOLD_power"].values[0].shape[0] > 0
                    for f_label, f_rows in (("f1", f1_data), ("f2", f2_data))
                )
                line_x, line_y = [], []

                # Loop over f1 and f2 data
                for plot_ix, f_data in enumerate([f1_data,f2_data]):
                    if f_data.shape[0] != 1:
                        continue
                    f_name = f_data.frequency_of_roi.values[0]
                    vertex_metric = f_data[f"{f_name}_BOLD_power"].values[0]
                    vertex_count = vertex_metric.shape[0]
                    if vertex_count == 0:
                        continue

                    if summary_metric_type == "median":
                        summary_metric = np.median(vertex_metric)
                    elif summary_metric_type == "mean":
                        summary_metric = np.mean(vertex_metric)
                    else:
                        raise ValueError(f"summary metric: {summary_metric_type} not supported.")

                    # Marker area scales linearly with vertex count: 0.1 at 0 vertices, 1.0 at 444.
                    # NOTE(review): 444 presumably the largest ROI vertex count — confirm.
                    marker_scale = (vertex_count*.9/444)+.1
                    pos = hcp_ix + roi_y_coords[plot_ix]
                    ax.scatter(pos, summary_metric, s=marker_scale*50, c=roi_c_dict[f_name],edgecolor='k',linewidths=LINEWIDTH,zorder=10, marker=marker)
                    all_metrics.append(summary_metric)
                    line_x.append(pos)
                    line_y.append(summary_metric)

                # Draw the f1-f2 connector once, after both points are placed
                # (it used to be drawn, and counted, once per frequency).
                if plot_line:
                    ax.plot(line_x, line_y, lw=LINEWIDTH, c='k',zorder=5, linestyle='dotted')
                    all_line_metrics += line_y

    # Y-axis: scale to the paired (line) points, falling back to all plotted points.
    scale_metrics = all_line_metrics or all_metrics
    if not scale_metrics:
        raise ValueError(f"{experiment_id}: nothing was plotted, cannot scale the y-axis.")
    YTICKMAX = round(np.max(scale_metrics),4)
    _yticks = [0,YTICKMAX/2,YTICKMAX]
    _yticklabels = [f"{i:.1e}" if i!=0 else "0" for i in _yticks]
    ax.set_yticks(_yticks)
    ax.set_yticklabels(_yticklabels,fontsize=FONTSIZE)
    ax.tick_params(axis="y",width=LINEWIDTH,pad=0.2,length=4)
    ax.tick_params(axis="x",width=LINEWIDTH,pad=0.4,length=0)
    YMAX = np.max(scale_metrics)
    Y_OFFSET = YMAX*.05
    # ROI name boxes are drawn below y=0, between X_TICKLABEL_POSITION and -5% of YMAX.
    X_TICKLABEL_POSITION = -(YMAX*.35)
    ax.set_ylim(X_TICKLABEL_POSITION,YMAX+Y_OFFSET)
    ax.plot([-.1]*2,[_yticks[0],_yticks[-1]],lw=LINEWIDTH,c='k',zorder=20)  # hand-drawn y-axis line

    # X-axis: ROI names in white on boxes coloured by hcp_c_dict.
    ax.set_xlim(-.1,len(unique_rois))
    ax.set_xticks([])
    XTICKS = [i-.5 for i in range(1, len(unique_rois)+1,1)]
    for xtickpos, xticklabel in zip(XTICKS,unique_rois):
        # "TPOJ2" gets a slightly smaller font (presumably so it fits its slot).
        label_fontsize = FONTSIZE-.5 if xticklabel == "TPOJ2" else FONTSIZE
        ax.text(xtickpos,X_TICKLABEL_POSITION+(YMAX*.15),xticklabel,fontsize=label_fontsize,ha="center",va="center",c='white',zorder=50)
    for start_ix, hcp_roi in enumerate(unique_rois):
        end_ix = start_ix+1
        ax.fill_between([start_ix, end_ix],[X_TICKLABEL_POSITION]*2,[-(YMAX*.05)]*2,color=hcp_c_dict[hcp_roi])

    for spine_type in ["left","right", "top", "bottom"]:
        ax.spines[spine_type].set_visible(False)

    # Centre the y-label on the region above y=0 (i.e. above the ROI name boxes).
    _ylabel_pos = -X_TICKLABEL_POSITION / (-X_TICKLABEL_POSITION+YMAX+Y_OFFSET)
    _ylabel_pos = _ylabel_pos + ( (1 - _ylabel_pos) / 2 )
    ax.set_ylabel(ylabel,fontsize=FONTSIZE,y=_ylabel_pos)

# Remove the "...FILL" padding axes, which only equalise mosaic row lengths.
for k,ax in ax_dict.items():
    if k.endswith("FILL"):
        ax.remove()

In [None]:
# Save the per-ROI f1/f2 power figure from the previous cell.
fig.savefig("D.png",dpi=600)

**Note**: In all power plots below, each point is a *mean* (across vertices), and the red bar marks the *median* across points.

In [None]:
from scipy import stats

Laterality dependence

In [None]:
roi_fo = .8
subset_df = df[(df.roi_fo==roi_fo)]
experiment_ids = subset_df.experiment_id.unique()
# Control for vertex count: compare only the N largest values on each side,
# where N is the smaller of the two ROIs' vertex counts.
CONTROL_FOR_VERTEX_COUNT = True

# One row per (experiment, subject, ROI task, frequency, HCP area) that has both a
# CONTRA and an IPSI ROI with vertices; metric = mean(CONTRA) - mean(IPSI) power.
data_dict = defaultdict(list)
for experiment_id in experiment_ids:
    experiment_df = subset_df[(subset_df.experiment_id==experiment_id)]
    for sub_id in experiment_df.sub_id.unique():
        sub_df = experiment_df[(experiment_df.sub_id==sub_id)]
        for roi_task_id in sub_df.roi_task_id.unique():
            task_df = sub_df[(sub_df.roi_task_id==roi_task_id)]
            hcp_rois = [i for i in task_df.hcp_roi.unique() if i.startswith("CONTRA")]
            for frequency_of_roi, f_metric_label in zip(["f1","f2"],["f1_BOLD_power","f2_BOLD_power"]):
                for hcp_roi in hcp_rois:
                    hcp_roi_stripped = hcp_roi.split("_")[-1]
                    # Match the two ROI names exactly. A substring match (str.contains)
                    # also caught e.g. V3A/V3B for V3 or V4t for V4, so those areas
                    # failed the two-row check below and were silently dropped.
                    _df = task_df[(task_df.frequency_of_roi==frequency_of_roi) & (task_df.hcp_roi.isin([f"CONTRA_{hcp_roi_stripped}", f"IPSI_{hcp_roi_stripped}"]))]
                    if _df.shape[0]!=2:
                        continue
                    ipsi_data = _df[(_df.hcp_roi==f"IPSI_{hcp_roi_stripped}")]
                    contra_data = _df[(_df.hcp_roi==f"CONTRA_{hcp_roi_stripped}")]
                    if ipsi_data.shape[0]==0 or contra_data.shape[0]==0:
                        continue
                    contra_values = contra_data[f_metric_label].values[0]
                    ipsi_values = ipsi_data[f_metric_label].values[0]
                    if contra_values.shape[0] == 0 or ipsi_values.shape[0] == 0:
                        continue

                    data_dict["experiment_id"].append(ipsi_data.experiment_id.values[0])
                    data_dict["roi_task_id"].append(ipsi_data.roi_task_id.values[0])
                    data_dict["sub_id"].append(ipsi_data.sub_id.values[0])
                    data_dict["frequency_of_roi"].append(ipsi_data.frequency_of_roi.values[0])
                    data_dict["hcp_roi"].append(ipsi_data.hcp_roi.values[0].split("_")[-1])
                    if CONTROL_FOR_VERTEX_COUNT:
                        _n_vertices = min(contra_values.shape[0], ipsi_values.shape[0])
                        contra_values = np.sort(contra_values)[::-1][:_n_vertices]
                        ipsi_values = np.sort(ipsi_values)[::-1][:_n_vertices]
                    pvalue = stats.mannwhitneyu(
                        contra_values,
                        ipsi_values,
                    ).pvalue
                    laterality_difference = np.mean(contra_values)-np.mean(ipsi_values) # CONTRA > IPSI -> laterality_difference > 0
                    data_dict["metric"].append(laterality_difference)
                    data_dict["p_value"].append(pvalue)

laterality_df = pd.DataFrame(data_dict)
laterality_df

In [None]:
# Contralateral minus ipsilateral power per (subject, ROI), split by frequency.
fig, ax_dict = plt.subplot_mosaic([["X"]], dpi=300, figsize=(.75,1.5),layout="constrained")

ax=ax_dict["X"]
s = 20
LATERALITY_YLIM = (-.0005,.001)
jitter_rng = np.random.default_rng(0)  # fixed seed so the saved figure is reproducible
store_data = {}
for ix, frequency in enumerate(["f1","f2"]):
    _laterality_df = laterality_df[(laterality_df.frequency_of_roi==frequency)]
    c = [hcp_c_dict[hcp_roi] for hcp_roi in _laterality_df.hcp_roi.values]
    y = _laterality_df.metric.values
    store_data[frequency] = y
    x_jitter = jitter_rng.uniform(low=-.35,high=.35,size=y.shape)
    ax.scatter(ix+x_jitter,y,c=c,s=s,edgecolor='k',linewidth=LINEWIDTH)
    ax.plot([ix-.5,ix+.35],[np.median(y)]*2,lw=LINEWIDTH,c='r')  # median across points

    print(stats.ttest_1samp(y,0))
ax.plot([-.5,1.5],[0]*2,lw=LINEWIDTH,c='grey',linestyle='dotted')  # zero reference

ax.set_xticks([0,1])
ax.set_xticklabels(["$f_1$","$f_2$"],fontsize=FONTSIZE)
# Fix the limits first so the ticks come from the final view. Labels use scientific
# notation, because int() truncated every tick in this range to 0.
ax.set_ylim(*LATERALITY_YLIM)
yticks = [t for t in ax.get_yticks() if LATERALITY_YLIM[0] <= t <= LATERALITY_YLIM[1]]
ax.set_yticks(yticks)
ax.set_yticklabels([f"{t:.0e}" if t != 0 else "0" for t in yticks],fontsize=FONTSIZE)
ax.set_ylabel(r"$\Delta$Power ($C$-$I$)",fontsize=FONTSIZE,y=.42,ha="center",va="center")

for spine_type in ["left","right", "top", "bottom"]:
    ax.spines[spine_type].set_visible(False)
ax.tick_params("both",width=LINEWIDTH,pad=0.4)

fig.savefig("E.png",dpi=600)

In [None]:
# Distribution of the f1 laterality differences (CONTRA - IPSI power).
y=store_data["f1"]
weights = np.ones_like(y)/len(y)  # bar heights become fractions of all points

fig,ax=plt.subplots(dpi=200,figsize=(2,1))
ax.hist(y,weights=weights,bins=600)
ax.set_ylim(0,.5)
ax.plot([0,0],[0,.1],linewidth=.4,c='r')  # red tick marking zero difference
ax.set_xlim(-.001,.001)
print(np.median(y), np.mean(y))

fig.savefig("F.png",dpi=600)

Frequency dependence

In [None]:
# roi_task_id -> (f1, f2) stimulation frequencies; read below as [0] = f1 and [1] = f2.
# NOTE(review): units presumably Hz — TODO confirm.
frequency_mapping = {
    "entrain": (.125,.2),
    "AttendAway": (.125,.2),
    "entrainA": (.125,.2),
    "entrainB": (.125,.175),
    "entrainC": (.125,.15),
    "entrainD": (.125,.2),
    "entrainE": (.15,.2),
    "entrainF": (.175,.2),
}

In [None]:
roi_fo = .8
subset_df = df[(df.roi_fo==roi_fo)]
experiment_ids = subset_df.experiment_id.unique()
# Control for vertex count: compare only the N largest values of each frequency,
# where N is the smaller of the two vertex counts.
CONTROL_FOR_VERTEX_COUNT = True

# One row per (experiment, subject, ROI task, ROI) with both an f1 and an f2 ROI.
# metric = mean(f1 power) - mean(f2 power), grouped by the f2 - f1 stimulation gap.
data_dict = defaultdict(list)
for experiment_id in experiment_ids:
    experiment_df = subset_df[(subset_df.experiment_id==experiment_id)]
    for sub_id in experiment_df.sub_id.unique():
        sub_df = experiment_df[(experiment_df.sub_id==sub_id)]
        for roi_task_id in sub_df.roi_task_id.unique():
            task_df = sub_df[(sub_df.roi_task_id==roi_task_id)]
            f1_frequency, f2_frequency = frequency_mapping[roi_task_id]
            f_difference = round(f2_frequency-f1_frequency,3)
            for hcp_roi in task_df.hcp_roi.unique():
                _df = task_df[(task_df.hcp_roi==hcp_roi)]
                f1_rows = _df[(_df.frequency_of_roi=="f1")]
                f2_rows = _df[(_df.frequency_of_roi=="f2")]
                if f1_rows.shape[0] != 1 or f2_rows.shape[0] != 1:
                    continue
                f1_values = f1_rows["f1_BOLD_power"].values[0]
                f2_values = f2_rows["f2_BOLD_power"].values[0]
                if f1_values.shape[0]==0 or f2_values.shape[0]==0:
                    continue

                if CONTROL_FOR_VERTEX_COUNT:
                    _n_vertices = min(f1_values.shape[0], f2_values.shape[0])
                    f1_values = np.sort(f1_values)[::-1][:_n_vertices]
                    f2_values = np.sort(f2_values)[::-1][:_n_vertices]
                # The p-value was previously computed and discarded; it is now kept, as in laterality_df.
                pvalue = stats.mannwhitneyu(
                    f1_values,
                    f2_values,
                ).pvalue
                metric = np.mean(f1_values) - np.mean(f2_values) # f1>f2 -> metric > 0

                data_dict["experiment_id"].append(experiment_id)
                data_dict["roi_task_id"].append(roi_task_id)
                data_dict["sub_id"].append(sub_id)
                data_dict["frequency_difference"].append(f_difference)
                data_dict["hcp_roi"].append(hcp_roi)
                data_dict["metric"].append(metric)
                data_dict["p_value"].append(pvalue)

frequency_difference_dependence_df = pd.DataFrame(data_dict)
frequency_difference_dependence_df

In [None]:
# f1 - f2 power difference grouped by the f2 - f1 stimulation gap: all experiments
# pooled ("ALL") and one panel per experiment.
fig, ax_dict = plt.subplot_mosaic([["ALL","3TNormal","7TNormal","3TVary","7TVary"]], dpi=300, figsize=(7,1.5),layout="constrained",gridspec_kw={"width_ratios": [1,.33,.33,1,1]})
s = 20
FREQUENCY_YLIM = (-.00025,.001)
jitter_rng = np.random.default_rng(0)  # fixed seed so the saved figure is reproducible

for key, ax in ax_dict.items():
    panel_df = frequency_difference_dependence_df
    if key != "ALL":
        panel_df = panel_df[(panel_df.experiment_id==key)]
    # Numeric order, so the x-axis increases left to right.
    unique_frequencies = np.sort(panel_df.frequency_difference.unique())
    for ix, f_difference in enumerate(unique_frequencies):
        _frequency_difference_dependence_df = panel_df[(panel_df.frequency_difference==f_difference)]
        c = [hcp_c_dict[hcp_roi.split("_")[-1]] for hcp_roi in _frequency_difference_dependence_df.hcp_roi.values]
        y = _frequency_difference_dependence_df.metric.values
        print(stats.ttest_1samp(y,0))
        x_jitter = jitter_rng.uniform(low=-.35,high=.35,size=y.shape)
        ax.scatter(ix+x_jitter,y,c=c,s=s,edgecolor='k',linewidth=LINEWIDTH)
        ax.plot([ix-.5,ix+.35],[np.median(y)]*2,lw=LINEWIDTH,c='r')  # median across points
    ax.plot([-.5,len(unique_frequencies)-.5],[0]*2,lw=LINEWIDTH,c='grey',linestyle='dotted')  # zero reference

    # Axis formatting, once per panel.
    ax.set_xticks(range(len(unique_frequencies)))
    ax.set_xticklabels([f"{round(i,4)}" for i in unique_frequencies], fontsize=FONTSIZE)
    # Set limits before reading ticks so tick labels match the final view.
    ax.set_ylim(*FREQUENCY_YLIM)
    yticks = [t for t in ax.get_yticks() if FREQUENCY_YLIM[0] <= t <= FREQUENCY_YLIM[1]]
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{i:.0e}" for i in yticks],fontsize=FONTSIZE)
    if key == "ALL":
        ax.set_ylabel(r"$\Delta$Power ($f_1$-$f_2$)",fontsize=FONTSIZE,y=.42,ha="center",va="center")
    ax.set_title(key, fontsize=FONTSIZE)

    for spine_type in ["left","right", "top", "bottom"]:
        ax.spines[spine_type].set_visible(False)
    ax.tick_params("both",width=LINEWIDTH,pad=0.4)

fig.savefig("G.png",dpi=600)

MRI dependence on power
    - only plots ROIs that are common across experiments

In [None]:
import itertools

# Mean power per ROI row at 3T vs 7T, per frequency, laterality and design (Normal / Vary).
experiment_suffices = ["Normal","Vary"]
lateralities = ["CONTRA","IPSI"]

mosaic_grid = [
    [
        "f1_CONTRA_Normal",
        "f1_IPSI_Normal",
        "f2_CONTRA_Normal",
        "f2_IPSI_Normal",
        "f1_CONTRA_Vary",
        "f1_IPSI_Vary",
        "f2_CONTRA_Vary",
        "f2_IPSI_Vary",
    ]
]
fig, ax_dict = plt.subplot_mosaic(mosaic_grid, dpi=300, figsize=(7.2,1.5),layout="constrained")
s = 20
MRI_YLIM = (-.0005,.002)
jitter_rng = np.random.default_rng(0)  # fixed seed so the saved figure is reproducible
for experiment_suffix, laterality in itertools.product(experiment_suffices,lateralities):

    # NOTE(review): the heading above says only ROIs common across experiments are
    # plotted, but this is the *union* of the 3T and 7T ROI lists — confirm intent.
    rois_to_keep = [f"{laterality}_{i}" for i in set(experiment_unique_rois[f"3T{experiment_suffix}"]+experiment_unique_rois[f"7T{experiment_suffix}"])]
    for f_type, f_metric in zip(["f1","f2"],["f1_BOLD_power","f2_BOLD_power"]):
        key = f"{f_type}_{laterality}_{experiment_suffix}"
        ax = ax_dict[key]
        data_for_testing = {}
        for ix, mri_id in enumerate(["3T","7T"]):
            # Filter by roi_fo like the other power analyses; without it, rows from
            # every roi_fo level were pooled.
            mri_data = df[(df.roi_fo==roi_fo) & (df.experiment_id==f"{mri_id}{experiment_suffix}") & (df.frequency_of_roi==f_type) & (df.hcp_roi.isin(rois_to_keep))]
            # Mean power per row. Rows whose mean is NaN are dropped from both the
            # colours and the values, so the two stay aligned.
            row_means = np.array([np.mean(f_data) for f_data in mri_data[f_metric]], dtype=float)
            is_valid = ~np.isnan(row_means)
            c = [hcp_c_dict[hcp_roi.split("_")[-1]] for hcp_roi in mri_data.hcp_roi.values[is_valid]]
            y = row_means[is_valid]
            data_for_testing[mri_id] = y
            x_jitter = jitter_rng.uniform(low=-.35,high=.35,size=y.shape)
            ax.scatter(ix+x_jitter,y,c=c,s=s,edgecolor='k',linewidth=LINEWIDTH)
            ax.plot([ix-.5,ix+.35],[np.median(y)]*2,lw=LINEWIDTH,c='r')  # median across points
        ax.plot([-.5,1.5],[0]*2,lw=LINEWIDTH,c='grey',linestyle='dotted')  # zero reference

        pvalue = stats.mannwhitneyu(
            data_for_testing["3T"],
            data_for_testing["7T"],
        ).pvalue
        print(experiment_suffix,laterality,f_type,pvalue)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(["3T","7T"], fontsize=FONTSIZE)
        # Set limits before reading ticks so tick labels match the final view.
        ax.set_ylim(*MRI_YLIM)
        yticks = [t for t in ax.get_yticks() if MRI_YLIM[0] <= t <= MRI_YLIM[1]]
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{i:.0e}" for i in yticks],fontsize=FONTSIZE)
        ax.set_title(key, fontsize=FONTSIZE)

        for spine_type in ["left","right", "top", "bottom"]:
            ax.spines[spine_type].set_visible(False)
        ax.tick_params("both",width=LINEWIDTH,pad=0.4)
ax_dict["f1_CONTRA_Normal"].set_ylabel("Power",fontsize=FONTSIZE,y=.42,ha="center",va="center")

fig.savefig("H.png",dpi=600)

Intersection dependence (f1&f2 vs f1/f2 alone)

In [None]:
# For each (experiment, ROI task, subject, ROI, f1/f2): compare power at the vertices
# that also lie in the "f1Uf2" ROI with power at the remaining vertices.
roi_fo = .8

data_dict = defaultdict(list)
experiment_ids = df.experiment_id.unique()
for experiment_id in experiment_ids:
    experiment_df = df[(df.roi_fo==roi_fo) & (df.experiment_id==experiment_id)]
    for roi_task_id in experiment_df.roi_task_id.unique():
        task_df = experiment_df[(experiment_df.roi_task_id==roi_task_id)]
        for sub_id in task_df.sub_id.unique():
            sub_df = task_df[(task_df.sub_id==sub_id)]
            for hcp_roi in sub_df.hcp_roi.unique():
                _df = sub_df[(sub_df.hcp_roi==hcp_roi)]
                for f_type in ["f1","f2"]:
                    f_data = _df[_df.frequency_of_roi==f_type]
                    f_inter_data = _df[_df.frequency_of_roi=="f1Uf2"]

                    # Exactly one row of each is required. Checking only the sum let a
                    # (2, 0) split through, which then crashed on f_inter_data[...].values[0].
                    if f_data.shape[0] != 1 or f_inter_data.shape[0] != 1:
                        continue
                    f_values = f_data[f"{f_type}_BOLD_power"].values[0]
                    if f_values.shape[0]==0:
                        continue
                    if f_inter_data[f"{f_type}_BOLD_power"].values[0].shape[0]==0:
                        continue
                    # Get intersection coordinates
                    f_data_vertices = f_data["vertex_coordinates"].values[0][0]
                    f_inter_data_vertices = f_inter_data["vertex_coordinates"].values[0][0]
                    # Vectorised form of `vertex in f_inter_data_vertices` for every vertex.
                    intersection_bools = np.isin(f_data_vertices, f_inter_data_vertices)
                    # Get power of non-intersection and intersection data
                    _f_not_inter_data = f_values[~intersection_bools]
                    _f_inter_data = f_values[intersection_bools]
                    # mannwhitneyu needs two non-empty samples.
                    if _f_not_inter_data.shape[0] == 0 or _f_inter_data.shape[0] == 0:
                        continue

                    pvalue = stats.mannwhitneyu(
                        _f_inter_data,
                        _f_not_inter_data,
                    ).pvalue
                    metric = np.mean(_f_inter_data) - np.mean(_f_not_inter_data) # power @ intersection of f1&f2 > outside

                    data_dict["experiment_id"].append(experiment_id)
                    data_dict["roi_task_id"].append(roi_task_id)
                    data_dict["sub_id"].append(sub_id)
                    data_dict["hcp_roi"].append(hcp_roi)
                    data_dict["f_type"].append(f_type)
                    data_dict["metric"].append(metric)
                    data_dict["p_value"].append(pvalue)

intersection_df = pd.DataFrame(data_dict)
intersection_df

In [None]:
# Rows where intersection vs. non-intersection power differs at p < .01 (no multiple-comparison correction applied).
intersection_df[intersection_df.p_value<.01]

In [None]:

# Power inside minus outside the f1/f2 intersection, per frequency.
fig, ax_dict = plt.subplot_mosaic([["X"]], dpi=300, figsize=(.75,1.5),layout="constrained")

ax=ax_dict["X"]
s = 20
jitter_rng = np.random.default_rng(0)  # fixed seed so the saved figure is reproducible
for ix, frequency in enumerate(["f1","f2"]):
    _intersection_df = intersection_df[(intersection_df.f_type==frequency)]
    c = [hcp_c_dict[hcp_roi.split("_")[-1]] for hcp_roi in _intersection_df.hcp_roi.values]
    y = _intersection_df.metric.values
    print(stats.ttest_1samp(y,0))
    x_jitter = jitter_rng.uniform(low=-.35,high=.35,size=y.shape)
    ax.scatter(ix+x_jitter,y,c=c,s=s,edgecolor='k',linewidth=LINEWIDTH)
    ax.plot([ix-.5,ix+.35],[np.median(y)]*2,lw=LINEWIDTH,c='r')  # median across points
ax.plot([-.5,1.5],[0]*2,lw=LINEWIDTH,c='grey',linestyle='dotted')  # zero reference

ax.set_xticks([0,1])
ax.set_xticklabels(["$f_1$","$f_2$"],fontsize=FONTSIZE)
ax.set_yticks([0])
ax.set_yticklabels(["0"],fontsize=FONTSIZE)
ax.set_ylabel(r"$\Delta$Power ($f_{both}$-$f_{single}$)",fontsize=FONTSIZE,y=.42,ha="center",va="center")

for spine_type in ["left","right", "top", "bottom"]:
    ax.spines[spine_type].set_visible(False)
ax.tick_params("both",width=LINEWIDTH,pad=0.4)

# Saved as I.png: the MRI-dependence figure already writes H.png, which this figure used to overwrite.
fig.savefig("I.png",dpi=600)

- Compare contralateral vs. ipsilateral
    - Consider only subjects and regions with both a contralateral and an ipsilateral component
    - x-axis: L, R
    - y-axis: power
- Compare stimulation frequencies
    - Consider only subjects and regions with both an $f_1$ and an $f_2$ component
    - x-axis: stimulation frequency
    - y-axis: power
- MRI dependence
    - x-axis: MRI field strength (3T, 7T)
    - y-axis: power
- Control experiment
    - Plot across all 3T subjects and regions
    - x-axis: 3T control, 3T entrain
    - y-axis: power
- Compare lower vs. higher levels of the visual hierarchy
    - How should this be defined?

# Compare single-subject maps to HCP retinotopic polar-angle maps

Create group-level retinotopy maps (from HCP)
- resample each map onto the dense (grayordinate) layout of `template_dscalar`, so its vertex indices line up with the metrics' vertex coordinates

In [None]:

# Resample the HCP S1200 7T group retinotopy maps onto the dense layout of
# `template_dscalar` (outputs written to `tmpdir`), presumably so they can be
# indexed with `df.vertex_coordinates` like the metrics — TODO confirm.
retinotopy_dir = Path("/opt/app/notebooks/data/S1200_7T_Retinotopy_Pr_9Zkk/S1200_7T_Retinotopy181/MNINonLinear/fsaverage_LR32k")
# Input maps, keyed by retinotopy measure.
retinotopy_dscalars = {k: retinotopy_dir / f"S1200_7T_Retinotopy181.Fit1_{k}_MSMAll.32k_fs_LR.dscalar.nii" for k in ["PolarAngle","Eccentricity","ReceptiveFieldSize"]}
# Resampled outputs, same keys.
tmp_retinotopy_dscalars = {k: tmpdir / f"S1200_7T_Retinotopy181.Fit1_{k}_MSMAll.32k_fs_LR.dscalar.nii" for k in ["PolarAngle","Eccentricity","ReceptiveFieldSize"]}
for retino_type, raw_retino in retinotopy_dscalars.items():
    if not raw_retino.exists():
        raise ValueError(f"{retino_type}: {raw_retino} not found.")

    tmp_retino = tmp_retinotopy_dscalars[retino_type]
    !wb_command -cifti-create-dense-from-template {template_dscalar} {tmp_retino} -cifti {raw_retino}
    # NOTE(review): a file left in tmpdir by an earlier run also passes this check even if wb_command failed.
    assert tmp_retino.exists()

Test extracting retinotopy information for a single row of the dataframe (`df`)

In [None]:
# Load each resampled retinotopy map. NOTE(review): `read_roi_path` is not defined in this section; presumably it returns the map's data array — TODO confirm.
retinotopy_metrics = {k: read_roi_path(v) for k,v in tmp_retinotopy_dscalars.items()}

In [None]:

import sys
sys.path.append("ComputeCanada/frequency_tagging")
from dfm import (
    get_roi_colour_codes,
    change_font,
)
change_font()

def circular_median(vals):
    median_cos = np.mean([np.cos(i) for i in vals if not np.isnan(i)])
    median_sin = np.mean([np.sin(i) for i in vals if not np.isnan(i)])
    x = np.arctan2(median_sin,median_cos)

    return x

In [None]:
FONTSIZE = 6

# Selection of rows of `df` to plot, and which per-vertex metric to show.
roi_fo = .8
hcp_roi = "CONTRA_V1"
metric = "power"
# Only the BOLD power columns are wired up below; fail loudly rather than
# drawing empty panels for any other metric.
if metric != "power":
    raise ValueError(f"Unsupported metric: {metric}")

roi_c_dict = get_roi_colour_codes()
for experiment_id in df.experiment_id.unique():

    sub_ids = df[(df.experiment_id==experiment_id)].sub_id.unique()

    for sub_id in sub_ids:

        roi_task_ids = df[(df.experiment_id==experiment_id) & (df.sub_id==sub_id)].roi_task_id.unique()

        for roi_task_id in roi_task_ids:

            subset_df = df[(df.roi_fo==roi_fo) & (df.hcp_roi==hcp_roi) & (df.experiment_id==experiment_id) & (df.sub_id==sub_id) & (df.roi_task_id==roi_task_id)]
            if subset_df.empty:
                # Nothing matches this roi_fo/hcp_roi selection: no figure to draw,
                # and max() below would raise on an empty sequence.
                continue
            # Expect at most one row per panel ("f1", "f2", "f1Uf2").
            assert subset_df.shape[0]<=3, subset_df
            # NOTE(review): the radial reference is scaled from f1 power only, even on
            # panels that also plot f2 power — confirm this is intended.
            max_metric = max(i.max() for i in subset_df['f1_BOLD_power'])

            fig, ax_dict = plt.subplot_mosaic([["f1","f2","f1Uf2"]],figsize=(2,1),subplot_kw={'projection':'polar'}, dpi=300)

            # Panels still listed here after the row loop received no data and are hidden.
            f_types = ["f1","f2","f1Uf2"]

            for _, r in subset_df.iterrows():
                if r.frequency_of_roi in f_types:
                    f_types.remove(r.frequency_of_roi)
                ax = ax_dict[r.frequency_of_roi]
                # PolarAngle is scaled by pi/180 (presumably degrees -> radians) for the polar axes.
                polar_angle = retinotopy_metrics["PolarAngle"][r.vertex_coordinates] * np.pi/180
                # Computed once per row; reused for the radial line and the title.
                median_angle = circular_median(polar_angle)

                # Power values to scatter on this panel; keys index `roi_c_dict` for colours.
                metric_map = {}
                if r.frequency_of_roi == "f1":
                    metric_map["f1"] = r.f1_BOLD_power
                elif r.frequency_of_roi == "f2":
                    metric_map["f2"] = r.f2_BOLD_power
                elif r.frequency_of_roi == "f1Uf2":
                    metric_map["f1"] = r.f1_BOLD_power
                    metric_map["f2"] = r.f2_BOLD_power
                else:
                    raise ValueError(f"Unexpected frequency_of_roi: {r.frequency_of_roi}")
                for k,_metric in metric_map.items():
                    ax.scatter(polar_angle, _metric, c=roi_c_dict[k],s=.2,zorder=1)
                # Radial line at the circular mean polar angle, drawn once per panel.
                ax.plot([median_angle]*2, [0,max_metric/5],linewidth=.5,c='k',zorder=2)
                ax.set_xticks([0])
                ax.set_xticklabels(["0"],fontsize=FONTSIZE-2)
                ax.set_yticks([max_metric/5])
                # NOTE(review): this tick sits at max_metric/5 but is labelled with
                # max_metric itself — confirm which value the label should report.
                ax.set_yticklabels([f"{max_metric:.4f}"],fontsize=FONTSIZE-3)
                ax.tick_params(axis="both",pad=0,zorder=3)
                ax.grid(linewidth=.5)
                ax.spines.polar.set_visible(False)
                ax.set_title(f"{r.frequency_of_roi} [{median_angle*180/np.pi:.1f}]",fontsize=FONTSIZE)
            for k in f_types:
                ax_dict[k].set_visible(False)

            # quadrant_id is read from the last plotted row; presumably constant within subset_df.
            fig.suptitle(f"[{experiment_id}] sub-{sub_id} roi-task-{roi_task_id}{r.quadrant_id}", fontsize=FONTSIZE-2)