# Jan 31 2024: generate a k-means parcellation in Allen CCFv3 space, then transform it to the N162 template space

REFERENCE:
https://allensdk.readthedocs.io/en/latest/_static/examples/nb/mouse_connectivity.html?highlight=major#

In [1]:
import os
import numpy as np 
import pandas as pd
import ants
import seaborn as sns 

from allensdk.core.mouse_connectivity_cache import (
    MouseConnectivityCache,
    MouseConnectivityApi
)
from allensdk.api.queries.ontologies_api import OntologiesApi

from sklearn.cluster import KMeans
from tqdm import tqdm

# ignore user warnings
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Bare attribute container; notebook-wide settings are attached to an instance."""


# Single shared settings object passed to the helper functions below.
args = ARGS()

# Random seed handed to KMeans (random_state) for repeatable parcellations.
args.SEED = 100

In [3]:
def to_nifti(args, img):
    """Convert a 3-D numpy volume into an ANTs image on the padded output grid.

    The volume is re-oriented (axes reordered to (2, 0, 1), then the first and
    last axes reversed), zero-padded by fixed amounts, cast to float32 and
    wrapped with a fixed origin and 0.1 spacing on every axis.  For the
    (132, 80, 114) Allen volumes used in this notebook the result has shape
    (118, 160, 90).

    Parameters
    ----------
    args : object
        Notebook settings object; not read by this function.
    img : numpy.ndarray
        3-D volume to convert.

    Returns
    -------
    ants image
        float32 image built by ``ants.from_numpy``.
    """
    # Axis reorder followed by a reversal of the (new) first and last axes.
    reoriented = np.transpose(img, (2, 0, 1))[::-1, :, ::-1]

    # Zero padding per axis: (before, after).
    padded = np.pad(
        reoriented,
        pad_width=((2, 2), (4, 24), (8, 2)),
        mode='constant',
        constant_values=0,
    )
    print(padded.dtype, padded.shape)

    return ants.from_numpy(
        data=padded.astype(np.float32),
        origin=[6.4, -13.2, -7.8],
        spacing=[0.1] * padded.ndim,
    )

In [4]:
# Local AllenSDK cache locations.
args.atlas_path = '/home/govindas/mouse_dataset/allen_atlas_ccfv3'
args.mcc_path = f'{args.atlas_path}/MouseConnectivity'

# Connectivity cache at resolution 100 (microns per the AllenSDK convention), CCF 2017.
mcc = MouseConnectivityCache(
    manifest_file=f'{args.mcc_path}/manifest.json',
    resolution=100,
    ccf_version=MouseConnectivityApi().CCF_2017,
)

# Average template (intensity) and annotation (label) volumes with their metadata.
template_volume, metaAVGT = mcc.get_template_volume()
annotation_volume, metaANO = mcc.get_annotation_volume()
AVGT = template_volume.astype(np.float32)
ANO = annotation_volume.astype(np.uint32)
print(AVGT.shape, ANO.shape)

# Structure ontology, also as a DataFrame for inspection.
STree = mcc.get_structure_tree()
STree_df = pd.DataFrame(STree.nodes())
# Disabled step (presumably to warm the local cache -- confirm before re-enabling):
# call mcc.get_structure_mask(structure_id=idx) for every idx in
# STree_df.id.to_list(), ignoring structures for which the call fails.

(132, 80, 114) (132, 80, 114)


In [5]:
# Directory that receives the generated parcellation volumes.
args.out_path = f'{args.mcc_path}/parcels'

# Create it (with parents) without going through a shell: no quoting problems
# for paths with spaces/special characters, and failures raise instead of
# being reported only as a non-zero exit status.
os.makedirs(args.out_path, exist_ok=True)

0

In [6]:
def separate_hemis(args, mask, lr_axis=2):
    """Split a mask into two mirror-symmetric hemisphere masks.

    The volume is cut in half along ``lr_axis``.  A voxel is kept only if both
    it and its mirror image (flip along ``lr_axis``) are set in ``mask``; the
    symmetric mask is then divided into the low-index ("l") and high-index
    ("r") halves.  When the axis length is odd the two half-slices overlap on
    the central plane, which therefore ends up cleared in both outputs.

    Parameters
    ----------
    args : object
        Notebook settings object; not read by this function.
    mask : numpy.ndarray
        Volume whose nonzero voxels form the region of interest.
    lr_axis : int
        Non-negative index of the axis to split along.

    Returns
    -------
    slices_l, slices_r : tuple of slice
        Index tuples selecting the low-index / high-index half of the volume.
    mask_l, mask_r : numpy.ndarray of bool
        Symmetrised hemisphere masks, same shape as ``mask``.
    nvxl_lr : int
        Length of ``mask`` along ``lr_axis``.
    """
    nvxl_lr = mask.shape[lr_axis]
    half = int(np.ceil(nvxl_lr / 2))

    def _index_for(lr_slice):
        # Full extent on every axis except the split axis.
        return tuple(
            lr_slice if axis == lr_axis else slice(extent)
            for axis, extent in enumerate(mask.shape)
        )

    slices_l = _index_for(slice(0, half))
    slices_r = _index_for(slice(nvxl_lr - half, nvxl_lr))

    def _split(volume):
        # astype returns a fresh array, so the input is never modified.
        low_side = volume.astype(bool)
        high_side = volume.astype(bool)
        low_side[slices_r] = False
        high_side[slices_l] = False
        return low_side, high_side

    left_only, right_only = _split(mask)

    # Complete each hemisphere with its own mirror image ...
    right_mirrored = right_only.copy()
    right_mirrored[slices_l] = np.flip(right_only[slices_r], axis=lr_axis)
    left_mirrored = left_only.copy()
    left_mirrored[slices_r] = np.flip(left_only[slices_l], axis=lr_axis)

    # ... and keep only voxels present in both completed volumes.
    mask_sym = right_mirrored & left_mirrored

    mask_l, mask_r = _split(mask_sym)
    return slices_l, slices_r, mask_l, mask_r, nvxl_lr

def do_kmeans(args, mask, ):
    """Cluster the nonzero voxels of ``mask`` into k-means ROIs.

    Voxel index coordinates are clustered with KMeans; the number of clusters
    is ``n_voxels // args.roi_size`` (at least 1).

    Parameters
    ----------
    args : object
        Must provide ``roi_size`` (target voxels per ROI) and ``SEED``
        (KMeans ``random_state``).
    mask : numpy.ndarray
        Volume whose nonzero voxels are to be parcellated.

    Returns
    -------
    parcels : numpy.ndarray of int
        Same shape as ``mask``; 0 outside the mask, labels 1..num_rois inside.
    clust_cntrs : dict
        Maps each label (1..num_rois) to its cluster centre in voxel index
        coordinates.
    num_rois : int
        Number of clusters; 0 when ``mask`` has no nonzero voxel.
    """
    # (n_voxels, ndim) integer index coordinates of every nonzero voxel.
    nonzero_voxels = np.argwhere(mask.astype(bool))

    parcels = np.zeros_like(mask, dtype=int)
    if len(nonzero_voxels) == 0:
        # KMeans cannot be fitted on zero samples; an empty region simply
        # contributes no parcels.
        return parcels, {}, 0

    num_rois = max(int(len(nonzero_voxels) // args.roi_size), 1)

    coords = nonzero_voxels.astype(float)
    kmeans = KMeans(
        n_clusters=num_rois,
        init='k-means++',
        random_state=args.SEED,
    ).fit(coords)
    # Shift to 1-based labels so that 0 stays reserved for background.
    rois = kmeans.predict(coords) + 1

    # Paint every voxel with its label in a single indexed assignment.
    parcels[tuple(nonzero_voxels.T)] = rois

    clust_cntrs = dict(zip(range(1, num_rois + 1), kmeans.cluster_centers_))
    return parcels, clust_cntrs, num_rois

def kmeans_parcellation(args, mask, ):
    """Parcellate a region mask separately in each hemisphere.

    The mask is symmetrised and split along axis 2 with ``separate_hemis``;
    the "l" half is parcellated with ``do_kmeans``.  The "r" half is either
    parcellated independently or, when ``args.maintain_symmetry`` is true,
    obtained by mirroring the "l" parcels (same labels, mirrored centres).

    Parameters
    ----------
    args : object
        Must provide ``maintain_symmetry`` plus whatever ``do_kmeans`` reads.
    mask : numpy.ndarray
        3-D region mask.

    Returns
    -------
    list of dict
        Two records (hemi 'l' then 'r') with keys 'parcels', 'cntrs', 'mask',
        'num_rois' and 'hemi'.
    """
    lr_axis = 2
    _, _, mask_l, mask_r, nvxl_lr = separate_hemis(
        args, mask=mask, lr_axis=lr_axis,
    )

    parcels_l, cntrs_l, num_rois_l = do_kmeans(args, mask_l)
    if not args.maintain_symmetry:
        parcels_r, cntrs_r, num_rois_r = do_kmeans(args, mask_r)
    else:
        parcels_r = np.flip(parcels_l, axis=lr_axis).copy() * mask_r
        # np.flip sends index j to (nvxl_lr - 1 - j), so a centre expressed in
        # voxel index coordinates mirrors to (nvxl_lr - 1) - c on that axis.
        cntrs_r = {}
        for label, cntr in cntrs_l.items():
            mirrored = np.array(cntr, dtype=float)
            mirrored[lr_axis] = (nvxl_lr - 1) - mirrored[lr_axis]
            cntrs_r[label] = mirrored
        num_rois_r = num_rois_l

    return [
        {'parcels': parcels_l, 'cntrs':cntrs_l, 'mask': mask_l, 'num_rois':num_rois_l, 'hemi':'l'},
        {'parcels': parcels_r, 'cntrs':cntrs_r, 'mask': mask_r, 'num_rois':num_rois_r, 'hemi':'r'}
    ]

In [7]:
# major brain divisions
set_ids = STree.get_structure_sets()
onto_df = pd.DataFrame(
    OntologiesApi().get_structure_sets(set_ids)
)
major_divs_id = onto_df[onto_df['name'] == 'Brain - Major Divisions']['id'].item()
major_divs_df = pd.DataFrame(STree.get_structures_by_set_id([major_divs_id]))

# kmeans within each division
args.maintain_symmetry = True
args.roi_size = 75 # voxels in roi: 3000, 1500, 500, 250, 30000, etc
rois_all = []
for idx, row in tqdm(major_divs_df.iterrows()):
    acro, div_id = row[['acronym', 'id']].to_list()
    DIV, metaDIV = mcc.get_structure_mask(div_id)
    DIV = DIV.astype(np.uint32)
    rois_all += kmeans_parcellation(args, mask=DIV)

# collect all kmeans outputs
rois_df = pd.DataFrame(rois_all)
parcels_all = np.zeros_like(AVGT, dtype=int)
cntrs_all = {}
for hemi, group in rois_df.groupby(by='hemi'):
    for idx, row in group.iterrows():
        num_rois = len(np.unique(parcels_all)[1:])
        
        parcels, cntrs, mask = row[['parcels', 'cntrs', 'mask']].to_list()
        parcels += mask * num_rois
        parcels_all += parcels
        
        cntrs = {k+num_rois:cntr for k, cntr in cntrs.items()}
        cntrs_all = {**cntrs_all, **cntrs}
        
num_rois = len(np.unique(parcels_all)[1:])
print(num_rois)
to_nifti(args, parcels_all).to_filename(f'{args.out_path}/parcels_{num_rois}.nii.gz')

0it [00:00, ?it/s]

12it [00:15,  1.28s/it]


5512
int64 (118, 160, 90)


---

In [8]:
# Fixed image for the registration below: the symmetric N162 template.
n162_template_path = '/home/govindas/mouse_dataset/gabe_symmetric_N162/Symmetric_N162_0.10_RAS.nii.gz'
n162_template = ants.image_read(n162_template_path)
# Displayed so the grid can be compared with the shape printed by to_nifti.
n162_template.numpy().shape

(118, 160, 90)

In [9]:
# Allen average template as an ANTs image (same to_nifti conversion as the
# parcel volume); used as the moving image in the registration.
allen_template = to_nifti(args, AVGT)

float32 (118, 160, 90)


In [10]:
# Register the Allen template (moving) onto the N162 template (fixed) with SyN.
# The seed is forwarded so the registration is repeatable like the k-means
# step; ANTs only guarantees exact reproducibility when it also runs
# single-threaded.
tx = ants.registration(
    fixed=n162_template,
    moving=allen_template,
    type_of_transform='SyN',
    random_seed=args.SEED,
)

In [11]:
# Warp the Allen template into N162 space with the forward transforms.
# The template is a continuous intensity image, so it is resampled with
# linear interpolation; 'genericLabel' is meant for label volumes only.
allen_template_warped = ants.apply_transforms(
    fixed=n162_template,
    moving=allen_template,
    transformlist=tx['fwdtransforms'],
    interpolator='linear',
)
allen_template_warped.to_filename('allen_warped.nii.gz')

In [31]:
# Warp the parcellation written by the generation cell into N162 space.
# `num_rois` is deliberately NOT overridden here: it is the count computed when
# the parcels were generated, so this cell always reads the file that run
# produced instead of a file left over from an earlier run with a different
# args.roi_size.  To warp another parcellation, change args.roi_size and
# re-run the generation cell.
parcels_path = f'{args.out_path}/parcels_{num_rois}.nii.gz'
parcels = ants.image_read(parcels_path)
parcels_warped = ants.apply_transforms(
    fixed=n162_template,
    moving=parcels,
    transformlist=tx['fwdtransforms'],
    # label-preserving interpolation: never blends neighbouring parcel ids
    interpolator='genericLabel',
)
parcels_warped.to_filename(f'parcels_warped_{num_rois}.nii.gz')

In [32]:
# Common brain mask; the warped parcels are resampled onto its grid and masked by it.
cmask_path = '/home/govindas/mouse_dataset/voxel/common_brain_mask.nii.gz'
cmask = ants.image_read(cmask_path)

In [33]:
# Put the warped parcels on the common-mask grid (label-preserving resampling).
parcels_on_cmask_grid = ants.resample_image_to_target(
    image=parcels_warped,
    target=cmask,
    interp_type='genericLabel',
)

# Voxel-wise product with the mask values.
# NOTE(review): this zeroes labels outside the mask only if cmask is strictly 0/1 -- confirm.
masked_labels = parcels_on_cmask_grid.numpy() * cmask.numpy()
parcels_warped_cm = parcels_on_cmask_grid.new_image_like(data=masked_labels)
parcels_warped_cm.to_filename(f'parcels_warped_cm_{num_rois}.nii.gz')