In [None]:
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from bids.layout import BIDSLayout
from nilearn import datasets, image, plotting
from nilearn.connectome import ConnectivityMeasure
from nilearn.interfaces.bids import get_bids_files, parse_bids_filename
from nilearn.interfaces.fmriprep import load_confounds
from nilearn.maskers import NiftiMapsMasker, NiftiLabelsMasker

# Make the project's local packages importable *before* importing from them.
# NOTE(review): assumes `connectivity` lives under this root too (or is
# installed); the guard avoids duplicate entries when the cell is re-run.
PROJECT_ROOT = "/homes_unix/jlegrand/MEMENTO"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from connectivity.atlases import fetch_aicha, overlay_atlas, make_overlay_slices, atlas_mapping
from utils.visualisation import make_and_show_middle_slices

# --- Configuration ---------------------------------------------------------
BIDSDIR = Path("/georges/memento/BIDS")

# Atlas to use; must be a key of `atlas_mapping`.
ATLAS = "difumo"

# Seed for the random choice of the run shown in the atlas overlay check,
# so the notebook produces the same figures on every Restart & Run All.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)


In [None]:

# Preprocessed BOLD runs (fMRIPrep 23.2.0 derivatives) of subject 0001,
# restricted to the MNI152NLin6Asym space. The previously used alternative
# filter was ("space", "MNI152NLin2009cAsym").
fmriprep_dir = BIDSDIR / "derivatives" / "fmriprep-23.2.0"
space_filters = [("space", "MNI152NLin6Asym")]

fmri_path = get_bids_files(
    fmriprep_dir,
    "bold",
    file_type="nii.gz",
    modality_folder="func",
    sub_label="0001",
    filters=space_filters,
)

fmri_path


In [None]:


# Visual sanity check: overlay the (binarised) atlas on the temporal mean of
# one randomly chosen run to confirm atlas and data are roughly aligned.
# TODO: this should be refactored with atlases as objects.

atlas = atlas_mapping[ATLAS]()

# Some atlases expose no "maps" entry; fall back to "rsn20".
# NOTE(review): presumably the Smith 2009 20-network maps — confirm.
if "maps" not in atlas:
    atlas["maps"] = atlas["rsn20"]

# `maps` may be a file path or an already-loaded image. nilearn's load_img
# accepts both, unlike nib.load (paths only), and avoids a catch-all
# `except TypeError` that could hide unrelated errors.
atlas_img = image.load_img(atlas["maps"]).get_fdata()

# Collapse probabilistic (4D) atlases to a single volume so they can be
# binarised for the overlay.
if atlas_img.ndim == 4:
    atlas_img = atlas_img.mean(axis=3)

assert len(fmri_path) > 0, "No BOLD run found for this subject/filters"
img = nib.load(random.choice(fmri_path)).get_fdata().mean(axis=3)
fig = overlay_atlas(img, np.where(atlas_img == 0, 0, 1))  # TODO Make the function more flexible
fig.suptitle(ATLAS, y=0.95)

plt.show()

In [None]:

# Atlases listed here are treated as probabilistic ("soft") maps and drawn
# with plot_prob_atlas; every other atlas is treated as a label image and
# drawn with plot_roi. The same split selects the masker further down.
SOFT_ATLASES = {"smith", "difumo", "msdl"}

plot_atlas = plotting.plot_prob_atlas if ATLAS in SOFT_ATLASES else plotting.plot_roi
plot_atlas(atlas.maps, title=ATLAS)
plt.show()


In [None]:
# Signal extractor matching the atlas type: a maps masker for soft atlases,
# a labels masker for hard parcellations. Both z-score the extracted signals.
shared_masker_kwargs = {"standardize": "zscore_sample"}
if ATLAS not in SOFT_ATLASES:
    masker = NiftiLabelsMasker(labels_img=atlas.maps, verbose=1, **shared_masker_kwargs)
else:
    masker = NiftiMapsMasker(maps_img=atlas.maps, **shared_masker_kwargs)
masker.fit()

# Nuisance regressors for every run in `fmri_path`: high-pass terms, basic
# motion parameters and basic white-matter/CSF signals. `sample_mask` is
# returned alongside the confounds.
confound_strategy = ["high_pass", "motion", "wm_csf"]
confounds, sample_mask = load_confounds(
    fmri_path,
    strategy=confound_strategy,
    motion="basic",
    wm_csf="basic",
)


In [None]:
    
    
# Extract one regional time series per run, applying that run's confounds
# and sample mask, and overlay the atlas as resampled by the masker on the
# run's mean EPI to check the resampling.
#
# load_confounds returns one sample mask per run (entries may be None when no
# volume is censored). Guard against the whole value being None instead of
# wrapping the loop in `try/except TypeError`, which used to hide a real bug
# and silently drop the sample masks.
run_sample_masks = sample_mask if sample_mask is not None else [None] * len(fmri_path)

time_series = []
# `run_mask` (not `sample_mask`) so the list above is not shadowed and the
# cell stays idempotent.
for fmri, confound, run_mask in zip(fmri_path, confounds, run_sample_masks):
    img = nib.load(fmri)
    ts = masker.transform(img, confounds=confound, sample_mask=run_mask)

    # Check resampling: the maps masker holds a 4D resampled image (averaged
    # over maps for display), the labels masker a 3D one.
    if isinstance(masker, NiftiMapsMasker):
        resampled_atlas = masker._resampled_maps_img_.get_fdata().mean(axis=3)
    else:
        resampled_atlas = masker._resampled_labels_img_.get_fdata()

    fig = overlay_atlas(
        img.get_fdata().mean(axis=3),
        np.where(resampled_atlas == 0, 0, 1),
    )
    fig.suptitle(parse_bids_filename(fmri)["file_basename"])
    plt.show()

    time_series.append(ts)
    


In [None]:

# Run-wise correlation connectomes: `res` holds one (n_regions, n_regions)
# matrix per run, in the order of `fmri_path`. No `.squeeze()`: with a single
# run it would drop the run axis and break `res[i]` / `res[0]` indexing.
assert time_series, "No time series extracted"
assert len(time_series) == len(fmri_path), "Expected one time series per run"

cm = ConnectivityMeasure(kind="correlation")
res = cm.fit_transform(time_series)

n_runs = len(time_series)
# One row per run (matrix + histogram). squeeze=False keeps `axes` 2D even
# when there is a single run.
fig, axes = plt.subplots(n_runs, 2, figsize=(10, 5 * n_runs), squeeze=False)

for i, fmri in enumerate(fmri_path):
    ses = parse_bids_filename(fmri)["ses"]

    # Zero the self-correlation diagonal so it does not dominate the
    # histogram. Done in place, so later cells also see a zero diagonal.
    np.fill_diagonal(res[i], 0)

    # TODO: group regions by network (e.g. labels=atlas.labels.yeo_networks7).
    plotting.plot_matrix(
        res[i],
        vmax=1,
        vmin=-1,
        axes=axes[i, 0],
        title=f"{ses}"
    )
    axes[i, 1].hist(res[i].flatten(), bins=30)
    axes[i, 1].set_xlim(-1, 1)
    axes[i, 1].set_xlabel("Correlation")

# All runs come from the same subject (sub_label filter above).
sub_id = parse_bids_filename(fmri_path[0])["sub"]
fig.suptitle(f"sub-{sub_id}, {ATLAS} atlas", y=1.01)
plt.show()


In [None]:

# Connectome plot of the first run's correlation matrix.
# TODO: the node positions look very odd — display them on the atlas to check.
# TODO: Functionectome from  (note left unfinished)
if ATLAS not in SOFT_ATLASES:
    # label_hemisphere only matters for labels present in both hemispheres.
    coords = plotting.find_parcellation_cut_coords(
        atlas.maps, label_hemisphere="right"
    )
else:
    coords = plotting.find_probabilistic_atlas_cut_coords(atlas.maps)

first_run_connectome = res[0, :, :]
plotting.plot_connectome(first_run_connectome, coords, colorbar=True, title=ATLAS)
plotting.show()

In [None]:
from sklearn.covariance import GraphicalLassoCV

# Sparse inverse covariance (graphical lasso, penalty chosen by
# cross-validation) estimated on a single run's time series.
# NOTE(review): this uses run index 1 while the connectome above shows run 0
# — confirm which run is intended.
glasso_run = 1
estimator = GraphicalLassoCV().fit(time_series[glasso_run])

# Negating the precision matrix gives off-diagonal entries with the sign of
# the partial correlations.
negated_precision = -estimator.precision_
plotting.plot_connectome(negated_precision, coords, title="Sparse inverse covariance")
plotting.show()


In [None]:
# Matrix view of the negated precision matrix from the graphical lasso.
# TODO: order/label regions by network (e.g. labels=atlas.labels.yeo_networks7).
negated_precision_matrix = -estimator.precision_
plotting.plot_matrix(negated_precision_matrix)
plotting.show()