In [1]:
import itertools
import nibabel as nib
import numpy as np
from scipy.stats import zscore
from typing import Dict, List, Tuple
from pathlib import Path
    
# Set-up
# Directory with the mouse ABA template; TEMPLATE_LABEL_NIFTI is the label volume
# queried by the ROI helpers, and ROI masks are written to its `rois/` subfolder.
mouse_template_dir = Path("/opt/animalfmritools/animalfmritools/data_template/MouseABA")
TEMPLATE_LABEL_NIFTI = mouse_template_dir / "P56_Annotation_downsample2.nii.gz"
roi_dir = mouse_template_dir / "rois"
# exist_ok avoids the exists()/mkdir() race, and still raises if `rois`
# exists but is not a directory (instead of failing later at save time).
roi_dir.mkdir(exist_ok=True)

"""
ABA parser
"""
def load_json(json_path: Path) -> Dict:
    """Read and decode a JSON file.

    Parameters
    ----------
    json_path : Path
        Path to the JSON document (e.g. the ABA ontology file).

    Returns
    -------
    Dict
        The decoded JSON content.
    """
    import json

    # JSON text is UTF-8; don't let the platform's locale encoding decide.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def assert_keys(check_dict: Dict, list_keys: List[str]) -> None:
    """Check that every key of `check_dict` is one of `list_keys`.

    Keys listed in `list_keys` but missing from `check_dict` are allowed;
    only unexpected extra keys fail the check.

    Raises
    ------
    AssertionError
        If `check_dict` contains keys not in `list_keys`. Raised explicitly
        (not with an ``assert`` statement) so the check still runs under
        ``python -O``.
    """
    allowed = set(list_keys)
    unexpected = [k for k in check_dict if k not in allowed]
    if unexpected:
        raise AssertionError(
            f"Unexpected keys {unexpected}; expected a subset of {list(list_keys)}"
        )

def organize_substructures(aba_dict: Dict, target_key: str, expected_keys: List[str]) -> Dict:
    """Index the children of one top-level ABA structure by name.

    Looks up `target_key` among the structures returned by
    `organize_main_structures`, maps each child's ``'name'`` to the child
    dict, and validates those names against `expected_keys` with
    `assert_keys`.
    """
    parent = organize_main_structures(aba_dict)[target_key]
    by_name = {}
    for node in parent['children']:
        by_name[node['name']] = node
    assert_keys(by_name, expected_keys)
    return by_name

# Allowed names of the top-level ABA ontology nodes (the children of
# aba_dict['msg'][0]); `organize_main_structures` fails via `assert_keys`
# if it meets any other name.
EXPECTED_KEYS_MAIN = [
    'Basic cell groups and regions',
    'fiber tracts',
    'ventricular systems',
    'grooves',
    'retina'
]

# Allowed child names of 'Basic cell groups and regions' (used by `organize_gm_structures`).
EXPECTED_KEYS_GM = [
    "Cerebrum",
    "Brain stem",
    "Cerebellum",
]

# Allowed child names of 'fiber tracts' (used by `organize_wm_structures`).
EXPECTED_KEYS_WM = [
    "cranial nerves",
    "cerebellum related fiber tracts",
    "supra-callosal cerebral white matter",
    "lateral forebrain bundle system",
    "extrapyramidal fiber systems",
    "medial forebrain bundle system",
]

# Allowed child names of 'ventricular systems' (used by `organize_vs_structures`).
EXPECTED_KEYS_VS = [
    "lateral ventricle",
    "interventricular foramen",
    "third ventricle",
    "cerebral aqueduct",
    "fourth ventricle",
    "central canal, spinal cord/medulla",
]

def organize_main_structures(aba_dict: Dict) -> Dict:
    """Index the top-level ABA ontology structures by name.

    The top-level structures are the ``'children'`` of the first entry of
    ``aba_dict['msg']``; their names are validated against
    `EXPECTED_KEYS_MAIN` with `assert_keys`.
    """
    root = aba_dict['msg'][0]
    structures = dict((node['name'], node) for node in root['children'])
    assert_keys(structures, EXPECTED_KEYS_MAIN)
    return structures

def organize_gm_structures(aba_dict: Dict) -> Dict:
    """Return the children of 'Basic cell groups and regions', keyed by name."""
    return organize_substructures(
        aba_dict,
        target_key='Basic cell groups and regions',
        expected_keys=EXPECTED_KEYS_GM,
    )

def organize_wm_structures(aba_dict: Dict) -> Dict:
    """Return the children of 'fiber tracts', keyed by name."""
    return organize_substructures(
        aba_dict,
        target_key='fiber tracts',
        expected_keys=EXPECTED_KEYS_WM,
    )

def organize_vs_structures(aba_dict: Dict) -> Dict:
    """Return the children of 'ventricular systems', keyed by name."""
    return organize_substructures(
        aba_dict,
        target_key='ventricular systems',
        expected_keys=EXPECTED_KEYS_VS,
    )

def extract_levels(
    children: List[Dict],
    extract_level: int,
    level: int = 1,
    graph_idxs: "Dict | None" = None,
    main_key = None,
    verbose = False,
):
    """Group ontology nodes under their ancestor at depth `extract_level`.

    Walks `children` depth-first. Every node at depth `extract_level` becomes
    a key ``(name, acronym)`` whose value is a list starting with that node's
    ``(name, acronym, graph_order, color_hex_triplet, st_level, ontology_id)``
    tuple. Every deeper node appends its own tuple to the list of the
    depth-`extract_level` ancestor it was reached from. Nodes shallower than
    `extract_level` are traversed but not recorded.

    Parameters
    ----------
    children : list of dict
        Ontology nodes to walk. Each needs 'name', 'acronym', 'graph_order',
        'color_hex_triplet', 'st_level' and 'ontology_id'; 'children' is
        optional.
    extract_level : int
        Depth whose nodes become the grouping keys.
    level : int
        Depth assigned to `children` (1 for a top-level call).
    graph_idxs : dict, optional
        Existing mapping to extend. It is deep-copied and never mutated.
    main_key : tuple, optional
        Key that nodes deeper than `extract_level` append to. It is only needed
        when the walk starts below `extract_level`.
    verbose : bool
        Print one line per visited node ('x' = key node, '>' = appended,
        'o' = above the extraction depth).

    Returns
    -------
    dict
        Mapping ``(name, acronym) -> list of node tuples``.
    """
    from copy import deepcopy

    # Copy once at the entry point and let the helper fill it in place.
    # Re-copying the growing dict at every recursive call is quadratic in the
    # number of nodes. The `None` default also avoids a shared mutable default.
    result = {} if graph_idxs is None else deepcopy(graph_idxs)
    _extract_levels_into(result, children, extract_level, level, main_key, verbose)
    return result


def _extract_levels_into(graph_idxs, children, extract_level, level, main_key, verbose):
    """Recursive worker for `extract_levels`; mutates `graph_idxs` in place."""
    for child in children:
        k_tuple = (child['name'], child['acronym'])
        v_tuple = (child['name'], child['acronym'], child['graph_order'], child['color_hex_triplet'], child['st_level'], child['ontology_id'])
        if extract_level == level:
            graph_idxs[k_tuple] = [v_tuple]
            main_key = k_tuple
            tracker = 'x'
        elif level > extract_level:
            graph_idxs[main_key].append(v_tuple)
            tracker = '>'
        else:
            tracker = 'o'

        if verbose:
            print(f"{tracker} [{str(child['graph_order']).zfill(4)}] {'-'*level} {child['name']} [{child['acronym']}] {level}")

        if 'children' in child:
            _extract_levels_into(graph_idxs, child['children'], extract_level, level + 1, main_key, verbose)

def check_template_for_idx(template_idx: int, template_nifti: "str | Path" = TEMPLATE_LABEL_NIFTI) -> bool:
    """Return True if any voxel of the label volume equals `template_idx`.

    Parameters
    ----------
    template_idx : int
        Label value to look for.
    template_nifti : str or Path
        Label NIfTI to search. It is loaded on every call.

    Returns
    -------
    bool
        Whether at least one voxel carries `template_idx`.
    """
    data = nib.load(template_nifti).get_fdata()
    # A boolean reduction is enough; np.where would build full index arrays
    # just to count them.
    return bool(np.any(data == template_idx))

def get_template_coords_from_idx(template_idx: int, template_nifti: str = TEMPLATE_LABEL_NIFTI) -> Tuple:
    """Return the voxel indices whose label equals `template_idx`.

    The result is the tuple of index arrays produced by ``np.where`` (one array
    per image axis), so it can index the label volume directly.
    """
    label_volume = nib.load(template_nifti).get_fdata()
    matches = label_volume == template_idx
    return np.where(matches)

def save_template_roi(
    parent_structure_label: str,
    roi_label: str,
    template_coords: Tuple, 
    outdir: Path,
    template_nifti: str = TEMPLATE_LABEL_NIFTI
) -> None:
    """Write a binary ROI mask in template space, unless the file already exists.

    The output is ``outdir / "P56_desc-{parent_structure_label}_roi-{roi_label}.nii.gz"``.
    It has the shape, header and affine of `template_nifti`; voxels listed in
    `template_coords` are set to 1 and all others to 0.

    Parameters
    ----------
    parent_structure_label : str
        Value of the ``desc-`` filename entity.
    roi_label : str
        Value of the ``roi-`` filename entity.
    template_coords : tuple of arrays or array of shape (3, N)
        Voxel indices, one array/row per axis. Either the tuple returned by
        ``np.where`` or a stacked ``(3, N)`` array (as built by `parse_children`)
        is accepted.
    outdir : Path
        Output directory.
    template_nifti : str or Path
        Reference image for shape, header and affine.
    """
    output_path = Path(outdir) / f"P56_desc-{parent_structure_label}_roi-{roi_label}.nii.gz"
    # Check first so an existing ROI doesn't cost a template load.
    if output_path.exists():
        print(f"{output_path} already exists.")
        return

    template_img = nib.load(template_nifti)
    # img.shape comes from the header; the voxel data isn't needed for the mask.
    roi_data = np.zeros(template_img.shape)
    # np.asarray turns an np.where-style tuple into a (3, N) array, so both input
    # forms index the same way.
    coords = np.asarray(template_coords)
    roi_data[tuple(coords[:3])] = 1
    roi_img = nib.Nifti1Image(
        roi_data,
        header = template_img.header,
        affine = template_img.affine,
    )
    print(f"Saving to {output_path}.")
    nib.save(roi_img, output_path)
    
def parse_children(
    children, 
    level = 1, 
    parent_structure = None, 
    previous_structure_label = None, 
    previous_structure_name = None,
    structure_mapping = None, 
    roi_hierarchy = None,
    outdir: Path = None
):
    """Recursively walk ABA ontology nodes, recording names/hierarchy and saving ROIs.

    For each node in `children`:
      * ``structure_mapping[acronym] = name`` is recorded;
      * when `previous_structure_label` is given, the acronym is appended to
        ``roi_hierarchy[previous_structure_label]``;
      * voxels of the label volume equal to the node's ``graph_order`` are
        collected;
      * the node's own children are parsed recursively with it as the parent.

    After the loop, if any direct child had voxels in the label volume, their
    union is saved as ROI ``P56_desc-{parent_structure}_roi-{previous_structure_label}``
    via `save_template_roi`. Only the direct children's own labels are used.
    Assuming graph_order values are unique per node, ROIs written for different
    parents therefore share no voxels.
    NOTE(review): at the top-level call (no `previous_structure_label`) such a
    file would be named ``roi-None``.

    Parameters
    ----------
    children : list of dict
        Ontology nodes. Each needs 'acronym', 'name' and 'graph_order';
        'children' is optional.
    level : int
        Depth of `children`; passed down the recursion but otherwise unused.
    parent_structure : str
        Acronym used for the ``desc-`` entity of every ROI filename.
    previous_structure_label, previous_structure_name : str, optional
        Acronym and name of the node whose children these are.
    structure_mapping, roi_hierarchy : dict, optional
        Accumulators. They are created when None, then mutated and returned.
    outdir : Path
        Directory the ROI files are written to.

    Returns
    -------
    tuple of (dict, dict)
        ``(structure_mapping, roi_hierarchy)``.
    """
    if structure_mapping is None:
        structure_mapping = dict()

    if roi_hierarchy is None:
        roi_hierarchy = dict()

    # (3, n_i) voxel-index arrays of the direct children found in the atlas.
    coord_chunks = []
    for c in children:

        # Fill structure mapping
        structure_mapping[c['acronym']] = c['name']

        # Fill roi hierarchy mapping
        if previous_structure_label is not None:
            roi_hierarchy.setdefault(previous_structure_label, []).append(c['acronym'])

        # Fetching the coordinates both tests whether the label exists and
        # provides its voxels, so the label volume is loaded once per node
        # rather than twice.
        coords = get_template_coords_from_idx(c['graph_order'])
        if coords[0].size > 0:
            coord_chunks.append(np.vstack(coords))

        structure_mapping, roi_hierarchy = parse_children(
            c.get('children', []),
            level = level + 1,
            parent_structure = parent_structure,
            previous_structure_label = c['acronym'],
            previous_structure_name = c['name'],
            structure_mapping = structure_mapping,
            roi_hierarchy = roi_hierarchy,
            outdir = outdir
        )

    if coord_chunks:
        # Concatenate once instead of growing an array per child.
        save_template_roi(parent_structure, previous_structure_label, np.hstack(coord_chunks), outdir)

    return structure_mapping, roi_hierarchy

"""
ABA ROI sorter
"""
def get_roi_path(roi_dir, roi_acronym, parent_acronym = 'CH'):
    """Return the saved ROI file for `roi_acronym`, or None if it does not exist.

    The file looked for is
    ``roi_dir / "P56_desc-{parent_acronym}_roi-{roi_acronym}.nii.gz"``.
    """
    candidate = roi_dir / f"P56_desc-{parent_acronym}_roi-{roi_acronym}.nii.gz"
    return candidate if candidate.exists() else None

def search_for_roi_paths(roi_dir, all_label_hierarchies, main_k, sub_k, parent_acronym, roi_paths = None):
    """Collect existing ROI files for every structure below `sub_k`, depth-first.

    ``all_label_hierarchies[main_k]`` maps a structure acronym to the acronyms
    of its children (the `roi_hierarchy` built by `parse_children`). Starting
    at `sub_k`, each child that itself has an entry in that mapping is
    visited. If its ROI file (see `get_roi_path`) exists, it is appended to
    `roi_paths`. The search then continues into its children either way.
    Children without an entry are skipped.

    Returns
    -------
    list of Path
        `roi_paths` (created when None) with the found paths appended in
        depth-first pre-order.
    """
    hierarchy = all_label_hierarchies[main_k]
    if roi_paths is None:
        roi_paths = []

    for s in hierarchy[sub_k]:
        # ROI files are written by `parse_children` under the label of a node
        # that has children, so nodes without an entry are skipped before
        # touching the filesystem. Dict membership is O(1).
        if s not in hierarchy:
            continue

        roi_path = get_roi_path(roi_dir, s, parent_acronym)
        if roi_path is not None:
            roi_paths.append(roi_path)
        # Descend whether or not this level had an ROI file.
        roi_paths = search_for_roi_paths(roi_dir, all_label_hierarchies, main_k, s, parent_acronym, roi_paths = roi_paths)

    return roi_paths


"""
Atlas functions
"""
def extract_roi_name(nifti_path: Path) -> str:
    """Return the ROI label encoded in an ROI NIfTI filename.

    For ``.../P56_desc-CH_roi-SSp-n.nii.gz`` this returns ``'SSp-n'``: the text
    after ``_roi-`` in the file name, up to ``.nii.gz``.

    Raises
    ------
    ValueError
        If the file name does not contain ``_roi-``.
    """
    # Only the final path component is used, so a '_roi-' in a parent
    # directory name cannot be mistaken for the ROI tag.
    file_name = Path(nifti_path).name
    if '_roi-' not in file_name:
        raise ValueError(f"No '_roi-' tag in file name: {nifti_path}")

    return file_name.split('_roi-')[1].split('.nii.gz')[0]

def get_empty_template(nifti_path: str) -> np.ndarray:
    """Return an all-zero array with the shape of the image at `nifti_path`."""
    image_shape = nib.load(nifti_path).shape
    blank = np.zeros(image_shape)
    return blank

def create_roi_array(roi_index, roi_path, hemi_path, hemi_idx) -> np.ndarray:
    """Return an ROI restricted to one hemisphere, with its voxels set to `roi_index`.

    A voxel gets the value `roi_index` where the ROI image is > 0 and the
    hemisphere image equals `hemi_idx`. All other voxels are 0.

    Parameters
    ----------
    roi_index : int
        Label value written into the ROI voxels.
    roi_path : path-like
        ROI mask NIfTI.
    hemi_path : path-like
        Hemisphere label NIfTI. It is assumed to be on the same voxel grid as
        the ROI.
    hemi_idx : int
        Hemisphere label to keep.

    Returns
    -------
    np.ndarray
        float64 array with the ROI image's shape.
    """
    roi_mask = nib.load(roi_path).get_fdata() > 0
    # A single comparison yields the hemisphere mask. Overwriting the
    # hemisphere image in place (non-matching -> 0, then matching -> 1) would
    # select the whole volume when hemi_idx == 0.
    hemi_mask = nib.load(hemi_path).get_fdata() == hemi_idx

    return (roi_mask & hemi_mask).astype(np.float64) * roi_index

def create_atlas(roi_nifti_list: List, hemi_nifti: str, out_dir: str = "/tmp/mouse_atlas.nii.gz"):
    """Merge ROI masks into one labelled, hemisphere-split atlas and save it.

    Each ROI is split into a right- and a left-hemisphere part using
    `hemi_nifti`. Labels are assigned consecutively from 1 in the order
    ``(roi_0, RH), (roi_0, LH), (roi_1, RH), ...``.

    Parameters
    ----------
    roi_nifti_list : list of path-like
        ROI mask NIfTIs, all with the same shape. The first one provides the
        affine and header of the output.
    hemi_nifti : path-like
        Hemisphere label image (RH == 2, LH == 1).
    out_dir : str
        Output *file* path of the atlas NIfTI (despite the parameter name).

    Returns
    -------
    tuple
        ``(atlas_annotations, out_dir)``, where `atlas_annotations` is a list
        of ``(label, "<roi>_<RH|LH>")``.

    Raises
    ------
    ValueError
        If `roi_nifti_list` is empty, or if an ROI's shape differs from the
        first ROI's.
    """
    if not roi_nifti_list:
        raise ValueError("roi_nifti_list is empty; nothing to merge into an atlas.")

    # `hemi_nifti` annotates RH with idx == 2 and LH with idx == 1
    hemi_mapping = {
        'RH': 2,
        'LH': 1,
    }

    # Instantiate empty atlas array (to be filled in)
    merged_atlas = get_empty_template(roi_nifti_list[0])
    atlas_annotations = []

    pairs = itertools.product(roi_nifti_list, hemi_mapping.keys())
    for ix, (roi_nifti, hemi_label) in enumerate(pairs, start=1):
        hemi_idx = hemi_mapping[hemi_label]

        # All niftis must be the same shape as the empty atlas array
        roi_shape = nib.load(roi_nifti).shape
        if roi_shape != merged_atlas.shape:
            raise ValueError(f"{roi_nifti} has shape {roi_shape}, expected {merged_atlas.shape}.")

        roi_array = create_roi_array(ix, roi_nifti, hemi_nifti, hemi_idx)
        roi_label = extract_roi_name(roi_nifti)

        # Labels are combined by addition, so a voxel claimed by two ROIs
        # would silently receive the sum of their labels; report it.
        n_overlap = np.count_nonzero((roi_array != 0) & (merged_atlas != 0))
        if n_overlap > 0:
            print(f"[WARNING] {roi_label}_{hemi_label} [{ix}] overlaps {n_overlap} already-labelled voxels.")

        merged_atlas += roi_array
        atlas_annotations.append((ix, f"{roi_label}_{hemi_label}"))

    # Round
    merged_atlas = np.round(merged_atlas)

    # Save with the first ROI's geometry
    template_img = nib.load(roi_nifti_list[0])
    merged_atlas_img = nib.Nifti1Image(
        merged_atlas,
        affine=template_img.affine,
        header=template_img.header,
    )
    nib.save(merged_atlas_img, out_dir)

    return atlas_annotations, out_dir

def check_labels_in_atlas(atlas_annots, atlas_path):
    """Print a warning for each annotated label that has no voxel in the atlas.

    Parameters
    ----------
    atlas_annots : iterable of (int, str)
        ``(label value, label name)`` pairs, e.g. as returned by `create_atlas`.
    atlas_path : path-like
        Atlas NIfTI. Its values are rounded before the comparison.
    """
    import nibabel as nib
    rounded = np.round(nib.load(atlas_path).get_fdata())
    present = set(np.unique(rounded).tolist())
    missing = [(idx, name) for idx, name in atlas_annots if idx not in present]
    for roi_idx, roi_label in missing:
        print(f"[WARNING] {roi_label} [{roi_idx}] not in atlas.")

def calculate_mean_tsnr(timeseries: np.ndarray) -> float:
    """Return the mean temporal SNR (temporal mean / temporal std) across voxels.

    Parameters
    ----------
    timeseries : np.ndarray
        2-D array with one row per voxel and time along axis 1 (the
        ``bold_data[coords]`` layout used by `calculate_correlation_matrix`).

    Returns
    -------
    float
        Mean of the per-row tSNR over rows where it is finite. Rows with zero
        temporal variance, which give inf or NaN, are excluded. Returns NaN if
        no row remains (e.g. an empty ROI).
    """
    timeseries = np.asarray(timeseries, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        _values = timeseries.mean(axis=1) / timeseries.std(axis=1)
    # A constant non-zero voxel gives x / 0 = inf, which would turn the mean
    # into inf. Drop every non-finite ratio, not only NaNs.
    _values = _values[np.isfinite(_values)]
    if _values.size == 0:
        return float('nan')

    return float(_values.mean())

def alff_and_falff(x, sampling_rate, low_freq_range = (.01, .1)):
    """Compute ALFF and fALFF of a 1-D signal from its Welch power spectrum.

    Parameters
    ----------
    x : array_like
        1-D time series.
    sampling_rate : float
        Sampling frequency in Hz, i.e. ``1 / TR`` for fMRI (not the TR itself).
    low_freq_range : tuple of float
        Inclusive ``(low, high)`` frequency band in Hz.

    Returns
    -------
    alff : float
        Square root of the mean Welch power inside `low_freq_range`.
    falff : float
        Fraction of spectral amplitude inside the band: the sum of
        ``sqrt(power)`` over `low_freq_range` divided by its sum over all
        frequencies. It lies in [0, 1] and does not depend on the signal's
        scale. NaN if the signal has no power.

    Both values are NaN when no Welch frequency bin falls inside the band,
    e.g. for a series too short to resolve it.
    """
    from scipy.signal import welch

    fs, power = welch(x, fs = sampling_rate)
    in_band = (fs >= low_freq_range[0]) & (fs <= low_freq_range[1])
    if not np.any(in_band):
        return np.nan, np.nan

    # ALFF: square root of the average power in the low-frequency range
    alff = np.sqrt(np.mean(power[in_band]))

    # fALFF must compare like with like (amplitude over amplitude). Dividing
    # a band amplitude by the total *power* is not a fraction and scales
    # as 1/|x|.
    amplitude = np.sqrt(power)
    total_amplitude = np.sum(amplitude)
    falff = np.sum(amplitude[in_band]) / total_amplitude if total_amplitude > 0 else np.nan

    return alff, falff

def lag_1_ta(x):
    """Return the lag-1 temporal autocorrelation of `x`.

    This is the Pearson correlation between ``x[t]`` and ``x[t + 1]``.
    """
    leading, trailing = x[:-1], x[1:]
    corr = np.corrcoef(leading, trailing)
    return corr[0, 1]


def calculate_correlation_matrix(bold_nifti, atlas_nifti, atlas_annots, TR):
    """Extract per-label BOLD signals, their correlation matrix and summary metrics.

    For each ``(label, name)`` in `atlas_annots`, the BOLD time series of all
    atlas voxels with that label are averaged. ALFF/fALFF and the lag-1
    autocorrelation are computed from that average signal. The z-scored
    average signals are then correlated pairwise. tSNR is computed per voxel
    and averaged with `calculate_mean_tsnr`.

    Parameters
    ----------
    bold_nifti : path-like
        BOLD image. It is assumed to be 4-D (x, y, z, time) on the atlas grid;
        the affines must match.
    atlas_nifti : path-like
        Label image. Its values are rounded before matching.
    atlas_annots : iterable of (int, str)
        ``(label value, label name)`` pairs.
    TR : float
        Repetition time, assumed to be in seconds. The sampling rate passed
        to `alff_and_falff` is ``1 / TR`` Hz.

    Returns
    -------
    tuple
        ``(C, labels, label_tsnr, label_alff, label_falff, label_ta_lag_1)``.
        `C` is the label-by-label correlation matrix; the lists follow the
        order of `atlas_annots`.
    """
    # Load niftis
    bold_img = nib.load(bold_nifti)
    atlas_img = nib.load(atlas_nifti)

    # Check whether images have same coordinate system
    assert np.allclose(bold_img.affine, atlas_img.affine)

    bold_data = bold_img.get_fdata()
    atlas_data = np.round(atlas_img.get_fdata())

    # Welch expects a sampling *rate*; TR is the sampling interval.
    sampling_rate = 1.0 / TR

    # Extract average timeseries from all labels
    label_timeseries, label_tsnr, labels = [], [], []
    label_alff, label_falff, label_ta_lag_1 = [], [], []
    for atlas_idx, atlas_labels in atlas_annots:
        atlas_coords = np.where(atlas_data == atlas_idx)
        label_all_timeseries = bold_data[atlas_coords]
        if label_all_timeseries.shape[0] == 0:
            # e.g. a label lost when reslicing the atlas; its metrics and row of C will be NaN
            print(f"[WARNING] {atlas_labels} [{atlas_idx}] has no voxels in {atlas_nifti}.")
        label_avg_timeseries = np.mean(label_all_timeseries, axis=0)
        alff, falff = alff_and_falff(label_avg_timeseries, sampling_rate)
        ta_lag_1 = lag_1_ta(label_avg_timeseries)
        label_alff.append(alff)
        label_falff.append(falff)
        label_ta_lag_1.append(ta_lag_1)  # temporal autocorrelation
        # z-score before correlating; the matrix is computed after the loop
        label_timeseries.append(zscore(label_avg_timeseries))
        labels.append(atlas_labels)
        label_tsnr.append(calculate_mean_tsnr(label_all_timeseries))

    # Correlation matrix
    C = np.corrcoef(label_timeseries)

    return C, labels, label_tsnr, label_alff, label_falff, label_ta_lag_1

In [2]:
# Parse `json_file`
# Build the ABA ontology lookups, write one ROI mask per parent structure into
# `roi_dir`, then merge the selected ROIs into a hemisphere-split atlas.
json_file = mouse_template_dir / "ABA_ontology.json"
aba_onto = load_json(json_file)
main = organize_main_structures(aba_onto)
gm = organize_gm_structures(aba_onto)
wm = organize_wm_structures(aba_onto)
vs = organize_vs_structures(aba_onto)

# Generate ROIs and store to `roi_dir`
# all_label_hierarchies["<name> <acronym>"] maps each structure acronym to its
# child acronyms (built by `parse_children`); it drives `search_for_roi_paths` below.
all_label_mappings = {}
all_label_hierarchies = {}
for structure in [gm, wm, vs]:
    for structure_ix, (k, v) in enumerate(structure.items()):
        print(f"[{structure_ix + 1}/{len(structure)}] {v['name']} {v['acronym']}")
        parent_label = f"{v['name']} {v['acronym']}"
        label_mapping, label_hierarchy = parse_children(v['children'], level=1, parent_structure=v['acronym'], outdir=roi_dir)
        all_label_mappings[parent_label] = label_mapping
        all_label_hierarchies[parent_label] = label_hierarchy

# Remove ROIs excluded from the atlas. The HIP ROI is small and does not include DG or CA.
# NOTE(review): the original note says MM, CN and VNC received no voxels after resampling
# the atlas to EPI space, but the command below removes HY, PB, AMB and MY-sat instead;
# confirm which list is intended.
!rm {roi_dir}/*roi-HIP* {roi_dir}/*roi-HY* {roi_dir}/*roi-PB* {roi_dir}/*roi-AMB* {roi_dir}/*roi-MY-sat*


# Get list of ROI paths
# `search_for_roi_paths` only returns ROI files that exist on disk (via `get_roi_path`),
# so the ROIs removed above are left out.
# Cerebrum
k = "Cerebrum CH"
isocortex = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "Isocortex", "CH")
olf = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "OLF", "CH")
hpf = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "HPF", "CH")
# Brain stem
k = 'Brain stem BS'
interbrain = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "IB", "BS")
midbrain = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "MB", "BS")
hindbrain = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "HB", "BS")
# Cerebellum
k = 'Cerebellum CB'
cerebellar_cortex = search_for_roi_paths(roi_dir, all_label_hierarchies, k, "CBX", "CB")

# Create atlas
# NOTE(review): `olf` and `cerebellar_cortex` are collected above but not passed
# to `create_atlas`; confirm that excluding them is intentional.
hemi_nifti = "/opt/animalfmritools/animalfmritools/data_template/MouseABA/roi-hemispheres_space-P56_downsample2.nii.gz" # RH == 2 & LH == 1
atlas_annots, atlas_nifti = create_atlas(
    isocortex + hpf + interbrain + midbrain + hindbrain, 
    hemi_nifti,
)

[1/3] Cerebrum CH
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-FRP.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-MOp.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-MOs.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-n.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-bfd.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-ll.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-m.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-ul.nii.gz already exists.
/opt/animalfmritools/animalfmritools/data_template/MouseABA/rois/P56_desc-CH_roi-SSp-tr.nii.gz already e

In [3]:
# Collect the denoised BOLD runs (IPython captures the `ls` output as a list of paths)
bolds = !ls /opt/animalfmritools/animalfmritools/data/MouseAD/bids/derivatives/bold_preproc/sub-*/ses-*/func/*desc-denoised_bold.nii.gz

# Repetition time. NOTE(review): presumably seconds; confirm against the BOLD headers.
TR = 1.5
for bold_ix, bold in enumerate(bolds):

    # Parse BIDS entities from the filename (`.stem` of a .nii.gz still ends in
    # ".nii", which these splits ignore)
    bold = Path(bold)
    sub_id = bold.stem.split('sub-')[1].split('_')[0]
    ses_id = bold.stem.split('ses-')[1].split('_')[0]
    run_id = bold.stem.split('run-')[1].split('_')[0]

    # One resliced atlas per subject/session, reused by all of its runs
    resliced_atlas_path = Path(f"/tmp/sub-{sub_id}_ses-{ses_id}_atlas.nii.gz")
    
    # Reslice atlas to the subject's template space (FSL flirt with nearest-neighbour
    # interpolation, so label values are not blended)
    if not resliced_atlas_path.exists():
        print(f"Reference: {bold.stem}")
        !flirt -in {atlas_nifti} -ref {bold} -out {resliced_atlas_path} -interp nearestneighbour -applyxfm -usesqform

    # Report labels lost in reslicing once per subject/session (on run 01)
    if run_id == '01':
        print(resliced_atlas_path)
        check_labels_in_atlas(atlas_annots, resliced_atlas_path)

    # NOTE(review): these results are overwritten on every iteration; store them
    # (e.g. keyed by sub/ses/run) if they are needed after the loop.
    C, labels, label_tsnr, label_alff, label_falff, label_ta_lag_1 = calculate_correlation_matrix(bold, resliced_atlas_path, atlas_annots, TR)

/tmp/sub-FGD3149F1_ses-20211208_atlas.nii.gz
Reference: sub-FGD3159F2_ses-20220222_task-rest_dir-AP_run-01_space-template_desc-denoised_bold.nii
/tmp/sub-FGD3159F2_ses-20220222_atlas.nii.gz
