In [None]:
import nibabel as nib
import numpy as np

import sys
sys.path.append("/opt/wbplot")

from wbplot import dscalar

from IPython.display import Image

from pathlib import Path
from collections import defaultdict
from functools import lru_cache

In [None]:
"""Get HCP labels
"""
# Build HCP_LABELS (e.g. "L_V1_ROI", "R_V1_ROI") by scraping the text that
# `wb_command -file-information` prints for the group dlabel file below.
dlabel_dir = Path("/opt/app/notebooks/data/dlabels")
hcp_label = dlabel_dir / "Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR.dlabel.nii"

# IPython `!` capture: one list entry per line of wb_command's stdout.
_HCP_INFO = !wb_command -file-information {hcp_label}
# Label names in the order they appear in the wb_command output.
HCP_LABELS = []
# Number of output lines accepted as label rows by the filter below.
HCP_COUNTER = 0
for i in _HCP_INFO:
    # Heuristic row filter: keep lines exactly 60 characters long that contain
    # "L_" or "R_". NOTE(review): this depends on wb_command's column layout;
    # a label whose row is not 60 chars is silently dropped. Worth asserting
    # len(HCP_LABELS) equals the expected parcel count (presumably 360).
    if len(i) == 60 and any(["L_" in i, "R_" in i]):
        # Splits the line on "0.", keeps the first space-separated token of the
        # last three pieces and reads each back as float("0.<token>"), then
        # appends 1 -- presumably the label's RGBA colour. Not used in this cell.
        hcp_colors = tuple([float(f"0.{k}") for k in [j.split(' ') [0] for j in i.split('0.')][-3:]] + [1])
        # Rebuild the full label name from the text between the hemisphere
        # prefix (" R_" / " L_") and "_ROI".
        if ' R_' in i:
            roi = i.split("_ROI")[0].split(' R_')[1]
            HCP_LABELS.append(f"R_{roi}_ROI")
        if ' L_' in i:
            roi = i.split("_ROI")[0].split(' L_')[1]
            HCP_LABELS.append(f"L_{roi}_ROI")
        HCP_COUNTER += 1

"""Get HCP label coordinates
"""
# For every label in HCP_LABELS, write a per-label ROI dscalar with
# `wb_command -cifti-label-to-roi` (presumably 1 inside the label and 0
# elsewhere) into /tmp, and record label -> file path in `hcp_mapping`.
dscalar_dir = Path("/opt/app/notebooks/data/dscalars")
tmpdir = Path("/tmp")
# Not referenced in this cell.
template_dscalar = dscalar_dir / "S1200.MyelinMap_BC_MSMAll.32k_fs_LR.dscalar.nii"

# Label (e.g. "L_V1_ROI") -> Path of its ROI dscalar.
hcp_mapping = {}
for roi_label in HCP_LABELS:
    out_dscalar = tmpdir / f"{roi_label}.dscalar.nii"
    # Reuse an earlier extraction. NOTE(review): a stale file left in /tmp
    # from a different dlabel would be reused as well.
    if out_dscalar.exists():
        hcp_mapping[roi_label] = out_dscalar
        continue
    !wb_command -cifti-label-to-roi {hcp_label} {out_dscalar} -name {roi_label}
    assert out_dscalar.exists(), f"{out_dscalar.stem} does not exist."
    hcp_mapping[roi_label] = out_dscalar
# Unique ROI names without the hemisphere prefix: the second "_"-separated
# field of each label. NOTE(review): a name that itself contains "_" would be
# truncated. Built from a set, so the list order is arbitrary between runs.
hcp_rois = list(set([k.split('_')[1] for k in hcp_mapping.keys()]))

In [None]:
def convert_to_fractional_overlap(data):
    """Collapse axis 0 of ``data`` into a per-element fraction.

    Sums ``data`` over its first axis and divides by the length of that axis.
    If the rows are 0/1 maps (e.g. one per bootstrap), the result is the
    fraction of rows in which each element is set.

    Parameters
    ----------
    data : numpy.ndarray
        Array with at least one dimension.

    Returns
    -------
    numpy.ndarray
        Array with the first axis removed.
    """
    n_rows = data.shape[0]
    column_totals = data.sum(axis=0)
    return column_totals / n_rows

def map_data_to_value(data_list):
    """Paint several indicator maps onto one array, one scalar value per map.

    Parameters
    ----------
    data_list : iterable of (array, value) pairs
        Each ``array`` (typically a 0/1 map) is multiplied by its scalar
        ``value`` and the products are summed element-wise. The arrays must
        broadcast to a common shape.

    Returns
    -------
    numpy.ndarray
        ``sum(array * value for array, value in data_list)``; the inputs are
        not modified.

    Raises
    ------
    ValueError
        If ``data_list`` is empty (there is no shape to build a result from).
    """
    pairs = list(data_list)
    if not pairs:
        raise ValueError("map_data_to_value needs at least one (array, value) pair.")

    first_array, first_value = pairs[0]
    # `*` already allocates a new array, so the inputs are never aliased.
    new_data = first_array * first_value
    for array, value in pairs[1:]:
        # Out-of-place add: an in-place `+=` refuses to cast when an int-valued
        # first term is followed by a float value.
        new_data = new_data + array * value
    return new_data

def get_quadrant_id(mask_path):
    """Return the quadrant tag ("Q1" or "Q2") embedded in a file's base name.

    The tag is the two characters starting at the first "Q" of the base name.
    Parent directories are ignored, so a "Q" in a folder name does not
    interfere.

    Parameters
    ----------
    mask_path : str or os.PathLike
        Path to the file, e.g. a bootstrapped mask.

    Returns
    -------
    str
        "Q1" or "Q2".

    Raises
    ------
    AssertionError
        If the base name contains no "Q", or the tag is not "Q1"/"Q2".
    """
    import os

    # os.fspath accepts both str and pathlib.Path; str.split alone did not.
    file_name = os.path.basename(os.fspath(mask_path))
    idx_start = file_name.find("Q")
    assert idx_start != -1, f"No quadrant tag found in {file_name}"
    quadrant_id = file_name[idx_start:idx_start + 2]
    assert quadrant_id in ['Q1', 'Q2'], f"{quadrant_id} not Q1 or Q2"

    return quadrant_id

@lru_cache(maxsize=360)
def read_roi_path(roi_path):
    """Load the first row of a ROI image's data as a float array (memoized).

    Results are cached with ``functools.lru_cache`` (up to 360 distinct
    paths), so every call with the same ``roi_path`` returns the *same* array
    object. The array is therefore made read-only. Otherwise an accidental
    in-place edit by one caller would silently corrupt the mask for every
    later caller.

    Parameters
    ----------
    roi_path : str or pathlib.Path
        Must be hashable (required by the cache).

    Returns
    -------
    numpy.ndarray
        ``nib.load(roi_path).get_fdata()[0, :]``, marked non-writeable.
    """
    roi_row = nib.load(roi_path).get_fdata()[0, :]
    roi_row.setflags(write=False)
    return roi_row

def roi_vertex_counter(
    roi_vertex_count,
    hcp_mapping,
    data,
    q_id,
):
    """Accumulate, per HCP ROI, the masked sum of ``data`` for one map.

    Each label in ``hcp_mapping`` (e.g. ``"L_V1_ROI"``) is re-keyed relative to
    the quadrant:
    - Labels starting with the contralateral prefix become ``"CONTRA_<name>"``.
    - All other labels become ``"IPSI_<name>"``.

    Here ``<name>`` is the label with its first 2 characters (hemisphere
    prefix) and last 4 characters (``"_ROI"``) removed. For every ROI,
    ``(roi_mask * data).sum()`` is appended to ``roi_vertex_count[<key>]``.

    Parameters
    ----------
    roi_vertex_count : mapping of str -> list
        Accumulator, updated in place (a ``defaultdict(list)`` in this notebook).
    hcp_mapping : mapping of str -> path
        ROI label -> path loaded with ``read_roi_path``.
    data : numpy.ndarray
        Map to count; must have the same shape as every ROI mask.
    q_id : {"Q1", "Q2"}
        "Q1" treats ``L_`` labels as contralateral, "Q2" treats ``R_`` labels
        as contralateral.

    Returns
    -------
    The same ``roi_vertex_count`` object.

    Raises
    ------
    ValueError
        If ``q_id`` is not "Q1" or "Q2" (checked before any ROI is read).
    """
    # Loop-invariant: decide the contralateral prefix once, up front.
    if q_id == "Q1":
        contra = "L_"
    elif q_id == "Q2":
        contra = "R_"
    else:
        raise ValueError(f"{q_id} not supported.")

    for roi_label, roi_path in hcp_mapping.items():
        side = "CONTRA" if roi_label.startswith(contra) else "IPSI"
        roi_key = f"{side}_{roi_label[2:-4]}"

        roi_mask = read_roi_path(roi_path)
        assert roi_mask.shape == data.shape

        roi_vertex_count[roi_key].append((roi_mask * data).sum())

    return roi_vertex_count

# NOTE: the triple-quoted string below is disabled code, not a docstring. It
# keeps a draft of `condense_roi_info_across_cohort` from being defined until
# the ROI mapping code is finished. If re-enabled, note two things:
# - The second element of each tuple is the *fraction* of subjects with a
#   non-zero count in the ROI, not a count, despite its inline comment.
# - ROIs are returned ordered by median count, highest first, keeping only
#   those with a fraction >= sub_threshold.
"""Reimplement when mapping code is done
def condense_roi_info_across_cohort(sub_ids, roi_vertex_count, sub_threshold=.5):
    
    n_sub_per_roi = []

    n_sub_total = len(sub_ids)
    for roi, vertex_count_across_subs in roi_vertex_count.items():
        n_sub_per_roi.append(
            (
                roi, # ROI label
                sum([i>0 for i in vertex_count_across_subs])/n_sub_total, # Number of subjects with vertex in a ROI
                np.median(vertex_count_across_subs), # Median vertex count
            )
        )

    my_list = [i[2] for i in n_sub_per_roi]
    sorted_indices = sorted(range(len(my_list)), key=lambda x: my_list[x], reverse=True)
    
    condensed_roi_info = []
    for i in sorted_indices:
        if n_sub_per_roi[i][1]<sub_threshold:
            continue
        condensed_roi_info.append(n_sub_per_roi[i])

    return condensed_roi_info
"""

def find_activations(experiment_id, mri_id, roi_task_id, roi_f_1, fo, sub_id, corr_type="uncp", match_str="activations.dtseries.nii", root="/scratch/fastfmri"):
    """List one subject's bootstrapped training-split files for an analysis.

    Looks in::

        {root}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200
        _batch-merged_desc-IMall_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}
        _fo-{fo}_bootstrap/sub-{sub_id}/bootstrap/

    (one directory name, wrapped here). It returns files whose names contain
    both ``"data-train"`` and ``match_str``.

    Parameters
    ----------
    experiment_id, mri_id, roi_task_id, sub_id : str
        Components of the directory name.
    roi_f_1, fo : float
        Frequency and fractional-overlap components of the directory name,
        formatted with ``str()`` (e.g. ``0.125`` -> ``"0.125"``).
    corr_type : str
        p-value correction component, e.g. ``"uncp"`` or ``"fdrp"``.
    match_str : str
        Substring the file name must contain (activations or mask file).
    root : str, optional
        Parent directory of the experiment folders. The default is the
        original hard-coded location.

    Returns
    -------
    list of str
        Directory string + file name, sorted by file name so the result does
        not depend on ``os.listdir`` order. For ``roi_task_id == "AttendInF1"``,
        names containing ``"AttendInF1F2"`` are dropped (presumably because
        "AttendInF1" is a prefix of "AttendInF1F2").

    Raises
    ------
    OSError
        E.g. ``FileNotFoundError`` if the directory does not exist.
    """
    import os

    directory = (
        f"{root}/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-200"
        f"_batch-merged_desc-IMall_roi-{roi_task_id}-{roi_f_1}_pval-{corr_type}_fo-{fo}"
        f"_bootstrap/sub-{sub_id}/bootstrap/"
    )
    # os.listdir order is arbitrary; sort so repeated runs agree.
    activations_files = sorted(
        file for file in os.listdir(directory)
        if 'data-train' in file and match_str in file
    )

    if roi_task_id == "AttendInF1":
        activations_files = [i for i in activations_files if "AttendInF1F2" not in i]
    return [f"{directory}{i}" for i in activations_files]

def set_base_dir(basedir):
    """Return ``basedir`` as a ``Path``, creating it (and its parents) if absent.

    Parameters
    ----------
    basedir : str or os.PathLike
        Directory that figures will be written into.

    Returns
    -------
    pathlib.Path
    """
    target = Path(basedir)
    if target.exists():
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target

# Disabled helper kept as a string literal (never defined or executed): it
# would load a dtseries and average its data over axis 0.
"""Not used
def load_mean_power_dtseries(dtseries):
    mean_power = nib.load(dtseries).get_fdata().mean(0)
    return mean_power
"""

def crop_and_save(input_file, output_file, left, top, right, bottom):
    """Crop an image to a pixel box and write the result to ``output_file``.

    Parameters
    ----------
    input_file, output_file : str or path
        Source image and destination for the cropped copy.
    left, top, right, bottom : int
        Crop box, passed to ``PIL.Image.Image.crop`` as one 4-tuple.

    Any exception while opening, cropping or saving is printed rather than
    raised (best effort; presumably so one bad figure does not stop a batch).
    """
    # Local import: the module-level `Image` name is IPython.display.Image.
    from PIL import Image

    crop_box = (left, top, right, bottom)
    try:
        with Image.open(input_file) as source_img:
            source_img.crop(crop_box).save(output_file)
            print("Cropped image saved successfully as", output_file)
    except Exception as err:
        print("An error occurred:", err)

def get_im_frequencies(f1, f2):
    """Build the frequency sets for first-order and intermodulation (IM) analyses.

    Parameters
    ----------
    f1, f2 : float
        Tagging frequencies; ``f2`` must be strictly greater than ``f1``.

    Returns
    -------
    dict of str -> list of float
        ``"first_order"`` maps to ``[f1, f2]``. Each IM term has a
        ``"f1f2_<term>"`` entry mapping to ``[f1, f2, <term value>]``, in the
        order f2-f1, f1+f2, 2f1, 2f2, 2f1-f2, 2f2-f1. Term values are rounded
        to 10 decimal places. Every value is a fresh list.

    Raises
    ------
    AssertionError
        If ``f2 <= f1``.
    """
    assert f2 > f1, f"{f2} <= {f1}"

    # (key, term value) in the order callers iterate the result.
    im_terms = (
        ("f1f2_f2-f1", round(f2-f1, 10)),
        ("f1f2_f1+f2", round(f1+f2, 10)),
        ("f1f2_2f1", round(f1*2, 10)),
        ("f1f2_2f2", round(f2*2, 10)),
        ("f1f2_2f1-f2", round(2*f1-f2, 10)),
        ("f1f2_2f2-f1", round(2*f2-f1, 10)),
    )

    im_frequencies = {"first_order": [f1, f2]}
    for key, term_value in im_terms:
        im_frequencies[key] = [f1, f2, term_value]
    return im_frequencies

def combine_f1f2_with_im(im_order, im_paths, im_c, fo=1., mask=None, mask_c=.41):
    """Encode f1∩f2, an IM term and their overlap as distinct values on one map.

    Each activation file is reduced with ``convert_to_fractional_overlap`` and
    thresholded (``>= fo``) to a 0/1 map. The output assigns:

    * ``im_c[0]`` to vertices active for both f1 and f2 but not the IM term,
    * ``im_c[1]`` to vertices active for the IM term but not for f1∩f2,
    * ``.14`` to vertices active for f1, f2 and the IM term,
    * ``mask_c`` (only if ``mask`` is given) to vertices inside the mask
      (fractional overlap ``>= 1``) that carry none of the above,
    * ``0`` everywhere else.

    Vertices active for only one of f1/f2 are not painted.

    Parameters
    ----------
    im_order : str
        Not used by this function (passed through from ``combine_im``).
    im_paths : sequence of 3 paths
        ``[f1, f2, im_product]`` activation files.
    im_c : sequence of 2 floats
        Values for the f1∩f2-only and IM-only categories.
    fo : float
        Fractional-overlap threshold applied to the three activation maps.
    mask : path or None
        Optional mask file; skipped when falsy.
    mask_c : float
        Value painted on mask-only vertices.

    Example inputs::

        im_paths = [f1[0], f2[0], im_product[0]]
        im_cs = [-.1, .82]  # red, blue, (overlap is yellow [.14])
    """
    def _binary(path, threshold):
        # Per-vertex fraction of rows that are set, thresholded to 0/1.
        overlap = convert_to_fractional_overlap(nib.load(path).get_fdata())
        return (overlap >= threshold).astype(int)

    f1_data = _binary(im_paths[0], fo)
    f2_data = _binary(im_paths[1], fo)
    f1f2_data = ((f1_data + f2_data) == 2).astype(int)
    im_data = _binary(im_paths[2], fo)
    f1f2im_data = ((f1f2_data + im_data) == 2).astype(int)
    # Make the three categories disjoint.
    im_data -= f1f2im_data
    f1f2_data -= f1f2im_data

    data_dict = [(f1f2_data, im_c[0]), (im_data, im_c[1]), (f1f2im_data, .14)]
    if mask:
        mask_data = _binary(mask, 1.)
        # Keep mask vertices that are not already painted. Plain subtraction
        # went to -1 (painting -mask_c) wherever an activation lay outside the
        # mask, which can happen when fo < 1 but the mask needs overlap >= 1.
        painted = (f1f2im_data + f1f2_data + im_data) > 0
        mask_data = ((mask_data == 1) & ~painted).astype(int)
        data_dict.append((mask_data, mask_c))
    return map_data_to_value(data_dict)

def combine_im(
    im_order, im_paths, im_c,
    fo=1.,
    mask=None,
    mask_c=.41,
):
    """Build a single painted map for one analysis order.

    ``im_order`` values starting with ``"f1f2"`` are delegated to
    ``combine_f1f2_with_im``. For ``"first_order"`` the f1 and f2 activation
    files (``im_paths[0]``, ``im_paths[1]``) are reduced with
    ``convert_to_fractional_overlap``, thresholded at ``fo`` and encoded as:

    * ``im_c[0]`` for f1-only vertices,
    * ``im_c[1]`` for f2-only vertices,
    * ``.14`` for vertices active for both,
    * ``mask_c`` (only if ``mask`` is given) for vertices inside the mask
      (fractional overlap ``>= 1``) that carry none of the above,
    * ``0`` everywhere else.

    Raises
    ------
    ValueError
        For any other ``im_order``.
    """
    if im_order.startswith("f1f2"):
        return combine_f1f2_with_im(im_order, im_paths, im_c, fo=fo, mask=mask, mask_c=mask_c)

    if im_order != "first_order":
        # Previously fell through and returned None, failing later and obscurely.
        raise ValueError(f"{im_order} not supported.")

    def _binary(path, threshold):
        # Per-vertex fraction of rows that are set, thresholded to 0/1.
        overlap = convert_to_fractional_overlap(nib.load(path).get_fdata())
        return (overlap >= threshold).astype(int)

    f1_data = _binary(im_paths[0], fo)
    f2_data = _binary(im_paths[1], fo)
    f1f2_data = ((f1_data + f2_data) == 2).astype(int)
    # Make the three categories disjoint.
    f1_data -= f1f2_data
    f2_data -= f1f2_data

    data_dict = [(f1_data, im_c[0]), (f2_data, im_c[1]), (f1f2_data, .14)]
    if mask:
        mask_data = _binary(mask, 1.)
        # Keep mask vertices that are not already painted; plain subtraction
        # went to -1 (painting -mask_c) for activations outside the mask.
        painted = (f1_data + f2_data + f1f2_data) > 0
        mask_data = ((mask_data == 1) & ~painted).astype(int)
        data_dict.append((mask_data, mask_c))
    return map_data_to_value(data_dict)

def merge_and_binarize_mask(data, im_order, im_c, mask_c=.41):
    """Collapse a painted map from ``combine_im`` into an activation map.

    Elements equal to ``im_c[0]``, ``im_c[1]`` or ``.14`` (the activation
    categories) become 1. Elements equal to ``mask_c`` then become 0, and
    every other value is left unchanged. Matching is by exact float equality,
    so ``im_c`` and ``mask_c`` must be the same values used to paint ``data``.
    ``data`` itself is not modified.

    Parameters
    ----------
    data : numpy.ndarray
        Painted map.
    im_order : str
        ``"first_order"`` or a ``"f1f2..."`` order; both share this encoding.
    im_c : sequence of 2 floats
        Category values used when painting.
    mask_c : float
        Mask-only value used when painting.

    Returns
    -------
    numpy.ndarray
        A binarized copy of ``data``.

    Raises
    ------
    ValueError
        For any other ``im_order``.
    """
    if not (im_order == "first_order" or im_order.startswith("f1f2")):
        raise ValueError(f"{im_order} not supported.")

    # Both supported orders use the same encoding, so one code path serves both.
    new_data = data.copy()
    is_active = (new_data == im_c[0]) | (new_data == im_c[1]) | (new_data == .14)
    new_data[is_active] = 1
    new_data[new_data == mask_c] = 0

    return new_data

def generate_single_subject_maps(
    experiment_id, mri_id, sub_ids,
    roi_task_ids, im_frequencies,
    ROI_FO=.8, SUB_THRESHOLD=.5,
    LEFT=600, TOP=120, RIGHT=1120, BOTTOM=420, VERTEX_TO = 59412,
    mask_c = .41, PALETTE="power_surf",
    SKIP_IF_EXISTS=True, corr_type="uncp",
    BOOTSTRAP_FO=.8,
):
    """Render one flatmap PNG per (subject/task, analysis order).

    ``sub_ids``, ``roi_task_ids`` and ``im_frequencies`` are zipped. For each
    ``im_order -> frequencies`` item of the frequency dict in every triple,
    the function:

    1. locates the bootstrapped mask and one activation file per frequency
       with ``find_activations``. If a directory is missing or a lookup
       returns no file, it prints a message and skips the figure;
    2. paints them into one map with ``combine_im`` (threshold ``ROI_FO``),
       keeping the first ``VERTEX_TO`` entries;
    3. binarizes the map and accumulates per-ROI counts with
       ``roi_vertex_counter`` (quadrant taken from the mask file name);
    4. writes the PNG with ``dscalar``.

    With ``SKIP_IF_EXISTS``, figures whose PNG already exists are skipped and
    the remaining figures are still processed.

    Parameters
    ----------
    experiment_id, mri_id : str
        Used in the input directory names and the output figure path.
    sub_ids, roi_task_ids : sequences of str
        Paired element-wise with ``im_frequencies``.
    im_frequencies : sequence of dict
        One ``get_im_frequencies``-style dict per subject/task pair.
    ROI_FO : float
        Fractional-overlap threshold for the activation maps (also in the file name).
    VERTEX_TO : int
        Number of leading map entries kept. NOTE(review): 59412 is presumably
        the cortical-surface part of a 32k_fs_LR CIFTI; confirm.
    mask_c : float
        Value used for mask-only vertices when painting and binarizing.
    PALETTE : str
        Palette name passed to ``dscalar``.
    SKIP_IF_EXISTS : bool
        Skip figures whose PNG already exists.
    corr_type : str
        p-value correction, used in the input lookup and the file name.
    BOOTSTRAP_FO : float
        ``fo`` component of the bootstrap directory name passed to
        ``find_activations`` (default .8).
    SUB_THRESHOLD, LEFT, TOP, RIGHT, BOTTOM
        Currently unused (cohort summary and cropping are disabled); kept for
        call compatibility.

    Returns
    -------
    None
        The per-ROI counts are accumulated but the cohort summary is not
        computed yet.
    """
    roi_vertex_count = defaultdict(list)
    im_cs = [-.1, .82]  # red, blue
    palette_params = {
        "disp-zero": False,
        "disp-neg": True,
        "disp-pos": True,
        "pos-user": (0, 1.),
        "neg-user": (-1,0),
        "interpolate": True,
    }

    for sub_id, roi_task_id, sub_frequencies in zip(sub_ids, roi_task_ids, im_frequencies):
        fig_dir = set_base_dir(f"./ComputeCanada/frequency_tagging/figures/data_exploration/bootstrapped_rois/{experiment_id}/mri-{mri_id}/sub-{sub_id}/figures")

        for im_order, im_f in sub_frequencies.items():
            if im_order == "first_order":
                im_order_str = "f1Uf2"
                n_terms = 2  # [f1, f2]
            elif im_order.startswith("f1f2"):
                im_order_str = im_order.replace("f1f2_", "f1f2U")
                n_terms = 3  # [f1, f2, im_product]
            else:
                raise ValueError(f"{im_order} not supported.")

            png_out = Path(fig_dir) / f"sub-{sub_id}_task-{roi_task_id}-{corr_type}_im-{im_order_str}_fo-{ROI_FO}.png"
            if png_out.exists() and SKIP_IF_EXISTS:
                # Skip only this figure (a `return` here used to abort every
                # remaining subject and order).
                continue

            try:
                mask = find_activations(
                    experiment_id, mri_id, roi_task_id,
                    sub_frequencies['first_order'][0], BOOTSTRAP_FO, sub_id,
                    corr_type=corr_type, match_str="mask.dtseries.nii",
                )
                activation_lists = [
                    find_activations(
                        experiment_id, mri_id, roi_task_id, freq, BOOTSTRAP_FO, sub_id,
                        corr_type=corr_type, match_str="activations.dtseries.nii",
                    )
                    for freq in im_f[:n_terms]
                ]
                im_paths = [paths[0] for paths in activation_lists]
            except (OSError, IndexError):
                # Missing directory (OSError) or no matching file (IndexError).
                print(f"\n\n\nSkipping: {sub_id} {roi_task_id} {im_f}\n\n\n")
                continue

            for f in activation_lists + [mask]:
                assert len(f) == 1, f"{roi_task_id}, {f}"

            data = combine_im(im_order, im_paths, im_cs, fo=ROI_FO, mask=mask[0], mask_c=mask_c)
            data = data[:VERTEX_TO]
            # Forward mask_c so a non-default value is also zeroed here.
            map_data = merge_and_binarize_mask(data, im_order, im_cs, mask_c=mask_c)
            q_id = get_quadrant_id(mask[0])
            roi_vertex_count = roi_vertex_counter(roi_vertex_count, hcp_mapping, map_data, q_id)

            dscalar(
                png_out, data,
                orientation="portrait",
                hemisphere='right',
                palette=PALETTE,
                palette_params=palette_params,
                transparent=False,
                flatmap=True,
                flatmap_style='plain',
            )
            # Cropping is disabled; to re-enable:
            # crop_and_save(png_out, str(png_out).replace("png", "cropped.png"), LEFT, TOP, RIGHT, BOTTOM)

    # TODO: once condense_roi_info_across_cohort is re-enabled, return
    # condense_roi_info_across_cohort(sub_ids, roi_vertex_count, sub_threshold=SUB_THRESHOLD)
    return None



In [None]:
# Set up for visualizing dual frequency tagging across each subject using
# fractional overlap: every run below renders one flatmap PNG per
# subject/task and analysis order, for each ROI_FO and p-value correction.
cohort_roi_info_across_experiments = {}
ROI_FOS = [.2, .8, 1.]
SKIP_IF_EXISTS=True

# Save png
for corr_type in ["fdrp", "uncp"]:
    # 3T normal
    f1, f2 = .125, .2
    for ROI_FO in ROI_FOS:
        experiment_id = "1_frequency_tagging"
        mri_id = "3T"
        sub_ids = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
        roi_task_ids = ["entrain"] * len(sub_ids)
        im_frequencies = [get_im_frequencies(f1, f2)] * len(sub_ids)
        cohort_roi_info = generate_single_subject_maps(
            experiment_id, mri_id, sub_ids,
            roi_task_ids, im_frequencies,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5, corr_type=corr_type, SKIP_IF_EXISTS=SKIP_IF_EXISTS
        )
    # 7T normal
    f1, f2 = .125, .2
    for ROI_FO in ROI_FOS:
        experiment_id = "1_attention"
        mri_id = "7T"
        sub_ids = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
        roi_task_ids = ["AttendAway"] * len(sub_ids)
        im_frequencies = [get_im_frequencies(f1, f2)] * len(sub_ids)
        cohort_roi_info = generate_single_subject_maps(
            experiment_id, mri_id, sub_ids,
            roi_task_ids, im_frequencies,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5, corr_type=corr_type, SKIP_IF_EXISTS=SKIP_IF_EXISTS
        )
    # 3T/7T vary: two subjects, one frequency pair per task
    for label, mri_id in zip(["3TVary", "7TVary"], ["3T", "7T"]):
        for ROI_FO in ROI_FOS:
            experiment_id = "1_frequency_tagging"
            sub_ids = ["020"] * 3 + ["021"] * 3
            roi_task_ids = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]
            im_frequencies = [
                get_im_frequencies(.125, .2),
                get_im_frequencies(.125, .175),
                get_im_frequencies(.125, .15),
                get_im_frequencies(.125, .2),
                get_im_frequencies(.15, .2),
                get_im_frequencies(.175, .2),
            ]
            # corr_type/SKIP_IF_EXISTS must be passed like in the other runs;
            # without them the "fdrp" pass silently rendered "uncp" maps.
            cohort_roi_info = generate_single_subject_maps(
                experiment_id, mri_id, sub_ids,
                roi_task_ids, im_frequencies,
                ROI_FO=ROI_FO, SUB_THRESHOLD=.5, corr_type=corr_type, SKIP_IF_EXISTS=SKIP_IF_EXISTS
            )
    # 7T attention: every subject under each of the four attention tasks
    f1, f2 = .125, .2
    for ROI_FO in ROI_FOS:
        experiment_id = "1_attention"
        mri_id = "7T"
        attention_tasks = ["AttendAway", "AttendInF1", "AttendInF2", "AttendInF1F2"]
        sub_ids = ["010", "011", "012", "013", "014", "015", "016"]
        # Task-major order: all subjects for task 1, then all for task 2, ...
        roi_task_ids = [task for task in attention_tasks for _ in sub_ids]
        im_frequencies = [get_im_frequencies(f1, f2)] * len(roi_task_ids)
        sub_ids = sub_ids * len(attention_tasks)
        cohort_roi_info = generate_single_subject_maps(
            experiment_id, mri_id, sub_ids,
            roi_task_ids, im_frequencies,
            ROI_FO=ROI_FO, SUB_THRESHOLD=.5, corr_type=corr_type, SKIP_IF_EXISTS=SKIP_IF_EXISTS
        )