### Important
This code was developed on an example dataset. Switch to the real preprocessed data once it is available.

In [None]:
import pandas as pd
import numpy as np
import sys
import os
import os.path as op
from utils import mkdir_no_exist#, make_mask_from_aal

import nibabel as nib
from nilearn import image, datasets
from nilearn.glm import threshold_stats_img
from nilearn.glm.first_level import make_first_level_design_matrix, FirstLevelModel
from nilearn.plotting import plot_design_matrix, plot_contrast_matrix, plot_stat_map, plot_roi
from nilearn.image import mean_img

from matplotlib import pyplot as plt

In [None]:
# ACCESS DIRECTORIES #

# Absolute path of the working directory; it is also appended to sys.path
current_dir = os.path.abspath("")
print(f"current_dir: {current_dir}")
sys.path.append(current_dir)

dataset_id = 'ds000171'
# Subject labels sub-control01 ... sub-control20
subjects = [f"sub-control{number:02d}" for number in range(1, 21)]

# Raw dataset and derivative locations, all under <current_dir>/data
dataset_path = os.path.join(current_dir, "data", dataset_id)
deriv_path = os.path.join(current_dir, "data", "derivatives")
preproc_path = os.path.join(deriv_path, 'preprocessed_data')

# mkdir_no_exist comes from the project's utils module;
# presumably it creates the directory when it is missing — confirm there
for needed_dir in (dataset_path, preproc_path):
    mkdir_no_exist(needed_dir)


In [None]:
# READ AND CONCAT EVENTS #

# Subject whose runs are concatenated; used for both the directory and the
# file name so the two can never disagree.
subject = subjects[0]
func_dir = os.path.join(dataset_path, subject, "func")

events = {}
for i in range(1, 4):
    events_file = os.path.join(func_dir, f"{subject}_task-music_run-{i}_events.tsv")
    events[i] = pd.read_csv(events_file, sep="\t")
    if i > 1:
        # Shift onsets so this run starts where the previous (already shifted)
        # run's last event ends.
        # NOTE(review): this assumes each run stops exactly at the end of its
        # last event; if volumes were acquired after that, the offset should
        # be the run length (n_volumes * TR) instead — confirm.
        previous_end = events[i - 1]['onset'].iloc[-1] + events[i - 1]['duration'].iloc[-1]
        events[i]['onset'] += previous_end

# for key, event in events.items():
#     display(f"EVENT #{key}\n",event)
events_concat = pd.concat(events.values(), ignore_index=True)
display(events_concat)
events_concat.to_csv(op.join(deriv_path, 'events_concat.csv'))

In [None]:
# ONLY FOR TESTING: need to replace with the preprocessed data
# plotted design matrix is not correct until glm is fit to own data!

# from nilearn.datasets import fetch_spm_auditory
# subject_data = fetch_spm_auditory()

# from nilearn.image import concat_imgs, mean_img
# fmri_img = concat_imgs(subject_data.func)

# Functional image of sub-control01 read from the preprocessed-data folder
preprocessed_bold_file = op.join(
    preproc_path,
    "sub-control01/func/final/sub-control01_task-music_run-all_bold_smoothed-9mm.nii.gz",
)
fmri_img = nib.load(preprocessed_bold_file)

In [None]:
# Show the dimensions of the loaded image (rendered as the cell output)
fmri_img.shape

In [None]:
# Task-level BOLD metadata loaded as a pandas Series;
# its RepetitionTime entry is read further down.
task_json_file = op.join(dataset_path, "task-music_bold.json")
task_music_bold = pd.read_json(task_json_file, typ='series')

# fmri_glm = FirstLevelModel(t_r = task_music_bold.RepetitionTime, 
#                            noise_model='ar1',
#                            standardize=False,
#                            hrf_model='spm + derivative + dispersion',
#                            drift_model="cosine",
#                            high_pass=.01)

# # Fit the model to our design and data
# fmri_glm = fmri_glm.fit(fmri_img, events_concat)
# design = fmri_glm.design_matrices_[0]

# plot_design_matrix(design,rescale=True, output_file = op.join(deriv_path, 'design_matrix.jpg'))


In [None]:
t_r = task_music_bold.RepetitionTime
# Read the number of volumes from the image shape: calling get_fdata() here
# would load the whole 4D array into memory just to look at one dimension.
n_scans = fmri_img.shape[-1]
# Acquisition time of every volume, starting at 0
frame_times = np.arange(n_scans) * t_r

# Volumes flagged as motion outliers; each gets its own indicator regressor
# (1 at the flagged volume, 0 elsewhere).
motion_outliers = [104, 121, 122, 209]
new_regs = np.zeros((len(motion_outliers), n_scans))  # one row per outlier
for idx, out_frame in enumerate(motion_outliers):
    new_regs[idx, out_frame] = 1
new_reg_names = [f"motion_outlier_{idx}" for idx in range(len(motion_outliers))]

design = make_first_level_design_matrix(
    frame_times=frame_times,
    events=events_concat,
    hrf_model='spm + derivative + dispersion',
    drift_model="cosine",
    high_pass=.01,
    add_regs=new_regs.T,  # transposed so that rows are scans and columns are regressors
    add_reg_names=new_reg_names,
)

# First call writes the figure to disk, second call draws it in the notebook
plot_design_matrix(design, rescale=False, output_file=op.join(deriv_path, 'design_matrix.jpg'))
plot_design_matrix(design, rescale=False)
plt.show()

In [None]:
# First-level GLM fitted on the precomputed design matrix.
# NOTE(review): with design_matrices supplied, nilearn presumably does not
# rebuild the design from hrf_model / drift_model / high_pass — confirm.
fmri_glm = FirstLevelModel(
    t_r=task_music_bold.RepetitionTime,
    noise_model='ar1',
    standardize=False,
    hrf_model='spm + derivative + dispersion',
    drift_model="cosine",
    high_pass=.01,
).fit(fmri_img, design_matrices=[design])

### Beta/statistical maps for each regressor

In [None]:
# Cut positions handed to the plotting calls: -30 up to 25 in steps of 5
to_show = list(np.arange(-30, 30, 5))
# Threshold passed to the statistical-map plots
z_threshold = 1.5

In [None]:
def condition_vector(position: int, n_regressors: int) -> np.ndarray:
    """Return a one-hot row vector that selects a single regressor.

    Args:
        position: Index of the regressor to select.
        n_regressors: Total number of regressors, i.e. the vector length.

    Returns:
        Float array of shape ``(1, n_regressors)`` holding 1 at ``position``
        and 0 everywhere else.

    Raises:
        IndexError: If ``position`` is out of bounds for ``n_regressors``.
    """
    vec = np.zeros((1, n_regressors))
    vec[0, position] = 1
    return vec

# One single-regressor contrast vector per design-matrix column,
# keyed by the column name.
n_regressors = design.shape[1]
conditions = dict(
    zip(
        design.columns,
        (condition_vector(idx, n_regressors) for idx in range(n_regressors)),
    )
)

In [None]:
# Regressors treated as covariates: any design column whose name contains
# one of these fragments. Everything else is a condition of interest.
covariates = [
    key
    for key in conditions
    if any(tag in key for tag in ("drift", "constant", "outlier", "derivative", "dispersion"))
]
conditions_of_interest = {
    key: value
    for key, value in conditions.items()
    if key not in covariates
}

# Output folder for the per-regressor z-maps
betas_zmap = op.join(preproc_path, 'betas_zmap')
mkdir_no_exist(betas_zmap)

cluster_size = 1  # number of voxels for a cluster to be kept
fdr_rate = 0.05  # alpha = 5%

In [None]:
#TESTING: Plot the statistical map of a single component
# key = 'negative_music'
# value = conditions_of_interest[key]

# z_map = fmri_glm.compute_contrast(value, output_type='z_score')

# mean_img_ = mean_img(fmri_img),
# plot_stat_map(z_map,
#               bg_img=mean_img_[0],
#               threshold=z_threshold,
#               display_mode='z',
#               cut_coords=to_show,
#               black_bg=True,
#               title='Condition: {} (|Z|>{})'.format(key, z_threshold),
#               #output_file = op.join(betas_zmap,f"z_map_fdr5p_cl10_{key}.jpg")
#               )
# plt.show()


In [None]:
# Mean image of the functional series, used as plot background.
# Kept as an explicit 1-tuple because the plotting cells read it as
# ``mean_img_[0]`` (the original relied on a stray trailing comma for this).
mean_img_ = (mean_img(fmri_img),)

fig, axs = plt.subplots(len(conditions_of_interest), 1, figsize=(20, 14))
# plt.subplots returns a bare Axes when only one row is requested;
# wrap it so that axs[i] works for a single condition as well.
axs = np.atleast_1d(axs)
for i, (key, value) in enumerate(conditions_of_interest.items()):
    z_map = fmri_glm.compute_contrast(value, output_type='z_score')

    plot_stat_map(z_map,
                  bg_img=mean_img_[0],
                  threshold=z_threshold,
                  display_mode='z',
                  cut_coords=to_show,
                  black_bg=True,
                  title='Condition: {} (|Z|>{})'.format(key, z_threshold),
                  axes=axs[i])
plt.subplots_adjust(wspace=0, hspace=0)
# Built from deriv_path like the other outputs instead of a cwd-relative literal
output_plot_path = op.join(deriv_path, 'all_cdts_9mm.png')
plt.savefig(output_plot_path, format='png')
plt.show()

### Contrasts

In [None]:
# Column positions kept for reference; the code in this cell does not read them
positive_music_position = 2
negative_music_position = 0

# Contrast vector: +1 on the positive_music column, -1 on negative_music
pos_vs_neg = np.subtract(conditions['positive_music'], conditions['negative_music'])

contrast_plot_file = op.join(deriv_path, 'contrast_pos_neg.jpg')
plot_contrast_matrix(pos_vs_neg, design_matrix=design, output_file=contrast_plot_file)
plt.show()

In [None]:
# Z-score map of the positive-minus-negative music contrast
z_map = fmri_glm.compute_contrast(pos_vs_neg, output_type='z_score')

# Explicit 1-tuple (the original relied on a stray trailing comma):
# the plotting cells read the background image as ``mean_img_[0]``.
mean_img_ = (mean_img(fmri_img),)
plot_stat_map(z_map,
              bg_img=mean_img_[0],
              threshold=z_threshold,
              display_mode='z',
              cut_coords=to_show,
              black_bg=True,
              title='Positive minus Negative Music (|Z|>{})'.format(z_threshold))
# Built from deriv_path like the other outputs instead of a cwd-relative literal
output_plot_path = op.join(deriv_path, 'contrast_9mm.png')
plt.savefig(output_plot_path, format='png')
plt.show()

### AAL overlay

In [None]:
from nilearn import datasets, image
from nilearn.image import resample_to_img
import ants

# File the AAL atlas image is written to (relative to the working directory)
atlas_path = "data/derivatives/atlas_template.nii"
# File the registered (warped) mean fMRI volume is written to
fMRI_MNI_path = op.join(preproc_path,"sub-control01/func/sub-control01_task-music_run-all_bold_moco_MNI.nii")

def atlas_fMRI_MNI(fmri_img, atlas_path, fMRI_MNI_path):
    """Fetch the AAL atlas and register the mean fMRI volume to it with ANTs.

    The atlas (SPM12 version) is saved to ``atlas_path``. ``fmri_img`` is
    resampled onto the atlas grid and averaged, and the mean volume is written
    under the module-level ``preproc_path``. That mean volume is then
    registered to the saved atlas with a SyN transform, and the warped result
    is written to ``fMRI_MNI_path``.

    Args:
        fmri_img: Functional image to resample and register.
        atlas_path: File the atlas image is written to and read back from.
        fMRI_MNI_path: File the warped mean volume is written to.

    Returns:
        Tuple ``(aal_atlas, atlas_img, warped_img)``: the fetched atlas
        object, the atlas image, and the image loaded from ``fMRI_MNI_path``.
    """
    aal_atlas = datasets.fetch_atlas_aal(version='SPM12')
    atlas_img = image.load_img(aal_atlas.maps)
    atlas_img.to_filename(atlas_path)

    # Bring the functional data onto the atlas grid, then average it
    fmri_resampled = resample_to_img(fmri_img, atlas_img, interpolation='continuous')
    mean_fmri = mean_img(fmri_resampled)
    fmri_resampled_path = op.join(preproc_path,"sub-control01/func/sub-control01_task-music_run-all_bold_moco_resampled.nii")
    mean_fmri.to_filename(fmri_resampled_path)

    moving_image = ants.image_read(fmri_resampled_path)
    # Read the atlas from the path it was just saved to; the previous
    # hard-coded literal silently ignored the atlas_path argument.
    fixed_image = ants.image_read(atlas_path)

    transformation = ants.registration(fixed=fixed_image, moving=moving_image, type_of_transform='SyN')
    warpedImage = ants.apply_transforms(fixed=fixed_image, moving=moving_image, transformlist=transformation['fwdtransforms'])
    ants.image_write(warpedImage, fMRI_MNI_path)

    return aal_atlas, atlas_img, nib.load(fMRI_MNI_path)

# Run the registration and keep the atlas object, atlas image and warped fMRI
aal_atlas, atlas_img, fMRI_MNI = atlas_fMRI_MNI(
    fmri_img=fmri_img,
    atlas_path=atlas_path,
    fMRI_MNI_path=fMRI_MNI_path,
)

In [None]:
# Show the derivatives directory (rendered as the cell output)
deriv_path

In [None]:
# TODO: same MNI space -> run it before analysis

# Contrast z-map on the mean functional image.
# The title is built from z_threshold so it always matches the threshold
# actually applied (it used to state a fixed "|Z|>2").
plot_stat_map(z_map,
              bg_img=mean_img_[0],
              threshold=z_threshold,
              display_mode='z',
              cut_coords=to_show,
              black_bg=True,
              title='Positive minus Negative Music (|Z|>{})'.format(z_threshold),
              # output_file = op.join(deriv_path,"Z_contrasts"),
            )

# AAL atlas regions drawn over the registered mean fMRI volume
plot_roi(atlas_img,
          bg_img=fMRI_MNI,
          display_mode='z',
          cut_coords=to_show,
          alpha=0.3,
          title="AAL Atlas Overlay",
          # output_file = op.join(deriv_path,"contrast_ROI"),
        )

plt.show()

In [None]:
# Alternatively, overlay a single AAL region selected by its atlas value

def make_mask_from_aal(mask_value, mask_name):
    """Save a binary mask of one AAL (SPM12) region as a NIfTI file.

    Args:
        mask_value: Atlas value of the region to keep; voxels equal to it
            become 1 in the mask, all others 0.
        mask_name: Output file name. ".nii" is appended when the name does
            not already contain ".nii".

    Returns:
        The file name the mask was actually written to. It differs from the
        ``mask_name`` argument when the extension had to be appended, so use
        this value to load the mask afterwards.
    """
    # Load the AAL atlas
    aal_atlas = datasets.fetch_atlas_aal(version='SPM12')
    atlas_img = image.load_img(aal_atlas.maps)
    atlas_data = atlas_img.get_fdata()  # Extract atlas data as numpy array

    # Create a binary mask for the specified region
    mask_data = atlas_data == mask_value

    # Create a new Nifti image with the mask, reusing the atlas affine and header
    mask_img = nib.Nifti1Image(mask_data.astype(np.uint8), atlas_img.affine, atlas_img.header)

    # Save the mask
    if ".nii" not in mask_name:
        mask_name += ".nii"
    nib.save(mask_img, mask_name)

    # Report the final path: the caller cannot otherwise know whether the
    # extension was appended.
    return mask_name

mask_value = 2001  # Replace with the region value from the AAL atlas
mask_name = "mask.nii.gz"  # Name the mask file

# Write the single-region mask to disk
make_mask_from_aal(mask_value, mask_name)

# Overlay AAL atlas mask on top of the fMRI
overlay_options = dict(
    bg_img=fMRI_MNI,
    display_mode='z',
    cut_coords=[-3, 20, 36, 70],
    alpha=0.3,
    title="AAL Atlas Overlay",
)
plot_roi(mask_name, **overlay_options)

plt.show()


In [None]:
# Print every AAL region label next to its atlas index
for label, region_index in zip(aal_atlas.labels, aal_atlas.indices):
    print(f"{label:<21s} {region_index}")