In [None]:
import nibabel as nib
import numpy as np

import sys
sys.path.append("/opt/wbplot")

from wbplot import dscalar

from IPython.display import Image
import matplotlib.pyplot as plt

from pathlib import Path
import pandas as pd
from collections import defaultdict
from functools import lru_cache

sys.path.append("ComputeCanada/frequency_tagging")
from dfm import (
    get_roi_colour_codes,
    change_font,
)
# Presumably applies the project's default plotting font (dfm.change_font; its behaviour is not visible here — confirm).
change_font()

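Get HCP info
- `hcp_mapping`: dict mapping each HCP ROI label (`L_<roi>_ROI` / `R_<roi>_ROI`) to the path of its single-ROI dscalar
- `hcp_rois`: sorted list of ROI names with the hemisphere prefix and `_ROI` suffix removed (`<roi>`)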

In [2]:

"""Get HCP labels
"""
dlabel_dir = Path("/opt/app/notebooks/data/dlabels")
hcp_label = dlabel_dir / "Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR.dlabel.nii"

# Capture the text report of the HCP parcellation dlabel (IPython returns one list entry per output line).
_HCP_INFO = !wb_command -file-information {hcp_label}
# ROI labels rebuilt as "L_<roi>_ROI" / "R_<roi>_ROI", in the order they appear in the report.
HCP_LABELS = []
# Number of report lines accepted as label-table rows.
HCP_COUNTER = 0
for i in _HCP_INFO:
    # Heuristic row filter: label-table rows are assumed to be exactly 60 characters wide
    # and to contain a hemisphere prefix. NOTE(review): fragile against changes in the
    # wb_command report layout — confirm.
    if len(i) == 60 and any(["L_" in i, "R_" in i]):
        # Parse the last three "0.xxx" fields of the row into a colour tuple with a trailing 1
        # (presumably RGB + alpha). NOTE(review): `hcp_colors` is not used anywhere in the visible notebook.
        hcp_colors = tuple([float(f"0.{k}") for k in [j.split(' ') [0] for j in i.split('0.')][-3:]] + [1])
        # ROI name is the text between the hemisphere prefix and "_ROI".
        if ' R_' in i:
            roi = i.split("_ROI")[0].split(' R_')[1]
            HCP_LABELS.append(f"R_{roi}_ROI")
        if ' L_' in i:
            roi = i.split("_ROI")[0].split(' L_')[1]
            HCP_LABELS.append(f"L_{roi}_ROI")
        HCP_COUNTER += 1

"""Get HCP label coordinates
"""
dscalar_dir = Path("/opt/app/notebooks/data/dscalars")
tmpdir = Path("/tmp")
template_dscalar = dscalar_dir / "S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"

hcp_mapping = {}
for roi_label in HCP_LABELS:
    out_dscalar = tmpdir / f"{roi_label}.dscalar.nii"
    if out_dscalar.exists():
        hcp_mapping[roi_label] = out_dscalar
        continue
    !wb_command -cifti-label-to-roi {hcp_label} {out_dscalar} -name {roi_label}
    assert out_dscalar.exists(), f"{out_dscalar.stem} does not exist."
    hcp_mapping[roi_label] = out_dscalar
hcp_rois = list(set([k.split('_')[1] for k in hcp_mapping.keys()]))

Functions

In [None]:

def convert_to_fractional_overlap(data):
    """Return, per element, the fraction of maps along axis 0 that contain it.

    `data` is stacked along its first axis (e.g. bootstrap maps); the result is
    the sum over that axis divided by the number of maps.
    """
    n_maps = data.shape[0]
    summed = data.sum(axis=0)
    return summed / n_maps

def map_data_to_value(data_list):
    """Combine maps into one by summing each map multiplied by its value.

    Parameters
    ----------
    data_list : iterable of (array, value) pairs
        Typically binary vertex maps paired with the palette value they should
        take in the combined map.

    Returns
    -------
    array
        ``sum(array_i * value_i)``; a new array, inputs are not modified.

    Raises
    ------
    ValueError
        If `data_list` is empty (previously this surfaced as UnboundLocalError).
    """
    combined = None
    for data, value in data_list:
        weighted = data * value
        # Out-of-place add: avoids numpy casting errors that `+=` raises when the
        # first term has an integer dtype and a later one is float.
        combined = weighted if combined is None else combined + weighted

    if combined is None:
        raise ValueError("data_list must contain at least one (data, value) pair.")
    return combined

def get_quadrant_id(mask_path):
    """Return the quadrant id ("Q1" or "Q2") encoded in a mask file name.

    The id is the two characters starting at the first "Q" of the file name
    (directory components are ignored).

    Raises
    ------
    AssertionError
        If those two characters are not "Q1" or "Q2".
    """
    filename = mask_path.rsplit("/", 1)[-1]
    start = filename.find("Q")
    quadrant_id = filename[start:start + 2]
    assert quadrant_id in ['Q1', 'Q2'], f"{quadrant_id} not Q1 or Q2"
    return quadrant_id

@lru_cache(maxsize=360)
def read_roi_path(roi_path):
    """Load an ROI CIFTI file and return the first row of its data.

    Results are memoized per path (up to 360 entries). The same array object is
    returned on every cache hit, so callers must not modify it in place.
    """
    roi_img = nib.load(roi_path)
    roi_data = roi_img.get_fdata()
    return roi_data[0, :]

def contains_all_strings(input_str, string_list):
    """Return True if every string in `string_list` occurs as a substring of `input_str`.

    An empty `string_list` yields True.
    """
    return all(substring in input_str for substring in string_list)

def find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, fo, sub_id, data_split_id="train", match_str="activations.dtseries.nii", additional_match_strs=None, additional_match_str=None, corr_type="uncp", root="/scratch/fastfmri"):
    """List bootstrap output files for one subject that match all requested substrings.

    The searched directory is
    ``{root}/experiment-{experiment_id}_mri-{mri_id}_..._roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/``.

    Parameters
    ----------
    data_split_id : str
        A file must contain ``f"data-{data_split_id}"``.
    match_str : str
        A file must contain this substring.
    additional_match_strs : iterable of str, optional
        Extra substrings that must all be present.
    additional_match_str : ignored
        Unused; kept only for backward compatibility.
    root : str
        Base directory of the bootstrap outputs (default ``/scratch/fastfmri``).

    Returns
    -------
    list of str
        Full paths (``directory + file name``), sorted by file name so the result
        does not depend on the arbitrary order of ``os.listdir``.

    Raises
    ------
    FileNotFoundError
        If the directory does not exist.
    """
    import os

    directory = (
        f"{root}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200_batch-merged"
        f"_desc-IMall_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/"
    )
    required = [match_str, f"data-{data_split_id}"]
    if additional_match_strs is not None:
        required = list(additional_match_strs) + required

    return [
        f"{directory}{name}"
        for name in sorted(os.listdir(directory))
        if all(substring in name for substring in required)
    ]

def set_base_dir(basedir):
    """Return `basedir` as a Path, creating it (and any parents) when it does not exist yet."""
    path = Path(basedir)
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)
    return path

def load_mean_dtseries(dtseries):
    """Load a CIFTI file and return its data averaged over the first axis."""
    series_data = nib.load(dtseries).get_fdata()
    return np.mean(series_data, axis=0)

def crop_and_save(input_file, output_file, left, top, right, bottom):
    """Crop `input_file` to the pixel box (left, top, right, bottom) and save it to `output_file`.

    Best effort: any exception is printed instead of raised, so one failed crop
    does not abort a batch of figures.
    """
    from PIL import Image

    crop_box = (left, top, right, bottom)
    try:
        with Image.open(input_file) as source_img:
            source_img.crop(crop_box).save(output_file)
            print("Cropped image saved successfully as", output_file)
    except Exception as err:
        print("An error occurred:", err)

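Save visualizations and collect per-HCP-ROI vertex data (`df`) for each cohort:
- 3T normal (.125/.2), including a pass that measures control-task power within the entrain ROIs
- 7T normal (.125/.2)
- 3T varying frequencies
- 7T varying frequencies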

In [4]:
def get_im_frequencies(f1,f2):
    """Return the first-, second- and third-order intermodulation frequencies of f1 < f2.

    Returns
    -------
    dict
        ``{"first_order": [("f1", f1), ("f2", f2)],
        "second_order": [("f2-f1", ...), ("f1+f2", ...), ("2f1", ...), ("2f2", ...)],
        "third_order": [("2f1-f2", ...), ("2f2-f1", ...)]}``; derived values are
        rounded to 10 decimals to remove floating-point noise.

    Raises
    ------
    AssertionError
        If ``f2 <= f1``.
    """
    assert f2 > f1, f"{f2} <= {f1}"

    def _r(value):
        return round(value, 10)

    return {
        "first_order": [
            ("f1", f1),
            ("f2", f2),
        ],
        "second_order": [
            ("f2-f1", _r(f2 - f1)),
            ("f1+f2", _r(f1 + f2)),
            ("2f1", _r(f1 * 2)),
            ("2f2", _r(f2 * 2)),
        ],
        "third_order": [
            ("2f1-f2", _r(2 * f1 - f2)),
            ("2f2-f1", _r(2 * f2 - f1)),
        ],
    }

def convert_f_im(f_im,fo=1.,mask=None,f_im_c=None,mask_c=.41):
    """Threshold an activation map and colour-code it against its mask.

    Parameters
    ----------
    f_im : str
        Path of the stacked activation CIFTI; converted to per-vertex fractional
        overlap across axis 0.
    fo : float
        Vertices with fractional overlap >= `fo` are "activated".
    mask : str, optional
        Path of the stacked mask CIFTI; vertices present in every map
        (overlap >= 1) that are not activated get `mask_c`.
    f_im_c : float
        Palette value for activated vertices. Required — previously the default
        referenced an undefined global, which raised NameError when this cell ran.
    mask_c : float
        Palette value for mask-only vertices (default .41, matching the setup cell).

    Returns
    -------
    array
        ``activated * f_im_c (+ mask_only * mask_c)`` per vertex.

    Raises
    ------
    ValueError
        If `f_im_c` is not given.
    """
    if f_im_c is None:
        raise ValueError("f_im_c (palette value for activated vertices) must be provided.")

    f_im_data = (convert_to_fractional_overlap(nib.load(f_im).get_fdata()) >= fo).astype(int)
    data_list = [(f_im_data, f_im_c)]

    if mask:
        mask_data = (convert_to_fractional_overlap(nib.load(mask).get_fdata()) >= 1.).astype(int)
        # Remove activated vertices from the mask so they only carry the f_im colour.
        # NOTE(review): a vertex activated but outside the mask becomes -1 here;
        # this assumes the activation map is a subset of the mask — confirm.
        mask_data -= f_im_data
        data_list.append((mask_data, mask_c))

    return map_data_to_value(data_list)

def binarize_mask(data, f_im_c, mask_c, im_key):
    """Turn a colour-coded vertex map back into a binary ROI map.

    Vertices equal to `f_im_c` become 1, vertices equal to `mask_c` become 0, and
    all other values are kept unchanged. The input array is not modified.

    Returns
    -------
    dict
        ``{im_key: binary_map}``.
    """
    # Both selections come from the untouched input, so the result does not depend
    # on assignment order (previously, mask_c == 1 zeroed the vertices just set to 1).
    is_f_im = data == f_im_c
    is_mask_only = (data == mask_c) & ~is_f_im

    binary = data.copy()
    binary[is_f_im] = 1
    binary[is_mask_only] = 0
    return {im_key: binary}

def append_data(
    df_data,
    hcp_mapping,
    map_data,
    power_im_data,
    pd_im_data, 
    q_id,
    experiment_label, 
    sub_id, 
    roi_fo,
    roi_task_id,
    task_id,
):
    """Append one row per (IM map, HCP ROI) overlap to the column lists in `df_data`.

    Parameters
    ----------
    df_data : mapping of str -> list
        Column accumulator (e.g. ``defaultdict(list)``); extended in place and returned.
    hcp_mapping : dict
        HCP ROI label (``"L_<roi>_ROI"`` / ``"R_<roi>_ROI"``) -> single-ROI CIFTI path.
    map_data : dict
        IM code -> binarized vertex map (1 marks a vertex of the IM-frequency ROI).
    power_im_data, pd_im_data : array
        Per-vertex power / phase delay, indexed with the overlap mask.
    q_id : {"Q1", "Q2"}
        Stimulated quadrant; with Q1, left-hemisphere ROIs are labelled
        contralateral, with Q2, right-hemisphere ROIs.
    experiment_label, sub_id, roi_fo, roi_task_id, task_id
        Stored verbatim in every row.

    Rows are only added when the ROI/IM overlap has a non-zero sum. Columns, in
    order: roi_task_id, task_id, roi_fo, experiment_id (= experiment_label),
    sub_id, quadrant_id, hcp_roi (``"CONTRA_<roi>"`` / ``"IPSI_<roi>"``), im_code,
    vertex_count (sum of ROI mask * IM map), vertex_coordinates (``np.where``
    tuple of overlapping vertices), f_im_BOLD_power (None when
    ``roi_task_id == "control"``), f_im_phase_delay (None when
    ``task_id != roi_task_id``).

    Raises
    ------
    ValueError
        If `q_id` is not "Q1" or "Q2".
    """
    # Hemisphere prefix contralateral to the stimulated quadrant.
    contra_by_quadrant = {"Q1": "L_", "Q2": "R_"}
    if q_id not in contra_by_quadrant:
        raise ValueError(f"{q_id} not supported.")
    contra = contra_by_quadrant[q_id]

    for im_code, f_data in map_data.items():
        for roi_label, roi_path in hcp_mapping.items():
            laterality = "CONTRA" if roi_label.startswith(contra) else "IPSI"
            # Strip the "L_"/"R_" prefix and the "_ROI" suffix.
            hcp_roi = f"{laterality}_{roi_label[2:-4]}"

            roi_mask = read_roi_path(roi_path)
            assert roi_mask.shape == f_data.shape

            overlap = roi_mask * f_data
            vertex_count = overlap.sum()
            if vertex_count == 0:
                continue
            in_overlap = overlap == 1
            vertex_coordinates = np.where(in_overlap)

            f_im_BOLD_power = None if roi_task_id == "control" else power_im_data[in_overlap]
            f_im_phase_delay = None if task_id != roi_task_id else pd_im_data[in_overlap]

            row = {
                "roi_task_id": roi_task_id,
                "task_id": task_id,
                "roi_fo": roi_fo,
                "experiment_id": experiment_label,
                "sub_id": sub_id,
                "quadrant_id": q_id,
                "hcp_roi": hcp_roi,
                "im_code": im_code,
                "vertex_count": vertex_count,
                "vertex_coordinates": vertex_coordinates,
                "f_im_BOLD_power": f_im_BOLD_power,
                "f_im_phase_delay": f_im_phase_delay,
            }
            for column, value in row.items():
                df_data[column].append(value)

    return df_data

def generate_single_subject_maps(
    label, experiment_id, mri_id, sub_ids, 
    roi_task_ids, im_frequencies,
    df_data=None,
    corr_type="uncp", 
    ROI_FO=.8, SUB_THRESHOLD=.5,
    LEFT=590, TOP=80, RIGHT=1140, BOTTOM=460, VERTEX_TO = 59412,
    FORCE_TASK_ID=None,
    mask_c = .41,
    PALETTE="power_surf",
    f_im_c=-.1,
):
    """Map every intermodulation (IM) frequency ROI per run, save the maps and collect per-ROI data.

    For each run (zipped `sub_ids`, `roi_task_ids`, `im_frequencies`) and each
    ``(label, frequency)`` pair in the run's `im_frequencies` dict, this:

    1. locates the activation, mask, phase-delay and power files (`find_activations`),
    2. thresholds the activation map at `ROI_FO` and colour-codes it against the
       mask (`convert_f_im`), keeping the first `VERTEX_TO` grayordinates,
    3. appends per-HCP-ROI rows to `df_data` (`append_data`),
    4. writes the coded map as a CIFTI file and a flatmap PNG plus a cropped copy.

    Parameters
    ----------
    label : str
        Cohort label; stored as ``experiment_id`` in the rows and used in file names.
    experiment_id, mri_id, corr_type : str
        Select the bootstrap output directory (see `find_activations`).
    sub_ids, roi_task_ids : sequence of str
        Subject and ROI task id of each run.
    im_frequencies : sequence of dict
        One dict per run, as returned by `get_im_frequencies`:
        ``{"first_order": [("f1", .125), ...], "second_order": [...], ...}``.
    df_data : defaultdict(list), optional
        Accumulator to extend; a new one is created when None.
    ROI_FO : float
        Fractional-overlap threshold applied to the activation map.
    SUB_THRESHOLD : float
        Currently unused; kept for backward compatibility.
    LEFT, TOP, RIGHT, BOTTOM : int
        Pixel crop box for the cropped PNG.
    VERTEX_TO : int
        Number of leading grayordinates kept (presumably the cortical surface
        vertices of 32k_fs_LR — confirm).
    FORCE_TASK_ID : str, optional
        Task whose power is read inside the ROIs; defaults to the ROI task
        (e.g. "control" measures control-task power within entrain ROIs).
    mask_c, f_im_c : float
        Palette values for mask-only and activated vertices. NOTE(review): the
        original read an undefined global `f_im_c`; -.1 (the f1 "red" value of the
        setup cell) is assumed as default — confirm.
    PALETTE : str
        Palette passed to `dscalar`.

    Returns
    -------
    defaultdict(list)
        `df_data`, extended with the new rows (also when no runs are given).
    """
    if df_data is None:
        df_data = defaultdict(list)

    # Fractional overlap encoded in the bootstrap output directory names.
    bootstrap_fo = .8
    png_dir = set_base_dir("./ComputeCanada/frequency_tagging/figures/im_frequency_mapping")
    cifti_dir = set_base_dir("./ComputeCanada/frequency_tagging/figures/im_frequency_mapping_cifti")
    palette_params = {
        "disp-zero": False,
        "disp-neg": True,
        "disp-pos": True,
        "pos-user": (0, 1.),
        "neg-user": (-1,0),
        "interpolate": True,
    }

    for sub_id, roi_task_id, run_im_frequencies in zip(sub_ids, roi_task_ids, im_frequencies):
        # Task whose power is read inside this run's ROIs (the ROI task unless overridden).
        power_task_id = roi_task_id if FORCE_TASK_ID is None else FORCE_TASK_ID
        # Power metrics were not calculated for control task condition (no voxels allocated to task-control ROIs)
        no_power = roi_task_id == "control" and experiment_id == "1_frequency_tagging"

        for im_order, im_pairs in run_im_frequencies.items():
            for im_str, im_f in im_pairs:
                stem = (
                    f"label-{label}_mri-{mri_id}_sub-{sub_id}_task-{roi_task_id}"
                    f"_f-{im_order}-{im_str}-{im_f}_corr-{corr_type}_fo-{ROI_FO}"
                )
                png_out = png_dir / f"{stem}.png"
                dscalar_out = cifti_dir / f"{stem}.dtseries.nii"

                f_im = find_activations(experiment_id, mri_id, roi_task_id, im_f, bootstrap_fo, sub_id, match_str="activations.dtseries.nii", corr_type=corr_type)
                mask = find_activations(experiment_id, mri_id, roi_task_id, im_f, bootstrap_fo, sub_id, match_str="mask.dtseries.nii", corr_type=corr_type)
                # NOTE(review): these filters originally used an undefined `roi_f_1`;
                # the ROI frequency `im_f` is assumed here — confirm.
                pd_im = find_activations(experiment_id, mri_id, roi_task_id, im_f, bootstrap_fo, sub_id, data_split_id="train", match_str="phasedelay.dtseries.nii", additional_match_strs=[roi_task_id, f"f-{im_f}"], corr_type=corr_type)
                power_im = find_activations(experiment_id, mri_id, roi_task_id, im_f, bootstrap_fo, sub_id, data_split_id="test", match_str="power.dtseries.nii", additional_match_strs=[power_task_id, f"f-{im_f}"], corr_type=corr_type)

                found_files = [
                    (im_str, f_im),
                    ("mask", mask),
                    (f"pd_{im_str}", pd_im),
                    (f"power_{im_str}", power_im),
                ]
                for f_label, found in found_files:
                    # Without power outputs only the activation map and the mask are required.
                    if no_power and f_label not in (im_str, "mask"):
                        continue
                    assert len(found) == 1, f"{sub_id}, {f_label} - {found}, {experiment_id} {roi_task_id}"

                f_im = f_im[0]
                data = convert_f_im(f_im, fo=ROI_FO, mask=mask[0], f_im_c=f_im_c, mask_c=mask_c)[:VERTEX_TO]
                map_data = binarize_mask(data, f_im_c, mask_c, im_str)
                # NOTE(review): for no_power runs pd_im is not asserted above; an empty
                # result raises IndexError here (unchanged from the original).
                pd_im_data = load_mean_dtseries(pd_im[0])[:VERTEX_TO]
                power_im_data = None if no_power else load_mean_dtseries(power_im[0])[:VERTEX_TO]
                q_id = get_quadrant_id(mask[0])
                df_data = append_data(
                    df_data,
                    hcp_mapping,
                    map_data,
                    power_im_data,
                    pd_im_data,
                    q_id,
                    label,
                    sub_id,
                    ROI_FO,
                    roi_task_id,
                    power_task_id,
                )

                # Save the coded map as a single-point CIFTI series, reusing the activation file's header.
                f_im_img = nib.load(f_im)
                cifti_data = np.zeros((1, f_im_img.shape[-1]))
                cifti_data[0, :VERTEX_TO] = data
                im_img = nib.Cifti2Image(cifti_data, header=f_im_img.header)
                im_img.header.matrix[0].number_of_series_points = 1
                nib.save(im_img, dscalar_out)

                dscalar(
                    png_out, data,
                    orientation="portrait",
                    hemisphere='right',
                    palette=PALETTE,
                    palette_params=dict(palette_params),
                    transparent=False,
                    flatmap=True,
                    flatmap_style='plain',
                )
                # "x.png" -> "x.cropped.png" without touching "png" elsewhere in the path.
                crop_and_save(png_out, str(png_out.with_suffix(".cropped.png")), LEFT, TOP, RIGHT, BOTTOM)

    return df_data



In [5]:
"""Set up for visualizing dual frequency tagging across each subject using fractional overlap
"""
PALETTE = "power_surf"
f1_c = -.1 # red -.1
f2_c = .82 # blue .82
f1f2_c = .88 # white .88 yellow .14
mask_c = .41 # .9 [green], .1 [goldish], .4 [black]
# Palette value for vertices of an IM-frequency ROI. The mapping code reads a
# global `f_im_c` that was never defined. NOTE(review): the f1 colour (red) is
# assumed here — confirm.
f_im_c = f1_c

cohort_roi_info_across_experiments = {}
ROI_FOS = [.8, 1.]
corr_type = "uncp"

# Save png
NORMAL_3T_SUB_IDS = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
NORMAL_7T_SUB_IDS = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
VARY_SUB_IDS = ["020"] * 3 + ["021"] * 3
VARY_ROI_TASK_IDS = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]
VARY_F_1S = [.125] * 3 + [.125, .15, .175]
VARY_F_2S = [.2, .175, .15] + [.2] * 3

# One entry per cohort pass, processed in this order for every ROI_FO.
COHORTS = [
    # 3T control under entrain condition (FORCE_TASK_ID="control" gives power
    # measurements of the control task within entrain ROIs)
    dict(label="3TNormal", experiment_id="1_frequency_tagging", mri_id="3T",
         sub_ids=NORMAL_3T_SUB_IDS, roi_task_ids=["entrain"] * len(NORMAL_3T_SUB_IDS),
         roi_f_1s=[.125] * len(NORMAL_3T_SUB_IDS), roi_f_2s=[.2] * len(NORMAL_3T_SUB_IDS),
         force_task_id="control"),
    # 3T normal
    dict(label="3TNormal", experiment_id="1_frequency_tagging", mri_id="3T",
         sub_ids=NORMAL_3T_SUB_IDS, roi_task_ids=["entrain"] * len(NORMAL_3T_SUB_IDS),
         roi_f_1s=[.125] * len(NORMAL_3T_SUB_IDS), roi_f_2s=[.2] * len(NORMAL_3T_SUB_IDS),
         force_task_id=None),
    # 7T normal
    dict(label="7TNormal", experiment_id="1_attention", mri_id="7T",
         sub_ids=NORMAL_7T_SUB_IDS, roi_task_ids=["AttendAway"] * len(NORMAL_7T_SUB_IDS),
         roi_f_1s=[.125] * len(NORMAL_7T_SUB_IDS), roi_f_2s=[.2] * len(NORMAL_7T_SUB_IDS),
         force_task_id=None),
    # 3T vary
    dict(label="3TVary", experiment_id="1_frequency_tagging", mri_id="3T",
         sub_ids=VARY_SUB_IDS, roi_task_ids=VARY_ROI_TASK_IDS,
         roi_f_1s=VARY_F_1S, roi_f_2s=VARY_F_2S,
         force_task_id=None),
    # 7T vary
    dict(label="7TVary", experiment_id="1_frequency_tagging", mri_id="7T",
         sub_ids=VARY_SUB_IDS, roi_task_ids=VARY_ROI_TASK_IDS,
         roi_f_1s=VARY_F_1S, roi_f_2s=VARY_F_2S,
         force_task_id=None),
]

df_data = None
for cohort in COHORTS:
    # generate_single_subject_maps takes one get_im_frequencies() dict per run,
    # not separate f1/f2 lists (passing those positionally raised TypeError).
    im_frequencies = [
        get_im_frequencies(f1, f2)
        for f1, f2 in zip(cohort["roi_f_1s"], cohort["roi_f_2s"])
    ]
    for ROI_FO in ROI_FOS:
        df_data = generate_single_subject_maps(
            cohort["label"], cohort["experiment_id"], cohort["mri_id"],
            cohort["sub_ids"], cohort["roi_task_ids"], im_frequencies,
            df_data=df_data,
            corr_type=corr_type,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5,
            FORCE_TASK_ID=cohort["force_task_id"],
        )

df = pd.DataFrame(df_data)

from IPython.display import clear_output
clear_output()


Info: Time to read /tmp/HumanCorticalParcellations/S1200.L.flat.32k_fs_LR.surf.gii was 0.061685 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.L.inflated_MSMAll.32k_fs_LR.surf.gii was 0.023692 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.L.midthickness_MSMAll.32k_fs_LR.surf.gii was 0.03587 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.L.very_inflated_MSMAll.32k_fs_LR.surf.gii was 0.022113 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.R.flat.32k_fs_LR.surf.gii was 0.060976 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.R.inflated_MSMAll.32k_fs_LR.surf.gii was 0.023907 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.R.midthickness_MSMAll.32k_fs_LR.surf.gii was 0.035913 seconds.


Info: Time to read /tmp/HumanCorticalParcellations/S1200.R.very_inflated_MSMAll.32k_fs_LR.surf.gii was 0.022091 seconds.


Info: Time to read /tmp/ImageDense.dscalar.nii was 0.038776 

In [None]:
show_top_x = 20
# NOTE(review): no visible cell fills `cohort_roi_info_across_experiments`, so indexing
# it directly always raised KeyError; report missing cohorts instead.
for cohort_label in ["3TNormal", "7TNormal",]: #"3TVary", "7TVary"]:
    print('\n', cohort_label)
    cohort_info = cohort_roi_info_across_experiments.get(cohort_label, {}).get(.2)
    if cohort_info is None:
        print("No ROI info available for", cohort_label)
        continue
    for rank, roi_info in enumerate(cohort_info[:show_top_x], start=1):
        print(rank, roi_info)

In [None]:
sub_id = "020"
task_id = "entrainA"
power_f = .125

# Display a power-map PNG expected under /tmp (presumably written by an earlier run); fails if missing.
png = Path("/tmp") / f"{sub_id}.{task_id}.power-f-{power_f}.png"
assert png.exists()
print(png)
Image(png)

In [None]:
sub_id = "020"
task_id = "entrainA"
power_f = .125

# Display a power-map PNG expected under /tmp (presumably written by an earlier run); fails if missing.
png = Path("/tmp") / f"{sub_id}.{task_id}.power-f-{power_f}.png"
assert png.exists()
print(png)
Image(png)

Plot power across ROIs

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import itertools

def read_pickle(pkl):
    """Load and return the object stored in the pickle file `pkl`.

    Security: pickle.load can execute arbitrary code — only call this on trusted files.
    """
    import pickle

    with open(pkl, 'rb') as handle:
        return pickle.load(handle)

def read_statistics(observed_statistics, n_bootstraps):
    """Summarize bootstrap results given as sequences whose [0] is a p-value and [1] a statistic.

    Returns
    -------
    dict
        ``accuracy``: number of p-values < .05 divided by `n_bootstraps`;
        ``mean``: mean of the statistics;
        ``CI``: 2.5th and 97.5th percentiles of the statistics.
    """
    n_significant = sum(stat[0] < .05 for stat in observed_statistics)
    values = np.array([stat[1] for stat in observed_statistics])
    return {
        "accuracy": n_significant / n_bootstraps,
        "mean": np.mean(values),
        "CI": np.percentile(values, [2.5, 97.5]),
    }

In [None]:
from copy import copy 

def plot_power_of_rois_across_cohort(
    experiment_id, 
    mri_id, 
    sub_ids, 
    roi_task_ids, 
    roi_fs, 
    primary_fs, 
    secondary_fs, 
    task_ids, 
    fo, 
    n_permutations,n_bootstraps,
    JITTER=.25,
    FONTSIZE=4,
    FIGSIZE=(5,2),
    NORMAL_EXPERIMENTS=True,
):
    """Plot bootstrapped ROI power (PSD) across a cohort on a 2x2 grid of panels.

    Rows: ``control_roi_size`` True/False; columns: ``phaseadjusted`` True/False.
    For each run (zipped `sub_ids`, `roi_task_ids`, `roi_fs`, `primary_fs`,
    `secondary_fs`, `task_ids`) the matching pickle under
    ``/scratch/fastfmri/pickles`` is read. Its ``"test-<frequency>"`` entries are
    summarized with `read_statistics`, and the mean is drawn with its 2.5–97.5
    percentile interval for the primary and IM (|f2 - f1|) frequencies. The
    secondary frequency is currently not drawn.

    Marker face colour encodes accuracy (fraction of bootstraps with p < .05):
    red == 1, gold in [.8, 1), grey otherwise. With `NORMAL_EXPERIMENTS`, each
    subject's primary-frequency means are joined by a dotted line. All panels
    share the y-limit.

    Raises
    ------
    AssertionError
        If an expected pickle does not exist.
    ValueError
        If a pickle key matches none of the primary, secondary or IM frequencies.
    """
    all_fs = sorted(set(roi_fs))
    min_f = min(all_fs) if all_fs else None
    max_f = max(all_fs) if all_fs else None

    fig, axs = plt.subplots(ncols=2, nrows=2, figsize=FIGSIZE, dpi=200)

    for row_ix, control_roi_size in enumerate([True, False]):
        for col_ix, phaseadjusted in enumerate([True, False]):
            # Primary, secondary and IM frequencies all share this panel.
            ax = axs[row_ix, col_ix]

            # Per-subject (x, mean) points for the primary-frequency connecting lines;
            # x and y are always appended together so their lengths match.
            primary_mean_across_cohort = defaultdict(list)
            primary_frequency_across_cohort = defaultdict(list)
            for sub_id, roi_task_id, roi_f, primary_f, secondary_f, task_id in zip(
                sub_ids, roi_task_ids, roi_fs, primary_fs, secondary_fs, task_ids
            ):
                if control_roi_size and primary_f < secondary_f:
                    pkl = Path(f"/scratch/fastfmri/pickles/experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{roi_task_id}_roi-{roi_f}_controlroisizetomatch-{secondary_f}_task-{task_id}_fo-{fo}_phaseadjusted-{phaseadjusted}_n-{n_permutations}.pkl")
                else:
                    pkl = Path(f"/scratch/fastfmri/pickles/experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_roitask-{roi_task_id}_roi-{roi_f}_task-{task_id}_fo-{fo}_phaseadjusted-{phaseadjusted}_n-{n_permutations}.pkl")
                assert pkl.exists(), f"{pkl} does not exist."

                im_f = abs(round(secondary_f-primary_f, 10))
                for k, v in read_pickle(pkl).items():
                    if k == f"test-{primary_f}":
                        label = "primary_frequency"
                    elif k == f"test-{secondary_f}":
                        label = "secondary_frequency"
                    elif k == f"test-{im_f}":
                        label = "im_frequency"
                    else:
                        raise ValueError(f"{k} does not match {primary_f} or {secondary_f}")

                    frequency = float(k.split("test-")[-1])
                    statistic_info = read_statistics(v, n_bootstraps)
                    mean = statistic_info['mean']
                    ci_low, ci_high = statistic_info['CI']

                    if label == "primary_frequency":
                        # Nudge the line end by JITTER toward the other frequencies at the extremes.
                        if frequency == min_f:
                            line_x = frequency+JITTER
                        elif frequency == max_f:
                            line_x = frequency-JITTER
                        else:
                            line_x = frequency
                        primary_frequency_across_cohort[sub_id].append(line_x)
                        primary_mean_across_cohort[sub_id].append(mean)

                    _accuracy = statistic_info['accuracy']
                    if sub_id == "000":
                        print(pkl)
                        print(f"[x] sub-{sub_id}_roi-{roi_task_id}-{roi_f}_task-{task_id}_mainf-{primary_f}_secf-{secondary_f}_{label}, acc: {_accuracy*100:.1f}%")

                    if label == 'secondary_frequency':
                        continue # temporarily remove

                    if _accuracy == 1:
                        mfc, zorder = 'r', 12
                    elif .8 <= _accuracy < 1.:
                        mfc, zorder = 'gold', 11
                    else:
                        mfc, zorder = 'grey', 10

                    # errorbar expects (lower, upper) distances from the point, not absolute
                    # CI bounds; clip guards against a mean falling outside its own CI.
                    yerr = np.clip([[mean - ci_low], [ci_high - mean]], 0, None)
                    jitter = np.random.uniform(low=-JITTER, high=JITTER)
                    ax.errorbar(
                        frequency+jitter, mean, yerr=yerr,
                        marker="o", ms=4, c='k',
                        markeredgewidth=.5, markerfacecolor=mfc, markeredgecolor='k',
                        lw=.5, capsize=2, capthick=1, zorder=zorder
                    )

            # Connect the primary-frequency points of each subject.
            if NORMAL_EXPERIMENTS:
                for sub_id, line_xs in primary_frequency_across_cohort.items():
                    ax.plot(
                        line_xs, primary_mean_across_cohort[sub_id], linestyle=':', lw=1, c='lightgrey', zorder=2
                    )
            ax.set_ylabel("PSD", fontsize=FONTSIZE)
            ax.set_xlabel("Frequency", fontsize=FONTSIZE)
            ax.set_title(f"Phaseadjusted-{phaseadjusted} controlsize-{control_roi_size}", fontsize=FONTSIZE)

    fig.suptitle(f"{mri_id} / ROI: f1/f2 / {fo}", fontsize=FONTSIZE)

    # Share the largest upper y-limit across all panels.
    max_ylim = max(ax.get_ylim()[-1] for ax in axs.flat)
    for ax in axs.flat:
        ax.set_ylim(-.0001, max_ylim)
        ax.set_xticks(all_fs)
        ax.set_xticklabels([str(f) for f in all_fs], fontsize=FONTSIZE)
        # tick_params keeps the auto-generated y tick labels (re-setting them from
        # get_yticklabels() before a draw can freeze stale/empty labels).
        ax.tick_params(axis='y', labelsize=FONTSIZE)
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(axis='x', length=0.)
        ax.tick_params(axis='y', length=0.)
        for _f in all_fs:
            ax.axvspan(
                _f-JITTER-.0015, _f+JITTER+.0015,
                facecolor='lightgray', alpha=0.5
            )

    fig.tight_layout()

def plot_power_of_rois_across_cohort_intersection(
    experiment_id, 
    mri_id, 
    sub_ids, 
    roi_task_ids, 
    roi_fs, 
    primary_fs, 
    secondary_fs, 
    task_ids, 
    fo, 
    n_permutations,n_bootstraps,
    JITTER=.25,
    FONTSIZE=4,
    FIGSIZE=(5,2),
    NORMAL_EXPERIMENTS=True,
):
    """Plot per-subject PSD estimates for the f1/f2 intersection ROI.

    Draws one panel per phase-adjustment setting (left: phaseadjusted=True,
    right: phaseadjusted=False). For every entry of the zipped per-entry
    lists, the matching pickle under /scratch/fastfmri/pickles is loaded.
    For the primary frequency and the difference frequency
    |secondary_f - primary_f|, an errorbar point (mean with CI) is drawn,
    coloured by accuracy: red for 1.0, gold for [0.8, 1.0), grey otherwise.

    Parameters
    ----------
    experiment_id, mri_id, fo, n_permutations :
        Interpolated into the pickle filename; ``mri_id`` and ``fo`` also
        appear in the figure title.
    sub_ids, roi_task_ids, roi_fs, primary_fs, secondary_fs, task_ids : sequences
        Parallel per-entry lists, consumed together with ``zip`` (the
        shortest one bounds the iteration).
    n_bootstraps : int
        Forwarded to ``read_statistics``.
    JITTER : float
        Half-width of the uniform horizontal jitter applied to each point,
        and of the grey band drawn around each ROI frequency.
    FONTSIZE : float
        Font size for labels, titles and tick labels.
    FIGSIZE : tuple
        Figure size passed to ``plt.subplots``.
    NORMAL_EXPERIMENTS : bool
        When True, connect each subject's primary-frequency points with a
        dotted line.

    Raises
    ------
    ValueError
        If a pickle key matches none of the expected ``test-<freq>`` keys.

    Notes
    -----
    Missing pickles are reported and skipped. Jitter uses the global NumPy
    RNG; seed it for reproducible figures. ``read_pickle`` and
    ``read_statistics`` are defined elsewhere in the notebook. The latter is
    assumed to return a mapping with 'mean', 'CI' (presumably a length-2
    array, TODO confirm) and 'accuracy'.
    """
    all_fs = sorted(set(roi_fs))
    # Endpoints of the frequency axis; connecting lines are pulled inward by JITTER there.
    lowest_f = min(all_fs, default=None)
    highest_f = max(all_fs, default=None)

    fig, axs = plt.subplots(ncols=2, nrows=1, figsize=FIGSIZE, dpi=200)

    for col_ix, phaseadjusted in enumerate([True, False]):
        # Bind the panel explicitly. The labelling below must target this
        # column even when every pickle for it is missing; otherwise `ax` would
        # be undefined or still point at the previous panel.
        ax = axs[col_ix]

        primary_mean_across_cohort = defaultdict(list)
        primary_frequency_across_cohort = defaultdict(list)
        # Mirrored entries (primary/secondary swapped) load the same pickle, so
        # its difference-frequency point is drawn only once.
        im_drawn_for = set()

        for sub_id, roi_task_id, roi_f, primary_f, secondary_f, task_id in zip(
            sub_ids, roi_task_ids, roi_fs, primary_fs, secondary_fs, task_ids
        ):
            # Intersection pickles name the two ROI frequencies in ascending order.
            low_roi_f, high_roi_f = (roi_f, secondary_f) if secondary_f > roi_f else (secondary_f, roi_f)
            pkl = Path(
                f"/scratch/fastfmri/pickles/experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}"
                f"_roitask-{roi_task_id}_roi-{low_roi_f}-{high_roi_f}_task-{task_id}"
                f"_fo-{fo}_phaseadjusted-{phaseadjusted}_n-{n_permutations}.pkl"
            )
            if not pkl.exists():
                print(f"Warning: {pkl} does not exist. Probably means intersection doesn't exist.\nSkipping")
                continue

            pkl_data = read_pickle(pkl)

            primary_key = f"test-{primary_f}"
            secondary_key = f"test-{secondary_f}"
            im_key = f"test-{abs(round(secondary_f-primary_f, 10))}"

            for k, v in pkl_data.items():
                if k == primary_key:
                    label = "primary_frequency"
                elif k == secondary_key:
                    label = "secondary_frequency"
                elif k == im_key:
                    label = "im_frequency"
                else:
                    raise ValueError(f"{k} does not match {primary_f} or {secondary_f}")

                if label == "secondary_frequency":
                    # Not drawn here; presumably covered by the mirrored entry in
                    # which this frequency is the primary one.
                    continue
                if label == "im_frequency":
                    if pkl in im_drawn_for:
                        continue
                    im_drawn_for.add(pkl)

                frequency = float(k.split("test-")[-1])
                statistic_info = read_statistics(v, n_bootstraps)

                if label == "primary_frequency":
                    # Append x and y together so the connecting line always has
                    # matching lengths, even for frequencies between the endpoints.
                    if frequency == lowest_f:
                        line_x = frequency + JITTER
                    elif frequency == highest_f:
                        line_x = frequency - JITTER
                    else:
                        line_x = frequency
                    primary_frequency_across_cohort[sub_id].append(line_x)
                    primary_mean_across_cohort[sub_id].append(statistic_info['mean'])

                _accuracy = statistic_info['accuracy']
                if sub_id == "000":
                    print(pkl)
                    print(f"[I] sub-{sub_id}_roi-{roi_task_id}-{roi_f}_task-{task_id}_mainf-{primary_f}_secf-{secondary_f}_{label}, acc: {_accuracy*100:.1f}%")

                # Colour (and stack) points by accuracy: perfect above >=80% above the rest.
                if _accuracy == 1:
                    mfc, zorder = 'r', 12
                elif .8 <= _accuracy < 1.:
                    mfc, zorder = 'gold', 11
                else:
                    mfc, zorder = 'grey', 10

                jitter = np.random.uniform(low=-JITTER, high=JITTER)
                ax.errorbar(
                    frequency + jitter, statistic_info['mean'],
                    yerr=statistic_info['CI'][:, np.newaxis],
                    marker="o", ms=4, c='k',
                    markeredgewidth=.5, markerfacecolor=mfc, markeredgecolor='k',
                    lw=.5, capsize=2, capthick=1, zorder=zorder,
                )

        # Connect points belonging to the same subject.
        if NORMAL_EXPERIMENTS:
            for sub_id, line_xs in primary_frequency_across_cohort.items():
                ax.plot(
                    line_xs, primary_mean_across_cohort[sub_id],
                    linestyle=':', lw=1, c='lightgrey', zorder=2,
                )
        ax.set_ylabel("PSD", fontsize=FONTSIZE)
        ax.set_xlabel("Frequency", fontsize=FONTSIZE)
        ax.set_title(f"Phaseadjusted-{phaseadjusted}", fontsize=FONTSIZE)

    # Share one y-range across both panels (upper bound = largest autoscaled top).
    max_ylim = max(panel.get_ylim()[1] for panel in axs)

    for panel in axs:
        panel.set_ylim(-.0001, max_ylim)
        panel.set_xticks(all_fs)
        panel.set_xticklabels([str(f) for f in all_fs], fontsize=FONTSIZE)
        # Resize the live tick labels rather than re-setting them from
        # get_yticklabels(). Re-setting freezes them with a FixedFormatter and,
        # in some matplotlib versions, captures empty text before the first draw.
        panel.tick_params(axis='y', labelsize=FONTSIZE)
        for side in ('top', 'right', 'bottom', 'left'):
            panel.spines[side].set_visible(False)
        panel.tick_params(axis='both', length=0.)
        # Shade the jitter band around each ROI frequency.
        for _f in all_fs:
            panel.axvspan(
                _f - JITTER - .0015, _f + JITTER + .0015,
                facecolor='lightgray', alpha=0.5,
            )

    # Title and layout last, so tight_layout accounts for the final tick formatting.
    fig.suptitle(f"{mri_id} / ROI: Intersection / {fo}", fontsize=FONTSIZE)
    fig.tight_layout()

In [None]:
import warnings

# Shared styling for the cohort plots below.
FONTSIZE = 4
FIGSIZE_F1F2 = (4, 3)            # used with plot_power_of_rois_across_cohort
FIGSIZE_INTERSECTION = (4, 1.5)  # used with plot_power_of_rois_across_cohort_intersection

# NOTE: this silences *all* warnings for the rest of the kernel session, not
# one specific warning. TODO: narrow it with `category=` / `message=` once the
# offending warning is identified, so unrelated problems stay visible.
warnings.filterwarnings("ignore")

3T normal experiments (entrain and control tasks)

In [None]:
NORMAL_3T_SUB_IDS = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]

# Each subject appears twice: in the first half with .125 as the primary ROI
# frequency, in the second half with .2. Building the lists as two aligned
# halves (as in the 7T cell) keeps that pairing correct for any cohort size.
# An alternating [.125, .2] * n only pairs correctly when n is odd.
_n_3t_subs = len(NORMAL_3T_SUB_IDS)

experiment_id = "1_frequency_tagging"
mri_id = "3T"
sub_ids = NORMAL_3T_SUB_IDS * 2
roi_task_ids = ["entrain"] * 2 * _n_3t_subs
roi_fs = [.125] * _n_3t_subs + [.2] * _n_3t_subs
primary_fs = [.125] * _n_3t_subs + [.2] * _n_3t_subs
secondary_fs = [.2] * _n_3t_subs + [.125] * _n_3t_subs
fo = .4
n_permutations = 1000
n_bootstraps = 200

for _task_id in ["entrain", "control"]:
    task_ids = [_task_id] * 2 * _n_3t_subs
    for plot_func, FIGSIZE in zip(
        [plot_power_of_rois_across_cohort, plot_power_of_rois_across_cohort_intersection],
        [FIGSIZE_F1F2, FIGSIZE_INTERSECTION],
    ):
        plot_func(
            experiment_id,
            mri_id,
            sub_ids,
            roi_task_ids,
            roi_fs,
            primary_fs,
            secondary_fs,
            task_ids,
            fo,
            n_permutations,
            n_bootstraps,
            JITTER=.015,
            FONTSIZE=FONTSIZE,
            FIGSIZE=FIGSIZE,
        )

7T normal experiments (AttendAway task)

In [None]:
NORMAL_7T_SUB_IDS = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]

# Each subject appears twice: in the first half with .125 as the primary ROI
# frequency, in the second half with .2. Derive the half-lengths from the
# cohort so every list stays aligned with sub_ids if a subject is added or
# dropped. The plotting functions zip() these lists and would otherwise
# silently truncate.
_n_7t_subs = len(NORMAL_7T_SUB_IDS)

experiment_id = "1_attention"
mri_id = "7T"
sub_ids = NORMAL_7T_SUB_IDS * 2
roi_task_ids = ["AttendAway"] * 2 * _n_7t_subs
roi_fs = [.125] * _n_7t_subs + [.2] * _n_7t_subs
primary_fs = [.125] * _n_7t_subs + [.2] * _n_7t_subs
secondary_fs = [.2] * _n_7t_subs + [.125] * _n_7t_subs
task_ids = ["AttendAway"] * 2 * _n_7t_subs
fo = .8
n_permutations = 1000
n_bootstraps = 200

for plot_func, FIGSIZE in zip(
    [plot_power_of_rois_across_cohort, plot_power_of_rois_across_cohort_intersection],
    [FIGSIZE_F1F2, FIGSIZE_INTERSECTION],
):
    plot_func(
        experiment_id,
        mri_id,
        sub_ids,
        roi_task_ids,
        roi_fs,
        primary_fs,
        secondary_fs,
        task_ids,
        fo,
        n_permutations,
        n_bootstraps,
        JITTER=.015,
        FONTSIZE=FONTSIZE,
        FIGSIZE=FIGSIZE,
    )

sub-020: compare 3T vs. 7T

In [None]:
experiment_id = "1_frequency_tagging"
mri_ids = ["3T", "7T"]

# Three runs; each is listed twice with the primary and secondary ROI
# frequencies swapped between the two halves.
_sub020_runs = ["entrainA", "entrainB", "entrainC"]
sub_ids = ["020"] * 6
roi_task_ids = _sub020_runs * 2
task_ids = _sub020_runs * 2
roi_fs = [.125, .125, .125, .2, .175, .15]
primary_fs = list(roi_fs)
secondary_fs = [.2, .175, .15, .125, .125, .125]

fo = .8
n_permutations = 1000
n_bootstraps = 200

# (plotting function, figure size) pairs, each rendered once per scanner.
_sub020_plot_specs = [
    (plot_power_of_rois_across_cohort, FIGSIZE_F1F2),
    (plot_power_of_rois_across_cohort_intersection, FIGSIZE_INTERSECTION),
]
for plot_func, FIGSIZE in _sub020_plot_specs:
    for mri_id in mri_ids:
        plot_func(
            experiment_id, mri_id, sub_ids, roi_task_ids,
            roi_fs, primary_fs, secondary_fs, task_ids,
            fo, n_permutations, n_bootstraps,
            JITTER=.005,
            FONTSIZE=FONTSIZE,
            FIGSIZE=FIGSIZE,
            NORMAL_EXPERIMENTS=False,
        )

sub-021: compare 3T vs. 7T

In [None]:
experiment_id = "1_frequency_tagging"
mri_ids = ["3T", "7T"]

# Three runs; each is listed twice with the primary and secondary ROI
# frequencies swapped between the two halves.
_sub021_runs = ["entrainD", "entrainE", "entrainF"]
sub_ids = ["021"] * 6
roi_task_ids = _sub021_runs * 2
task_ids = _sub021_runs * 2
roi_fs = [.125, .15, .175, .2, .2, .2]
primary_fs = list(roi_fs)
secondary_fs = [.2, .2, .2, .125, .15, .175]

fo = .8
n_permutations = 1000
n_bootstraps = 200

# (plotting function, figure size) pairs, each rendered once per scanner.
_sub021_plot_specs = [
    (plot_power_of_rois_across_cohort, FIGSIZE_F1F2),
    (plot_power_of_rois_across_cohort_intersection, FIGSIZE_INTERSECTION),
]
for plot_func, FIGSIZE in _sub021_plot_specs:
    for mri_id in mri_ids:
        plot_func(
            experiment_id, mri_id, sub_ids, roi_task_ids,
            roi_fs, primary_fs, secondary_fs, task_ids,
            fo, n_permutations, n_bootstraps,
            JITTER=.005,
            FONTSIZE=FONTSIZE,
            FIGSIZE=FIGSIZE,
            NORMAL_EXPERIMENTS=False,
        )