# Feb 10, 2024: Functional ROI creation

In [1]:
import os
import numpy as np
import pandas as pd
import ants
import seaborn as sns

from allensdk.core.mouse_connectivity_cache import (
    MouseConnectivityCache,
    MouseConnectivityApi
)
from allensdk.api.queries.ontologies_api import OntologiesApi

from sklearn.cluster import KMeans
from tqdm import tqdm
from copy import deepcopy

# ignore user warnings
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Plain attribute bag that carries this notebook's run configuration."""


args = ARGS()
args.SEED = 100  # random seed passed to ants.registration and KMeans

In [3]:
def to_nifti(args, img, print_=True):
    """Wrap a 3-D Allen CCF numpy volume as an ANTs image.

    Steps: reorder the axes as (2, 0, 1), reverse the last axis, zero-pad
    ((2, 2), (4, 24), (8, 2)), and build an ANTs image with spacing 0.1 on
    every axis and origin [6.4, -13.2, -7.8]. Presumably this places the Allen
    100 um array on the grid of the N162 RAS 100 um template (TODO confirm the
    orientation).

    Args:
        args: unused; kept for a uniform call signature.
        img: 3-D numpy array.
        print_: when True, print the padded array's dtype and shape.

    Returns:
        ants image (float32 data).
    """
    reoriented = np.transpose(img, (2, 0, 1))[..., ::-1]
    padded = np.pad(
        reoriented,
        pad_width=((2, 2), (4, 24), (8, 2)),
        mode='constant',
        constant_values=0,
    )
    if print_:
        print(padded.dtype, padded.shape)
    return ants.from_numpy(
        data=padded.astype(np.float32),
        origin=[6.4, -13.2, -7.8],
        spacing=[0.1] * padded.ndim,
    )

In [4]:
# Allen CCFv3 (2017) at 100 um: average template, annotation and structure tree
args.atlas_path = '/home/govindas/mouse_dataset/allen_atlas_ccfv3'
args.mcc_path = f'{args.atlas_path}/MouseConnectivity'
mcc = MouseConnectivityCache(
    manifest_file=f'{args.mcc_path}/manifest.json',
    ccf_version=MouseConnectivityApi().CCF_2017,
    resolution=100,  # voxel size in micrometres
)

AVGT, metaAVGT = mcc.get_template_volume()
ANO, metaANO = mcc.get_annotation_volume()
AVGT, ANO = AVGT.astype(np.float32), ANO.astype(np.uint32)
print(AVGT.shape, ANO.shape)

STree = mcc.get_structure_tree()
STree_df = pd.DataFrame(STree.nodes())
# optional: pre-cache every structure mask by calling
# mcc.get_structure_mask(structure_id=...) for each id in STree_df.id

(132, 80, 114) (132, 80, 114)


In [5]:
args.tx_path = f'{args.mcc_path}/parcels'
# Create the output directory in-process. This avoids spawning a shell with an
# unquoted, interpolated path (the previous `os.system('mkdir -p ...')`).
os.makedirs(args.tx_path, exist_ok=True)

0

In [6]:
# templates in nifti: N162 symmetric template (100 um) and the Allen template
# converted onto a comparable grid
n162_100um_path = '/home/govindas/mouse_dataset/gabe_symmetric_N162/Symmetric_N162_0.10_RAS.nii.gz'
n162_100um_template = ants.image_read(n162_100um_path)
n162_100um_arr = n162_100um_template.numpy()
print(n162_100um_arr.dtype, n162_100um_arr.shape)

allen_template = to_nifti(args, AVGT)

float32 (118, 160, 90)
float32 (118, 160, 90)


In [7]:
# For reproducible registration, ANTsPy needs a fixed random seed and a single
# ITK thread. Set the variables in THIS process's environment: the previous
# `os.system('export ...')` only changed a short-lived child shell, so it had no effect.
# NOTE(review): ITK may read the thread count when its thread pool is first
# created. For a guaranteed effect, set these before `import ants` (or in the
# kernel's environment) -- confirm.
os.environ['ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS'] = '1'
os.environ['ANTS_RANDOM_SEED'] = '1'

# Allen template (moving) -> N162 100 um template (fixed)
tx = ants.registration(
    fixed=n162_100um_template,
    moving=allen_template,
    type_of_transform='SyN',
    random_seed=args.SEED,
)

def transform(args, img, interpolator='genericLabel'):
    """Warp an Allen-space image onto the N162 100 um template grid.

    Applies the forward transforms of the global registration result ``tx``,
    with the global ``n162_100um_template`` as the reference (fixed) image.

    Args:
        args: unused; kept for a uniform call signature.
        img: ants image in Allen space.
        interpolator: ANTs interpolator. The default 'genericLabel' suits
            label/mask volumes; pass e.g. 'linear' for intensity images.

    Returns:
        ants image on the N162 100 um grid.
    """
    return ants.apply_transforms(
        fixed=n162_100um_template,
        moving=img,
        transformlist=tx['fwdtransforms'],
        interpolator=interpolator,
    )


# the template is an intensity image, so interpolate it linearly rather than as labels
allen_template_tx = transform(args, img=allen_template, interpolator='linear')

In [8]:
# resampling to 0.2mm resolution
n162_200um_template = ants.image_read(
    '/home/govindas/mouse_dataset/gabe_symmetric_N162/Symmetric_N162_0.20_RAS.nii.gz'
)


def resample(args, target, img, interp_type='genericLabel'):
    """Resample ``img`` onto ``target``'s grid and zero voxels where target is 0.

    Args:
        args: unused; kept for a uniform call signature.
        target: ants image defining the output grid; its nonzero voxels form the
            retained region.
        img: ants image to resample.
        interp_type: ANTs interpolation type. The default 'genericLabel' suits
            label volumes; pass e.g. 'linear' for intensity images.

    Returns:
        ants image on ``target``'s grid (its shape is printed).
    """
    img_rs = ants.resample_image_to_target(
        image=img,
        target=target,
        interp_type=interp_type,
    )
    img_rs = img_rs.new_image_like(
        data=img_rs.numpy() * (target.numpy() > 0)
    )
    print(img_rs.numpy().shape)
    return img_rs


# the warped template holds intensities, not labels -> linear interpolation
allen_template_tx_rs = resample(
    args, target=n162_200um_template, img=allen_template_tx, interp_type='linear'
)

(60, 81, 46)


In [9]:
# common brain mask (across subs): voxelwise product of every subject's mask
BASE_path = '/home/govindas/mouse_dataset'
all_files_path = f'{BASE_path}/voxel/all_file_collections'
# sorted: os.listdir order is arbitrary, and the first file supplies the output header
all_files = sorted(os.listdir(all_files_path))
if not all_files:
    raise FileNotFoundError(f'no file collections found in {all_files_path}')


def _mask_path_from_collection(collection_file):
    """Return the path on the 2nd line of a collection file (presumably the subject's brain mask)."""
    with open(collection_file, 'r') as f:
        # strip() rather than [:-1]: a last line without '\n' would otherwise lose a real character
        return f.readlines()[1].strip()


# cmask : common brain mask
cmask_img = None
for files in tqdm(all_files):
    sub_mask_img = ants.image_read(_mask_path_from_collection(f'{all_files_path}/{files}'))
    if cmask_img is None:
        cmask_img = sub_mask_img
        cmask = sub_mask_img.numpy()
    else:
        cmask *= sub_mask_img.numpy()
cmask_img = cmask_img.new_image_like(cmask)
cmask_img.to_filename(
    f'{BASE_path}/voxel/common_brain_mask.nii.gz'
)

116it [00:00, 393.72it/s]


---

In [10]:
# major brain divisions: members of the Allen "Brain - Major Divisions" structure set
set_ids = STree.get_structure_sets()
onto_df = pd.DataFrame(OntologiesApi().get_structure_sets(set_ids))
is_major_divs = onto_df['name'] == 'Brain - Major Divisions'
major_divs_id = onto_df.loc[is_major_divs, 'id'].item()
major_divs_df = pd.DataFrame(STree.get_structures_by_set_id([major_divs_id]))
major_divs_df

Unnamed: 0,acronym,graph_id,graph_order,id,name,structure_id_path,structure_set_ids,rgb_triplet
0,Isocortex,1,5,315,Isocortex,"[997, 8, 567, 688, 695, 315]","[2, 112905828, 691663206, 12, 184527634, 11290...","[112, 255, 113]"
1,OLF,1,379,698,Olfactory areas,"[997, 8, 567, 688, 695, 698]","[2, 3, 112905828, 691663206, 12, 184527634, 11...","[154, 210, 189]"
2,HPF,1,454,1089,Hippocampal formation,"[997, 8, 567, 688, 695, 1089]","[2, 112905828, 691663206, 12, 184527634, 11290...","[126, 208, 75]"
3,CTXsp,1,555,703,Cortical subplate,"[997, 8, 567, 688, 703]","[2, 3, 112905828, 691663206, 12, 184527634, 68...","[138, 218, 135]"
4,STR,1,571,477,Striatum,"[997, 8, 567, 623, 477]","[2, 112905828, 691663206, 12, 184527634, 11290...","[152, 214, 249]"
5,PAL,1,608,803,Pallidum,"[997, 8, 567, 623, 803]","[2, 112905828, 691663206, 12, 184527634, 11290...","[133, 153, 204]"
6,TH,1,641,549,Thalamus,"[997, 8, 343, 1129, 549]","[2, 112905828, 691663206, 12, 184527634, 11290...","[255, 112, 128]"
7,HY,1,715,1097,Hypothalamus,"[997, 8, 343, 1129, 1097]","[2, 112905828, 691663206, 12, 184527634, 11290...","[230, 68, 56]"
8,MB,1,806,313,Midbrain,"[997, 8, 343, 313]","[2, 112905828, 691663206, 12, 184527634, 11290...","[255, 100, 255]"
9,P,1,883,771,Pons,"[997, 8, 343, 1065, 771]","[2, 112905828, 691663206, 12, 184527634, 11290...","[255, 155, 136]"


In [11]:
# major divisions: one label volume where division i (row index of
# major_divs_df) is painted with label i + 1
PARCELS_md = np.zeros_like(n162_100um_template.numpy())  # parcels_major_div
for idx, row in tqdm(major_divs_df.iterrows(), total=len(major_divs_df)):
    DIV, _ = mcc.get_structure_mask(row['id'])
    DIV_img = to_nifti(args, DIV.astype(np.uint32), print_=False)
    DIV_img_tx = transform(args, img=DIV_img)
    # Assign rather than add. If two warped masks ever touch the same voxel,
    # summing would produce another division's label (e.g. 1 + 2 == 3).
    PARCELS_md[DIV_img_tx.numpy() > 0] = idx + 1

PARCELS_md_tx_img = n162_100um_template.new_image_like(PARCELS_md)
PARCELS_md_tx_rs_img = resample(args, target=n162_200um_template, img=PARCELS_md_tx_img)
PARCELS_md_tx_rs = PARCELS_md_tx_rs_img.numpy()

# save (on the common-brain-mask grid)
PARCELS_md_tx_rs_cm_img = resample(args, target=cmask_img, img=PARCELS_md_tx_rs_img)
PARCELS_md_tx_rs_cm_img.to_filename(f'{args.tx_path}/major_divisions.nii.gz')

12it [00:05,  2.34it/s]


(60, 81, 46)
(58, 79, 45)


In [12]:
def separate_hemis(args, mask,):
    """Split ``mask`` into mirror-symmetric halves along ``args.lr_axis``.

    A voxel is kept only if both it and its mirror image (np.flip along
    ``args.lr_axis``) are inside ``mask``. When the axis length is odd, the
    centre slice is dropped from both halves. The two returned halves are
    therefore mirror images of each other. Which half is anatomically left
    depends on the image orientation (not established here).

    Args:
        args: namespace providing ``lr_axis`` (int axis index).
        mask: array; nonzero voxels count as inside.

    Returns:
        slices_l: index tuple selecting the first ceil(n/2) positions along lr_axis.
        slices_r: index tuple selecting the last ceil(n/2) positions along lr_axis.
        mask_l: bool mask of the symmetrised region on the low-index side.
        mask_r: bool mask of the symmetrised region on the high-index side.
        nvxl_lr: number of voxels n along lr_axis.
    """
    # separate hemispheres
    
    nvxl_lr = mask.shape[args.lr_axis]
    # ceil so both halves have equal size; for odd n they overlap on the centre slice
    coverage = int(np.ceil(nvxl_lr / 2))
    
    # create separate left-right masks
    # slice(mask.shape[i]) == slice(0, mask.shape[i]): full extent on the other axes
    slices_l = tuple(
        slice(0, coverage) if i == args.lr_axis
        else slice(mask.shape[i])
        for i in range(len(mask.shape))
    )
    slices_r = tuple(
        slice(nvxl_lr - coverage, nvxl_lr)
        if i == args.lr_axis else slice(mask.shape[i])
        for i in range(len(mask.shape))
    )
    
    # each raw half excludes the other half (and the shared centre slice, if n is odd)
    mask_l = mask.copy().astype(bool)
    mask_r = mask.copy().astype(bool)
    mask_l[slices_r] = 0
    mask_r[slices_l] = 0
    
    # ensure symmetry
    # mirror each half onto the opposite side, then keep only voxels present in
    # both mirrored versions (= voxel and its mirror are both in mask)
    mask_r_full = mask_r.copy()
    mask_l_full = mask_l.copy()
    mask_r_full[slices_l] = np.flip(mask_r[slices_r], axis=args.lr_axis)
    mask_l_full[slices_r] = np.flip(mask_l[slices_l], axis=args.lr_axis)
    mask_sym = np.logical_and(mask_r_full, mask_l_full)

    # cut the symmetric mask back into its two halves
    mask_l = mask_sym.copy().astype(bool)
    mask_r = mask_sym.copy().astype(bool)
    mask_l[slices_r] = 0
    mask_r[slices_l] = 0
    
    return slices_l, slices_r, mask_l, mask_r, nvxl_lr

def do_kmeans(args, mask, ):
    """Split the nonzero voxels of ``mask`` into spatially compact ROIs with k-means.

    Clustering runs on voxel index coordinates, seeded with ``args.SEED``. The
    number of clusters is ``n_voxels // args.roi_size`` (at least 1), so
    ``args.roi_size`` is the target number of voxels per ROI. That count is printed.

    Args:
        args: namespace providing ``roi_size`` and ``SEED``.
        mask: array; nonzero voxels are clustered.

    Returns:
        parcels: int array shaped like ``mask``. It is 0 outside the mask and
            holds labels 1..num_rois inside.
        clust_cntrs: dict mapping label -> cluster centre (voxel-index coordinates).
        num_rois: number of ROIs.
        An empty mask yields an all-zero parcellation, {} and 0, instead of
        KMeans failing on zero samples.
    """
    # (n_voxels, ndim) integer coordinates, in the same order as np.where
    nonzero_voxels = np.argwhere(mask.astype(bool))

    if len(nonzero_voxels) == 0:
        # nothing to cluster (e.g. a division absent from one hemisphere)
        print(0)
        return np.zeros_like(mask, dtype=int), {}, 0

    num_rois = max(int(len(nonzero_voxels) // args.roi_size), 1)
    print(num_rois)

    coords = nonzero_voxels.astype(float)
    kmeans = KMeans(
        n_clusters=num_rois,
        init='k-means++',
        random_state=args.SEED,
    ).fit(coords)
    rois = kmeans.predict(coords) + 1  # labels 1..num_rois; 0 stays background

    parcels = np.zeros_like(mask, dtype=int)
    # one vectorised scatter instead of a Python loop over ROIs
    parcels[tuple(nonzero_voxels.T)] = rois

    clust_cntrs = {
        label: centre
        for label, centre in enumerate(kmeans.cluster_centers_, start=1)
    }
    assert num_rois == len(np.unique(rois))
    return parcels, clust_cntrs, num_rois

def kmeans_parcellation(args, mask, acro):
    """K-means parcellate one division mask separately in each hemisphere.

    The mask is first symmetrised and split with ``separate_hemis``. The
    low-index half ('l') is clustered with ``do_kmeans``. With
    ``args.maintain_symmetry`` the high-index half ('r') is the mirror image of
    the 'l' parcellation (same labels, same ROI count); otherwise 'r' is
    clustered independently.

    Args:
        args: namespace providing ``lr_axis``, ``maintain_symmetry``,
            ``roi_size`` and ``SEED``.
        mask: array marking the division's voxels.
        acro: division acronym, copied into the output records.

    Returns:
        list of two dicts (hemi 'l' then 'r'). Each has keys acro, hemi,
        parcels, cntrs, mask and num_rois.
    """
    _, _, mask_l, mask_r, nvxl_lr = separate_hemis(args, mask=mask)

    parcels_l, cntrs_l, num_rois_l = do_kmeans(args, mask_l)
    if args.maintain_symmetry:
        parcels_r = np.flip(parcels_l, axis=args.lr_axis).copy() * mask_r
        # Mirror the centres along the SAME axis the parcels were flipped on.
        # Under np.flip, index i maps to (nvxl_lr - 1) - i. Previously axis 2 was
        # hard-coded and the -1 was missing.
        cntrs_r = {}
        for label, centre in cntrs_l.items():
            mirrored = np.array(centre, dtype=float)
            mirrored[args.lr_axis] = (nvxl_lr - 1) - mirrored[args.lr_axis]
            cntrs_r[label] = mirrored
        num_rois_r = num_rois_l
    else:
        parcels_r, cntrs_r, num_rois_r = do_kmeans(args, mask_r)

    return [
        {'acro': acro, 'hemi': 'l', 'parcels': parcels_l, 'cntrs': cntrs_l, 'mask': mask_l, 'num_rois': num_rois_l, },
        {'acro': acro, 'hemi': 'r', 'parcels': parcels_r, 'cntrs': cntrs_r, 'mask': mask_r, 'num_rois': num_rois_r, }
    ]

In [13]:
def combine(args, rois_df,):
    """Merge per-division, per-hemisphere k-means parcels into one label volume.

    Rows are processed hemisphere group by group (sorted 'hemi' values). Each
    row's ROI labels are shifted past every label already placed, so the final
    labels are unique across rows. The total count is printed.

    Args:
        args: unused; kept for a uniform call signature.
        rois_df: DataFrame with columns 'hemi', 'parcels' (int label array) and
            'mask' (bool array of the row's voxels).

    Returns:
        (parcels_all, num_rois): combined int label volume and number of labels.
    """
    if len(rois_df) > 0:
        shape_ref = rois_df['parcels'].iloc[0]
    else:
        # nothing to combine: fall back to the notebook's division volume shape
        shape_ref = PARCELS_md_tx_rs
    parcels_all = np.zeros_like(shape_ref, dtype=int)

    for _, group in rois_df.groupby(by='hemi'):
        for _, row in group.iterrows():
            # count nonzero labels; unlike unique()[1:], this is correct even
            # if no background (0) voxel exists yet
            offset = np.count_nonzero(np.unique(parcels_all))
            parcels_all += row['parcels'] + row['mask'] * offset

    num_rois = np.count_nonzero(np.unique(parcels_all))
    print(f'total rois: {num_rois}')
    return parcels_all, num_rois

def save_parcels(args, parcels, base_parcels, cmask_img):
    """Write a parcellation to ``args.tx_path`` on the common-brain-mask grid.

    The file name encodes args.num_rois, roi_size, maintain_symmetry and
    brain_div. ``parcels`` is cast to uint32, wrapped with ``base_parcels``'
    header, and resampled onto ``cmask_img`` via ``resample`` before saving.

    Returns:
        path of the written .nii.gz file.
    """
    stem = '_'.join([
        'type-functional',
        f'nrois-{args.num_rois}',
        f'size-{args.roi_size}',
        f'symm-{args.maintain_symmetry}',
        f'braindiv-{args.brain_div}',
        'desc-parcels.nii.gz',
    ])
    out_file = f'{args.tx_path}/{stem}'
    labels_img = base_parcels.new_image_like(parcels.astype(np.uint32))
    resample(args, target=cmask_img, img=labels_img).to_filename(out_file)
    return out_file

def roi_labels(args, mask_file):
    """Write a labels text file for a parcellation using AFNI's 3dROIstats.

    Runs ``3dROIstats -quiet`` with ``mask_file`` as both the mask and the
    input. Each ROI's mean therefore equals its own label, so the output
    presumably lists the ROI labels. The output is redirected to
    ``{args.tx_path}/type-functional_..._desc-labels.txt``. Requires
    3dROIstats on PATH. A non-zero exit status is reported on stdout; it was
    previously ignored silently.

    Returns:
        None
    """
    import shlex

    # file with roi labels
    labels_name = (
        f'type-functional'
        f'_nrois-{args.num_rois}'
        f'_size-{args.roi_size}'
        f'_symm-{args.maintain_symmetry}'
        f'_braindiv-{args.brain_div}'
        f'_desc-labels.txt'
    )
    labels_file = f'{args.tx_path}/{labels_name}'
    # quote paths so spaces/shell metacharacters cannot break the command
    cmd = (
        f'3dROIstats -overwrite '
        f'-quiet '
        f'-mask {shlex.quote(mask_file)} '
        f'{shlex.quote(mask_file)} > {shlex.quote(labels_file)}'
    )
    status = os.system(cmd)
    if status != 0:
        # print, not warnings.warn: this notebook filters all warnings
        print(f'3dROIstats failed (status {status}) for {mask_file}')
    return None

In [14]:
# kmeans within each division; division i of major_divs_df carries label i + 1
# in PARCELS_md_tx_rs
args.lr_axis = 0
args.maintain_symmetry = True
args.roi_size = 600  # target voxels per ROI

rois_all = []
for idx, row in major_divs_df.iterrows():
    acronym = row['acronym']
    print(acronym)
    division_mask = PARCELS_md_tx_rs == idx + 1
    rois_all.extend(kmeans_parcellation(args, division_mask, acronym))
rois_df = pd.DataFrame(rois_all)

Isocortex
9
OLF
3
HPF
3
CTXsp
1
STR
2
PAL
1
TH
1
HY
1
MB
2
P
1
MY
2
CB
3


In [15]:
# Save one parcellation (+ labels file) per brain subset:
# whole brain, isocortex only, and everything except Isocortex/OLF ("subcortex").
iso_rois_df = rois_df[rois_df['acro'].isin(['Isocortex'])].reset_index(drop=True)
sub_rois_df = rois_df[~rois_df['acro'].isin(['Isocortex', 'OLF'])].reset_index(drop=True)

subsets = [
    # (heading, brain_div tag, rows to combine)
    ('whole brain', 'whl', rois_df),
    ('isocortex', 'ctx', iso_rois_df),
    ('subcortex', 'sub', sub_rois_df),
]
saved_files = {}
for heading, brain_div, subset_df in subsets:
    print(heading)
    args.brain_div = brain_div
    parcels, args.num_rois = combine(args, subset_df)
    saved_files[brain_div] = save_parcels(args, parcels, PARCELS_md_tx_rs_img, cmask_img)
    roi_labels(args, saved_files[brain_div])

whl_file, ctx_file, sub_file = saved_files['whl'], saved_files['ctx'], saved_files['sub']

whole brain
total rois: 58
(58, 79, 45)
isocortex
total rois: 18
(58, 79, 45)
subcortex
total rois: 34
(58, 79, 45)
