In [2]:
import os

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

In [3]:
atlas_path = 'atlases/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_2mm.nii.gz'
atlas = nib.load(atlas_path)
atlas_data = atlas.get_fdata()
# Work on copies so the loaded arrays are not modified by any of the code below.
atlas_copy = np.copy(atlas_data)

mni_path = 'atlases/MNI152_T1_2mm_brain.nii.gz'
mni = nib.load(mni_path)
mni_data = mni.get_fdata()
mni_copy = np.copy(mni_data)

# The plotting code indexes the MNI template with slice indices computed from the
# atlas, so both volumes must share the same voxel grid.
# NOTE(review): matching shapes do not guarantee matching affines; consider also
# checking np.allclose(atlas.affine, mni.affine) if other files are swapped in.
assert atlas_copy.shape == mni_copy.shape, (
    f'atlas shape {atlas_copy.shape} does not match MNI template shape {mni_copy.shape}'
)

In [4]:
def find_maximal_slices(nifti, parcel):
    """Find, along each of the three array axes, the slice holding the most voxels of `parcel`.

    Parameters
    ----------
    nifti : np.ndarray
        3-D label volume (e.g. atlas data from ``get_fdata()``).
    parcel : int or float
        Label value to look for; compared with ``==`` against ``nifti``.

    Returns
    -------
    list
        ``[m, n, p]``: the index along axis 0, 1 and 2 respectively whose slice
        contains the largest number of ``parcel`` voxels. Ties resolve to the
        lowest index (``np.argmax``). If ``parcel`` does not occur at all, every
        count is 0 and all three indices are 0.
    """
    # Compare the whole volume against the label once and reuse the mask for all
    # three projections, instead of redoing the comparison per axis.
    parcel_mask = nifti == parcel

    def find_max_slice(axis):
        # Summing the mask over the two other axes gives a per-slice voxel count
        # along the remaining axis; argmax picks the fullest slice.
        axis_counts = np.sum(parcel_mask, axis=axis)
        return np.argmax(axis_counts)

    m = find_max_slice(axis=(1, 2))  # index along axis 0
    n = find_max_slice(axis=(0, 2))  # index along axis 1
    p = find_max_slice(axis=(0, 1))  # index along axis 2
    return [m, n, p]
    
    
def plot_slices(data, max_list, parcel, background=None, out_dir='parcels', lh_max_label=200):
    """Save a sagittal / coronal / axial overlay of one parcel on a grayscale template.

    Parameters
    ----------
    data : np.ndarray
        3-D label volume (the atlas) in which ``parcel`` is looked up.
    max_list : sequence of int
        Slice indices ``[x, y, z]`` along axes 0, 1, 2 (e.g. the output of
        ``find_maximal_slices``); only the first three entries are used.
    parcel : int
        Label to highlight in red.
    background : np.ndarray, optional
        Grayscale volume drawn underneath the mask. Defaults to the notebook-level
        ``mni_copy``. It is indexed with the same ``x, y, z`` as ``data``, so it
        must share ``data``'s voxel grid.
    out_dir : str, optional
        Directory the PNG is written to; created if missing. Default ``'parcels'``.
    lh_max_label : int, optional
        Highest label treated as left hemisphere (inclusive); those parcels get
        their sagittal view mirrored. Default 200 assumes the Schaefer 400-parcel
        numbering (1-200 left, 201-400 right) -- TODO confirm for other atlases.

    Side effects
    ------------
    Writes ``<out_dir>/parcel_<parcel>.png`` and closes the figure it created.
    """
    if background is None:
        background = mni_copy  # template loaded in the data-loading cell

    # Slices that show the maximal coverage of the current parcel.
    x, y, z = max_list[0], max_list[1], max_list[2]

    def _extract_views(volume):
        # sagittal, coronal, axial -- each rotated 90 degrees for display
        return [np.rot90(volume[x, :, :]),
                np.rot90(volume[:, y, :]),
                np.rot90(volume[:, :, z])]

    slices = _extract_views(data)
    mni_slices = _extract_views(background)

    # Mirror the sagittal view of left-hemisphere parcels so it reads more intuitively.
    # The bound is inclusive: label `lh_max_label` itself is still left hemisphere.
    if parcel <= lh_max_label:
        slices[0] = np.fliplr(slices[0])
        mni_slices[0] = np.fliplr(mni_slices[0])

    # Mirror (left-right flip) the coronal and axial views for all parcels,
    # since the atlases are stored flipped relative to the desired display.
    for i in (1, 2):
        slices[i] = np.fliplr(slices[i])
        mni_slices[i] = np.fliplr(mni_slices[i])

    mask_color = [1, 0, 0, 1]  # RGBA for opaque red
    fig, axes = plt.subplots(1, 3, figsize=(10, 8))

    for ax, curr_slice, bg_slice in zip(axes, slices, mni_slices):
        # Fully transparent RGBA image (alpha 0) except where the parcel is present.
        mask = np.zeros((*curr_slice.shape, 4))
        mask[curr_slice == parcel] = mask_color
        ax.imshow(bg_slice, cmap='gray')
        ax.imshow(mask, interpolation='none')  # mask drawn on top of the grayscale image
        ax.axis('off')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'parcel_{parcel}.png'))
    # Close this specific figure so hundreds of calls don't accumulate open figures.
    plt.close(fig)

    

In [5]:
# Take the parcel labels from the atlas itself (0 is background) rather than
# hard-coding range(1, 401), so only labels that actually occur get plotted.
# NOTE(review): plot_slices' default hemisphere split still assumes the
# Schaefer 400-parcel numbering.
parcel_labels = np.unique(atlas_copy)
parcel_labels = parcel_labels[parcel_labels > 0].astype(int)

for parcel in parcel_labels:
    slices = find_maximal_slices(atlas_copy, parcel)
    plot_slices(atlas_copy, slices, parcel)