Add save() and load() methods to MetaResult objects #771

Merged 4 commits on Feb 21, 2023
2 changes: 1 addition & 1 deletion docs/api.rst
@@ -326,4 +326,4 @@ For more information about fetching data from the internet, see :ref:`fetching t
:template: class.rst

base.NiMAREBase
base.Estimator
estimator.Estimator
126 changes: 1 addition & 125 deletions nimare/base.py
@@ -3,11 +3,9 @@
import inspect
import logging
import pickle
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from collections import defaultdict

from nimare.results import MetaResult

LGR = logging.getLogger(__name__)


@@ -214,125 +212,3 @@ def load(cls, filename, compressed=True):
raise IOError(f"Pickled object must be {cls}, not {type(obj)}")

return obj


class Estimator(NiMAREBase):
"""Estimators take in Datasets and return MetaResults.

All Estimators must have a ``_fit`` method implemented, which applies algorithm-specific
methods to a Dataset and returns dictionaries of maps and tables to be converted into a MetaResult.

Users will interact with the ``_fit`` method by calling the user-facing ``fit`` method.
``fit`` takes in a ``Dataset``, calls ``_collect_inputs``, then ``_preprocess_input``,
then ``_fit``, and finally converts the dictionaries returned by ``_fit`` into a ``MetaResult``.
"""

# Inputs that must be available in input Dataset. Keys are names of
# attributes to set; values are strings indicating location in Dataset.
_required_inputs = {}

def _collect_inputs(self, dataset, drop_invalid=True):
"""Search for, and validate, required inputs as necessary.

This method populates the ``inputs_`` attribute.

.. versionchanged:: 0.0.12

Renamed from ``_validate_input``.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
drop_invalid : :obj:`bool`, optional
Whether to automatically drop any studies in the Dataset without valid data.
Default is True.

Attributes
----------
inputs_ : :obj:`dict`
A dictionary of required inputs for the Estimator, extracted from the Dataset.
The actual inputs collected in this attribute are determined by the
``_required_inputs`` variable that should be specified in each child class.
"""
if not hasattr(dataset, "slice"):
raise ValueError(
f"Argument 'dataset' must be a valid Dataset object, not a {type(dataset)}."
)

if self._required_inputs:
data = dataset.get(self._required_inputs, drop_invalid=drop_invalid)
# Do not overwrite existing inputs_ attribute.
# This is necessary for PairwiseCBMAEstimator, which validates two sets of coordinates
# in the same object.
It makes the *strong* assumption that required inputs will not change within an
# Estimator across fit calls, so all fields of inputs_ will be overwritten instead of
# retaining outdated fields from previous fit calls.
if not hasattr(self, "inputs_"):
self.inputs_ = {}

for k, v in data.items():
if v is None:
raise ValueError(
f"Estimator {self.__class__.__name__} requires input dataset to contain "
f"{k}, but no matching data were found."
)
self.inputs_[k] = v

@abstractmethod
def _preprocess_input(self, dataset):
"""Perform any additional preprocessing steps on data in self.inputs_.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
The Dataset
"""
pass

@abstractmethod
def _fit(self, dataset):
"""Apply estimation to dataset and output results.

Must return a tuple of two dictionaries: ``maps``, where keys are image names
and values are ndarrays, and ``tables``, where keys are table names and values
are tabular results.
"""
pass

def fit(self, dataset, drop_invalid=True):
"""Fit Estimator to Dataset.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
Dataset object to analyze.
drop_invalid : :obj:`bool`, optional
Whether to automatically ignore any studies without the required data.
Default is True.

Returns
-------
:obj:`~nimare.results.MetaResult`
Results of Estimator fitting.

Attributes
----------
inputs_ : :obj:`dict`
Inputs used in _fit.

Notes
-----
The `fit` method is a light wrapper that runs input validation and
preprocessing before fitting the actual model. Estimators' individual
"fitting" methods are implemented as `_fit`, although users should
call `fit`.
"""
self._collect_inputs(dataset, drop_invalid=drop_invalid)
self._preprocess_input(dataset)
maps, tables = self._fit(dataset)

if hasattr(self, "masker") and self.masker is not None:
masker = self.masker
else:
masker = dataset.masker

return MetaResult(self, mask=masker, maps=maps, tables=tables)
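
The ``load`` remnant retained above shows the pattern that this PR lets ``MetaResult`` inherit: pickle the whole object, optionally through gzip, and type-check on load. A minimal sketch of that pattern follows; the signatures and the ``IOError`` check come from the diff and the tests, while the bodies are an assumption rather than NiMARE's exact implementation:

```python
import gzip
import pickle


class SaveLoadSketch:
    """Hedged sketch of the NiMAREBase save/load pattern."""

    def save(self, filename, compress=True):
        # Pickle the whole object, optionally through gzip.
        opener = gzip.open if compress else open
        with opener(filename, "wb") as fo:
            pickle.dump(self, fo)

    @classmethod
    def load(cls, filename, compressed=True):
        opener = gzip.open if compressed else open
        with opener(filename, "rb") as fo:
            obj = pickle.load(fo)

        # Reject pickles of the wrong type, as in the remnant above.
        if not isinstance(obj, cls):
            raise IOError(f"Pickled object must be {cls}, not {type(obj)}")

        return obj
```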
127 changes: 127 additions & 0 deletions nimare/estimator.py
@@ -0,0 +1,127 @@
"""Base class for estimators."""
from abc import abstractmethod

from nimare.base import NiMAREBase
from nimare.results import MetaResult


class Estimator(NiMAREBase):
"""Estimators take in Datasets and return MetaResults.

All Estimators must have a ``_fit`` method implemented, which applies algorithm-specific
methods to a Dataset and returns dictionaries of maps and tables to be converted into a MetaResult.

Users will interact with the ``_fit`` method by calling the user-facing ``fit`` method.
``fit`` takes in a ``Dataset``, calls ``_collect_inputs``, then ``_preprocess_input``,
then ``_fit``, and finally converts the dictionaries returned by ``_fit`` into a ``MetaResult``.
"""

# Inputs that must be available in input Dataset. Keys are names of
# attributes to set; values are strings indicating location in Dataset.
_required_inputs = {}

def _collect_inputs(self, dataset, drop_invalid=True):
"""Search for, and validate, required inputs as necessary.

This method populates the ``inputs_`` attribute.

.. versionchanged:: 0.0.12

Renamed from ``_validate_input``.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
drop_invalid : :obj:`bool`, optional
Whether to automatically drop any studies in the Dataset without valid data.
Default is True.

Attributes
----------
inputs_ : :obj:`dict`
A dictionary of required inputs for the Estimator, extracted from the Dataset.
The actual inputs collected in this attribute are determined by the
``_required_inputs`` variable that should be specified in each child class.
"""
if not hasattr(dataset, "slice"):
raise ValueError(
f"Argument 'dataset' must be a valid Dataset object, not a {type(dataset)}."
)

if self._required_inputs:
data = dataset.get(self._required_inputs, drop_invalid=drop_invalid)
# Do not overwrite existing inputs_ attribute.
# This is necessary for PairwiseCBMAEstimator, which validates two sets of coordinates
# in the same object.
It makes the *strong* assumption that required inputs will not change within an
# Estimator across fit calls, so all fields of inputs_ will be overwritten instead of
# retaining outdated fields from previous fit calls.
if not hasattr(self, "inputs_"):
self.inputs_ = {}

for k, v in data.items():
if v is None:
raise ValueError(
f"Estimator {self.__class__.__name__} requires input dataset to contain "
f"{k}, but no matching data were found."
)
self.inputs_[k] = v

@abstractmethod
def _preprocess_input(self, dataset):
"""Perform any additional preprocessing steps on data in self.inputs_.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
The Dataset
"""
pass

@abstractmethod
def _fit(self, dataset):
"""Apply estimation to dataset and output results.

Must return a tuple of two dictionaries: ``maps``, where keys are image names
and values are ndarrays, and ``tables``, where keys are table names and values
are tabular results.
"""
pass

def fit(self, dataset, drop_invalid=True):
"""Fit Estimator to Dataset.

Parameters
----------
dataset : :obj:`~nimare.dataset.Dataset`
Dataset object to analyze.
drop_invalid : :obj:`bool`, optional
Whether to automatically ignore any studies without the required data.
Default is True.

Returns
-------
:obj:`~nimare.results.MetaResult`
Results of Estimator fitting.

Attributes
----------
inputs_ : :obj:`dict`
Inputs used in _fit.

Notes
-----
The `fit` method is a light wrapper that runs input validation and
preprocessing before fitting the actual model. Estimators' individual
"fitting" methods are implemented as `_fit`, although users should
call `fit`.
"""
self._collect_inputs(dataset, drop_invalid=drop_invalid)
self._preprocess_input(dataset)
maps, tables = self._fit(dataset)

if hasattr(self, "masker") and self.masker is not None:
masker = self.masker
else:
masker = dataset.masker

return MetaResult(self, mask=masker, maps=maps, tables=tables)
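
As a usage sketch of the workflow these docstrings describe, the snippet below fits an ``ALE`` estimator, whose ``fit`` call chains ``_collect_inputs``, ``_preprocess_input``, and ``_fit``. The dataset file name is an assumption borrowed from NiMARE's test data and may differ:

```python
import os

from nimare.dataset import Dataset
from nimare.meta import ale
from nimare.tests.utils import get_test_data_path

# Assumed test file name; substitute any NiMARE-format JSON dataset.
dset = Dataset(os.path.join(get_test_data_path(), "test_pain_dataset.json"))

# fit() runs _collect_inputs -> _preprocess_input -> _fit, then wraps the
# returned maps/tables dictionaries in a MetaResult.
meta = ale.ALE(null_method="approximate")
res = meta.fit(dset)
print(type(res))  # <class 'nimare.results.MetaResult'>
```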
2 changes: 1 addition & 1 deletion nimare/meta/cbma/base.py
@@ -11,7 +11,7 @@
from scipy import ndimage
from tqdm.auto import tqdm

from nimare.base import Estimator
from nimare.estimator import Estimator
from nimare.meta.kernel import KernelTransformer
from nimare.meta.utils import _calculate_cluster_measures, _get_last_bin
from nimare.results import MetaResult
2 changes: 1 addition & 1 deletion nimare/meta/ibma.py
@@ -12,7 +12,7 @@
from nilearn.input_data import NiftiMasker
from nilearn.mass_univariate import permuted_ols

from nimare.base import Estimator
from nimare.estimator import Estimator
from nimare.transforms import p_to_z, t_to_z
from nimare.utils import _boolean_unmask, _check_ncores, get_masker

3 changes: 2 additions & 1 deletion nimare/results.py
@@ -7,12 +7,13 @@
import pandas as pd
from nibabel.funcs import squeeze_image

from nimare.base import NiMAREBase
from nimare.utils import get_masker

LGR = logging.getLogger(__name__)


class MetaResult(object):
class MetaResult(NiMAREBase):
"""Base class for meta-analytic results.

Parameters
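Because ``MetaResult`` now subclasses ``NiMAREBase``, results pick up the inherited ``save``/``load`` methods, which is the point of this PR. A round-trip sketch using the signatures exercised in the test below:

```python
from nimare.results import MetaResult

# `res` is a fitted MetaResult, e.g. from the ALE example above.
res.save("result.pkl.gz", compress=True)

res2 = MetaResult.load("result.pkl.gz", compressed=True)
assert isinstance(res2, MetaResult)
```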
43 changes: 28 additions & 15 deletions nimare/tests/test_meta_ale.py
@@ -10,14 +10,16 @@
import nimare
from nimare.correct import FDRCorrector, FWECorrector
from nimare.meta import ale
from nimare.results import MetaResult
from nimare.tests.utils import get_test_data_path
from nimare.utils import vox2mm


def test_ALE_approximate_null_unit(testdata_cbma, tmp_path_factory):
"""Unit test for ALE with approximate null_method."""
tmpdir = tmp_path_factory.mktemp("test_ALE_approximate_null_unit")
out_file = os.path.join(tmpdir, "file.pkl.gz")
est_out_file = os.path.join(tmpdir, "est_file.pkl.gz")
res_out_file = os.path.join(tmpdir, "res_file.pkl.gz")

meta = ale.ALE(null_method="approximate")
res = meta.fit(testdata_cbma)
@@ -31,20 +33,31 @@ def test_ALE_approximate_null_unit(testdata_cbma, tmp_path_factory):
assert res2 != res
assert isinstance(res, nimare.results.MetaResult)

# Test saving/loading
meta.save(out_file, compress=True)
assert os.path.isfile(out_file)
meta2 = ale.ALE.load(out_file, compressed=True)
assert isinstance(meta2, ale.ALE)
with pytest.raises(pickle.UnpicklingError):
ale.ALE.load(out_file, compressed=False)

meta.save(out_file, compress=False)
assert os.path.isfile(out_file)
meta2 = ale.ALE.load(out_file, compressed=False)
assert isinstance(meta2, ale.ALE)
with pytest.raises(OSError):
ale.ALE.load(out_file, compressed=True)
# Test saving/loading estimator
for compress in [True, False]:
meta.save(est_out_file, compress=compress)
assert os.path.isfile(est_out_file)
meta2 = ale.ALE.load(est_out_file, compressed=compress)
assert isinstance(meta2, ale.ALE)
if compress:
with pytest.raises(pickle.UnpicklingError):
ale.ALE.load(est_out_file, compressed=(not compress))
else:
with pytest.raises(OSError):
ale.ALE.load(est_out_file, compressed=(not compress))

# Test saving/loading MetaResult object
for compress in [True, False]:
res.save(res_out_file, compress=compress)
assert os.path.isfile(res_out_file)
res2 = MetaResult.load(res_out_file, compressed=compress)
assert isinstance(res2, MetaResult)
if compress:
with pytest.raises(pickle.UnpicklingError):
MetaResult.load(res_out_file, compressed=(not compress))
else:
with pytest.raises(OSError):
MetaResult.load(res_out_file, compressed=(not compress))

# Test MCC methods
# Monte Carlo FWE
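The mismatched-flag assertions above lean on standard-library behavior: unpickling gzip-compressed bytes fails on the gzip magic byte with ``pickle.UnpicklingError``, while gzip-reading a plain pickle raises ``gzip.BadGzipFile``, an ``OSError`` subclass since Python 3.8. A standalone sketch of both failure modes:

```python
import gzip
import pickle

obj = {"z": [1.0, 2.0]}

# Write one gzip-compressed pickle and one plain pickle.
with gzip.open("obj.pkl.gz", "wb") as f:
    pickle.dump(obj, f)
with open("obj.pkl", "wb") as f:
    pickle.dump(obj, f)

# Unpickling compressed bytes directly hits the gzip magic byte (0x1f),
# which is not a valid pickle opcode.
try:
    with open("obj.pkl.gz", "rb") as f:
        pickle.load(f)
except pickle.UnpicklingError as err:
    print("UnpicklingError:", err)

# gzip-reading an uncompressed pickle fails the gzip header check.
try:
    with gzip.open("obj.pkl", "rb") as f:
        pickle.load(f)
except OSError as err:  # gzip.BadGzipFile
    print("OSError:", err)
```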