Skip to content

Commit

Permalink
[ENH] Add LassoCV to nilearn.decoding.DecoderRegressor (#3781)
Browse files Browse the repository at this point in the history
* add LassoCV to DecoderRegressor

* update decoder tests

* add 'lasso'/'lasso_regressor' to regressor options

* update changelog entry

* use n_alphas=30 as default instead of 100

* add lasso regression decoding to example

* fix formatting in example

* add lasso(_regressor) to list of estimators in docs

* revert example

* fix formatting
  • Loading branch information
michellewang committed Aug 17, 2023
1 parent 48aecfa commit b7a7adf
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Enhancements

- Make return key names in the description file of destrieux surface consistent with :func:`~datasets.fetch_atlas_surf_destrieux` (:gh:`3774` by `Tarun Samanta`_).

- Add ``LassoCV`` as a new estimator option for Decoder objects (:gh:`3781` by `Michelle Wang`_).

Changes
-------

Expand Down
2 changes: 2 additions & 0 deletions doc/decoding/estimator_choice.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ In :class:`nilearn.decoding.DecoderRegressor` you can use some of these objects

* `ridge_regressor` (same as `ridge`) : `Ridge regression <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.RidgeCV.html>`_.

* `lasso_regressor` (same as `lasso`) : `Lasso regression <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html>`_.

* `dummy_regressor` : A `dummy regressor <https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html>`_ is a regressor that makes predictions using simple rules. It is useful as a simple baseline to compare with other regressors.

.. note::
Expand Down
11 changes: 11 additions & 0 deletions nilearn/_utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,17 @@ def custom_function(vertices):
svr = SVR(kernel="linear",
max_iter=1e4)
- `lasso`: \
:class:`{Lasso regression} <sklearn.linear_model.LassoCV>`.
.. code-block:: python
lasso = LassoCV()
- `lasso_regressor`: \
:class:`{Lasso regression} <sklearn.linear_model.LassoCV>`.
.. note::
Same option as `lasso`.
- `dummy_regressor`: \
:class:`{Dummy regressor} <sklearn.dummy.DummyRegressor>`.
.. code-block:: python
Expand Down
25 changes: 21 additions & 4 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sklearn import clone
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.linear_model import (
LassoCV,
LinearRegression,
LogisticRegressionCV,
RidgeClassifierCV,
Expand Down Expand Up @@ -59,6 +60,8 @@
ridge_classifier=RidgeClassifierCV(),
ridge_regressor=RidgeCV(),
ridge=RidgeCV(),
lasso=LassoCV(),
lasso_regressor=LassoCV(),
svr=SVR(kernel="linear", max_iter=10000),
dummy_classifier=DummyClassifier(strategy="stratified", random_state=0),
dummy_regressor=DummyRegressor(strategy="mean"),
Expand Down Expand Up @@ -114,6 +117,8 @@ def _check_param_grid(estimator, X, y, param_grid=None):
elif isinstance(estimator, LogisticRegressionCV):
param_grid = _replace_param_grid_key(param_grid, "C", "Cs")
param_grid = _wrap_param_grid(param_grid, "Cs")
elif isinstance(estimator, LassoCV):
param_grid = _wrap_param_grid(param_grid, "alphas")

return param_grid

Expand Down Expand Up @@ -155,7 +160,14 @@ def _default_param_grid(estimator, X, y):
raise NotImplementedError(message)
elif not isinstance(
estimator,
(LogisticRegressionCV, LinearSVC, RidgeCV, RidgeClassifierCV, SVR),
(
LogisticRegressionCV,
LinearSVC,
RidgeCV,
RidgeClassifierCV,
SVR,
LassoCV,
),
):
raise ValueError(
"Invalid estimator. The supported estimators are:"
Expand Down Expand Up @@ -185,6 +197,11 @@ def _default_param_grid(estimator, X, y):
# so for L2 penalty, param_grid["Cs"] is either 1e-3, ..., 1e4, and
# for L1 penalty the values are obtained in a more data-driven way
param_grid["Cs"] = [np.geomspace(2e-3, 2e4, 8) * min_c]
elif isinstance(estimator, LassoCV):
# the default is to generate 30 alphas based on the data
# (alpha values can also be set with the 'alphas' parameter, in which
# case 'n_alphas' is ignored)
param_grid["n_alphas"] = [30]
elif isinstance(estimator, (LinearSVC, SVR)):
# similar logic as above:
# - for L2 penalty this is [1, 10, 100]
Expand Down Expand Up @@ -398,7 +415,7 @@ def _parallel_fit(
elif isinstance(estimator, DummyRegressor):
dummy_output = estimator.constant_

if isinstance(estimator, (RidgeCV, RidgeClassifierCV)):
if isinstance(estimator, (RidgeCV, RidgeClassifierCV, LassoCV)):
params["best_alpha"] = estimator.alpha_
elif isinstance(estimator, LogisticRegressionCV):
params["best_C"] = estimator.C_.item()
Expand Down Expand Up @@ -658,8 +675,8 @@ def fit(self, X, y, groups=None):
built-in cross-validation, this will include an additional key for
the single best value estimated by the built-in cross-validation
('best_C' for LogisticRegressionCV and 'best_alpha' for
RidgeCV/RidgeClassifierCV), in addition to the input list of
values.
RidgeCV/RidgeClassifierCV/LassoCV), in addition to the input list
of values.
'scorer_' : function
Scorer function used on the held out data to choose the best
Expand Down
9 changes: 8 additions & 1 deletion nilearn/decoding/tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import (
LassoCV,
LogisticRegressionCV,
RidgeClassifierCV,
RidgeCV,
Expand Down Expand Up @@ -140,7 +141,12 @@ def multiclass_data():


@pytest.mark.parametrize(
"regressor, param", [(RidgeCV(), ["alphas"]), (SVR(kernel="linear"), "C")]
"regressor, param",
[
(RidgeCV(), ["alphas"]),
(SVR(kernel="linear"), ["C"]),
(LassoCV(), ["n_alphas"]),
],
)
def test_check_param_grid_regression(regressor, param):
"""Test several estimators.
Expand Down Expand Up @@ -414,6 +420,7 @@ def test_parallel_fit(rand_X_Y):
(RidgeCV(), "alphas", "best_alpha", False),
(RidgeClassifierCV(), "alphas", "best_alpha", True),
(LogisticRegressionCV(), "Cs", "best_C", True),
(LassoCV(), "alphas", "best_alpha", False),
],
)
def test_parallel_fit_builtin_cv(
Expand Down

0 comments on commit b7a7adf

Please sign in to comment.