Skip to content

Commit

Permalink
[FIX] BIC/AIC for Lasso (scikit-learn#9022)
Browse files Browse the repository at this point in the history
* correcting information criterion calculation in least_angle.py

The information criterion calculation is not compatible with the original paper 
Zou, Hui, Trevor Hastie, and Robert Tibshirani. "On the “degrees of freedom” of the lasso." The Annals of Statistics 35.5 (2007): 2173-2192.
APA

* FIX : fix AIC/BIC computation in LassoLarsIC

* update what's new

* fix test

* fix test

* address comments

* DOC comments and docstring on criterion computation
  • Loading branch information
agramfort authored and Jeremiah Johnson committed Dec 18, 2017
1 parent 157fb12 commit 93d694a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ Bug fixes
- Add ``shuffle`` parameter to :func:`model_selection.train_test_split`.
:issue:`#8845` by :user:`themrmax <themrmax>`

- Fix AIC/BIC criterion computation in :class:`linear_model.LassoLarsIC`
by `Alexandre Gramfort`_ and :user:`Mehmet Basbug <mehmetbasbug>`.

API changes summary
-------------------

Expand Down
12 changes: 8 additions & 4 deletions sklearn/linear_model/least_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,8 +1400,10 @@ class LassoLarsIC(LassoLars):
criterion_ : array, shape (n_alphas,)
The value of the information criteria ('aic', 'bic') across all
alphas. The alpha which has the smallest information criteria
is chosen.
alphas. The alpha which has the smallest information criteria is chosen.
This value is larger by a factor of ``n_samples`` compared to Eqns. 2.15
and 2.16 in (Zou et al, 2007).
Examples
--------
Expand Down Expand Up @@ -1487,6 +1489,7 @@ def fit(self, X, y, copy_X=True):

R = y[:, np.newaxis] - np.dot(X, coef_path_) # residuals
mean_squared_error = np.mean(R ** 2, axis=0)
sigma2 = np.var(y)

df = np.zeros(coef_path_.shape[1], dtype=np.int) # Degrees of freedom
for k, coef in enumerate(coef_path_.T):
Expand All @@ -1499,8 +1502,9 @@ def fit(self, X, y, copy_X=True):
df[k] = np.sum(mask)

self.alphas_ = alphas_
with np.errstate(divide='ignore'):
self.criterion_ = n_samples * np.log(mean_squared_error) + K * df
eps64 = np.finfo('float64').eps
self.criterion_ = (n_samples * mean_squared_error / (sigma2 + eps64) +
K * df) # Eqns. 2.15--16 in (Zou et al, 2007)
n_best = np.argmin(self.criterion_)

self.alpha_ = alphas_[n_best]
Expand Down
13 changes: 2 additions & 11 deletions sklearn/linear_model/tests/test_least_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.testing import assert_no_warnings, assert_warns
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import TempMemmap
from sklearn.exceptions import ConvergenceWarning
from sklearn import linear_model, datasets
Expand Down Expand Up @@ -430,7 +430,7 @@ def test_lasso_lars_ic():
rng = np.random.RandomState(42)
X = diabetes.data
y = diabetes.target
X = np.c_[X, rng.randn(X.shape[0], 4)] # add 4 bad features
X = np.c_[X, rng.randn(X.shape[0], 5)] # add 5 bad features
lars_bic.fit(X, y)
lars_aic.fit(X, y)
nonzero_bic = np.where(lars_bic.coef_)[0]
Expand All @@ -444,15 +444,6 @@ def test_lasso_lars_ic():
assert_raises(ValueError, lars_broken.fit, X, y)


def test_no_warning_for_zero_mse():
# LassoLarsIC should not warn for log of zero MSE.
y = np.arange(10, dtype=float)
X = y.reshape(-1, 1)
lars = linear_model.LassoLarsIC(normalize=False)
assert_no_warnings(lars.fit, X, y)
assert_true(np.any(np.isinf(lars.criterion_)))


def test_lars_path_readonly_data():
# When using automated memory mapping on large input, the
# fold data is in read-only mode
Expand Down
3 changes: 2 additions & 1 deletion sklearn/linear_model/tests/test_randomized_l1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raises_regex
from sklearn.utils.testing import assert_allclose

from sklearn.linear_model.randomized_l1 import (lasso_stability_path,
RandomizedLasso,
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_randomized_lasso():
clf = RandomizedLasso(verbose=False, alpha='aic', random_state=42,
scaling=scaling)
feature_scores = clf.fit(X, y).scores_
assert_array_equal(feature_scores, X.shape[1] * [1.])
assert_allclose(feature_scores, [1., 1., 1., 0.225, 1.], rtol=0.2)

clf = RandomizedLasso(verbose=False, scaling=-0.1)
assert_raises(ValueError, clf.fit, X, y)
Expand Down

0 comments on commit 93d694a

Please sign in to comment.