Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add LassoCV to nilearn.decoding.DecoderRegressor #3781

Merged
merged 14 commits into from
Aug 17, 2023
Merged
2 changes: 2 additions & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Enhancements

- Make return key names in the description file of destrieux surface consistent with :func:`~datasets.fetch_atlas_surf_destrieux` (:gh:`3774` by `Tarun Samanta`_).

- Add ``LassoCV`` as a new estimator option for Decoder objects (:gh:`3781` by `Michelle Wang`_).

Changes
-------

Expand Down
11 changes: 11 additions & 0 deletions nilearn/_utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,17 @@ def custom_function(vertices):
svr = SVR(kernel="linear",
max_iter=1e4)

- `lasso`: \
:class:`{Lasso regression} <sklearn.linear_model.LassoCV>`.
.. code-block:: python

lasso = LassoCV()

- `lasso_regressor`: \
:class:`{Lasso regression} <sklearn.linear_model.LassoCV>`.
.. note::
Same option as `lasso`.

- `dummy_regressor`: \
:class:`{Dummy regressor} <sklearn.dummy.DummyRegressor>`.
.. code-block:: python
Expand Down
25 changes: 21 additions & 4 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sklearn import clone
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.linear_model import (
LassoCV,
LinearRegression,
LogisticRegressionCV,
RidgeClassifierCV,
Expand Down Expand Up @@ -59,6 +60,8 @@
ridge_classifier=RidgeClassifierCV(),
ridge_regressor=RidgeCV(),
ridge=RidgeCV(),
lasso=LassoCV(),
lasso_regressor=LassoCV(),
svr=SVR(kernel="linear", max_iter=10000),
dummy_classifier=DummyClassifier(strategy="stratified", random_state=0),
dummy_regressor=DummyRegressor(strategy="mean"),
Expand Down Expand Up @@ -114,6 +117,8 @@ def _check_param_grid(estimator, X, y, param_grid=None):
elif isinstance(estimator, LogisticRegressionCV):
param_grid = _replace_param_grid_key(param_grid, "C", "Cs")
param_grid = _wrap_param_grid(param_grid, "Cs")
elif isinstance(estimator, LassoCV):
param_grid = _wrap_param_grid(param_grid, "alphas")

return param_grid

Expand Down Expand Up @@ -155,7 +160,14 @@ def _default_param_grid(estimator, X, y):
raise NotImplementedError(message)
elif not isinstance(
estimator,
(LogisticRegressionCV, LinearSVC, RidgeCV, RidgeClassifierCV, SVR),
(
LogisticRegressionCV,
LinearSVC,
RidgeCV,
RidgeClassifierCV,
SVR,
LassoCV,
),
):
raise ValueError(
"Invalid estimator. The supported estimators are:"
Expand Down Expand Up @@ -185,6 +197,11 @@ def _default_param_grid(estimator, X, y):
# so for L2 penalty, param_grid["Cs"] is either 1e-3, ..., 1e4, and
# for L1 penalty the values are obtained in a more data-driven way
param_grid["Cs"] = [np.geomspace(2e-3, 2e4, 8) * min_c]
elif isinstance(estimator, LassoCV):
# the default is to generate 30 alphas based on the data
# (alpha values can also be set with the 'alphas' parameter, in which
# case 'n_alphas' is ignored)
param_grid["n_alphas"] = [30]
elif isinstance(estimator, (LinearSVC, SVR)):
# similar logic as above:
# - for L2 penalty this is [1, 10, 100]
Expand Down Expand Up @@ -398,7 +415,7 @@ def _parallel_fit(
elif isinstance(estimator, DummyRegressor):
dummy_output = estimator.constant_

if isinstance(estimator, (RidgeCV, RidgeClassifierCV)):
if isinstance(estimator, (RidgeCV, RidgeClassifierCV, LassoCV)):
params["best_alpha"] = estimator.alpha_
elif isinstance(estimator, LogisticRegressionCV):
params["best_C"] = estimator.C_.item()
Expand Down Expand Up @@ -658,8 +675,8 @@ def fit(self, X, y, groups=None):
built-in cross-validation, this will include an additional key for
the single best value estimated by the built-in cross-validation
('best_C' for LogisticRegressionCV and 'best_alpha' for
RidgeCV/RidgeClassifierCV), in addition to the input list of
values.
RidgeCV/RidgeClassifierCV/LassoCV), in addition to the input list
of values.

'scorer_' : function
Scorer function used on the held out data to choose the best
Expand Down
9 changes: 8 additions & 1 deletion nilearn/decoding/tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import (
LassoCV,
LogisticRegressionCV,
RidgeClassifierCV,
RidgeCV,
Expand Down Expand Up @@ -140,7 +141,12 @@ def multiclass_data():


@pytest.mark.parametrize(
"regressor, param", [(RidgeCV(), ["alphas"]), (SVR(kernel="linear"), "C")]
"regressor, param",
[
(RidgeCV(), ["alphas"]),
(SVR(kernel="linear"), ["C"]),
(LassoCV(), ["n_alphas"]),
],
)
def test_check_param_grid_regression(regressor, param):
"""Test several estimators.
Expand Down Expand Up @@ -414,6 +420,7 @@ def test_parallel_fit(rand_X_Y):
(RidgeCV(), "alphas", "best_alpha", False),
(RidgeClassifierCV(), "alphas", "best_alpha", True),
(LogisticRegressionCV(), "Cs", "best_C", True),
(LassoCV(), "alphas", "best_alpha", False),
],
)
def test_parallel_fit_builtin_cv(
Expand Down