Skip to content

Commit

Permalink
Merge pull request #2011 from arokem/fix-2010
Browse files Browse the repository at this point in the history
BF: Use the `sklearn.base` interface, instead of deprecated `sklearn.linear_model.base`
  • Loading branch information
skoudoro committed Dec 6, 2019
2 parents 1fb90ab + 308c4e8 commit 108cd11
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions dipy/reconst/sfm.py
Expand Up @@ -35,7 +35,8 @@
from dipy.reconst.cache import Cache
from dipy.core.onetime import auto_attr

lm, has_sklearn, _ = optional_package('sklearn.linear_model')
sklearn, has_sklearn, _ = optional_package('sklearn')
lm, _, _ = optional_package('sklearn.linear_model')

# If sklearn is unavailable, we can fall back on nnls (but we also warn the
# user that we are about to do that):
Expand Down Expand Up @@ -307,12 +308,17 @@ def __init__(self, gtab, sphere=None, response=[0.0015, 0.0005, 0.0005],
The eigenvalues of a canonical tensor to be used as the response
function of single-fascicle signals.
Default:[0.0015, 0.0005, 0.0005]
solver : string, dipy.core.optimize.SKLearnLinearSolver object, or sklearn.linear_model.base.LinearModel object, optional.
solver : string, or initialized linear model object.
This will determine the algorithm used to solve the set of linear
equations underlying this model. If it is a string it needs to be
one of the following: {'ElasticNet', 'NNLS'}. Otherwise, it can be
an object that inherits from `dipy.optimize.SKLearnLinearSolver`.
an object that inherits from `dipy.optimize.SKLearnLinearSolver`
or an object with a similar interface from Scikit Learn:
`sklearn.linear_model.ElasticNet`, `sklearn.linear_model.Lasso` or
`sklearn.linear_model.Ridge` and other objects that inherit from
`sklearn.base.RegressorMixin`.
Default: 'ElasticNet'.
l1_ratio : float, optional
Sets the balance betwee L1 and L2 regularization in ElasticNet
[Zou2005]_. Default: 0.5
Expand Down Expand Up @@ -356,7 +362,7 @@ def __init__(self, gtab, sphere=None, response=[0.0015, 0.0005, 0.0005],
self.solver = opt.NonNegativeLeastSquares()

elif (isinstance(solver, opt.SKLearnLinearSolver) or
has_sklearn and isinstance(solver, lm.base.LinearModel)):
has_sklearn and isinstance(solver, sklearn.base.RegressorMixin)):
self.solver = solver

else:
Expand Down

0 comments on commit 108cd11

Please sign in to comment.