Skip to content

Commit

Permalink
fix intercept
Browse files Browse the repository at this point in the history
  • Loading branch information
mathurinm committed Apr 22, 2018
1 parent f14c3f2 commit 4c2c7ca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 5 additions & 3 deletions celer/dropin_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ class Lasso(Lasso_sklearn):
"""

def __init__(self, alpha=1., max_iter=100, gap_freq=10,
max_epochs=50000, p0=10, verbose=0, tol=1e-6, prune=0):
max_epochs=50000, p0=10, verbose=0, tol=1e-6, prune=0,
fit_intercept=True):
super(Lasso, self).__init__(
alpha=alpha, tol=tol, max_iter=max_iter)
alpha=alpha, tol=tol, max_iter=max_iter,
fit_intercept=fit_intercept)
self.verbose = verbose
self.gap_freq = gap_freq
self.max_epochs = max_epochs
Expand Down Expand Up @@ -184,7 +186,7 @@ class LassoCV(LassoCV_sklearn):
"""

def __init__(self, eps=1e-3, n_alphas=100, alphas=None,
fit_intercept=False, max_iter=100,
fit_intercept=True, max_iter=100,
tol=1e-6, cv=None, verbose=0, gap_freq=10,
max_epochs=50000, p0=10, prune=0,
normalize=False, precompute='auto'):
Expand Down
9 changes: 6 additions & 3 deletions celer/tests/test_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,12 @@ def test_LassoCV_compatibility():
clf2 = sklearn_LassoCV(**params)
clf2.fit(X, y)

np.testing.assert_allclose(clf.mse_path_, clf2.mse_path_, rtol=1e-05)
np.testing.assert_allclose(clf.alpha_, clf2.alpha_, rtol=1e-05)
np.testing.assert_allclose(clf.coef_, clf2.coef_, rtol=1e-05)
np.testing.assert_allclose(clf.mse_path_, clf2.mse_path_,
rtol=1e-04)
np.testing.assert_allclose(clf.alpha_, clf2.alpha_,
rtol=1e-05)
np.testing.assert_allclose(clf.coef_, clf2.coef_,
rtol=1e-05)

check_estimator(LassoCV)

Expand Down

0 comments on commit 4c2c7ca

Please sign in to comment.