Plotting function

In [None]:
def _as_voxel_array(img_or_path):
    """Return the voxel data of a NIfTI image, loading it first if given a path.

    Parameters
    ----------
    img_or_path : object with ``get_fdata()`` or str / os.PathLike
        An already-loaded image, or a path that is opened with ``nibabel.load``.

    Returns
    -------
    The array returned by the image's ``get_fdata()``.
    """
    import os

    if isinstance(img_or_path, (str, os.PathLike)):
        # nibabel is only needed when the caller hands us a path.
        import nibabel as nib

        img_or_path = nib.load(img_or_path)
    return img_or_path.get_fdata()


def plot_slices(nifti_path,overlay=None,title='',z_threshold = 2.,_cmap='magma', underlay_vmax=30000,overlay_type='scatter',figsize=(12,3),scatter_s_size=3.):
    """Plot every slice along the third axis of a volume on a 5-column grid.

    Each slice of the underlay is drawn in gray scale and labelled ``z=<index>``.
    When ``overlay`` is given, voxels whose overlay value is strictly greater
    than ``z_threshold`` are drawn on top of the matching slice.

    Parameters
    ----------
    nifti_path : image object or path
        Underlay volume. Either an object exposing ``get_fdata()`` (this is
        what the original implementation required, despite the name) or a
        path to a file readable by ``nibabel.load``.
    overlay : image object or path, optional
        Statistical map drawn over the underlay. Indexed with the same
        (row, column, slice) indices as the underlay.
    title : str
        Figure title. With an overlay, a second line reporting the overlay's
        maximum value is appended.
    z_threshold : float
        Only overlay values strictly above this threshold are displayed.
    _cmap : str
        Matplotlib colormap used for the overlay.
    underlay_vmax : float
        Upper limit of the gray-scale color range of the underlay.
    overlay_type : str
        ``'scatter'`` draws one marker per supra-threshold voxel plus a
        colorbar per slice; ``'imshow'`` draws the thresholded map as a
        semi-transparent image. Any other value draws no overlay.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    scatter_s_size : float
        Marker size forwarded to ``Axes.scatter`` (``s``).

    Returns
    -------
    None
        The figure is created as a side effect (and rendered by the active
        matplotlib backend).
    """
    import math
    import warnings

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm

    n_cols = 5

    data = _as_voxel_array(nifti_path)
    z_dim = data.shape[2]

    # Everything that does not depend on the slice index is computed once.
    overlay_data = None
    overlay_max = None
    thresholded = None
    if overlay is not None:
        overlay_data = _as_voxel_array(overlay)
        overlay_max = overlay_data.max()
        if overlay_type == 'imshow':
            # Sub-threshold voxels are zeroed, supra-threshold ones keep their value.
            thresholded = np.where(overlay_data > z_threshold, overlay_data, 0.)

    # Exactly as many rows as needed; squeeze=False keeps `axs` two-dimensional
    # even for a single row, so no dummy second row is required.
    nrows = max(1, math.ceil(z_dim / n_cols))
    fig, axs = plt.subplots(nrows=nrows, ncols=n_cols, figsize=figsize, squeeze=False)

    for z in range(z_dim):
        ax = axs[z // n_cols, z % n_cols]
        ax.imshow(data[:, :, z], cmap='gray',vmax=underlay_vmax)
        ax.text(2,6,s=f"z={z}",c='white',fontsize=20)

        if overlay_data is None:
            continue

        if overlay_type == 'imshow':
            ax.imshow(thresholded[:, :, z], cmap=_cmap, alpha=.5)
        elif overlay_type == 'scatter':
            # Supra-threshold voxels of this slice only; array rows map to the
            # plot's y axis and array columns to its x axis.
            rows, cols = np.nonzero(overlay_data[:, :, z] > z_threshold)
            # NOTE(review): the -.25 offset is kept from the original code;
            # presumably a visual nudge of the markers - confirm.
            ax.scatter(
                cols - .25,
                rows - .25,
                c=overlay_data[rows, cols, z],
                s=scatter_s_size,
                cmap=_cmap,
                vmin=0,
                vmax=overlay_max,
            )
            # The colorbar is best-effort, as in the original, but a failure
            # is now reported instead of being silently discarded.
            try:
                sm = cm.ScalarMappable(
                    cmap=_cmap,
                    norm=plt.Normalize(vmin=0, vmax=overlay_max),
                )
                sm.set_array([])  # public replacement for assigning sm._A
                fig.colorbar(sm, ax=ax)
            except Exception as exc:
                warnings.warn(f"Could not draw colorbar for slice z={z}: {exc}")

    # Hide the frames of all grid cells, including the unused trailing ones.
    for ax in axs.flat:
        ax.axis('off')

    if overlay_data is not None:
        fig.suptitle(f"{title}\nMax Z-score: {overlay_max:.2f}")
    else:
        fig.suptitle(f"{title}")

    fig.tight_layout()

Get all data

In [None]:
# Collect the functional NIfTI runs of both datasets.
# glob replaces the `!ls` capture: when a pattern matches nothing, `!ls` puts
# the shell's error text into the list, whereas glob returns an empty list.
# sorted() makes the order (which `idx` later indexes into) independent of the
# shell locale.
# NOTE(review): `ls` sorts with the locale's collation; verify that the run
# selected by `idx` is unchanged.
from glob import glob

niftis_1 = sorted(glob("/data/mouse_data/bids_visualblock/sub-43393072F/ses-06/func/*nii.gz"))
niftis_2 = sorted(glob("/data/mouse_data/bids/sub-06393073M/ses-Pilot01/func/*TEST0?VisualBlock*nii.gz"))
niftis = niftis_1 + niftis_2
if not niftis:
    raise FileNotFoundError("No NIfTI files matched the dataset patterns.")

# A single events file is used for every run.
event_tsvs = sorted(glob("/data/mouse_data/bids_visualblock/sub-43393072F/ses-06/func/*tsv"))
if not event_tsvs:
    raise FileNotFoundError("No events .tsv file found.")
event_tsv = event_tsvs[0]

niftis, event_tsv

Get data ready

In [None]:
import pandas as pd
import nibabel as nib
import numpy as np
from nilearn import image, masking

# Position in `niftis` of the run to analyse.
idx = 5

fmri_raw = nib.load(niftis[idx])

# Read the repetition time from the file as loaded. The cleaning/smoothing
# steps below return newly built images, and their headers are not guaranteed
# to carry the original time step, so reading it afterwards can yield a wrong TR.
zooms = fmri_raw.header.get_zooms()
if len(zooms) < 4:
    raise ValueError(f"Expected a 4D image, got zooms {zooms} for {niftis[idx]}")
# NOTE(review): assumes the header's time unit is seconds - confirm.
TR = zooms[-1]

# Mean volume: used as the plotting underlay and as the template for the mask.
mean_img = image.mean_img(fmri_raw)
# All-ones mask, i.e. every voxel of the volume is included.
mask = nib.Nifti1Image(np.ones(mean_img.shape), mean_img.affine, mean_img.header)

# Clean without standardizing, then smooth with a kernel width of 1.
fmri_clean = image.clean_img(fmri_raw, standardize=False)
fmri_img = image.smooth_img(fmri_clean, 1.)

events = pd.read_table(event_tsv)

print(f"Nifti: {niftis[idx]}\nEvents_tsv: {event_tsv}\nTR: {TR}")

Set up and fit the GLM

In [None]:
from nilearn.glm.first_level import FirstLevelModel

# First-level model settings: cosine drift regressors, no signal scaling,
# the precomputed mask, and memory minimization switched off.
glm_options = dict(
    t_r=TR,
    drift_model="cosine",
    signal_scaling=False,
    mask_img=mask,
    minimize_memory=False,
)

# Build the model and fit it to the preprocessed run and its events.
fmri_glm = FirstLevelModel(**glm_options)
fmri_glm = fmri_glm.fit(fmri_img, events)

# Statistical map for the "visual_10Hz" contrast.
z_map = fmri_glm.compute_contrast("visual_10Hz")

Plot

In [None]:
# Scale factor applied to the base 6 x 5 figure size.
resize_factor = 4
figure_size = (6*resize_factor, 5*resize_factor)

# Use the brightest voxel of the mean image as the top of the gray scale.
underlay_peak = mean_img.get_fdata().max()

# Mean image as underlay, contrast map as scatter overlay above threshold 2.
plot_slices(
    mean_img,
    overlay=z_map,
    title=f"IDX: {idx}",
    z_threshold=2.,
    _cmap='magma',
    underlay_vmax=underlay_peak,
    overlay_type='scatter',
    figsize=figure_size,
    scatter_s_size=50,
)