diff --git a/conda-env.yml b/conda-env.yml index ac1c7544d3..80158fbd3d 100644 --- a/conda-env.yml +++ b/conda-env.yml @@ -13,7 +13,7 @@ dependencies: - ruamel.yaml>=0.17,<0.19 - h5py>=3.10 - tqdm>=4.66.1,<4.67.0 - - templateflow>=23.0.0 + - templateflow>=23.0.0,<25.0.0 - lapy>=1.0.0,<2.0.0 - lazy_loader==0.4 - importlib_metadata diff --git a/docs/builtin.rst b/docs/builtin.rst index e67fd7a328..ef45e3d545 100644 --- a/docs/builtin.rst +++ b/docs/builtin.rst @@ -743,5 +743,39 @@ Available Planned ~~~~~~~ + +Maps +---- + +.. + Provide a list of the maps that are implemented or planned. + + Version added: The junifer version in which the maps was added. + +Available +~~~~~~~~~ + +.. list-table:: + :widths: auto + :header-rows: 1 + + * - Name + - Options + - Keys + - Template Spaces + - Version Added + - Publication + * - Smith + - ``components``, ``dimension`` + - | ``Smith_rsn_10``, ``Smith_rsn_20``, ``Smith_rsn_70``, + | ``Smith_bm_10``, ``Smith_bm_20``, ``Smith_bm_70`` + - ``MNI152NLin6Asym`` + - 0.0.7 + - | S.M. Smith, P.T. Fox, K.L. Miller et al. + | Correspondence of the brain's functional architecture during activation and rest + | Proc. Natl. Acad. Sci. U.S.A. 106 (31) 13040-13045 (2009) + | https://doi.org/10.1073/pnas.0905267106 + + .. helpful site for creating tables: https://rest-sphinx-memo.readthedocs.io/en/latest/ReST.html#tables diff --git a/docs/changes/newsfragments/458.change b/docs/changes/newsfragments/458.change new file mode 100644 index 0000000000..7de395f943 --- /dev/null +++ b/docs/changes/newsfragments/458.change @@ -0,0 +1 @@ +:func:`.get_data`, :func:`.load_data`, :func:`.list_data`, :func:`.register_data` and :func:`.deregister_data` now support ``"maps"`` as a valid argument for ``kind`` by `Synchon Mandal`_ diff --git a/docs/changes/newsfragments/458.feature b/docs/changes/newsfragments/458.feature new file mode 100644 index 0000000000..50d1a7e4ca --- /dev/null +++ b/docs/changes/newsfragments/458.feature @@ -0,0 +1 @@ +Introduce :class:`.MapsRegistry` to centralise probabilistic atlas (maps) data management by `Synchon Mandal`_ diff --git a/docs/changes/newsfragments/458.misc b/docs/changes/newsfragments/458.misc new file mode 100644 index 0000000000..d82cb6619f --- /dev/null +++ b/docs/changes/newsfragments/458.misc @@ -0,0 +1 @@ +Constrain ``templateflow`` version to ``>=23.0.0,<25.0.0`` by `Synchon Mandal`_ diff --git a/docs/extending/index.rst b/docs/extending/index.rst index 2f3348dfb3..2e73168325 100644 --- a/docs/extending/index.rst +++ b/docs/extending/index.rst @@ -30,6 +30,7 @@ DataGrabbers, Preprocessors, Markers, etc., following the *junifer* way. parcellations coordinates masks + maps plugins data_registries data_types diff --git a/docs/extending/maps.rst b/docs/extending/maps.rst new file mode 100644 index 0000000000..74516adfa5 --- /dev/null +++ b/docs/extending/maps.rst @@ -0,0 +1,140 @@ +.. include:: ../links.inc + +.. _adding_maps: + +Adding Maps +=========== + +Maps in ``junifer`` are basically probabilistic atlases. They are differentiated +from parcellations for ease of computation and to handle edge cases in a sane manner. +Before you start adding your own maps, check whether ``junifer`` has +the map(s) :ref:`in-built already `. Perhaps, what is available +there will suffice to achieve your goals. However, of course ``junifer`` will not +have every map(s) available that you may want to use, and if so, it will +be nice to be able to add it yourself using a format that ``junifer`` understands. + +Similarly, you may even be interested in creating your own custom maps +and then adding them to ``junifer``, so you can use ``junifer`` to obtain +different Markers to assess and validate your own maps. So, how can you do +this? + +Since both of these use-cases are quite common, and not being able to use your +favourite map(s) is of course quite a buzzkill, ``junifer`` actually +provides the easy-to-use :func:`.register_data` function to do just that. +Let's try to understand the API reference and then use this function to register our +own map(s). + +From the API reference, we can see that it has 3 positional arguments: + +* ``kind`` +* ``name`` +* ``space`` + +as well as one optional keyword argument: ``overwrite``. +As the ``kind`` needs to be ``"maps"``, we can check +``MapsRegistry.register`` for keyword arguments to be passed: + +* ``maps_path`` +* ``maps_labels`` + +The ``name`` of the map(s) is up to you and will be the name that +``junifer`` will use to refer to this particular maps. You can think of +this as being similar to a key in a Python dictionary, i.e. a key that is used to +obtain and operate on the actual maps data. This ``name`` must always +be a string. For example, we could call our map(s) +``"my_custom_maps"`` (Note, that in a real-world use case this is +likely not a good name, and you should try to choose a meaningful name that +conveys as much relevant information about your maps as necessary). + +The ``maps_path`` must be a ``str`` or ``Path`` object indicating a +path to a valid 4D NIfTI image, which contains floating-point labels indicating the +individual regions-of-interest (ROIs) of your map(s). + +We also want to make sure that we can associate each label with +a human readable name (i.e. the name for each ROI). This serves naming the +features that maps-based markers produce in an unambiguous way, such +that a user can easily identify which ROIs were used to produce a specific +feature (multiple ROIs, because some features consist of information from two +or more ROIs, as for example in functional connectivity). Therefore, we provide +junifer with a list of strings, that contains the names for each ROI. In this +list, the label at the i-th position indicates the i-th index in the 4th dimension +of the NIfTI image. + +Lastly, we specify the ``space`` that the map(s) is in, for example, +``"MNI152NLin6Asym"`` or ``"native"`` (scanner-native space). + +Step 1: Prepare code to register a maps +--------------------------------------- + +Now we know everything that we need to know to make sure ``junifer`` can use our +own map(s) to compute any maps-based Marker. A simple example could +look like this: + +.. code-block:: python + + from pathlib import Path + + import numpy as np + from junifer.data import register_data + + + # these are of course just example paths, replace it with your own: + path_to_maps = ( + Path.cwd() / "my_custom_maps.nii.gz" + ) + path_to_labels = ( + Path.cwd() / "my_custom_maps_labels.txt" + ) + + my_labels = list(np.loadtxt(path_to_labels, dtype=str)) + + register_data( + kind="maps", + name="my_custom_maps", + maps_path=path_to_maps, + maps_labels=my_labels, + space="MNI152NLin6Asym", + ) + +We can run this code and it seems to work, however, how can we actually +include the custom map(s) in a ``junifer`` pipeline using a +:ref:`code-less YAML configuration `? + +Step 2: Add maps registration to the YAML file +---------------------------------------------- + +In order to use the maps in a ``junifer`` pipeline configured by a YAML +file, we can save the above code in a Python file, say +``registering_my_maps.py``. We can then simply add this file using the +``with`` keyword provided by ``junifer``: + +.. code-block:: yaml + + with: + - registering_my_maps.py + +Afterwards continue configuring the rest of the pipeline in this YAML file, and +you will be able to use this maps using the name you gave the +maps when registering it. For example, we can add a +``MapsAggregation`` Marker to demonstrate how this can be done: + +.. code-block:: yaml + + markers: + - name: CustomMaps_mean + kind: MapsAggregation + maps: my_custom_maps + method: mean + +Now, you can simply use this YAML file to run your pipeline. + +.. important:: + + It's important to keep in mind that if the paths given in + ``registering_my_maps.py`` are relative paths, they will be interpreted + by ``junifer`` as relative to the jobs directory (i.e. where ``junifer`` will + create submit files, logs directory and so on). For simplicity, you may just + want to use absolute paths to avoid confusion, yet using relative paths is + likely a better way to make your pipeline directory / repository more portable + and therefore more reproducible for others. Really, once you understand how + these paths are interpreted by ``junifer``, it is quite easy. diff --git a/junifer/data/__init__.pyi b/junifer/data/__init__.pyi index d4e8ac1cce..4702dd4dce 100644 --- a/junifer/data/__init__.pyi +++ b/junifer/data/__init__.pyi @@ -3,6 +3,7 @@ __all__ = [ "CoordinatesRegistry", "DataDispatcher", "ParcellationRegistry", + "MapsRegistry", "MaskRegistry", "get_data", "list_data", @@ -17,6 +18,7 @@ __all__ = [ from .pipeline_data_registry_base import BasePipelineDataRegistry from .coordinates import CoordinatesRegistry from .parcellations import ParcellationRegistry +from .maps import MapsRegistry from .masks import MaskRegistry from ._dispatch import ( diff --git a/junifer/data/_dispatch.py b/junifer/data/_dispatch.py index 81c1d5d4ef..dbf75de7c5 100644 --- a/junifer/data/_dispatch.py +++ b/junifer/data/_dispatch.py @@ -17,6 +17,7 @@ from ..utils import raise_error from .coordinates import CoordinatesRegistry +from .maps import MapsRegistry from .masks import MaskRegistry from .parcellations import ParcellationRegistry from .pipeline_data_registry_base import BasePipelineDataRegistry @@ -52,6 +53,7 @@ def __new__(cls): { "coordinates": CoordinatesRegistry, "parcellation": ParcellationRegistry, + "maps": MapsRegistry, "mask": MaskRegistry, } ) @@ -118,14 +120,14 @@ def get_data( extra_input: Optional[dict[str, Any]] = None, ) -> Union[ tuple[ArrayLike, list[str]], # coordinates - tuple["Nifti1Image", list[str]], # parcellation + tuple["Nifti1Image", list[str]], # parcellation / maps "Nifti1Image", # mask ]: """Get tailored ``kind`` for ``target_data``. Parameters ---------- - kind : {"coordinates", "parcellation", "mask"} + kind : {"coordinates", "parcellation", "mask", "maps"} Kind of data to fetch and apply. names : str or dict or list of str / dict The registered name(s) of the data. @@ -166,7 +168,7 @@ def list_data(kind: str) -> list[str]: Parameters ---------- - kind : {"coordinates", "parcellation", "mask"} + kind : {"coordinates", "parcellation", "mask", "maps"} Kind of data registry to list. Returns @@ -195,7 +197,9 @@ def load_data( **kwargs, ) -> Union[ tuple[ArrayLike, list[str], str], # coordinates - tuple[Optional["Nifti1Image"], list[str], Path, str], # parcellation + tuple[ + Optional["Nifti1Image"], list[str], Path, str + ], # parcellation / maps tuple[ Optional[Union["Nifti1Image", Callable]], Optional[Path], str ], # mask @@ -204,7 +208,7 @@ def load_data( Parameters ---------- - kind : {"coordinates", "parcellation", "mask"} + kind : {"coordinates", "parcellation", "mask", "maps"} Kind of data to load. name : str The registered name of the data. @@ -244,7 +248,7 @@ def register_data( Parameters ---------- - kind : {"coordinates", "parcellation", "mask"} + kind : {"coordinates", "parcellation", "mask", "maps"} Kind of data to register. name : str The name to register. @@ -277,7 +281,7 @@ def deregister_data(kind: str, name: str) -> None: Parameters ---------- - kind : {"coordinates", "parcellation", "mask"} + kind : {"coordinates", "parcellation", "mask", "maps"} Kind of data to register. name : str The name to de-register. diff --git a/junifer/data/maps/__init__.py b/junifer/data/maps/__init__.py new file mode 100644 index 0000000000..fcac784545 --- /dev/null +++ b/junifer/data/maps/__init__.py @@ -0,0 +1,9 @@ +"""Maps.""" + +# Authors: Synchon Mandal +# License: AGPL + +import lazy_loader as lazy + + +__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__) diff --git a/junifer/data/maps/__init__.pyi b/junifer/data/maps/__init__.pyi new file mode 100644 index 0000000000..5b9c62f890 --- /dev/null +++ b/junifer/data/maps/__init__.pyi @@ -0,0 +1,5 @@ +__all__ = [ + "MapsRegistry", +] + +from ._maps import MapsRegistry diff --git a/junifer/data/maps/_ants_maps_warper.py b/junifer/data/maps/_ants_maps_warper.py new file mode 100644 index 0000000000..1991db1b19 --- /dev/null +++ b/junifer/data/maps/_ants_maps_warper.py @@ -0,0 +1,156 @@ +"""Provide class for maps space warping via ANTs.""" + +# Authors: Synchon Mandal +# License: AGPL + +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import nibabel as nib + +from ...pipeline import WorkDirManager +from ...utils import logger, raise_error, run_ext_cmd +from ..template_spaces import get_template, get_xfm + + +if TYPE_CHECKING: + from nibabel.nifti1 import Nifti1Image + + +__all__ = ["ANTsMapsWarper"] + + +class ANTsMapsWarper: + """Class for maps space warping via ANTs. + + This class uses ANTs ``antsApplyTransforms`` for transformation. + + """ + + def warp( + self, + maps_name: str, + maps_img: "Nifti1Image", + src: str, + dst: str, + target_data: dict[str, Any], + warp_data: Optional[dict[str, Any]], + ) -> "Nifti1Image": + """Warp ``maps_img`` to correct space. + + Parameters + ---------- + maps_name : str + The name of the maps. + maps_img : nibabel.nifti1.Nifti1Image + The maps image to transform. + src : str + The data type or template space to warp from. + It should be empty string if ``dst="T1w"``. + dst : str + The data type or template space to warp to. + `"T1w"` is the only allowed data type and it uses the resampled T1w + found in ``target_data.reference``. The ``"reference"`` + key is added if the :class:`.SpaceWarper` is used or if the + data is provided in native space. + target_data : dict + The corresponding item of the data object to which the maps + will be applied. + warp_data : dict or None + The warp data item of the data object. The value is unused if + ``dst!="T1w"``. + + Returns + ------- + nibabel.nifti1.Nifti1Image + The transformed maps image. + + Raises + ------ + ValueError + If ``warp_data`` is None when ``dst="T1w"``. + + """ + # Create element-scoped tempdir so that warped maps is + # available later as nibabel stores file path reference for + # loading on computation + prefix = ( + f"ants_maps_warper_{maps_name}" + f"{'' if not src else f'_from_{src}'}_to_{dst}_" + f"{uuid.uuid1()}" + ) + element_tempdir = WorkDirManager().get_element_tempdir( + prefix=prefix, + ) + + # Native space warping + if dst == "native": # pragma: no cover + # Warp data check + if warp_data is None: + raise_error("No `warp_data` provided") + if "reference" not in target_data: + raise_error("No `reference` provided") + if "path" not in target_data["reference"]: + raise_error("No `path` provided in `reference`") + + logger.debug("Using ANTs for maps transformation") + + # Save existing maps image to a tempfile + prewarp_maps_path = element_tempdir / "prewarp_maps.nii.gz" + nib.save(maps_img, prewarp_maps_path) + + # Create a tempfile for warped output + warped_maps_path = element_tempdir / "maps_warped.nii.gz" + # Set antsApplyTransforms command + apply_transforms_cmd = [ + "antsApplyTransforms", + "-d 3", + "-e 3", + "-n LanczosWindowedSinc", + f"-i {prewarp_maps_path.resolve()}", + # use resampled reference + f"-r {target_data['reference']['path'].resolve()}", + f"-t {warp_data['path'].resolve()}", + f"-o {warped_maps_path.resolve()}", + ] + # Call antsApplyTransforms + run_ext_cmd(name="antsApplyTransforms", cmd=apply_transforms_cmd) + + # Template space warping + else: + logger.debug(f"Using ANTs to warp maps from {src} to {dst}") + + # Get xfm file + xfm_file_path = get_xfm(src=src, dst=dst) + # Get template space image + template_space_img = get_template( + space=dst, + target_img=maps_img, + extra_input=None, + ) + # Save template to a tempfile + template_space_img_path = element_tempdir / f"{dst}_T1w.nii.gz" + nib.save(template_space_img, template_space_img_path) + + # Save existing maps image to a tempfile + prewarp_maps_path = element_tempdir / "prewarp_maps.nii.gz" + nib.save(maps_img, prewarp_maps_path) + + # Create a tempfile for warped output + warped_maps_path = element_tempdir / "maps_warped.nii.gz" + # Set antsApplyTransforms command + apply_transforms_cmd = [ + "antsApplyTransforms", + "-d 3", + "-e 3", + "-n LanczosWindowedSinc", + f"-i {prewarp_maps_path.resolve()}", + f"-r {template_space_img_path.resolve()}", + f"-t {xfm_file_path.resolve()}", + f"-o {warped_maps_path.resolve()}", + ] + # Call antsApplyTransforms + run_ext_cmd(name="antsApplyTransforms", cmd=apply_transforms_cmd) + + # Load nifti + return nib.load(warped_maps_path) diff --git a/junifer/data/maps/_fsl_maps_warper.py b/junifer/data/maps/_fsl_maps_warper.py new file mode 100644 index 0000000000..fe00690f84 --- /dev/null +++ b/junifer/data/maps/_fsl_maps_warper.py @@ -0,0 +1,85 @@ +"""Provide class for maps space warping via FSL FLIRT.""" + +# Authors: Synchon Mandal +# License: AGPL + +import uuid +from typing import TYPE_CHECKING, Any + +import nibabel as nib + +from ...pipeline import WorkDirManager +from ...utils import logger, run_ext_cmd + + +if TYPE_CHECKING: + from nibabel.nifti1 import Nifti1Image + + +__all__ = ["FSLMapsWarper"] + + +class FSLMapsWarper: + """Class for maps space warping via FSL FLIRT. + + This class uses FSL FLIRT's ``applywarp`` for transformation. + + """ + + def warp( + self, + maps_name: str, + maps_img: "Nifti1Image", + target_data: dict[str, Any], + warp_data: dict[str, Any], + ) -> "Nifti1Image": # pragma: no cover + """Warp ``maps_img`` to correct space. + + Parameters + ---------- + maps_name : str + The name of the maps. + maps_img : nibabel.nifti1.Nifti1Image + The maps image to transform. + target_data : dict + The corresponding item of the data object to which the maps + will be applied. + warp_data : dict + The warp data item of the data object. + + Returns + ------- + nibabel.nifti1.Nifti1Image + The transformed maps image. + + """ + logger.debug("Using FSL for maps transformation") + + # Create element-scoped tempdir so that warped maps is + # available later as nibabel stores file path reference for + # loading on computation + element_tempdir = WorkDirManager().get_element_tempdir( + prefix=f"fsl_maps_warper_{maps_name}_{uuid.uuid1()}" + ) + + # Save existing maps image to a tempfile + prewarp_maps_path = element_tempdir / "prewarp_maps.nii.gz" + nib.save(maps_img, prewarp_maps_path) + + # Create a tempfile for warped output + warped_maps_path = element_tempdir / "maps_warped.nii.gz" + # Set applywarp command + applywarp_cmd = [ + "applywarp", + "--interp=spline", + f"-i {prewarp_maps_path.resolve()}", + # use resampled reference + f"-r {target_data['reference']['path'].resolve()}", + f"-w {warp_data['path'].resolve()}", + f"-o {warped_maps_path.resolve()}", + ] + # Call applywarp + run_ext_cmd(name="applywarp", cmd=applywarp_cmd) + + # Load nifti + return nib.load(warped_maps_path) diff --git a/junifer/data/maps/_maps.py b/junifer/data/maps/_maps.py new file mode 100644 index 0000000000..020d3da5fb --- /dev/null +++ b/junifer/data/maps/_maps.py @@ -0,0 +1,446 @@ +"""Provide a class for centralized maps data registry.""" + +# Authors: Synchon Mandal +# License: AGPL + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, Union + +import nibabel as nib +import nilearn.image as nimg +import numpy as np +from junifer_data import get + +from ...utils import logger, raise_error +from ..pipeline_data_registry_base import BasePipelineDataRegistry +from ..utils import ( + JUNIFER_DATA_PARAMS, + closest_resolution, + get_dataset_path, + get_native_warper, +) +from ._ants_maps_warper import ANTsMapsWarper +from ._fsl_maps_warper import FSLMapsWarper + + +if TYPE_CHECKING: + from nibabel.nifti1 import Nifti1Image + + +__all__ = ["MapsRegistry"] + + +class MapsRegistry(BasePipelineDataRegistry): + """Class for maps data registry. + + This class is a singleton and is used for managing available maps + data in a centralized manner. + + """ + + def __init__(self) -> None: + """Initialize the class.""" + super().__init__() + # Each entry in registry is a dictionary that must contain at least + # the following keys: + # * 'family': the maps' family name (e.g., 'Smith') + # * 'space': the maps' space (e.g., 'MNI') + # and can also have optional key(s): + # * 'valid_resolutions': a list of valid resolutions for the + # maps (e.g., [1, 2]) + # The built-in maps are files that are shipped with the + # junifer-data dataset. + # Make built-in and external dictionaries for validation later + self._builtin = {} + self._external = {} + + # Add Smith + for comp in ["rsn", "bm"]: + for dim in [10, 20, 70]: + self._builtin.update( + { + f"Smith_{comp}_{dim}": { + "family": "Smith2009", + "components": comp, + "dimension": dim, + "space": "MNI152NLin6Asym", + } + } + ) + + # Update registry with built-in ones + self._registry.update(self._builtin) + + def register( + self, + name: str, + maps_path: Union[str, Path], + maps_labels: list[str], + space: str, + overwrite: bool = False, + ) -> None: + """Register a custom user map(s). + + Parameters + ---------- + name : str + The name of the map(s). + maps_path : str or pathlib.Path + The path to the map(s) file. + maps_labels : list of str + The list of labels for the map(s). + space : str + The template space of the map(s), e.g., "MNI152NLin6Asym". + overwrite : bool, optional + If True, overwrite an existing maps with the same name. + Does not apply to built-in maps (default False). + + Raises + ------ + ValueError + If the map(s) ``name`` is a built-in map(s) or + if the map(s) ``name`` is already registered and + ``overwrite=False``. + + """ + # Check for attempt of overwriting built-in maps + if name in self._builtin: + raise_error( + f"Map(s): {name} already registered as built-in map(s)." + ) + # Check for attempt of overwriting external maps + if name in self._external: + if overwrite: + logger.info(f"Overwriting map(s): {name}") + else: + raise_error( + f"Map(s): {name} already registered. Set " + "`overwrite=True` to update its value." + ) + # Convert str to Path + if not isinstance(maps_path, Path): + maps_path = Path(maps_path) + # Registration + logger.info(f"Registering map(s): {name}") + # Add user maps info + self._external[name] = { + "path": maps_path, + "labels": maps_labels, + "family": "CustomUserMaps", + "space": space, + } + # Update registry + self._registry[name] = { + "path": maps_path, + "labels": maps_labels, + "family": "CustomUserMaps", + "space": space, + } + + def deregister(self, name: str) -> None: + """De-register a custom user map(s). + + Parameters + ---------- + name : str + The name of the map(s). + + """ + logger.info(f"De-registering map(s): {name}") + # Remove maps info + _ = self._external.pop(name) + # Update registry + _ = self._registry.pop(name) + + def load( + self, + name: str, + target_space: str, + resolution: Optional[float] = None, + path_only: bool = False, + ) -> tuple[Optional["Nifti1Image"], list[str], Path, str]: + """Load map(s) and labels. + + Parameters + ---------- + name : str + The name of the map(s). + target_space : str + The desired space of the map(s). + resolution : float, optional + The desired resolution of the map(s) to load. If it is not + available, the closest resolution will be loaded. Preferably, use a + resolution higher than the desired one. By default, will load the + highest one (default None). + path_only : bool, optional + If True, the map(s) image will not be loaded (default False). + + Returns + ------- + Nifti1Image or None + Loaded map(s) image. + list of str + Map(s) labels. + pathlib.Path + File path to the map(s) image. + str + The space of the map(s). + + Raises + ------ + ValueError + If ``name`` is invalid or + if the map(s) family is invalid or + if the map(s) values and labels + don't have equal dimension or if the value range is invalid. + + """ + # Check for valid maps name + if name not in self._registry: + raise_error( + f"Map(s): {name} not found. Valid options are: {self.list}" + ) + + # Copy maps definition to avoid edits in original object + maps_def = self._registry[name].copy() + t_family = maps_def.pop("family") + space = maps_def.pop("space") + + # Check and get highest resolution + if space != target_space: + logger.info( + f"Map(s) will be warped from {space} to {target_space} " + "using highest resolution" + ) + resolution = None + + # Check if the maps family is custom or built-in + if t_family == "CustomUserMaps": + maps_fname = maps_def["path"] + maps_labels = maps_def["labels"] + elif t_family in [ + "Smith2009", + ]: + # Load maps and labels + if t_family == "Smith2009": + maps_fname, maps_labels = _retrieve_smith( + resolution=resolution, + **maps_def, + ) + else: # pragma: no cover + raise_error(f"Unknown map(s) family: {t_family}") + + # Load maps image and values + logger.info(f"Loading map(s): {maps_fname.absolute()!s}") + maps_img = None + if not path_only: + # Load image via nibabel + maps_img = nib.load(maps_fname) + # Get regions + maps_regions = maps_img.get_fdata().shape[-1] + # Check for dimension + if maps_regions != len(maps_labels): + raise_error( + f"Map(s) {name} has {maps_regions} " + f"regions but {len(maps_labels)} labels." + ) + + return maps_img, maps_labels, maps_fname, space + + def get( + self, + maps: str, + target_data: dict[str, Any], + extra_input: Optional[dict[str, Any]] = None, + ) -> tuple["Nifti1Image", list[str]]: + """Get map(s), tailored for the target image. + + Parameters + ---------- + maps : str + The name of the map(s). + target_data : dict + The corresponding item of the data object to which the map(s) + will be applied. + extra_input : dict, optional + The other fields in the data object. Useful for accessing other + data kinds that needs to be used in the computation of + map(s) (default None). + + Returns + ------- + Nifti1Image + The map(s) image. + list of str + Map(s) labels. + + Raises + ------ + ValueError + If ``extra_input`` is None when ``target_data``'s space is native. + + """ + # Check pre-requirements for space manipulation + target_space = target_data["space"] + logger.debug(f"Getting {maps} in {target_space} space.") + # Extra data type requirement check if target space is native + if target_space == "native": # pragma: no cover + # Check for extra inputs + if extra_input is None: + raise_error( + "No extra input provided, requires `Warp` and `T1w` " + "data types in particular for transformation to " + f"{target_data['space']} space for further computation." + ) + # Get native space warper spec + warper_spec = get_native_warper( + target_data=target_data, + other_data=extra_input, + ) + # Set target standard space to warp file space source + target_std_space = warper_spec["src"] + logger.debug( + f"Target space is native. Will warp from {target_std_space}" + ) + else: + # Set target standard space to target space + target_std_space = target_space + + # Get the min of the voxels sizes and use it as the resolution + target_img = target_data["data"] + resolution = np.min(target_img.header.get_zooms()[:3]) + + # Load maps + logger.debug(f"Loading map(s) {maps}") + img, labels, _, space = self.load( + name=maps, + resolution=resolution, + target_space=target_space, + ) + + # Convert maps spaces if required; + # cannot be "native" due to earlier check + if space != target_std_space: + logger.debug( + f"Warping {maps} to {target_std_space} space using ANTs." + ) + raw_img = ANTsMapsWarper().warp( + maps_name=maps, + maps_img=img, + src=space, + dst=target_std_space, + target_data=target_data, + warp_data=None, + ) + # Remove extra dimension added by ANTs + img = nimg.math_img("np.squeeze(img)", img=raw_img) + + if target_space != "native": + # No warping is going to happen, just resampling, because + # we are in the correct space + logger.debug(f"Resampling {maps} to target image.") + # Resample maps to target image + img = nimg.resample_to_img( + source_img=img, + target_img=target_img, + interpolation="continuous", + copy=True, + ) + else: # pragma: no cover + # Warp maps if target space is native as either + # the image is in the right non-native space or it's + # warped from one non-native space to another non-native space + logger.debug( + "Warping map(s) to native space using " + f"{warper_spec['warper']}." + ) + # extra_input check done earlier and warper_spec exists + if warper_spec["warper"] == "fsl": + img = FSLMapsWarper().warp( + maps_name="native", + maps_img=img, + target_data=target_data, + warp_data=warper_spec, + ) + elif warper_spec["warper"] == "ants": + img = ANTsMapsWarper().warp( + maps_name="native", + maps_img=img, + src="", + dst="native", + target_data=target_data, + warp_data=warper_spec, + ) + + return img, labels + + +def _retrieve_smith( + resolution: Optional[float] = None, + components: Optional[str] = None, + dimension: Optional[int] = None, +) -> tuple[Path, list[str]]: + """Retrieve Smith maps. + + Parameters + ---------- + resolution : 2.0, optional + The desired resolution of the maps to load. If it is not + available, the closest resolution will be loaded. Preferably, use a + resolution higher than the desired one. By default, will load the + highest one (default None). Available resolution for these + maps are 2mm. + components : {"rsn", "bm"}, optional + The components to load. "rsn" loads the resting-fMRI components and + "bm" loads the BrainMap components (default None). + dimension : {10, 20, 70}, optional + The number of dimensions to load (default None). + + Returns + ------- + pathlib.Path + File path to the maps image. + list of str + Maps labels. + + Raises + ------ + ValueError + If invalid value is provided for ``components`` or ``dimension``. + + """ + logger.info("Maps parameters:") + logger.info(f"\tresolution: {resolution}") + logger.info(f"\tcomponents: {components}") + logger.info(f"\tdimension: {dimension}") + + # Check resolution + _valid_resolutions = [2.0] + resolution = closest_resolution(resolution, _valid_resolutions) + + # Check components value + _valid_components = ["rsn", "bm"] + if components not in _valid_components: + raise_error( + f"The parameter `components` ({components}) needs to be one of " + f"the following: {_valid_components}" + ) + + # Check dimension value + _valid_dimension = [10, 20, 70] + if dimension not in _valid_dimension: + raise_error( + f"The parameter `dimension` ({dimension}) needs to be one of the " + f"following: {_valid_dimension}" + ) + + # Fetch file path + maps_img_path = get( + file_path=Path( + f"parcellations/Smith2009/{components}{dimension}.nii.gz" + ), + dataset_path=get_dataset_path(), + **JUNIFER_DATA_PARAMS, + ) + + return maps_img_path, [f"Map_{i}" for i in range(dimension)] diff --git a/junifer/data/maps/tests/test_maps.py b/junifer/data/maps/tests/test_maps.py new file mode 100644 index 0000000000..0e9fa506d0 --- /dev/null +++ b/junifer/data/maps/tests/test_maps.py @@ -0,0 +1,255 @@ +"""Provide tests for maps.""" + +# Authors: Synchon Mandal +# License: AGPL + +import pytest +from numpy.testing import assert_array_equal + +from junifer.data import ( + deregister_data, + get_data, + list_data, + load_data, + register_data, +) +from junifer.data.maps._maps import _retrieve_smith +from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datareader import DefaultDataReader +from junifer.pipeline.utils import _check_ants +from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber + + +def test_register_built_in_check() -> None: + """Test maps registration check for built-in maps.""" + with pytest.raises(ValueError, match=r"built-in"): + register_data( + kind="maps", + name="Smith_rsn_10", + maps_path="testmaps.nii.gz", + maps_labels=["1", "2"], + space="MNI", + ) + + +def test_list_incorrect() -> None: + """Test incorrect information check for list mapss.""" + assert "testmaps" not in list_data(kind="maps") + + +def test_register_overwrite() -> None: + """Test maps registration check for overwriting.""" + register_data( + kind="maps", + name="testmaps", + maps_path="testmaps.nii.gz", + maps_labels=["1", "2"], + space="MNI152NLin6Sym", + ) + with pytest.raises(ValueError, match=r"already registered"): + register_data( + kind="maps", + name="testmaps", + maps_path="testmaps.nii.gz", + maps_labels=["1", "2"], + space="MNI152NLin6Sym", + overwrite=False, + ) + + register_data( + kind="maps", + name="testmaps", + maps_path="testmaps.nii.gz", + maps_labels=["1", "2"], + space="MNI152NLin6Sym", + overwrite=True, + ) + + assert ( + load_data( + kind="maps", + name="testmaps", + target_space="MNI152NLin6Sym", + path_only=True, + )[2].name + == "testmaps.nii.gz" + ) + + +def test_register_valid_input() -> None: + """Test maps registration check for valid input.""" + maps, labels, maps_path, _ = load_data( + kind="maps", + name="Smith_rsn_10", + target_space="MNI152NLin6Asym", + ) + assert maps is not None + + # Test wrong number of labels + register_data( + kind="maps", + name="WrongLabels", + maps_path=maps_path, + maps_labels=labels[:5], + space="MNI152NLin6Asym", + ) + with pytest.raises(ValueError, match=r"has 10 regions but 5"): + load_data( + kind="maps", + name="WrongLabels", + target_space="MNI152NLin6Asym", + ) + + +def test_list() -> None: + """Test listing of available coordinates.""" + assert {"Smith_rsn_10", "Smith_bm_70"}.issubset( + set(list_data(kind="maps")) + ) + + +def test_load_nonexisting() -> None: + """Test loading maps that not exist.""" + with pytest.raises(ValueError, match=r"not found"): + load_data(kind="maps", name="nomaps", target_space="MNI152NLin6Sym") + + +def test_get() -> None: + """Test tailored maps fetch.""" + with PatternDataladDataGrabber( + uri="https://github.com/OpenNeuroDatasets/ds005226.git", + types=["BOLD"], + patterns={ + "BOLD": { + "pattern": ( + "derivatives/pre-processed_data/space-MNI/{subject}/" + "{subject-padded}_task-{task}_run-{run}_space-MNI152NLin6Asym" + "_res-2_desc-preproc_bold.nii.gz" + ), + "space": "MNI152NLin6Asym", + }, + }, + replacements=["subject", "subject-padded", "task", "run"], + ) as dg: + element = dg[("sub-01", "sub-001", "rest", "1")] + element_data = DefaultDataReader().fit_transform(element) + bold = element_data["BOLD"] + bold_img = bold["data"] + # Get tailored coordinates + tailored_maps, tailored_labels = get_data( + kind="maps", names="Smith_rsn_10", target_data=bold + ) + + # Check shape with original element data + assert tailored_maps.shape[:3] == bold_img.shape[:3] + + # Get raw maps + raw_maps, raw_labels, _, _ = load_data( + kind="maps", name="Smith_rsn_10", target_space="MNI152NLin6Asym" + ) + # Tailored and raw shape should be same + assert tailored_maps.shape[:3] == raw_maps.shape[:3] + assert tailored_labels == raw_labels + + +@pytest.mark.skipif( + _check_ants() is False, reason="requires ANTs to be in PATH" +) +def test_get_different_space() -> None: + """Test tailored maps fetch in different space.""" + with PartlyCloudyTestingDataGrabber() as dg: + element = dg["sub-01"] + element_data = DefaultDataReader().fit_transform(element) + bold = element_data["BOLD"] + bold_img = bold["data"] + # Get tailored coordinates + tailored_maps, tailored_labels = get_data( + kind="maps", names="Smith_rsn_10", target_data=bold + ) + + # Check shape with original element data + assert tailored_maps.shape[:3] == bold_img.shape[:3] + + # Get raw maps + raw_maps, raw_labels, _, _ = load_data( + kind="maps", + name="Smith_rsn_10", + target_space="MNI152NLin2009cAsym", + ) + # Tailored and raw should not be same + assert tailored_maps.shape[:3] != raw_maps.shape[:3] + assert tailored_labels == raw_labels + + +def test_deregister() -> None: + """Test maps deregistration.""" + deregister_data(kind="maps", name="testmaps") + assert "testmaps" not in list_data(kind="maps") + + +@pytest.mark.parametrize( + "resolution, components, dimension", + [ + (2.0, "rsn", 10), + (2.0, "rsn", 20), + (2.0, "rsn", 70), + (2.0, "bm", 10), + (2.0, "bm", 20), + (2.0, "bm", 70), + ], +) +def test_smith( + resolution: float, + components: str, + dimension: int, +) -> None: + """Test Smith maps. + + Parameters + ---------- + resolution : float + The parametrized resolution values. + components : str + The parametrized components values. + dimension : int + The parametrized dimension values. + + """ + maps = list_data(kind="maps") + maps_name = f"Smith_{components}_{dimension}" + assert maps_name in maps + + maps_file = f"{components}{dimension}.nii.gz" + # Load maps + img, label, img_path, space = load_data( + kind="maps", + name=maps_name, + target_space="MNI152NLin6Asym", + resolution=resolution, + ) + assert img is not None + assert img_path.name == maps_file + assert space == "MNI152NLin6Asym" + assert len(label) == dimension + assert_array_equal( + img.header["pixdim"][1:4], + 3 * [2.0], + ) + + +def test_retrieve_smith_incorrect_components() -> None: + """Test retrieve Smith with incorrect components.""" + with pytest.raises(ValueError, match="The parameter `components`"): + _retrieve_smith( + components="abc", + dimension=10, + ) + + +def test_retrieve_smith_incorrect_dimension() -> None: + """Test retrieve Smith with incorrect dimension.""" + with pytest.raises(ValueError, match="The parameter `dimension`"): + _retrieve_smith( + components="rsn", + dimension=100, + ) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 8a75260104..1bd36b2f3a 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -231,8 +231,9 @@ def __init__(self) -> None: # * 'family': the mask's family name # (e.g., 'Vickery-Patil', 'Callable') # * 'space': the mask's space (e.g., 'MNI', 'inherit') - # The built-in masks are files that are shipped with the package in the - # data/masks directory. The user can also register their own masks. + # The built-in masks are files that are shipped either with the + # junifer-data dataset or computed on-demand. The user can also + # register their own masks. # Callable masks should be functions that take at least one parameter: # * `target_img`: the image to which the mask will be applied. # and should be included in the registry as a value to a key: `func`. diff --git a/junifer/data/parcellations/_parcellations.py b/junifer/data/parcellations/_parcellations.py index b9947c796a..b6fb09b7d4 100644 --- a/junifer/data/parcellations/_parcellations.py +++ b/junifer/data/parcellations/_parcellations.py @@ -55,7 +55,7 @@ def __init__(self) -> None: # and can also have optional key(s): # * 'valid_resolutions': a list of valid resolutions for the # parcellation (e.g., [1, 2]) - # The built-in coordinates are files that are shipped with the + # The built-in parcellations are files that are shipped with the # junifer-data dataset. # Make built-in and external dictionaries for validation later self._builtin = {} diff --git a/pyproject.toml b/pyproject.toml index 2c5d8056a1..e00b1f3c6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "ruamel.yaml>=0.17,<0.19", "h5py>=3.10", "tqdm>=4.66.1,<4.67.0", - "templateflow>=23.0.0", + "templateflow>=23.0.0,<25.0.0", "lapy>=1.0.0,<2.0.0", "lazy_loader==0.4", "importlib_metadata; python_version<'3.9'",