In [7]:
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

In [134]:
# Location of the parcellation atlas, relative to the current working directory.
# NOTE(review): judging by the file name this is the Schaefer 2018 atlas
# (400 parcels, 7 networks) in FSL MNI152 2 mm space -- confirm.
atlas_path = 'atlases/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_2mm.nii.gz'
# Open the image with nibabel; `atlas` is the image object, not the voxel array.
atlas = nib.load(atlas_path)
# Voxel values as a NumPy array (nibabel's get_fdata() yields floating-point
# data). Presumably each non-zero value is a parcel label and 0 is background
# -- verify against the atlas documentation.
atlas_data = atlas.get_fdata()

In [155]:
def isolate_parcel(nifti, i, parcel_value=100, other_value=300):
    """Relabel a label volume so that a single parcel stands out.

    Voxels equal to ``i`` are set to ``parcel_value``, every other non-zero
    voxel is set to ``other_value``, and voxels equal to 0 are left untouched.

    The array is modified IN PLACE and the very same object is returned, so
    pass a copy if the original labels are still needed afterwards.

    Parameters
    ----------
    nifti : numpy.ndarray
        Array of labels (any shape); 0 is treated as background.
    i : scalar
        Label of the parcel to highlight.
    parcel_value : scalar, optional
        Value written to the highlighted parcel. Defaults to 100.
    other_value : scalar, optional
        Value written to all other non-zero voxels. Defaults to 300.
        The defaults are arbitrary; they only need to differ from each other
        and from 0 so the parcel contrasts with its surroundings.

    Returns
    -------
    numpy.ndarray
        ``nifti`` itself, relabelled.
    """
    # Build both masks before writing anything, so that original labels which
    # happen to equal parcel_value / other_value are still classified by their
    # original value rather than by a freshly written one.
    parcel_mask = nifti == i
    other_mask = (nifti != i) & (nifti != 0)

    nifti[other_mask] = other_value
    nifti[parcel_mask] = parcel_value
    return nifti
    
def plot_slices(data, slice_nums, parcel_num, cmap='jet', figsize=(10, 8), cols_per_row=3):
    """Draw selected slices of a 3-D array in a grid and save the grid as a PNG.

    Each requested slice ``data[:, :, slice_num]`` is shown with ``imshow`` in
    its own subplot (axes hidden). The figure is written to
    ``parcels/parcel_<parcel_num>.png`` relative to the current working
    directory; the ``parcels`` directory is created if it does not exist.
    The figure is always closed afterwards, even if drawing or saving fails.

    Parameters
    ----------
    data : array-like
        Volume indexed as ``data[:, :, k]``; slices are taken along the
        third axis.
    slice_nums : sequence of int
        Indices along the third axis to plot. Must not be empty.
    parcel_num : int or str
        Identifier used only to build the output file name.
    cmap : str, optional
        Matplotlib colormap name passed to ``imshow``.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    cols_per_row : int, optional
        Number of subplot columns in the grid.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``slice_nums`` is empty.
    """
    num_slices = len(slice_nums)
    if num_slices == 0:
        # A zero-row grid is rejected by matplotlib with a less helpful message.
        raise ValueError("slice_nums must contain at least one slice index")

    rows = (num_slices + cols_per_row - 1) // cols_per_row  # ceiling division

    # squeeze=False guarantees a 2-D array of Axes, so flattening also works
    # for a 1x1 grid (where subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(rows, cols_per_row, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    try:
        # zip stops at the last requested slice; surplus axes stay empty.
        for ax, slice_num in zip(axes, slice_nums):
            ax.imshow(data[:, :, slice_num], cmap=cmap)

        # Hide ticks/frames on every subplot, used or not.
        for ax in axes:
            ax.axis('off')

        out_path = f'parcels/parcel_{parcel_num}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Close this specific figure so repeated calls do not accumulate
        # open figures in memory, including when an error occurred above.
        plt.close(fig)
    

In [156]:
# Indices along the third array axis to render; identical for every parcel,
# so the list is built once up front.
slices = list(range(20, 70, 3))

for i in range(1, 401):  # parcel labels 1..400
    # isolate_parcel relabels its argument in place, so hand it a fresh copy
    # and keep atlas_data untouched for the next iteration.
    new_nifti = atlas_data.copy(order='K')
    parcel = isolate_parcel(new_nifti, i)
    plot_slices(parcel, slices, parcel_num=i, cmap='jet')