Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Ensure that fetchers for atlas and func return a Bunch #4233

Merged
merged 10 commits into from
Jan 22, 2024
4 changes: 2 additions & 2 deletions examples/07_advanced/plot_beta_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@
from nilearn.datasets import fetch_language_localizer_demo_dataset
from nilearn.glm.first_level import FirstLevelModel, first_level_from_bids

data_dir, _ = fetch_language_localizer_demo_dataset()
data = fetch_language_localizer_demo_dataset(legacy_output=False)

models, models_run_imgs, events_dfs, models_confounds = first_level_from_bids(
data_dir,
data.data_dir,
"languagelocalizer",
img_filters=[("desc", "preproc")],
n_jobs=2,
Expand Down
6 changes: 3 additions & 3 deletions examples/07_advanced/plot_bids_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
# confounds.tsv files.
from nilearn.datasets import fetch_language_localizer_demo_dataset

data_dir, _ = fetch_language_localizer_demo_dataset()
data = fetch_language_localizer_demo_dataset(legacy_output=False)

# %%
# Here is the location of the dataset on disk.
print(data_dir)
print(data.data_dir)

# %%
# Obtain automatically FirstLevelModel objects and fit arguments
Expand All @@ -57,7 +57,7 @@
models_events,
models_confounds,
) = first_level_from_bids(
data_dir, task_label, img_filters=[("desc", "preproc")], n_jobs=2
data.data_dir, task_label, img_filters=[("desc", "preproc")], n_jobs=2
)

# %%
Expand Down
8 changes: 4 additions & 4 deletions examples/07_advanced/plot_surface_bids_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
# confounds.tsv files.
from nilearn.datasets import fetch_language_localizer_demo_dataset

data_dir, _ = fetch_language_localizer_demo_dataset()
data = fetch_language_localizer_demo_dataset(legacy_output=False)

# %%
# Here is the location of the dataset on disk.
print(data_dir)
print(data.data_dir)

# %%
# Obtain automatically FirstLevelModel objects and fit arguments
Expand All @@ -53,7 +53,7 @@
task_label = 'languagelocalizer'
_, models_run_imgs, models_events, models_confounds = \
first_level_from_bids(
data_dir, task_label,
data.data_dir, task_label,
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
img_filters=[('desc', 'preproc')],
n_jobs=2
)
Expand All @@ -65,7 +65,7 @@
import os

json_file = os.path.join(
data_dir,
data.data_dir,
'derivatives',
'sub-01',
'func',
Expand Down
44 changes: 40 additions & 4 deletions nilearn/datasets/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,24 +2168,44 @@ def _reduce_confounds(regressors, keep_confounds):


@fill_doc
def fetch_language_localizer_demo_dataset(data_dir=None, verbose=1):
def fetch_language_localizer_demo_dataset(
data_dir=None, verbose=1, legacy_output=True
):
"""Download language localizer demo dataset.

Parameters
----------
%(data_dir)s
%(verbose)s
legacy_output: bool, default=True

.. versionadded:: 0.11.0
.. deprecated:: 0.11.0

Starting from version 0.13.0
the ``legacy_output`` argument will be removed
and the fetcher will always return
a ``sklearn.datasets.base.Bunch``.


Returns
-------
data : sklearn.datasets.base.Bunch
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
Dictionary-like object, the interest attributes are :

- 'data_dir': :obj:`str` Path to downloaded dataset.
- 'func': :obj:`list` of :obj:`str`,
Absolute paths of downloaded files on disk
- 'description' : :obj:`str`, dataset description

Legacy output
-------------
data_dir : :obj:`str`
Path to downloaded dataset.

downloaded_files : :obj:`list` of :obj:`str`
Absolute paths of downloaded files on disk

description : :obj:`str`

"""
url = "https://osf.io/3dj2a/download"
# When it starts working again change back to:
Expand All @@ -2208,7 +2228,23 @@ def fetch_language_localizer_demo_dataset(data_dir=None, verbose=1):
for path, _, files in os.walk(data_dir)
for f in files
]
return data_dir, sorted(file_list)
if legacy_output:
warnings.warn(
category=DeprecationWarning,
stacklevel=2,
message=(
"From version 0.13.0 this fetcher "
"will always return a Bunch.\n"
"Use `legacy_output=False` "
"to switch to this new behavior."
),
)
return data_dir, sorted(file_list)

description = get_dataset_descr("language_localizer_demo")
return Bunch(
data_dir=data_dir, func=sorted(file_list), description=description
)


@fill_doc
Expand Down
32 changes: 32 additions & 0 deletions nilearn/datasets/tests/test_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from sklearn.utils import Bunch

from nilearn._utils import data_gen
from nilearn._utils.testing import serialize_niimg
Expand All @@ -22,6 +23,24 @@
from nilearn.image import get_data


@pytest.mark.parametrize(
    "fetcher",
    [
        atlas.fetch_atlas_allen_2011,
        atlas.fetch_atlas_basc_multiscale_2015,
        atlas.fetch_atlas_schaefer_2018,
        atlas.fetch_atlas_smith_2009,
        atlas.fetch_atlas_yeo_2011,
        atlas.fetch_coords_dosenbach_2010,
        atlas.fetch_coords_power_2011,
        atlas.fetch_coords_seitzman_2018,
    ],
)
def test_atlas_fetcher_return_bunch(fetcher):
    """Smoke-test that each atlas/coords fetcher returns a sklearn Bunch."""
    result = fetcher()
    assert isinstance(result, Bunch)


def test_downloader(tmp_path, request_mocker):
# Sandboxing test
# ===============
Expand Down Expand Up @@ -218,6 +237,7 @@ def test_fetch_atlas_fsl(
data_dir=tmp_path,
symmetric_split=split,
)
assert isinstance(atlas_instance, Bunch)
_test_atlas_instance_should_match_data(
atlas_instance,
is_symm=is_symm or split,
Expand All @@ -239,6 +259,7 @@ def test_fetch_atlas_craddock_2012(tmp_path, request_mocker):
bunch = atlas.fetch_atlas_craddock_2012(
data_dir=tmp_path, verbose=0, homogeneity="spatial"
)
assert isinstance(bunch, Bunch)
bunch_rand = atlas.fetch_atlas_craddock_2012(
data_dir=tmp_path, verbose=0, homogeneity="random"
)
Expand Down Expand Up @@ -359,6 +380,8 @@ def test_fetch_atlas_destrieux_2009(tmp_path, request_mocker, lateralized):
lateralized=lateralized, data_dir=tmp_path, verbose=0
)

assert isinstance(bunch, Bunch)

assert request_mocker.url_count == 1

name = "_lateralized" if lateralized else ""
Expand Down Expand Up @@ -394,6 +417,7 @@ def test_fetch_atlas_msdl(tmp_path, request_mocker):
)
dataset = atlas.fetch_atlas_msdl(data_dir=tmp_path, verbose=0)

assert isinstance(dataset, Bunch)
assert isinstance(dataset.labels, list)
assert isinstance(dataset.region_coords, list)
assert isinstance(dataset.networks, list)
Expand Down Expand Up @@ -447,6 +471,7 @@ def test_fetch_atlas_difumo(tmp_path, request_mocker):
data_dir=tmp_path, dimension=dim, resolution_mm=res, verbose=0
)

assert isinstance(dataset, Bunch)
assert len(dataset.keys()) == 3
assert len(dataset.labels) == dim
assert isinstance(dataset.maps, str)
Expand Down Expand Up @@ -507,6 +532,7 @@ def test_fetch_atlas_aal(
version=version, data_dir=tmp_path, verbose=0
)

assert isinstance(dataset, Bunch)
assert isinstance(dataset.maps, str)
assert isinstance(dataset.labels, list)
assert isinstance(dataset.indices, list)
Expand Down Expand Up @@ -656,6 +682,8 @@ def test_fetch_atlas_surf_destrieux(tmp_path):

bunch = atlas.fetch_atlas_surf_destrieux(data_dir=tmp_path, verbose=0)

assert isinstance(bunch, Bunch)

# Our mock annots have 4 labels
assert len(bunch.labels) == 4
assert bunch.map_left.shape == (4,)
Expand Down Expand Up @@ -683,6 +711,8 @@ def test_fetch_atlas_talairach(tmp_path, request_mocker):
level_values = np.ones((81, 3)) * [0, 1, 2]
talairach = atlas.fetch_atlas_talairach("hemisphere", data_dir=tmp_path)

assert isinstance(talairach, Bunch)

assert talairach.description != ""

assert_array_equal(
Expand Down Expand Up @@ -713,6 +743,8 @@ def test_fetch_atlas_pauli_2017(tmp_path, request_mocker):

data = atlas.fetch_atlas_pauli_2017("det", data_dir)

assert isinstance(data, Bunch)

assert data.description != ""

assert len(data.labels) == 16
Expand Down