New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sparse Fascicle Model #460
Changes from all commits
78e0ab7
caf6e56
80a2058
a0f8116
98d1c55
2dd5e27
344bbe1
3f040f9
833406a
c0943d4
0c1be53
bb940e9
36353ba
0d044e1
4a861df
04b655d
0508e02
4263173
2f81db6
d181d3b
5691af9
e0f683e
29d01f3
fc4efec
6e5bb74
807e46b
fae6afb
ccb27f9
01910e1
8ef004c
6cea1a9
791de59
d262cb8
2908326
809d23c
14b3dce
bd3a145
bb4c06d
851757c
59fdc53
e2ae7d0
8ef225c
c80e8ae
eb324e3
b1ea909
4f371fb
248474c
4ef15c2
2fa5ffc
03fbdfa
d92274e
3472d9b
36d8b37
c3f99f6
8cdde64
9741a19
0979ca4
7945229
2c1152f
c979ed4
cc55cf3
3319a60
c149a48
fcc6e8e
8f15f36
7b66242
9647bc8
0508a7e
f9209cd
d1b03bb
558ebf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,11 +3,13 @@ | |
Only L-BFGS-B and Powell are supported in this class for versions of | ||
Scipy < 0.12. All optimizers are available for scipy >= 0.12. | ||
""" | ||
|
||
import abc | ||
from distutils.version import LooseVersion | ||
import numpy as np | ||
import scipy | ||
import scipy.sparse as sps | ||
import scipy.optimize as opt | ||
from dipy.utils.six import with_metaclass | ||
|
||
SCIPY_LESS_0_12 = LooseVersion(scipy.__version__) < '0.12' | ||
|
||
|
@@ -218,7 +220,6 @@ def __init__(self, fun, x0, args=(), method='L-BFGS-B', jac=None, | |
|
||
def history_of_x(kx): | ||
self._evol_kx.append(kx) | ||
|
||
res = minimize(fun, x0, args, method, jac, hess, hessp, bounds, | ||
constraints, tol, callback=history_of_x, | ||
options=options) | ||
|
@@ -293,24 +294,6 @@ def spdot(A, B): | |
return np.dot(A, B) | ||
|
||
|
||
def rsq(ss_residuals, ss_residuals_to_mean): | ||
""" | ||
Calculate: $R^2 = 100 \left(1 - \frac{SSE}{SST}\right)$, where SSE is the sum of squared model residuals and SST the sum of squared deviations of the data from its mean | ||
|
||
Parameters | ||
---------- | ||
ss_residuals : array | ||
Model fit errors relative to the data | ||
ss_residuals_to_mean : array | ||
Residuals of the data relative to the mean of the data (variance) | ||
|
||
Returns | ||
------- | ||
rsq : float — the percentage of variance explained. | ||
""" | ||
return 100 * (1 - ss_residuals/ss_residuals_to_mean) | ||
|
||
|
||
def sparse_nnls(y, X, | ||
momentum=1, | ||
step_size=0.01, | ||
|
@@ -357,17 +340,14 @@ def sparse_nnls(y, X, | |
h_best : The best estimate of the parameters. | ||
|
||
""" | ||
num_data = y.shape[0] | ||
num_regressors = X.shape[1] | ||
# Initialize the parameters at the origin: | ||
h = np.zeros(num_regressors) | ||
# If nothing good happens, we'll return that: | ||
h_best = h | ||
gradient = np.zeros(num_regressors) | ||
iteration = 1 | ||
count = 1 | ||
ss_residuals_min = np.inf # This will keep track of the best solution | ||
ss_residuals_to_mean = np.sum((y - np.mean(y)) ** 2) # The variance of y | ||
sse_best = np.inf # This will keep track of the best performance so far | ||
count_bad = 0 # Number of times estimation error has gone up. | ||
error_checks = 0 # How many error checks have we done so far | ||
|
@@ -396,21 +376,73 @@ def sparse_nnls(y, X, | |
if sse < ss_residuals_min: | ||
# Update your expectations about the minimum error: | ||
ss_residuals_min = sse | ||
n_iterations = iteration # This holds the number of iterations | ||
# for the best solution so far. | ||
h_best = h # This holds the best params we have so far | ||
|
||
# Are we generally (over iterations) converging on | ||
# sufficient improvement in r-squared? | ||
if sse < converge_on_sse * sse_best: | ||
sse_best = sse | ||
count_bad = 0 | ||
else: | ||
count_bad +=1 | ||
count_bad += 1 | ||
else: | ||
count_bad += 1 | ||
|
||
if count_bad >= max_error_checks: | ||
return h_best | ||
error_checks += 1 | ||
iteration += 1 | ||
|
||
|
||
class SKLearnLinearSolver(with_metaclass(abc.ABCMeta, object)): | ||
""" | ||
Provide a sklearn-like uniform interface to algorithms that solve problems | ||
of the form: $y = Ax$ for $x$ | ||
|
||
Sub-classes of SKLearnLinearSolver should provide a 'fit' method that has | ||
the following signature: `SKLearnLinearSolver.fit(X, y)`, which would set | ||
an attribute `SKLearnLinearSolver.coef_`, with the shape (X.shape[1],), | ||
such that an estimate of y can be calculated as: | ||
`y_hat = np.dot(X, SKLearnLinearSolver.coef_.T)` | ||
""" | ||
def __init__(self, *args, **kwargs): | ||
self._args = args | ||
self._kwargs = kwargs | ||
|
||
@abc.abstractmethod | ||
def fit(self, X, y): | ||
"""Implement for all derived classes """ | ||
|
||
def predict(self, X): | ||
""" | ||
Predict using the result of the model | ||
|
||
Parameters | ||
---------- | ||
X : array-like (n_samples, n_features) | ||
Samples. | ||
|
||
Returns | ||
------- | ||
C : array, shape = (n_samples,) | ||
Predicted values. | ||
""" | ||
X = np.asarray(X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why the custom implementation, instead of using the scikit-learn solver's predict method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It requires scikit-learn, and we're trying to deal with situations where that's not installed. Though, why wouldn't it? |
||
return np.dot(X, self.coef_.T) | ||
|
||
|
||
class NonNegativeLeastSquares(SKLearnLinearSolver): | ||
""" | ||
A sklearn-like interface to scipy.optimize.nnls | ||
|
||
""" | ||
def fit(self, X, y): | ||
""" | ||
Fit the NonNegativeLeastSquares linear model to data | ||
|
||
Parameters | ||
---------- | ||
|
||
""" | ||
coef, rnorm = opt.nnls(X, y) | ||
self.coef_ = coef | ||
return self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not also make this an abstract method?