diff --git a/docs/changes/newsfragments/451.doc b/docs/changes/newsfragments/451.doc new file mode 100644 index 0000000000..35098253ef --- /dev/null +++ b/docs/changes/newsfragments/451.doc @@ -0,0 +1 @@ +Add documentation on adding data types by `Synchon Mandal`_ diff --git a/docs/changes/newsfragments/451.feature b/docs/changes/newsfragments/451.feature new file mode 100644 index 0000000000..bf1109c440 --- /dev/null +++ b/docs/changes/newsfragments/451.feature @@ -0,0 +1 @@ +Enable data types to be added by introducing :func:`.register_data_type` by `Synchon Mandal`_ diff --git a/docs/extending/data_types.rst b/docs/extending/data_types.rst new file mode 100644 index 0000000000..4c229aea50 --- /dev/null +++ b/docs/extending/data_types.rst @@ -0,0 +1,57 @@ +.. include:: ../links.inc + +.. _adding_data_types: + +Adding Data Types +================= + +``junifer`` supports most of the :ref:`data types ` required for fMRI +but also provides a way to add custom data types in case you work with other +modalities like EEG. + +How to add a data type +---------------------- + +#. Check :ref:`extending junifer ` on how to create a + *junifer extension* if you have not done so. +#. Define the data type schema in the *extension script* like so: + + .. code-block:: python + + from junifer.datagrabber import DataTypeSchema + + + dtype_schema: DataTypeSchema = { + "mandatory": ["pattern"], + "optional": { + "mask": { + "mandatory": ["pattern"], + "optional": [], + }, + }, + } + + * The :obj:`.DataTypeSchema` has two mandatory keys: + + * ``mandatory`` : list of str + * ``optional`` : dict of str and :obj:`.OptionalTypeSchema` + + * ``mandatory`` defines the keys that must be present when defining a *pattern* + in a DataGrabber. + * ``optional`` defines the mapping from *sub-types* that are optional, to their patterns. + The patterns in turn require a ``mandatory`` key and an ``optional`` key both + just being the keys that must be there if the optional key is found. 
It's + possible that the *sub-type* (``mask`` in the example) can be absent from the dataset. + +#. Register the data type before defining / using a DataGrabber like so: + + .. code-block:: python + + from junifer.datagrabber import register_data_type + ... + + + # registers the data type as "dtype" + register_data_type(name="dtype", schema=dtype_schema) + + ... diff --git a/docs/extending/index.rst b/docs/extending/index.rst index 0fba1a95d8..a6ed8ca123 100644 --- a/docs/extending/index.rst +++ b/docs/extending/index.rst @@ -32,3 +32,4 @@ DataGrabbers, Preprocessors, Markers, etc., following the *junifer* way. masks plugins data_registries + data_types diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 06c17f9d5f..51fcd6f492 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -10,7 +10,11 @@ __all__ = [ "DataladHCP1200", "MultipleDataGrabber", "DMCC13Benchmark", + "DataTypeManager", + "DataTypeSchema", + "OptionalTypeSchema", "PatternValidationMixin", + "register_data_type", ] # These 4 need to be in this order, otherwise it is a circular import @@ -24,4 +28,10 @@ from .hcp1200 import HCP1200, DataladHCP1200 from .multiple import MultipleDataGrabber from .dmcc13_benchmark import DMCC13Benchmark -from .pattern_validation_mixin import PatternValidationMixin +from .pattern_validation_mixin import ( + DataTypeManager, + DataTypeSchema, + OptionalTypeSchema, + PatternValidationMixin, + register_data_type, +) diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index 52479b1836..77ba2d6e72 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -3,71 +3,199 @@ # Authors: Synchon Mandal # License: AGPL +from collections.abc import Iterator, MutableMapping +from typing import TypedDict + from ..typing import DataGrabberPatterns from ..utils import logger, raise_error, warn_with_log 
-__all__ = ["PatternValidationMixin"] - - -# Define schema for pattern-based datagrabber's patterns -PATTERNS_SCHEMA = { - "T1w": { - "mandatory": ["pattern", "space"], - "optional": { - "mask": {"mandatory": ["pattern", "space"], "optional": []}, - }, - }, - "T2w": { - "mandatory": ["pattern", "space"], - "optional": { - "mask": {"mandatory": ["pattern", "space"], "optional": []}, - }, - }, - "BOLD": { - "mandatory": ["pattern", "space"], - "optional": { - "mask": {"mandatory": ["pattern", "space"], "optional": []}, - "confounds": { - "mandatory": ["pattern", "format"], - "optional": ["mappings"], - }, - "reference": {"mandatory": ["pattern"], "optional": []}, - "prewarp_space": {"mandatory": [], "optional": []}, - }, - }, - "Warp": { - "mandatory": ["pattern", "src", "dst", "warper"], - "optional": {}, - }, - "VBM_GM": { - "mandatory": ["pattern", "space"], - "optional": {}, - }, - "VBM_WM": { - "mandatory": ["pattern", "space"], - "optional": {}, - }, - "VBM_CSF": { - "mandatory": ["pattern", "space"], - "optional": {}, - }, - "DWI": { - "mandatory": ["pattern"], - "optional": {}, - }, - "FreeSurfer": { - "mandatory": ["pattern"], - "optional": { - "aseg": {"mandatory": ["pattern"], "optional": []}, - "norm": {"mandatory": ["pattern"], "optional": []}, - "lh_white": {"mandatory": ["pattern"], "optional": []}, - "rh_white": {"mandatory": ["pattern"], "optional": []}, - "lh_pial": {"mandatory": ["pattern"], "optional": []}, - "rh_pial": {"mandatory": ["pattern"], "optional": []}, - }, - }, -} +__all__ = [ + "DataTypeManager", + "DataTypeSchema", + "OptionalTypeSchema", + "PatternValidationMixin", + "register_data_type", +] + + +class OptionalTypeSchema(TypedDict): + """Optional type schema.""" + + mandatory: list[str] + optional: list[str] + + +class DataTypeSchema(TypedDict): + """Data type schema.""" + + mandatory: list[str] + optional: dict[str, OptionalTypeSchema] + + +class DataTypeManager(MutableMapping): + """Class for managing data types.""" + + _instance 
= None + + def __new__(cls): + """Overridden to make the class singleton.""" + # Make class singleton + if cls._instance is None: + cls._instance = super().__new__(cls) + # Set global schema + cls._global: dict[str, DataTypeSchema] = {} + cls._builtin: dict[str, DataTypeSchema] = {} + cls._external: dict[str, DataTypeSchema] = {} + cls._builtin.update( + { + "T1w": { + "mandatory": ["pattern", "space"], + "optional": { + "mask": { + "mandatory": ["pattern", "space"], + "optional": [], + }, + }, + }, + "T2w": { + "mandatory": ["pattern", "space"], + "optional": { + "mask": { + "mandatory": ["pattern", "space"], + "optional": [], + }, + }, + }, + "BOLD": { + "mandatory": ["pattern", "space"], + "optional": { + "mask": { + "mandatory": ["pattern", "space"], + "optional": [], + }, + "confounds": { + "mandatory": ["pattern", "format"], + "optional": ["mappings"], + }, + "reference": { + "mandatory": ["pattern"], + "optional": [], + }, + "prewarp_space": {"mandatory": [], "optional": []}, + }, + }, + "Warp": { + "mandatory": ["pattern", "src", "dst", "warper"], + "optional": {}, + }, + "VBM_GM": { + "mandatory": ["pattern", "space"], + "optional": {}, + }, + "VBM_WM": { + "mandatory": ["pattern", "space"], + "optional": {}, + }, + "VBM_CSF": { + "mandatory": ["pattern", "space"], + "optional": {}, + }, + "DWI": { + "mandatory": ["pattern"], + "optional": {}, + }, + "FreeSurfer": { + "mandatory": ["pattern"], + "optional": { + "aseg": {"mandatory": ["pattern"], "optional": []}, + "norm": {"mandatory": ["pattern"], "optional": []}, + "lh_white": { + "mandatory": ["pattern"], + "optional": [], + }, + "rh_white": { + "mandatory": ["pattern"], + "optional": [], + }, + "lh_pial": { + "mandatory": ["pattern"], + "optional": [], + }, + "rh_pial": { + "mandatory": ["pattern"], + "optional": [], + }, + }, + }, + } + ) + cls._global.update(cls._builtin) + return cls._instance + + def __getitem__(self, key: str) -> DataTypeSchema: + """Retrieve schema for ``key``.""" + return 
self._global[key] + + def __iter__(self) -> Iterator[str]: + """Iterate over data types.""" + return iter(self._global) + + def __len__(self) -> int: + """Get data type count.""" + return len(self._global) + + def __delitem__(self, key: str) -> None: + """Remove schema for ``key``.""" + # Internal check + if key in self._builtin: + raise_error(f"Cannot delete in-built key: {key}") + # Non-existing key + if key not in self._external: + raise_error(klass=KeyError, msg=key) + # Update external + _ = self._external.pop(key) + # Update global + _ = self._global.pop(key) + + def __setitem__(self, key: str, value: DataTypeSchema) -> None: + """Update ``key`` with ``value``.""" + # Internal check + if key in self._builtin: + raise_error(f"Cannot set value for in-built key: {key}") + # Value type check + if not isinstance(value, dict): + raise_error(f"Invalid value type: {type(value)}") + # Update external + self._external[key] = value + # Update global + self._global[key] = value + + def popitem(self): + """Not implemented.""" + pass + + def clear(self): + """Not implemented.""" + pass + + def setdefault(self, key: str, value=None): + """Not implemented.""" + pass + + +def register_data_type(name: str, schema: DataTypeSchema) -> None: + """Register custom data type. + + Parameters + ---------- + name : str + The data type name. + schema : DataTypeSchema + The data type schema.
+ + """ + DataTypeManager()[name] = schema class PatternValidationMixin: @@ -311,12 +439,13 @@ def validate_patterns( msg="`patterns` must contain all `types`", klass=ValueError ) # Check against schema + dtype_mgr = DataTypeManager() for dtype_key, dtype_val in patterns.items(): # Check if valid data type is provided - if dtype_key not in PATTERNS_SCHEMA: + if dtype_key not in dtype_mgr: raise_error( f"Unknown data type: {dtype_key}, " - f"should be one of: {list(PATTERNS_SCHEMA.keys())}" + f"should be one of: {list(dtype_mgr.keys())}" ) # Conditional for list dtype vals like Warp if isinstance(dtype_val, list): @@ -324,14 +453,14 @@ def validate_patterns( # Check mandatory keys for data type self._validate_mandatory_keys( keys=list(entry), - schema=PATTERNS_SCHEMA[dtype_key]["mandatory"], + schema=dtype_mgr[dtype_key]["mandatory"], data_type=f"{dtype_key}.{idx}", partial_pattern_ok=partial_pattern_ok, ) # Check optional keys for data type - for optional_key, optional_val in PATTERNS_SCHEMA[ - dtype_key - ]["optional"].items(): + for optional_key, optional_val in dtype_mgr[dtype_key][ + "optional" + ].items(): if optional_key not in entry: logger.debug( f"Optional key: `{optional_key}` missing for " @@ -344,12 +473,12 @@ def validate_patterns( ) # Set nested type name for easier access nested_dtype = f"{dtype_key}.{idx}.{optional_key}" - nested_mandatory_keys_schema = PATTERNS_SCHEMA[ + nested_mandatory_keys_schema = dtype_mgr[ dtype_key ]["optional"][optional_key]["mandatory"] - nested_optional_keys_schema = PATTERNS_SCHEMA[ - dtype_key - ]["optional"][optional_key]["optional"] + nested_optional_keys_schema = dtype_mgr[dtype_key][ + "optional" + ][optional_key]["optional"] # Check mandatory keys for nested type self._validate_mandatory_keys( keys=list(optional_val["mandatory"]), @@ -392,10 +521,8 @@ def validate_patterns( self._identify_stray_keys( keys=list(entry.keys()), schema=( - PATTERNS_SCHEMA[dtype_key]["mandatory"] - + list( - 
PATTERNS_SCHEMA[dtype_key]["optional"].keys() - ) + dtype_mgr[dtype_key]["mandatory"] + + list(dtype_mgr[dtype_key]["optional"].keys()) ), data_type=dtype_key, ) @@ -412,12 +539,12 @@ def validate_patterns( # Check mandatory keys for data type self._validate_mandatory_keys( keys=list(dtype_val), - schema=PATTERNS_SCHEMA[dtype_key]["mandatory"], + schema=dtype_mgr[dtype_key]["mandatory"], data_type=dtype_key, partial_pattern_ok=partial_pattern_ok, ) # Check optional keys for data type - for optional_key, optional_val in PATTERNS_SCHEMA[dtype_key][ + for optional_key, optional_val in dtype_mgr[dtype_key][ "optional" ].items(): if optional_key not in dtype_val: @@ -432,12 +559,12 @@ def validate_patterns( ) # Set nested type name for easier access nested_dtype = f"{dtype_key}.{optional_key}" - nested_mandatory_keys_schema = PATTERNS_SCHEMA[ - dtype_key - ]["optional"][optional_key]["mandatory"] - nested_optional_keys_schema = PATTERNS_SCHEMA[ - dtype_key - ]["optional"][optional_key]["optional"] + nested_mandatory_keys_schema = dtype_mgr[dtype_key][ + "optional" + ][optional_key]["mandatory"] + nested_optional_keys_schema = dtype_mgr[dtype_key][ + "optional" + ][optional_key]["optional"] # Check mandatory keys for nested type self._validate_mandatory_keys( keys=list(optional_val["mandatory"]), @@ -476,8 +603,8 @@ def validate_patterns( self._identify_stray_keys( keys=list(dtype_val.keys()), schema=( - PATTERNS_SCHEMA[dtype_key]["mandatory"] - + list(PATTERNS_SCHEMA[dtype_key]["optional"].keys()) + dtype_mgr[dtype_key]["mandatory"] + + list(dtype_mgr[dtype_key]["optional"].keys()) ), data_type=dtype_key, ) diff --git a/junifer/datagrabber/tests/test_pattern_validation_mixin.py b/junifer/datagrabber/tests/test_pattern_validation_mixin.py index 76437d0693..83da5f5a57 100644 --- a/junifer/datagrabber/tests/test_pattern_validation_mixin.py +++ b/junifer/datagrabber/tests/test_pattern_validation_mixin.py @@ -9,7 +9,110 @@ import pytest -from 
junifer.datagrabber.pattern_validation_mixin import PatternValidationMixin +from junifer.datagrabber.pattern_validation_mixin import ( + DataTypeManager, + DataTypeSchema, + PatternValidationMixin, + register_data_type, +) + + +def test_dtype_mgr_addition_errors() -> None: + """Test data type manager addition errors.""" + with pytest.raises(ValueError, match="Cannot set"): + dtype_schema: DataTypeSchema = { + "mandatory": ["pattern"], + "optional": {}, + } + DataTypeManager()["T1w"] = dtype_schema + + with pytest.raises(ValueError, match="Invalid"): + DataTypeManager()["DType"] = "" + + +def test_dtype_mgr_removal_errors() -> None: + """Test data type manager removal errors.""" + with pytest.raises(ValueError, match="Cannot delete"): + _ = DataTypeManager().pop("T1w") + + with pytest.raises(KeyError, match="DType"): + del DataTypeManager()["DType"] + + +@pytest.mark.parametrize( + "dtype", + [ + { + "mandatory": ["pattern"], + "optional": {}, + }, + { + "mandatory": ["pattern"], + "optional": { + "subtype": { + "mandatory": [], + "optional": [], + } + }, + }, + { + "mandatory": ["pattern"], + "optional": { + "subtype": { + "mandatory": ["pattern"], + "optional": [], + } + }, + }, + { + "mandatory": ["pattern"], + "optional": { + "subtype": { + "mandatory": ["pattern"], + "optional": ["pattern"], + } + }, + }, + ], +) +def test_dtype_mgr(dtype: DataTypeSchema) -> None: + """Test data type manager addition and removal. + + Parameters + ---------- + dtype : DataTypeSchema + The parametrized schema. 
+ + """ + + DataTypeManager().update({"DType": dtype}) + assert "DType" in DataTypeManager() + + _ = DataTypeManager().pop("DType") + assert "DType" not in DataTypeManager() + + +def test_register_data_type() -> None: + """Test data type registration.""" + + dtype_schema: DataTypeSchema = { + "mandatory": ["pattern"], + "optional": { + "mask": { + "mandatory": ["pattern"], + "optional": [], + }, + }, + } + + register_data_type( + name="dtype", + schema=dtype_schema, + ) + + assert "dtype" in DataTypeManager() + _ = DataTypeManager().pop("dtype") + assert "dumb" not in DataTypeManager() @pytest.mark.parametrize( diff --git a/junifer/typing/_typing.py b/junifer/typing/_typing.py index d95417ac3e..42d4e01bd8 100644 --- a/junifer/typing/_typing.py +++ b/junifer/typing/_typing.py @@ -64,7 +64,7 @@ ExternalDependencies = Sequence[MutableMapping[str, Union[str, Sequence[str]]]] MarkerInOutMappings = MutableMapping[str, MutableMapping[str, str]] DataGrabberPatterns = dict[ - str, Union[dict[str, str], Sequence[dict[str, str]]] + str, Union[dict[str, str], list[dict[str, str]]] ] ConfigVal = Union[bool, int, float, str] Element = Union[str, tuple[str, ...]]