Merge pull request scikit-learn#3778 from MechCoder/expose_positive
[MRG] Expose positive option in elasticnet and lasso path
agramfort committed Oct 16, 2014
2 parents b698d9f + 13c2b00 commit 0db66ba
Showing 3 changed files with 26 additions and 5 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
@@ -154,6 +154,10 @@ API changes summary
``precompute="auto"`` is now deprecated and will be removed in 0.18
By `Manoj Kumar`_.

- Expose ``positive`` option in :func:`linear_model.lasso_path` and
  :func:`linear_model.enet_path` which constrains coefficients to be
  positive. By `Manoj Kumar`_.

.. _changes_0_15_2:

0.15.2
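The entry above documents the new option. A minimal usage sketch on synthetic data (assumes a scikit-learn build that includes this change):

import numpy as np
from sklearn.linear_model import enet_path, lasso_path

rng = np.random.RandomState(0)
X = rng.randn(50, 20)
y = rng.randn(50)

# With positive=True, every coefficient along the regularization path is
# constrained to be non-negative.
alphas, coefs, dual_gaps = enet_path(X, y, l1_ratio=0.5, positive=True)
assert np.all(coefs >= 0)

alphas, coefs, dual_gaps = lasso_path(X, y, positive=True)
assert np.all(coefs >= 0)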
14 changes: 10 additions & 4 deletions sklearn/linear_model/coordinate_descent.py
@@ -107,7 +107,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
precompute='auto', Xy=None, fit_intercept=None,
normalize=None, copy_X=True, coef_init=None,
verbose=False, return_models=False, return_n_iter=False,
**params):
positive=False, **params):
"""Compute Lasso path with coordinate descent
The Lasso optimization function varies for mono and multi-outputs.
@@ -182,6 +182,9 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
params : kwargs
keyword arguments passed to the coordinate descent solver.
positive : bool, default False
If set to True, forces coefficients to be positive.
Returns
-------
models : a list of models along the regularization path
@@ -266,14 +269,15 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
alphas=alphas, precompute=precompute, Xy=Xy,
fit_intercept=fit_intercept, normalize=normalize,
copy_X=copy_X, coef_init=coef_init, verbose=verbose,
return_models=return_models, **params)
return_models=return_models, positive=positive,
**params)


def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
precompute='auto', Xy=None, fit_intercept=True,
normalize=False, copy_X=True, coef_init=None,
verbose=False, return_models=False, return_n_iter=False,
**params):
positive=False, **params):
"""Compute elastic net path with coordinate descent
The elastic net optimization function varies for mono and multi-outputs.
@@ -359,6 +363,9 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
return_n_iter : bool
whether to return the number of iterations or not.
positive : bool, default False
If set to True, forces coefficients to be positive.
Returns
-------
models : a list of models along the regularization path
@@ -459,7 +466,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,

n_alphas = len(alphas)
tol = params.get('tol', 1e-4)
positive = params.get('positive', False)
max_iter = params.get('max_iter', 1000)
dual_gaps = np.empty(n_alphas)
n_iters = []
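As the last hunk above shows, positive is now an explicit keyword argument of lasso_path and enet_path instead of being read from **params, while tol and max_iter are still taken from **params. A hedged sketch of a call that sets all three (values are the defaults shown in the diff):

import numpy as np
from sklearn.linear_model import enet_path

rng = np.random.RandomState(0)
X = rng.randn(30, 10)
y = rng.randn(30)

# positive is a named parameter; tol and max_iter still travel through **params
# to the coordinate descent solver.
alphas, coefs, dual_gaps = enet_path(X, y, l1_ratio=0.5, positive=True,
                                     tol=1e-4, max_iter=1000)
assert np.all(coefs >= 0)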
13 changes: 12 additions & 1 deletion sklearn/linear_model/tests/test_coordinate_descent.py
@@ -21,7 +21,7 @@

from sklearn.linear_model.coordinate_descent import Lasso, \
LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet, \
MultiTaskElasticNetCV, MultiTaskLassoCV, lasso_path
MultiTaskElasticNetCV, MultiTaskLassoCV, lasso_path, enet_path
from sklearn.linear_model import LassoLarsCV, lars_path


@@ -574,6 +574,17 @@ def test_deprection_precompute_enet():
assert_warns(DeprecationWarning, clf.fit, X, y)


def test_enet_path_positive():
"""
Test that the coefficients returned with positive=True by enet_path and lasso_path are non-negative
"""

X, y, _, _ = build_dataset(n_samples=50, n_features=50)
for path in [enet_path, lasso_path]:
pos_path_coef = path(X, y, positive=True)[1]
assert_true(np.all(pos_path_coef >= 0))


if __name__ == '__main__':
import nose
nose.runmodule()
