# Project overlapping masks to atlas
This notebook takes a list of masks (generated for example by CONN or SPM) and projects them onto an atlas. It then extracts the overlapping regions and plots them. It also saves these regions as a mask, so that you can use another application to generate a nicer visualization.

By Stephen Larroque from Coma Science Group, University of Liège, created on 2017-04-18.

Version v1.6.6

TODO:
* generate one map per overlap threshold, from 1 to the maximum, when `overlap_min_thres` is set to `'all'` instead of an int.

In [None]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as pltcol
import numpy as np
import nibabel as nib
from nilearn import image
from nilearn import plotting

In [None]:
# Minimum value for a voxel to be considered non-background signal (background voxels are not always
# exactly 0.0, they can be 0.0000000000001). Can be a float or a str ('1%' to give a percentage).
# TODO: autodetect the minimum value (can be -4, 0.02, etc.) as the background and use it as the threshold.
voxel_threshold = 0.0001
dpi_resolution = 300  # resolution of the saved PNG figures
# List of masks
# NOTE: the first image will be used as the template to resample other masks!
list_imgs = [
            r'hyperc2019\spmresults\dex\dmn-dex-hyperc.nii',
            r'hyperc2019\spmresults\doc-uws\dmn-doc-hyperc.nii',
            r'hyperc2019\spmresults\doc-mcs\dmn-docmcs-hyperc.nii',
            r'hyperc2019\spmresults\doc-emcs\dmn-docemcs-hyperc.nii',
            r'hyperc2019\spmresults\doc-uws-acute\dmn-docacute-hyperc.nii',
            r'hyperc2019\spmresults\doc-mcs-acute\dmn-docacutemcs-hyperc.nii',
            r'hyperc2019\spmresults\ket\dmn-ket-hyperc.nii',
            r'hyperc2019\spmresults\prop\dmn-prop-hyperc.nii',
            r'hyperc2019\spmresults\sevo\dmn-sevo-hyperc.nii',
            r'hyperc2019\spmresults\sleep\dmn-sleep-hyperc.nii',
            r'hyperc2019\spmresults\sleep-n2n3\dmn-sleepn2n3-hyperc.nii',
            ]
#list_imgs = [
#            r'hyperc2019\spmresults\dex-allsess\dmn-dex-hypercgroup.nii',
#            r'hyperc2019\spmresults\doc-allsess\dmn-doc-hypercgroup.nii',
#            r'hyperc2019\spmresults\doc-acute-allsess\dmn-docacute-hypercgroup.nii',
#            r'hyperc2019\spmresults\ket-allsess\dmn-ket-hypercgroup.nii',
#            r'hyperc2019\spmresults\prop-allsess\dmn-prop-hypercgroup.nii',
#            r'hyperc2019\spmresults\sevo-allsess\dmn-sevo-hypercgroup.nii',
#            r'hyperc2019\spmresults\sleep-allsess\dmn-sleep-hypercgroup.nii',
#            ]
# Optional: name for each mask, will be used to compute the top overlapping regions table at the end
list_imgs_names = [
    'Dex',
    'UWS',
    'MCS',
    'EMCS',
    'UWS acute',
    'MCS acute',
    'Ketamine',
    'Propofol',
    'Sevo',
    'Sleep N3-Awake',
    'Sleep N3-N2',
]
#list_imgs_names = [
#    'Dex',
#    'DOC',
#    'DOC acute',
#    'Ketamine',
#    'Propofol',
#    'Sevo',
#    'Sleep N3',
#]

# Names are optional: when none are given, use each file's base name without extension
# (paths may use Windows backslashes, so normalise separators before splitting).
if not list_imgs_names:
    list_imgs_names = [path.replace('\\', '/').split('/')[-1].split('.')[0] for path in list_imgs]
# Names are matched to masks by position (e.g. as the overlap table columns), so an accidental
# mismatch when switching between the alternative lists above would silently misalign results.
assert len(list_imgs_names) == len(list_imgs), \
    'list_imgs_names must have one name per mask (got %i names for %i masks)' % (len(list_imgs_names), len(list_imgs))

# Load masks and resample them onto the grid of the first one
imgs = []
for img in list_imgs:
    im = image.load_img(img)
    # Resample when either the shape or the voxel-to-world affine differs: two images can share a
    # shape while sitting on different grids, and would then be compared at the wrong locations.
    if imgs and (im.shape != imgs[0].shape or not np.allclose(im.affine, imgs[0].affine)):
        im = image.resample_to_img(im, imgs[0])
    im = image.threshold_img(im, voxel_threshold)
    imgs.append(im)
# Plot all masks on one brain. A 5-entry colour list cannot give more than 5 masks distinct
# colours, so use the 20-colour qualitative tab20 colormap instead.
plotting.plot_prob_atlas(imgs, view_type="filled_contours",
                         title="All input masks", colorbar=True, cut_coords=(0, 0, 0), draw_cross=True,
                         cmap=plt.cm.tab20)
plotting.plot_roi(imgs[0])

In [None]:
from nilearn import plotting, datasets
# Atlas
atlas_choice = 'aal2'  # 'anatomytoolbox' or 'aal2'
if atlas_choice == 'anatomytoolbox':
    # Raw string: the Windows path contains backslashes.
    # TODO: build an `atlas` variable with all infos and data (labels, indices, nib niftiimage with
    # affine etc), like the AAL branch provides, so that the following cells can use this atlas.
    atlas_path = r'masks\AnatomyToolbox_Atlas_Map.nii'
    # Fail here with a clear message instead of a NameError on `atlas` in a later cell
    raise NotImplementedError('atlas_choice="anatomytoolbox" is not supported yet: '
                              'no labels/indices are loaded for %s' % atlas_path)
elif atlas_choice == 'aal2':
    atlas = datasets.fetch_atlas_aal(version='SPM12', data_dir='atlas')
else:
    raise ValueError("Unknown atlas_choice %r, expected 'anatomytoolbox' or 'aal2'" % atlas_choice)

# Load mask images (same as the plotting cell above, repeated so this cell can run on its own)
imgs = []
for img in list_imgs:
    im = image.load_img(img)
    # If at least one image is already loaded, resample subsequent images onto the first one's grid
    # whenever the shape or the affine differs
    if imgs and (im.shape != imgs[0].shape or not np.allclose(im.affine, imgs[0].affine)):
        im = image.resample_to_img(im, imgs[0])
    im = image.threshold_img(im, voxel_threshold)
    imgs.append(im)

In [None]:
atlas_im = image.load_img(atlas.maps)
print('Atlas shape: %s' % str(atlas_im.shape))
# get_data() was removed in nibabel 5.0. np.asanyarray(dataobj) returns the (scaled) voxel array
# without forcing a float64 conversion like get_fdata() does.
atlas_values = np.unique(np.asanyarray(atlas_im.dataobj))  # sorted distinct label values
# 0 is background: count only non-zero values (also correct if 0 happens to be absent)
nb_regions = int(np.count_nonzero(atlas_values))
print('%i regions in this atlas: %s' % (nb_regions, str(atlas_values)))
print('%i labels' % len(atlas['labels']))
print('%i indices: %s' % (len(atlas['indices']), atlas['indices']))
atlas.keys()

In [None]:
print('Number of non-zero voxels per map (after thresholding):')
for i, im in enumerate(imgs):
    name = list_imgs_names[i] if i < len(list_imgs_names) else str(i + 1)
    # np.asanyarray(dataobj) replaces get_data(), which was removed in nibabel 5.0
    print('%s: %i' % (name, np.count_nonzero(np.asanyarray(im.dataobj))))

In [None]:
# Plot each thresholded map separately, titled with its name so the figures can be told apart
for i, im in enumerate(imgs):
    name = list_imgs_names[i] if i < len(list_imgs_names) else str(i + 1)
    plotting.plot_stat_map(im, title=name)

In [None]:
def get_atlas_label(atlas, region_idx):
    """Return the label name of an atlas region given its numeric value in the atlas image.

    atlas: mapping with parallel 'labels' and 'indices' sequences (indices may be strings such
        as '2001' or integers).
    region_idx: region value read from the atlas image. Python ints, numpy integers and integral
        floats (e.g. 2001.0 from a float-typed atlas) are all accepted.
    Raises ValueError if the value is not one of the atlas indices.
    """
    # Normalise through int() so 2001.0 and np.int16(2001) both become '2001'
    key = str(int(region_idx))
    for label, idx in zip(atlas['labels'], atlas['indices']):
        if str(idx) == key:
            return label
    raise ValueError('%s is not an index of this atlas' % key)

In [None]:
# Resample masks to the atlas grid so that each mask voxel maps onto the same atlas voxel
imgs2 = []
for img in imgs:
    # Compare affines too: two images can share a shape but sit on different grids
    if img.shape != atlas_im.shape or not np.allclose(img.affine, atlas_im.affine):
        img = image.resample_to_img(img, atlas_im)
    img = image.threshold_img(img, voxel_threshold)
    imgs2.append(img)
imgs = imgs2
del imgs2

# Extract activated atlas brain regions for each mask.
# np.asanyarray(img.dataobj) replaces get_data(), which was removed in nibabel 5.0.
atlas_data = np.asanyarray(atlas_im.dataobj)  # loop-invariant, read once
maps_regions = []        # per mask: list of atlas labels present in the mask (atlas order)
maps_regions_idxs = []   # per mask: set of atlas region values present in the mask
maps_regions_count = []  # per mask: {label: number of mask voxels falling in that region}
for img in imgs:
    im_data = np.asanyarray(img.dataobj)
    # Atlas value under every non-zero mask voxel, background (0) dropped
    hit_values = atlas_data[np.nonzero(im_data)]
    hit_values = hit_values[hit_values != 0]
    region_values, voxel_counts = np.unique(hit_values, return_counts=True)
    region_indices = set(region_values.tolist())
    region_count = {}
    for region_idx, n_voxels in zip(region_values, voxel_counts):
        region_label = get_atlas_label(atlas, region_idx)
        region_count[region_label] = region_count.get(region_label, 0) + int(n_voxels)
    print('Atlas indices of brain regions activated in current mask: %s' % str(sorted(region_indices)))
    # Build a real list, not filter(): in Python 3 filter() is a one-shot iterator, so the repeated
    # `label in ...` tests done by later cells would silently exhaust it and miss regions.
    map_brain_regions = [label for label, idx in zip(atlas['labels'], atlas['indices'])
                         if int(idx) in region_indices]
    maps_regions.append(map_brain_regions)
    maps_regions_idxs.append(region_indices)
    maps_regions_count.append(region_count)

In [None]:
# Atlas labels found in each mask (one entry per mask, in list_imgs order)
maps_regions

In [None]:
# Display number of voxels for each region of each mask
from collections import OrderedDict
# One line per mask: 1-based number and name (matching the percentage display), then its regions
# sorted by decreasing number of voxels
for i, region_count in enumerate(maps_regions_count):
    name = list_imgs_names[i] if i < len(list_imgs_names) else ''
    print('== %i (%s): %s' % (i + 1, name, str(OrderedDict(sorted(region_count.items(), key=lambda x: x[1], reverse=True)))))

In [None]:
# Compute count of voxels per atlas region ({label: number of atlas voxels})
# Read the atlas here instead of relying on `atlas_data` leaked from the extraction loop.
atlas_data = np.asanyarray(atlas_im.dataobj)
# One vectorised pass instead of a Python loop over every atlas voxel
region_values, region_voxels = np.unique(atlas_data[atlas_data != 0], return_counts=True)
atlas_regions_count = {}
for region_idx, n_voxels in zip(region_values, region_voxels):
    region_label = get_atlas_label(atlas, region_idx)
    atlas_regions_count[region_label] = atlas_regions_count.get(region_label, 0) + int(n_voxels)
atlas_regions_count

In [None]:
# Compute percentage of each atlas region covered by each mask (maps_regions_count / atlas_regions_count)
maps_regions_percent = [
    {region_label: float(region_count) / atlas_regions_count[region_label] * 100
     for region_label, region_count in regions_count.items()}
    for regions_count in maps_regions_count
]

# Display percentage: one line per mask (1-based number and name), regions by decreasing coverage
from collections import OrderedDict
for i, region_percent in enumerate(maps_regions_percent):
    name = list_imgs_names[i] if i < len(list_imgs_names) else ''
    print('== %i (%s): %s' % (i + 1, name, str(OrderedDict(sorted(region_percent.items(), key=lambda x: x[1], reverse=True)))))

In [None]:
# Rank how much each brain region is overlapping across masks
overlap_min_thres = 1  # len(imgs)-1 # minimum overlapping threshold = minimum number of maps that need to have these regions activated
overlap_min_voxels = 5  # minimum number of voxels required to consider a region really activated and not spurious activity
overlap_min_percent = 1  # minimum percentage of voxels activated over total atlas region surface to consider as real activation and not spurious

overlaps = {}
for region in atlas['labels']:
    # Membership is tested on the per-mask count dicts (keys = labels with at least one voxel in
    # the mask). maps_regions entries may be one-shot filter() iterators that each `in` test
    # consumes, which would make every region after the first absent one count as missing.
    overlaps[region] = sum(
        1
        for regions_count, regions_percent in zip(maps_regions_count, maps_regions_percent)
        if region in regions_count
        and regions_count[region] >= overlap_min_voxels
        and regions_percent[region] >= overlap_min_percent
    )
overlaps_top = {k: v for k, v in overlaps.items() if v >= overlap_min_thres}
print(len(overlaps_top))
overlaps_top

In [None]:
# Get atlas indices for top overlapping regions
# One boolean per atlas label: is this label among the top overlapping regions?
overlap_condition = [label in overlaps_top for label in atlas['labels']]
# Integer atlas index of every flagged label, in atlas order
overlap_idxs = [int(idx) for idx, keep in zip(atlas['indices'], overlap_condition) if keep]
overlap_idxs

In [None]:
# Construct atlas constrained to only top overlapping regions
atlas_data = np.asanyarray(atlas_im.dataobj)
# Boolean mask of atlas voxels that belong to one of the top overlapping regions
# (np.isin replaces np.in1d, which is deprecated in NumPy 2)
ix = np.isin(atlas_data, overlap_idxs)
vox_overlap_atlas = np.where(ix)
if len(vox_overlap_atlas[0]) == 0:
    print('Found no overlapping voxel, cannot generate a map!')
# Atlas label values where overlapping, 0 elsewhere (all zeros if nothing overlaps, so later cells
# still find `overlap_data`). int32 instead of uint8: atlas label values are not guaranteed to fit
# in 0-255 (AAL's SPM12 indices are in the thousands; verify for other atlases), and uint8 would
# wrap them around or raise.
overlap_data = np.where(ix, atlas_data, 0).astype(np.int32)
print(overlap_data.dtype)
# Number of voxels kept (instead of dumping every coordinate)
np.count_nonzero(overlap_data)

In [None]:
# Generate a full brain atlas to show in other vis softwares (so that there is no missing brain):
# every atlas voxel outside the overlapping regions gets one filler label.
fill_label = 7  # per the original note, 7 is grey in MRIcroGL; it must not be an overlapping region's label
if fill_label in overlap_idxs:
    print('WARNING: filler label %i is also an overlapping region index, choose another fill_label!' % fill_label)
overlap_data_full = overlap_data.copy()
# Vectorised equivalent of visiting every non-background atlas voxel not yet labelled
overlap_data_full[(overlap_data_full == 0) & (atlas_data != 0)] = fill_label

In [None]:
# Convert numpy mask to nifti image using nibabel
overlap_image = nib.Nifti1Image(overlap_data, affine=atlas_im.affine)
overlap_image_full = nib.Nifti1Image(overlap_data_full, affine=atlas_im.affine)
# Set the header intent_code to 1002 (NIFTI_INTENT_LABEL) to signify it's an atlas (label regions)
# to ease visualization in other softwares
# https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h
# https://nifti.nimh.nih.gov/nifti-1/documentation/nifti1fields/nifti1fields_pages/group__NIFTI1__INTENT__CODES.html#a26
for label_image in (overlap_image, overlap_image_full):
    label_image.header['intent_code'] = 1002
# Save overlap images as masks
nib.save(overlap_image, 'overlap_image.nii')  # just the overlap regions
nib.save(overlap_image_full, 'overlap_image_full.nii')  # overlap regions + the rest of the atlas brain as a single non-zero filler label

# Plot!
plot_title = 'Overlapping brain regions with min overlap=%i' % overlap_min_thres
print('Overlapping brain regions with min overlap=%i: %s' % (overlap_min_thres, overlaps_top))
fig1 = plotting.plot_roi(overlap_image, title=plot_title, cut_coords=[5, 10, 37], cmap=plt.cm.prism)
fig1.savefig('overlap_image.png', dpi=dpi_resolution)
fig2 = plotting.plot_glass_brain(overlap_image, title=plot_title, cmap=plt.cm.prism)
fig2.savefig('overlap_image_glass.png', dpi=dpi_resolution)
print('Saved results in overlap_image.nii, overlap_image_full.nii, overlap_image.png and overlap_image_glass.png')

In [None]:
# Compute overlap table, which is a summary of whether each region is present or not in each map
import pandas as pd
disable_overlap_min_checks = False  # can disable the minimum amount of voxels required to count the region in
table_columns = list_imgs_names + ['total', 'real_total']
table_rows = []
for region_label, total in overlaps_top.items():
    # Presence of the region in each map. Membership is tested on the per-map count dicts (keys =
    # labels with at least one voxel in the map) rather than on maps_regions, whose entries may be
    # one-shot filter() iterators that each `in` test consumes.
    presence = [
        region_label in regions_count and (
            disable_overlap_min_checks
            or (regions_count[region_label] >= overlap_min_voxels
                and regions_percent[region_label] >= overlap_min_percent))
        for regions_count, regions_percent in zip(maps_regions_count, maps_regions_percent)
    ]
    # Pair presence with names explicitly so an unnamed map can never overwrite the totals
    map_presence = dict(zip(list_imgs_names, presence))
    # total: maps passing the checks (from the ranking cell); real_total honours disable_overlap_min_checks
    table_rows.append(dict(map_presence, total=total, real_total=sum(map_presence.values())))
# Build the frame in one go (growing it row by row with .loc is slow and yields mixed dtypes)
overlap_table = pd.DataFrame(table_rows, index=list(overlaps_top.keys()), columns=table_columns)

# Display results
# Better table formatting + enable hovering line by line
from IPython.display import HTML
def hover(hover_color="#ffff99"):
    """Table-style rule that paints the row under the mouse pointer with `hover_color`."""
    background = ("background-color", "%s" % hover_color)
    return {"selector": "tr:hover", "props": [background]}


# Borderless, centred table styling passed to Styler.set_table_styles
_no_border = ("border", "none")
styles = [
    hover(),
    {"selector": "table", "props": [_no_border, ("border-collapse", "collapse")]},
    {"selector": "th", "props": [("font-size", "110%"), ("text-align", "center"), _no_border]},
    {"selector": "caption", "props": [("caption-side", "top"), ("text-align", "center")]},
    {"selector": "td", "props": [("text-align", "center"), _no_border, ("font-size", "110%")]},
    {"selector": "tr", "props": [_no_border]},
]


def green(val):
    """CSS giving a light green background to cells holding exactly True, '' otherwise.

    From https://stackoverflow.com/questions/41555678/highlighting-multiple-cells-in-different-colors-with-pandas#
    """
    if val is not True:
        return ''  # a string must be returned in every case
    color = 'lightgreen'
    return 'background-color: %s' % color


def green_image(val):
    """CSS replacing the text of a True cell by a centred check-mark picture, '' otherwise."""
    # Styler needs a string for every cell, hence the empty string fallback
    return ('width:32px;height:27px;background:url("checkgreen_small.png");background-repeat:no-repeat;background-position:center;text-indent:-9999px'
            if val is True else '')

# Sort regions by the number of maps in which they really pass the checks (most shared first)
overlap_table = overlap_table.sort_values(['real_total'], ascending=False)
overlap_table = overlap_table.replace(False, '')  # empty false (cleaner table, more legible)
styler = (overlap_table.style
          .set_caption('Overlapping hyperconnectivity regions summary table')
          .set_table_styles(styles))
# Styler.applymap was deprecated in pandas 2.1 in favour of Styler.map: use whichever exists
apply_elementwise = getattr(styler, 'map', None) or styler.applymap
html = apply_elementwise(green_image)  # show a check-mark picture in True cells

html  # display!

In [None]:
# Save the table to a semicolon-separated csv file. If the display cell above has run, its False
# cells were replaced by '' there, so they are written as empty fields.
overlap_table.to_csv('overlap_table.csv', sep=';')

----------------------------------------------
## Test

In [None]:
# Scratch: inspect the raw header field values of the generated overlap image
overlap_image.header.values()

In [None]:
# Scratch: load a reference AAL atlas file (relative path, expected in the working directory)
# to compare its header with the generated overlap image
aal = nib.load('aal.nii.gz')

In [None]:
# Header size field of the reference atlas (presumably 348 for a NIfTI-1 file — verify)
aal.header['sizeof_hdr']

In [None]:
# Scratch experiment: set header fields on the in-memory image. overlap_image.nii was already
# written by nib.save() in an earlier cell, so the file is unchanged unless it is saved again.
overlap_image.header['regular'] = 'r'
overlap_image.header['intent_code'] = 1002

In [None]:
# Compare the reference AAL header with the generated overlap image header, printing only fields
# that differ or are missing. Header fields are numpy arrays, so `!=` would compare elementwise
# (and `if` on the result is ambiguous); np.array_equal compares whole fields.
for k in aal.header.keys():
    if k not in overlap_image.header.keys():
        print('MISSING: %s=%s' % (k, aal.header[k]))
    elif not np.array_equal(aal.header[k], overlap_image.header[k]):
        print('%s=%s OR %s' % (k, aal.header[k], overlap_image.header[k]))

In [None]:
# Number of masks in which each atlas region passes the overlap criteria
overlaps