Revert "Revert "First draft of surface stuff after discussion during …
Browse files Browse the repository at this point in the history
…meetings & drop-in hours" (nilearn#3848)"

This reverts commit e002d24.
jeromedockes committed Jul 20, 2023
1 parent e002d24 commit 1728c0b
Showing 6 changed files with 494 additions and 0 deletions.
159 changes: 159 additions & 0 deletions examples/08_experimental/surface_image_and_maskers.py
@@ -0,0 +1,159 @@
"""A short demo of the surface images & maskers
copied from the nilearn sandbox discussion, to be transformed into tests &
examples
"""
from typing import Optional, Sequence

from matplotlib import pyplot as plt

from nilearn import plotting
from nilearn.experimental import surface


def plot_surf_img(
    img: surface.SurfaceImage,
    parts: Optional[Sequence[str]] = None,
    mesh: Optional[surface.PolyMesh] = None,
    **kwargs,
) -> plt.Figure:
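    """Plot each mesh part of a SurfaceImage in its own 3D subplot."""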
    if mesh is None:
        mesh = img.mesh
    if parts is None:
        parts = list(img.data.keys())
    fig, axes = plt.subplots(
        1,
        len(parts),
        subplot_kw={"projection": "3d"},
        figsize=(4 * len(parts), 4),
    )
    for ax, mesh_part in zip(axes, parts):
        plotting.plot_surf(
            mesh[mesh_part],
            img.data[mesh_part],
            axes=ax,
            title=mesh_part,
            **kwargs,
        )
    assert isinstance(fig, plt.Figure)
    return fig


img = surface.fetch_nki()[0]
print(f"NKI image: {img}")

masker = surface.SurfaceMasker()
masked_data = masker.fit_transform(img)
print(f"Masked data shape: {masked_data.shape}")

mean_data = masked_data.mean(axis=0)
mean_img = masker.inverse_transform(mean_data)
print(f"Image mean: {mean_img}")

plot_surf_img(mean_img)
plotting.show()
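
###############################################################################
# As a quick check (a minimal sketch, relying only on the per-hemisphere
# ``.data`` dict that ``plot_surf_img`` above already uses), print the shape
# of each hemisphere's array in the mean image:

for part_name, part_data in mean_img.data.items():
    print(f"{part_name}: {part_data.shape}")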

###############################################################################
# ### Connectivity with a surface atlas and `SurfaceLabelsMasker`

from nilearn import connectome, plotting

img = surface.fetch_nki()[0]
print(f"NKI image: {img}")

labels_img, label_names = surface.fetch_destrieux()
print(f"Destrieux image: {labels_img}")
plot_surf_img(labels_img, cmap="gist_ncar", avg_method="median")

labels_masker = surface.SurfaceLabelsMasker(labels_img, label_names).fit()
masked_data = labels_masker.transform(img)
print(f"Masked data shape: {masked_data.shape}")

connectivity_matrix = (
    connectome.ConnectivityMeasure(kind="correlation").fit([masked_data]).mean_
)
plotting.plot_matrix(connectivity_matrix, labels=labels_masker.label_names_)

plotting.show()


###############################################################################
# ### Using the `Decoder`

import numpy as np

from nilearn import decoding, plotting
from nilearn._utils import param_validation

###############################################################################
# The following disables a couple of checks performed by the decoder that
# would otherwise force us to use a `NiftiMasker`.


def monkeypatch_masker_checks():
    def adjust_screening_percentile(screening_percentile, *args, **kwargs):
        return screening_percentile

    param_validation._adjust_screening_percentile = adjust_screening_percentile

    def check_embedded_nifti_masker(estimator, *args, **kwargs):
        return estimator.mask

    decoding.decoder._check_embedded_nifti_masker = check_embedded_nifti_masker


monkeypatch_masker_checks()

###############################################################################
# Now, with the appropriate masker, we can use a `Decoder` on surface data
# just as we do for volume images.

img = surface.fetch_nki()[0]
y = np.random.RandomState(0).choice([0, 1], replace=True, size=img.shape[0])

decoder = decoding.Decoder(
    mask=surface.SurfaceMasker(),
    param_grid={"C": [0.01, 0.1]},
    cv=3,
    screening_percentile=1,
)
decoder.fit(img, y)
print("CV scores:", decoder.cv_scores_)

plot_surf_img(decoder.coef_img_[0], threshold=1e-6)
plotting.show()

###############################################################################
# ### Decoding with a scikit-learn `Pipeline`

import numpy as np
from sklearn import feature_selection, linear_model, pipeline, preprocessing

from nilearn import plotting

img = surface.fetch_nki()[0]
y = np.random.RandomState(0).normal(size=img.shape[0])

decoder = pipeline.make_pipeline(
    surface.SurfaceMasker(),
    preprocessing.StandardScaler(),
    feature_selection.SelectKBest(
        score_func=feature_selection.f_regression, k=500
    ),
    linear_model.Ridge(),
)
decoder.fit(img, y)

coef_img = decoder[:-1].inverse_transform(np.atleast_2d(decoder[-1].coef_))


vmax = max([np.absolute(dp).max() for dp in coef_img.data.values()])
plot_surf_img(
    coef_img,
    cmap="cold_hot",
    vmin=-vmax,
    vmax=vmax,
    threshold=1e-6,
)
plotting.show()
29 changes: 29 additions & 0 deletions nilearn/experimental/surface/__init__.py
@@ -0,0 +1,29 @@
from nilearn.experimental.surface._datasets import (
    fetch_destrieux,
    fetch_nki,
    load_fsaverage,
)
from nilearn.experimental.surface._maskers import (
    SurfaceLabelsMasker,
    SurfaceMasker,
)
from nilearn.experimental.surface._surface_image import (
    FileMesh,
    InMemoryMesh,
    Mesh,
    PolyMesh,
    SurfaceImage,
)

__all__ = [
"FileMesh",
"InMemoryMesh",
"Mesh",
"PolyMesh",
"SurfaceImage",
"SurfaceLabelsMasker",
"SurfaceMasker",
"fetch_destrieux",
"fetch_nki",
"load_fsaverage",
]
62 changes: 62 additions & 0 deletions nilearn/experimental/surface/_datasets.py
@@ -0,0 +1,62 @@
"""Fetching a few example datasets to use during development.
eventually nilearn.datasets would be updated
"""
from typing import Dict, Sequence, Tuple

from nilearn import datasets
from nilearn.experimental.surface import _io
from nilearn.experimental.surface._surface_image import (
FileMesh,
Mesh,
PolyMesh,
SurfaceImage,
)


def load_fsaverage(mesh_name: str = "fsaverage5") -> Dict[str, PolyMesh]:
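    """Load fsaverage meshes as a nested dict: mesh name -> hemisphere.

    Mesh names are remapped to ``pial``, ``white_matter`` and ``inflated``,
    and hemispheres are keyed ``left_hemisphere`` / ``right_hemisphere``.
    """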
    fsaverage = datasets.fetch_surf_fsaverage(mesh_name)
    meshes: Dict[str, Dict[str, Mesh]] = {}
    renaming = {"pial": "pial", "white": "white_matter", "infl": "inflated"}
    for mesh_type, new_name in renaming.items():
        meshes[new_name] = {}
        for hemisphere in "left", "right":
            meshes[new_name][f"{hemisphere}_hemisphere"] = FileMesh(
                fsaverage[f"{mesh_type}_{hemisphere}"]
            )
    return meshes


def fetch_nki(n_subjects=1) -> Sequence[SurfaceImage]:
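    """Fetch NKI enhanced surface data for ``n_subjects`` subjects.

    Left and right hemisphere time series are combined into one
    ``SurfaceImage`` per subject, on the fsaverage5 pial mesh.
    """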
    fsaverage = load_fsaverage("fsaverage5")
    nki_dataset = datasets.fetch_surf_nki_enhanced(n_subjects=n_subjects)
    images = []
    for left, right in zip(
        nki_dataset["func_left"], nki_dataset["func_right"]
    ):
        left_data = _io.read_array(left).T
        right_data = _io.read_array(right).T
        img = SurfaceImage(
            {"left_hemisphere": left_data, "right_hemisphere": right_data},
            mesh=fsaverage["pial"],
        )
        images.append(img)
    return images


def fetch_destrieux() -> Tuple[SurfaceImage, Dict[int, str]]:
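    """Fetch the Destrieux surface atlas as ``(labels_img, label_names)``.

    ``label_names`` maps each integer label to its decoded name.
    """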
    fsaverage = load_fsaverage("fsaverage5")
    destrieux = datasets.fetch_atlas_surf_destrieux()
    label_names = {
        i: label.decode("utf-8") for (i, label) in enumerate(destrieux.labels)
    }
    return (
        SurfaceImage(
            {
                "left_hemisphere": destrieux["map_left"],
                "right_hemisphere": destrieux["map_right"],
            },
            mesh=fsaverage["pial"],
        ),
        label_names,
    )
15 changes: 15 additions & 0 deletions nilearn/experimental/surface/_io.py
@@ -0,0 +1,15 @@
import pathlib
from typing import Dict, Union

import numpy as np

from nilearn import surface as old_surface


def read_array(array_file: Union[pathlib.Path, str]) -> np.ndarray:
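    """Load a data array from a surface file via ``load_surf_data``."""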
    return old_surface.load_surf_data(array_file)


def read_mesh(mesh_file: Union[pathlib.Path, str]) -> Dict[str, np.ndarray]:
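    """Load a mesh and return its ``coordinates`` and ``faces`` arrays."""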
    loaded = old_surface.load_surf_mesh(mesh_file)
    return {"coordinates": loaded.coordinates, "faces": loaded.faces}
