# Atlas-based parcellation

In [None]:
import numpy as np

from mne.io import read_raw_fif
from mne_bids import BIDSPath
from nilearn import datasets, image
from collections import Counter

import pandas as pd
from sklearn.utils import Bunch
from scipy.spatial import KDTree
from nilearn.datasets._utils import (
    get_dataset_descr,
    get_dataset_dir,
    fetch_single_file,
)
import shutil
import os


def get_coord_atlas_labels(
    coords: np.ndarray,
    atlas_map: str,
    atlas_labels: "list[str] | dict[int, str]",
) -> list[str]:
    """Label each coordinate with the region of its nearest labelled atlas voxel.

    Every non-zero voxel of the atlas image is converted to world coordinates
    with the image affine, and each input point takes the label id of the
    closest such voxel (Euclidean distance, via a KD-tree). Points lying
    outside every labelled region therefore still receive the nearest label.

    Parameters
    ----------
    coords : np.ndarray
        Points in the atlas's world space, shape ``(n_points, 3)``; a single
        point of shape ``(3,)`` is also accepted. Presumably MNI millimetres
        for the atlases used here -- confirm against where ``coords`` is built.
    atlas_map : str
        Anything ``nilearn.image.load_img`` accepts (typically a file path) for
        a 3-D integer label image; voxels equal to 0 are ignored.
    atlas_labels : list[str] or dict[int, str]
        Lookup from integer label id to region name, indexed as
        ``atlas_labels[label_id]``.

    Returns
    -------
    list[str]
        One region name per input point, in input order.

    Raises
    ------
    ValueError
        If the atlas image contains no non-zero voxel.
    """
    atlas_img = image.load_img(atlas_map)
    label_volume = atlas_img.get_fdata().astype(int)

    # Voxel indices (i, j, k) of every labelled (non-zero) voxel.
    voxel_ijk = np.nonzero(label_volume)
    if voxel_ijk[0].size == 0:
        # Without this guard KDTree.query on an empty tree returns an
        # out-of-range index and the lookup below fails obscurely.
        raise ValueError("atlas image contains no labelled (non-zero) voxels")

    # Label id of each labelled voxel, in the same order as voxel_ijk.
    voxel_label_ids = label_volume[voxel_ijk]

    # Voxel indices -> world coordinates, shape (n_voxels, 3), so they are
    # comparable with `coords`.
    voxel_xyz = np.vstack(
        image.coord_transform(*voxel_ijk, atlas_img.affine)
    ).T

    # Promote a single (3,) point to (1, 3); otherwise query() returns a scalar
    # index and the final comprehension cannot iterate it.
    query_points = np.atleast_2d(np.asarray(coords, dtype=float))

    _, nearest_voxel = KDTree(voxel_xyz).query(query_points, k=1)

    return [atlas_labels[int(label_id)] for label_id in voxel_label_ids[nearest_voxel]]

In [None]:
# Assign every electrode coordinate to its nearest atlas region.
# Default: Destrieux (2009). To switch parcellation, swap in one of the
# commented atlas blocks AND its matching label-cleaning line further down.
atlas = datasets.fetch_atlas_destrieux_2009()
# Read fields/columns by name; assumes atlas['labels'] exposes 'index' and
# 'name' (record array in the nilearn versions this notebook targets --
# confirm against the installed nilearn). Attribute access `.index` would
# silently mean the row index if this were a pandas DataFrame.
atlas_labels = dict(zip(atlas['labels']['index'], atlas['labels']['name']))

# atlas = datasets.fetch_atlas_schaefer_2018(n_rois=100, resolution_mm=1, yeo_networks=17)
# atlas_labels = ['background'] + [l.decode() for l in atlas['labels']]

# atlas = fetch_atlas_glasser_2016()
# atlas_labels = atlas['labels'].to_dict()['label']

# NOTE(review): `coords` is not defined in this section; presumably an
# (n_electrodes, 3) array from an earlier cell -- confirm.
original_elec_labels = get_coord_atlas_labels(coords, atlas['maps'], atlas_labels)

# Destrieux: drop the first whitespace-separated token (presumably the
# hemisphere prefix) and keep the rest. maxsplit=1 gives the same result as
# split()[1] for "<hemi> <region>" names but does not raise on a name
# without a prefix.
elec_labels = [label.split(maxsplit=1)[-1] for label in original_elec_labels]

# Label cleaning for the alternative atlases -- these read the raw atlas
# names, so they must start from original_elec_labels:
# elec_labels = [label.split("_")[1] for label in original_elec_labels]  # for glasser
# elec_labels = [label.split("_", 2)[-1].split("_")[1] for label in original_elec_labels]  # for schaefer
# elec_labels = [label.split("_", 2)[-1].split("_")[0] for label in original_elec_labels]  # for kong networks
# elec_labels = [label.split("_", 2)[-1] for label in original_elec_labels]  # for schaefer, just remove network and hemi

In [None]:
# One marker panel per most-frequent region (up to n_rows * n_cols), showing
# the electrodes assigned to that region.
# NOTE(review): `coords`, `plt` and `plot_markers` are not defined in this
# section -- presumably matplotlib.pyplot and nilearn.plotting.plot_markers
# plus the coordinate array from earlier cells; confirm.
n_rows, n_cols = 6, 4
counter = Counter(elec_labels)
top_rois = counter.most_common(n_rows * n_cols)

fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, dpi=300, figsize=(12, 10))

# Hoisted out of the loop so each ROI mask is a single vectorised comparison.
label_array = np.asarray(elec_labels)

for (roi, _count), ax in zip(top_rois, axes.ravel()):
    # Exact match, consistent with the Counter above; a substring test would
    # also pull in e.g. Glasser "V3A" electrodes when plotting "V3".
    roi_coords = coords[label_array == roi]
    scores = np.ones(len(roi_coords))

    plot_markers(scores, roi_coords,
                 node_size=10, display_mode='xz',
                 alpha=0.8, colorbar=False,
                 node_cmap='Grays', node_vmin=0, node_vmax=1,
                 figure=fig, axes=ax)

    ax.set_title(roi, fontsize=8)

# Blank any panels left over when there are fewer regions than grid cells.
for ax in axes.ravel()[len(top_rois):]:
    ax.axis('off')

fig.show()