In [None]:
import nibabel as nib
import numpy as np

import sys
sys.path.append("/opt/wbplot")

from wbplot import dscalar

from IPython.display import Image
import matplotlib.pyplot as plt

from pathlib import Path
import pandas as pd
from collections import defaultdict
from functools import lru_cache

sys.path.append("ComputeCanada/frequency_tagging")
from dfm import (
    get_roi_colour_codes,
    change_font,
)
# Apply the project's plotting font settings (presumably matplotlib font
# defaults — dfm.change_font is defined outside this notebook; confirm there).
change_font()

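Get HCP info
- `hcp_mapping`: dict mapping each HCP ROI label (`L_<name>_ROI` / `R_<name>_ROI`) to the path of its single-ROI dscalar file
- `hcp_rois`: list of unique HCP ROI names with the hemisphere prefix removed (order is arbitrary)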

In [None]:
"""Get HCP labels
"""
dlabel_dir = Path("/opt/app/notebooks/data/dlabels")
hcp_label = dlabel_dir / "Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR.dlabel.nii"

_HCP_INFO = !wb_command -file-information {hcp_label}
HCP_LABELS = []
HCP_COUNTER = 0
for i in _HCP_INFO:
    if len(i) == 60 and any(["L_" in i, "R_" in i]):
        hcp_colors = tuple([float(f"0.{k}") for k in [j.split(' ') [0] for j in i.split('0.')][-3:]] + [1])
        if ' R_' in i:
            roi = i.split("_ROI")[0].split(' R_')[1]
            HCP_LABELS.append(f"R_{roi}_ROI")
        if ' L_' in i:
            roi = i.split("_ROI")[0].split(' L_')[1]
            HCP_LABELS.append(f"L_{roi}_ROI")
        HCP_COUNTER += 1

"""Get HCP label coordinates
"""
dscalar_dir = Path("/opt/app/notebooks/data/dscalars")
tmpdir = Path("/tmp")
template_dscalar = dscalar_dir / "S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"

hcp_mapping = {}
for roi_label in HCP_LABELS:
    out_dscalar = tmpdir / f"{roi_label}.dscalar.nii"
    if out_dscalar.exists():
        hcp_mapping[roi_label] = out_dscalar
        continue
    !wb_command -cifti-label-to-roi {hcp_label} {out_dscalar} -name {roi_label}
    assert out_dscalar.exists(), f"{out_dscalar.stem} does not exist."
    hcp_mapping[roi_label] = out_dscalar
hcp_rois = list(set([k.split('_')[1] for k in hcp_mapping.keys()]))

Functions

In [None]:

def convert_to_fractional_overlap(data):
    """Return the column-wise fraction of ``data`` along axis 0.

    For maps stacked along axis 0 with values 0/1 this is, per column, the
    fraction of maps in which that column is set (the fractional overlap).
    For general numeric input it equals the mean over axis 0.
    """
    n_rows = data.shape[0]
    column_totals = np.sum(data, axis=0)
    return column_totals / n_rows

def map_data_to_value(data_list):
    """Return the weighted sum of a sequence of ``(array, weight)`` pairs.

    In ``combine_f1_f2`` the arrays are binary region maps and the weights are
    colour codes, so each vertex receives the code of the region it belongs to
    (0 where it belongs to none).

    Parameters
    ----------
    data_list : iterable of (array_like, number)
        Arrays to scale by their weight and add together.

    Returns
    -------
    numpy.ndarray
        ``sum(array * weight)`` over all pairs.

    Raises
    ------
    ValueError
        If ``data_list`` is empty.
    """
    total = None
    for array, weight in data_list:
        term = np.multiply(array, weight)
        # Out-of-place addition lets the result upcast (e.g. an int first term
        # followed by a float-weighted one) instead of failing on ``+=`` casting.
        total = term if total is None else total + term
    if total is None:
        raise ValueError("data_list must contain at least one (array, weight) pair")
    return total

def combine_f1_f2(f1,f2,fo=1.,mask=None, f1_c=.01, f2_c=.82, f1f2_c=.28, mask_c=.01):
    """Combine f1 and f2 activation files into one colour-coded map.

    Each file is loaded with ``nib.load``, reduced to its fractional overlap
    along axis 0 and thresholded at ``fo``. Vertices active for both
    frequencies form the intersection map and are removed from the f1-only and
    f2-only maps, so the three maps are disjoint. The maps are then painted
    with their colour codes via ``map_data_to_value``.

    Parameters
    ----------
    f1, f2 : str or path-like
        Activation files for the two frequencies.
    fo : float
        Fractional-overlap threshold a vertex must reach to count as active.
    mask : str or path-like, optional
        Mask file. Vertices with full overlap (>= 1) in the mask that are in
        none of the f1/f2/intersection maps are painted with ``mask_c``.
        Skipped when falsy.
    f1_c, f2_c, f1f2_c, mask_c : float
        Colour codes for f1-only, f2-only, intersection and mask-only vertices.

    Returns
    -------
    numpy.ndarray
        Colour-coded map; 0 where no region applies.
    """
    f1_data = convert_to_fractional_overlap(nib.load(f1).get_fdata())
    f2_data = convert_to_fractional_overlap(nib.load(f2).get_fdata())
    f1_data = (f1_data >= fo).astype(int)
    f2_data = (f2_data >= fo).astype(int)
    f1f2_data = ((f1_data + f2_data) == 2).astype(int)
    # Make f1-only / f2-only disjoint from the intersection.
    f1_data -= f1f2_data
    f2_data -= f1f2_data
    weighted_maps = [(f1_data, f1_c), (f2_data, f2_c), (f1f2_data, f1f2_c)]
    if mask:
        mask_data = convert_to_fractional_overlap(nib.load(mask).get_fdata())
        mask_data = (mask_data >= 1.).astype(int)
        # Keep only mask vertices outside every f1/f2 map. Plain subtraction
        # would give -1 wherever an active vertex lies outside the mask, which
        # shifts that vertex's colour code by -mask_c.
        in_any_f_map = (f1_data + f2_data + f1f2_data) > 0
        mask_data = ((mask_data == 1) & ~in_any_f_map).astype(int)
        weighted_maps.append((mask_data, mask_c))

    return map_data_to_value(weighted_maps)

def get_quadrant_id(mask_path):
    """Return the quadrant id ("Q1" or "Q2") encoded in a file name.

    The id is the two characters starting at the first "Q" in the final path
    component of ``mask_path``.

    Parameters
    ----------
    mask_path : str or os.PathLike
        Path whose file name contains the quadrant id.

    Returns
    -------
    str
        "Q1" or "Q2".

    Raises
    ------
    ValueError
        If the extracted id is not "Q1" or "Q2".
    """
    # Path(...).name works for both str and Path inputs.
    file_name = Path(mask_path).name
    idx_start = file_name.find("Q")
    quadrant_id = file_name[idx_start:idx_start + 2] if idx_start >= 0 else ""
    # Raise explicitly so the check survives ``python -O`` (asserts are stripped).
    if quadrant_id not in ("Q1", "Q2"):
        raise ValueError(f"{quadrant_id} not Q1 or Q2")

    return quadrant_id

def merge_and_binarize_mask(data, f1_c, f2_c, f1f2_c, mask_c):
    """Split a colour-coded map into binary f1, f2 and intersection maps.

    Parameters
    ----------
    data : numpy.ndarray
        Colour-coded map whose values are ``f1_c`` (f1 only), ``f2_c`` (f2 only),
        ``f1f2_c`` (both), ``mask_c`` (mask only) or 0.
    f1_c, f2_c, f1f2_c, mask_c : float
        The colour codes used to build ``data``.

    Returns
    -------
    dict
        Keys "f1", "f2", "f1Uf2" (in that order). Each maps to an array with
        the dtype of ``data`` that is 1 at member vertices and 0 elsewhere.
        "f1" and "f2" include the intersection; "f1Uf2" is the intersection
        only. Mask-only vertices, and any value that is not one of the codes,
        are 0 in every map.
    """
    # Every map is derived from the original values in one pass. Relabelling a
    # copy in place step by step can clobber vertices when a code equals 0 or 1,
    # and leaves unrecognised values in the "binary" output. ``mask_c`` is never
    # selected, so mask-only vertices end up 0.
    memberships = {
        "f1": (f1_c, f1f2_c),
        "f2": (f2_c, f1f2_c),
        "f1Uf2": (f1f2_c,),
    }
    return {
        name: np.isin(data, codes).astype(data.dtype)
        for name, codes in memberships.items()
    }

@lru_cache(maxsize=360)
def read_roi_path(roi_path):
    """Load an ROI file with nibabel and return the first row of its data.

    Results are memoised per ``roi_path`` (up to 360 entries), so repeat calls
    return the very same array object; callers should not modify it in place.
    """
    image = nib.load(roi_path)
    roi_rows = image.get_fdata()
    return roi_rows[0, :]

def append_data(
    df_data,
    hcp_mapping,
    map_data,
    power_f1_data,
    power_f2_data,
    pd_f1_data, 
    pd_f2_data, 
    q_id,
    experiment_label, 
    sub_id, 
    roi_fo,
    roi_task_id,
    task_id,
):
    """Append one row per overlapping (frequency ROI, HCP ROI) pair to ``df_data``.

    For every frequency map in ``map_data`` and every HCP ROI in
    ``hcp_mapping``, the vertices where both equal 1 are selected. Pairs whose
    ``roi_mask * f_data`` sums to 0 are skipped. Every other pair appends one
    value to each of these columns, in this order:

    - roi_task_id: task used to define the frequency ROIs
    - task_id: task whose power / phase-delay data are being stored
    - roi_fo: fractional-overlap threshold recorded for the ROIs
    - experiment_id: ``experiment_label``
    - sub_id
    - quadrant_id: ``q_id``
    - hcp_roi: HCP ROI name (``L_``/``R_`` prefix and ``_ROI`` suffix removed)
      prefixed with ``CONTRA_`` or ``IPSI_``. For Q1 the left hemisphere is
      contralateral; for Q2 the right hemisphere is.
    - frequency_of_roi: key of ``map_data`` (e.g. "f1", "f2", "f1Uf2")
    - vertex_count: sum of ``roi_mask * f_data`` (the vertex count for binary maps)
    - vertex_coordinates: ``np.where`` indices of the selected vertices
    - f1_BOLD_power / f2_BOLD_power: power values at the selected vertices;
      ``None`` when ``roi_task_id == "control"``
    - f1_phase_delay / f2_phase_delay: phase-delay values at the selected
      vertices; ``None`` when ``task_id != roi_task_id``

    Parameters
    ----------
    df_data : mapping of column name -> list
        Accumulator (e.g. ``defaultdict(list)``); extended in place.
    hcp_mapping : dict
        HCP label (``L_<name>_ROI`` / ``R_<name>_ROI``) -> ROI file path.
    map_data : dict
        Frequency-ROI name -> vertex map with the same shape as each ROI mask.
    power_f1_data, power_f2_data, pd_f1_data, pd_f2_data : numpy.ndarray or None
        Per-vertex values indexed with the selected-vertex mask.
    q_id : str
        "Q1" or "Q2".
    experiment_label, sub_id, roi_fo, roi_task_id, task_id :
        Stored as-is in the corresponding columns.

    Returns
    -------
    The same ``df_data`` object.

    Raises
    ------
    ValueError
        If ``q_id`` is not "Q1" or "Q2".
    """
    # The contralateral hemisphere depends only on the quadrant, so resolve it once.
    if q_id == "Q1":
        contra = "L_"
    elif q_id == "Q2":
        contra = "R_"
    else:
        raise ValueError(f"{q_id} not supported.")

    for frequency_of_roi, f_data in map_data.items():
        for roi_label, roi_path in hcp_mapping.items():
            # Drop the 2-char hemisphere prefix and the "_ROI" suffix.
            side = "CONTRA" if roi_label.startswith(contra) else "IPSI"
            hcp_roi = f"{side}_{roi_label[2:-4]}"

            roi_mask = read_roi_path(roi_path)
            assert roi_mask.shape == f_data.shape

            hcp_and_f_roi = roi_mask * f_data
            vertex_count = hcp_and_f_roi.sum()
            if vertex_count == 0:
                continue
            in_roi = hcp_and_f_roi == 1
            vertex_coordinates = np.where(in_roi)

            if roi_task_id == "control":
                f1_BOLD_power = None
                f2_BOLD_power = None
            else:
                f1_BOLD_power = power_f1_data[in_roi]
                f2_BOLD_power = power_f2_data[in_roi]
            if task_id != roi_task_id:
                f1_phase_delay = None
                f2_phase_delay = None
            else:
                f1_phase_delay = pd_f1_data[in_roi]
                f2_phase_delay = pd_f2_data[in_roi]

            # Key order fixes the column order of the resulting DataFrame.
            row = {
                "roi_task_id": roi_task_id,
                "task_id": task_id,
                "roi_fo": roi_fo,
                "experiment_id": experiment_label,
                "sub_id": sub_id,
                "quadrant_id": q_id,
                "hcp_roi": hcp_roi,
                "frequency_of_roi": frequency_of_roi,
                "vertex_count": vertex_count,
                "vertex_coordinates": vertex_coordinates,
                "f1_BOLD_power": f1_BOLD_power,
                "f2_BOLD_power": f2_BOLD_power,
                "f1_phase_delay": f1_phase_delay,
                "f2_phase_delay": f2_phase_delay,
            }
            for column, value in row.items():
                df_data[column].append(value)

    return df_data

def contains_all_strings(input_str, string_list):
    """Return True if every entry of ``string_list`` occurs in ``input_str``.

    An empty ``string_list`` yields True; checking stops at the first miss.
    """
    return all(needle in input_str for needle in string_list)

def find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, fo, sub_id, data_split_id="train", match_str="activations.dtseries.nii", additional_match_strs=None, additional_match_str=None, corr_type="uncp", base_dir="/scratch/fastfmri"):
    """Return the paths of a subject's bootstrap files that match all filters.

    Files are listed from
    ``{base_dir}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200_batch-merged_desc-basic_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/``.
    A file name matches when it contains ``match_str``, ``f"data-{data_split_id}"``
    and every string in ``additional_match_strs``.

    Parameters
    ----------
    experiment_id, mri_id, roi_task_id, roi_f_1, fo, sub_id, corr_type :
        Formatted into the directory name.
    data_split_id : str
        Data split that must appear as ``data-<id>`` in the file name.
    match_str : str
        Substring every matching file name must contain.
    additional_match_strs : iterable of str, optional
        Extra substrings that must all be present.
    additional_match_str : ignored
        Unused; kept only for backward compatibility.
    base_dir : str
        Root of the bootstrap output directories.

    Returns
    -------
    list of str
        Matching paths, sorted by file name so the order does not depend on
        ``os.listdir``.
    """
    import os
    directory = f"{base_dir}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200_batch-merged_desc-basic_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/"
    required = [match_str, f"data-{data_split_id}"]
    if additional_match_strs is not None:
        required = list(additional_match_strs) + required

    return [
        f"{directory}{file_name}"
        for file_name in sorted(os.listdir(directory))
        if all(part in file_name for part in required)
    ]

def set_base_dir(basedir):
    """Return ``basedir`` as a Path, creating it (with parents) if it does not exist.

    An existing path is returned untouched, even if it is not a directory.
    """
    target = Path(basedir)
    if target.exists():
        return target
    target.mkdir(exist_ok=True, parents=True)
    return target

def load_mean_dtseries(dtseries):
    """Load a file with nibabel and return the mean of its data over axis 0."""
    series = nib.load(dtseries).get_fdata()
    return series.mean(axis=0)

def generate_single_subject_maps(
    label, experiment_id, mri_id, sub_ids, 
    roi_task_ids, roi_f_1s, roi_f_2s, roi_fo,
    df_data=None,
    corr_type="uncp",
    ROI_FO=.8, SUB_THRESHOLD=.5,
    LEFT=590, TOP=80, RIGHT=1140, BOTTOM=460, VERTEX_TO = 59412,
    FORCE_TASK_ID=None,
):
    """Build, save and tabulate dual-frequency maps for each subject/run.

    For every ``(sub_id, roi_task_id, roi_f_1, roi_f_2)`` taken in lockstep
    from the four sequences, this:

    1. locates the f1/f2 activation, mask, phase-delay and power files with
       ``find_activations`` and asserts that exactly one file was found for
       each. For ``control`` ROIs of ``1_frequency_tagging`` only f1, f2 and
       mask are required. Input directories are always the ``fo=0.8``
       bootstrap outputs; ``ROI_FO`` only sets the threshold applied to them.
    2. combines f1/f2/mask into a colour-coded map (``combine_f1_f2``), keeps
       the first ``VERTEX_TO`` entries and splits it into binary maps;
    3. appends per-HCP-ROI rows to ``df_data`` (``append_data``);
    4. saves the colour-coded map as a single-point CIFTI series, renders it
       with ``wbplot.dscalar`` and writes a cropped copy of the PNG.

    Reads the module-level globals ``f1_c``, ``f2_c``, ``f1f2_c``, ``mask_c``,
    ``PALETTE`` and ``hcp_mapping``.

    Parameters
    ----------
    label : str
        Dataset label used in output names and the ``experiment_id`` column.
    experiment_id, mri_id, corr_type :
        Passed to ``find_activations`` to locate the input files.
    sub_ids, roi_task_ids, roi_f_1s, roi_f_2s : sequences
        Iterated in lockstep (stops at the shortest, as with ``zip``).
    roi_fo : float
        Value stored in the ``roi_fo`` column.
    df_data : mapping of column -> list, optional
        Accumulator to extend; a new ``defaultdict(list)`` when None.
    ROI_FO : float
        Fractional-overlap threshold for the f1/f2 maps; also used in file names.
    SUB_THRESHOLD : float
        Unused; kept for backward compatibility.
    LEFT, TOP, RIGHT, BOTTOM : int
        Crop box (pixels) for the cropped PNG.
    VERTEX_TO : int
        Number of leading entries kept from each map (presumably the cortical
        surface vertices of 32k_fs_LR — TODO confirm).
    FORCE_TASK_ID : str, optional
        If given, power files are looked up with this task id, and it is
        stored as ``task_id``, instead of the ROI's own task.

    Returns
    -------
    The updated ``df_data``.
    """
    if df_data is None:
        df_data = defaultdict(list)

    # wbplot display settings; identical for every subject.
    palette_params = {
        "disp-zero": False,
        "disp-neg": True,
        "disp-pos": True,
        "pos-user": (0, 1.),
        "neg-user": (-1,0),
        "interpolate": True,
    }

    for sub_id, roi_task_id, roi_f_1, roi_f_2 in zip(sub_ids, roi_task_ids, roi_f_1s, roi_f_2s):
        _roi_task_id = roi_task_id if FORCE_TASK_ID is None else FORCE_TASK_ID
        # Power metrics were not calculated for control task condition (no voxels allocated to task-control ROIs)
        is_3t_control = roi_task_id == "control" and experiment_id == "1_frequency_tagging"

        out_name = f"label-{label}_mri-{mri_id}_sub-{sub_id}_task-{roi_task_id}_f-{roi_f_1}-{roi_f_2}_corr-{corr_type}_fo-{ROI_FO}"
        png_out = Path(set_base_dir("./ComputeCanada/frequency_tagging/figures/dual_frequency_mapping")) / f"{out_name}.png"
        dscalar_out = Path(set_base_dir("./ComputeCanada/frequency_tagging/figures/dual_frequency_mapping_cifti")) / f"{out_name}.dtseries.nii"
        # Derive the cropped name from the file name only: str.replace("png", ...)
        # on the whole path would also rewrite any "png" in directories or labels.
        cropped_out = str(png_out.with_name(f"{out_name}.cropped.png"))

        f1 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, .8, sub_id, match_str="activations.dtseries.nii", corr_type=corr_type)
        f2 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_2, .8, sub_id, match_str="activations.dtseries.nii", corr_type=corr_type)
        mask = find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, .8, sub_id, match_str="mask.dtseries.nii", corr_type=corr_type)
        pd_f1 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, .8, sub_id, data_split_id = "train", match_str="phasedelay.dtseries.nii", additional_match_strs=[roi_task_id,f"f-{roi_f_1}"], corr_type=corr_type)
        pd_f2 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_2, .8, sub_id, data_split_id = "train", match_str="phasedelay.dtseries.nii", additional_match_strs=[roi_task_id,f"f-{roi_f_2}"], corr_type=corr_type)
        power_f1 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, .8, sub_id, data_split_id = "test", match_str="power.dtseries.nii", additional_match_strs=[_roi_task_id,f"f-{roi_f_1}"], corr_type=corr_type)
        power_f2 = find_activations(experiment_id, mri_id, roi_task_id, roi_f_2, .8, sub_id, data_split_id = "test", match_str="power.dtseries.nii", additional_match_strs=[_roi_task_id,f"f-{roi_f_2}"], corr_type=corr_type)
        for f_label, f in zip(["f1","f2","mask","pd_f1","pd_f2","power_f1","power_f2"], [f1,f2, mask, pd_f1, pd_f2, power_f1, power_f2]):
            if is_3t_control:
                if f_label in ["f1", "f2", "mask"]:
                    assert len(f) == 1, f"{sub_id}, {f_label} - {f}"
            else:
                assert len(f) == 1, f"{sub_id}, {f_label} - {f}, {experiment_id} {roi_task_id}"

        f1, f2 = f1[0], f2[0]
        data = combine_f1_f2(f1, f2, fo=ROI_FO, mask=mask[0], f1_c=f1_c, f2_c=f2_c, f1f2_c=f1f2_c, mask_c=mask_c)
        data = data[:VERTEX_TO]

        map_data = merge_and_binarize_mask(data, f1_c, f2_c, f1f2_c, mask_c)
        pd_f1_data = load_mean_dtseries(pd_f1[0])[:VERTEX_TO]
        pd_f2_data = load_mean_dtseries(pd_f2[0])[:VERTEX_TO]
        if is_3t_control:
            power_f1_data = None
            power_f2_data = None
        else:
            power_f1_data = load_mean_dtseries(power_f1[0])[:VERTEX_TO]
            power_f2_data = load_mean_dtseries(power_f2[0])[:VERTEX_TO]
        q_id = get_quadrant_id(mask[0])
        df_data = append_data(
            df_data, 
            hcp_mapping, 
            map_data, 
            power_f1_data, 
            power_f2_data, 
            pd_f1_data, 
            pd_f2_data, 
            q_id,
            label, 
            sub_id,
            roi_fo,
            roi_task_id,
            _roi_task_id, 
        )

        # Save the f1f2 map as a one-point series using the f1 file's header;
        # entries past VERTEX_TO stay 0.
        f1_img = nib.load(f1)
        dscalar_to_save_as_cifti = np.zeros((1, f1_img.shape[-1]))
        dscalar_to_save_as_cifti[0, :VERTEX_TO] = data
        f1f2_img = nib.Cifti2Image(dscalar_to_save_as_cifti, header=f1_img.header)
        f1f2_img.header.matrix[0].number_of_series_points = 1
        nib.save(f1f2_img, dscalar_out)
        dscalar(
            png_out, data, 
            orientation="portrait", 
            hemisphere='right',
            palette=PALETTE, 
            palette_params=palette_params,
            transparent=False,
            flatmap=True,
            flatmap_style='plain',
        )
        crop_and_save(png_out, cropped_out, LEFT, TOP, RIGHT, BOTTOM)

        # Progress: current length of every accumulated column.
        print([len(column) for column in df_data.values()])

    return df_data

def crop_and_save(input_file, output_file, left, top, right, bottom):
    """Crop ``input_file`` to the pixel box (left, top, right, bottom) and save it.

    Best effort: any exception raised while opening, cropping or saving is
    printed instead of propagated.
    """
    from PIL import Image
    crop_box = (left, top, right, bottom)
    try:
        with Image.open(input_file) as source_img:
            source_img.crop(crop_box).save(output_file)
            print("Cropped image saved successfully as", output_file)
    except Exception as err:
        print("An error occurred:", err)

Save visualizations and create DataFrame storing data for simple analysis
- Datasets
    - `3TNormal` ($f_1$=.125, $f_2$=.2)
    - `7TNormal` ($f_1$=.125, $f_2$=.2)
    - `3TVary` varying frequencies 
    - `7TVary` varying frequencies
- Region info `df.frequency_of_roi` and `df.hcp_roi`
    - The $f_1$ region data *include* the vertices where $f_1$ and $f_2$ intersect; the same goes for $f_2$
    - ROI choices are `["f1","f2","f1Uf2"]`, where `f1Uf2` denotes the intersected vertices (active for both $f_1$ and $f_2$), not their union
- Data includes
    - Note: to load `fdrp`-corrected data instead of `uncp` data, set the `corr_type` variable in the cell below to `"fdrp"`
    - `df.roi_task_id` task used to generate the ROI
    - `df.roi_fo` fractional overlap used to generate the ROI
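    - `df.task_id` task whose power / phase-delay data were stored (differs from `df.roi_task_id` when a task is forced, e.g. `control`)
    - `df.experiment_id` dataset label (e.g. `3TNormal`, `7TVary`)
    - `df.sub_id` subject ID
    - `df.quadrant_id` stimulated quadrant (`Q1` or `Q2`); `df.hcp_roi` is prefixed with `CONTRA_`/`IPSI_` relative to it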
    - `df.hcp_roi` HCP ROI used to filter data from
    - `df.frequency_of_roi` frequency of ROI used to filter data from (related to the frequency of `df.roi_task_id`)
    - `df.vertex_count` number of vertices that lie in both the HCP ROI and the frequency ROI
    - `df.vertex_coordinates` coordinates on a 32k_fs_LR surface (consistent with file structure of `template_dscalar`)
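    - `df.f1_[BOLD_power,phase_delay]` vertex-wise $f_1$ power / phase delay at the vertices of the row's region (extracted for every `frequency_of_roi`, including $f_2$-only regions)
    - `df.f2_[BOLD_power,phase_delay]` vertex-wise $f_2$ power / phase delay at the vertices of the row's region (extracted for every `frequency_of_roi`, including $f_1$-only regions)
        - All regions, including `f1Uf2`, therefore contain both $f_1$ and $f_2$ metrics; power is `None` for `control` ROIs, and phase delay is `None` when `df.task_id` differs from `df.roi_task_id`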


In [None]:
"""Set up for visualizing dual frequency tagging across each subject using fractional overlap
"""
PALETTE = "power_surf"
f1_c = -.1 # red -.1
f2_c = .82 # blue .82
f1f2_c = .14 # white .88 yellow .1
mask_c = .41 # .9 [green], .1 [goldish], .4 [black]

cohort_roi_info_across_experiments = {}
ROI_FOS = [.8,1.]
corr_type = "uncp"

"""Save png
"""
# 3T control under entrain condition (set this to get power measurements with entrain ROIs)
label = "3TNormal"
df_data = None
for _roi_task_id in ["entrain"]:
    for ROI_FO in ROI_FOS:
        experiment_id = "1_frequency_tagging" 
        mri_id = "3T"
        sub_ids = ["000", "002", "003", "004", "005", "006", "007", "008", "009"] 
        roi_task_ids = [_roi_task_id] * len(sub_ids)
        roi_f_1s = [.125] * len(sub_ids)
        roi_f_2s = [.2] * len(sub_ids)
        df_data = generate_single_subject_maps(
            label, experiment_id, mri_id, sub_ids, 
            roi_task_ids, roi_f_1s, roi_f_2s, ROI_FO,
            df_data=df_data,
            corr_type=corr_type,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5,
            FORCE_TASK_ID="control"
        )
# 3T normal
label = "3TNormal"
for _roi_task_id in ["entrain"]:
    for ROI_FO in ROI_FOS:
        experiment_id = "1_frequency_tagging" 
        mri_id = "3T"
        sub_ids = ["000", "002", "003", "004", "005", "006", "007", "008", "009"] 
        roi_task_ids = [_roi_task_id] * len(sub_ids)
        roi_f_1s = [.125] * len(sub_ids)
        roi_f_2s = [.2] * len(sub_ids)
        df_data = generate_single_subject_maps(
            label, experiment_id, mri_id, sub_ids, 
            roi_task_ids, roi_f_1s, roi_f_2s, ROI_FO,
            df_data=df_data,
            corr_type=corr_type,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5
        )
# 7T normal
label = "7TNormal"
for ROI_FO in ROI_FOS:
    experiment_id = "1_attention" 
    mri_id = "7T"
    sub_ids = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
    roi_task_ids = ["AttendAway"] * len(sub_ids)
    roi_f_1s = [.125] * len(sub_ids)
    roi_f_2s = [.2] * len(sub_ids)
    df_data = generate_single_subject_maps(
        label, experiment_id, mri_id, sub_ids, 
        roi_task_ids, roi_f_1s, roi_f_2s, ROI_FO,
        df_data=df_data,
        corr_type=corr_type,
        ROI_FO=ROI_FO, SUB_THRESHOLD=.5
    )
# 3T vary
label = "3TVary"
for ROI_FO in ROI_FOS:
    experiment_id = "1_frequency_tagging"
    mri_id = "3T"
    sub_ids = ["020"] * 3 + ["021"] * 3
    roi_task_ids = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]
    roi_f_1s = [.125] * 3 + [.125, .15, .175]
    roi_f_2s = [.2, .175, .15] + [.2] * 3
    df_data = generate_single_subject_maps(
        label, experiment_id, mri_id, sub_ids, 
        roi_task_ids, roi_f_1s, roi_f_2s, ROI_FO,
        df_data=df_data,
        corr_type=corr_type,
        ROI_FO=ROI_FO, SUB_THRESHOLD=.5
    )
# 7T vary
label = "7TVary"
for ROI_FO in ROI_FOS:
    experiment_id = "1_frequency_tagging"
    mri_id = "7T"
    sub_ids = ["020"] * 3 + ["021"] * 3
    roi_task_ids = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]
    roi_f_1s = [.125] * 3 + [.125, .15, .175]
    roi_f_2s = [.2, .175, .15] + [.2] * 3
    df_data = generate_single_subject_maps(
        label, experiment_id, mri_id, sub_ids, 
        roi_task_ids, roi_f_1s, roi_f_2s, ROI_FO,
        df_data=df_data,
        corr_type=corr_type,
        ROI_FO=ROI_FO, SUB_THRESHOLD=.5
    )

# One row per run, frequency ROI and overlapping HCP ROI, built from the
# accumulated column lists.
df = pd.DataFrame(df_data)

from IPython.display import clear_output
# Clear the per-subject progress and crop messages printed while generating the maps.
clear_output()
