Skip to content

Commit

Permalink
Fix solve routine for NnlsL2 solver
Browse files Browse the repository at this point in the history
For some reason this was not working properly previously (i.e.
`Nnls(weights=True)` and `NnlsL2(weights=True, reg=0)` were giving
very different results).
  • Loading branch information
hunse committed May 30, 2016
1 parent e06fd04 commit 34ef8b3
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions nengo/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,17 +456,28 @@ def __init__(self, weights=False, reg=0.1):
super(NnlsL2, self).__init__(weights=weights)
self.reg = reg

def _solve(self, A, Y, rng, E, sigma):
def _solve(self, A, Y, rng, E, sigma=0.):
import scipy.optimize

tstart = time.time()
Y, m, n, _, matrix_in = format_system(A, Y)
Y = self.mul_encoders(Y, E, copy=True)
d = Y.shape[1]

# form Gram matrix so we can add regularization
GA = np.dot(A.T, A)
GY = np.dot(A.T, Y)
np.fill_diagonal(GA, GA.diagonal() + A.shape[0] * sigma**2)
X, info = super(NnlsL2, self).__call__(GA, GY, rng=rng, E=E)
GY = np.dot(A.T, Y.clip(0, None))
# ^ TODO: why is it better if we clip Y to be positive here?

X = np.zeros((n, d))
residuals = np.zeros(d)
for i in range(d):
X[:, i], residuals[i] = scipy.optimize.nnls(GA, GY[:, i])

t = time.time() - tstart
# recompute the RMSE in terms of the original matrices
info = {'rmses': rmses(A, X, Y), 'gram_info': info, 'time': t}
return X, info
info = {'rmses': rmses(A, X, Y), 'residuals': residuals, 'time': t}
return X if matrix_in or X.shape[1] > 1 else X.ravel(), info

def __call__(self, A, Y, rng=None, E=None):
return self._solve(A, Y, rng, E, sigma=self.reg * A.max())
Expand Down

0 comments on commit 34ef8b3

Please sign in to comment.