Skip to content

Commit

Permalink
Change syntax np.dot(A,x) to A.dot(x)
Browse files Browse the repository at this point in the history
This way the user can provide a sparse matrix without a problem
  • Loading branch information
rodrigo-pena committed Oct 6, 2017
1 parent 7048b63 commit d1af42a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pyunlocbox/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,20 @@ def __init__(self, y=0, A=None, At=None, tight=True, nu=1, tol=1e-3,
self.A = A
else:
# Transform matrix form to operator form.
self.A = lambda x: np.dot(A, x)
self.A = lambda x: A.dot(x)

if At is None:
if A is None:
self.At = lambda x: x
elif callable(A):
self.At = A
else:
self.At = lambda x: np.dot(np.transpose(A), x)
self.At = lambda x: A.T.dot(x)
else:
if callable(At):
self.At = At
else:
self.At = lambda x: np.dot(At, x)
self.At = lambda x: At.dot(x)

self.tight = tight
self.nu = nu
Expand Down
6 changes: 3 additions & 3 deletions pyunlocbox/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,20 @@ def __init__(self, L=None, Lt=None, d0=None, *args, **kwargs):
self.L = L
else:
# Transform matrix form to operator form.
self.L = lambda x: np.dot(L, x)
self.L = lambda x: L.dot(x)

if Lt is None:
if L is None:
self.Lt = lambda x: x
elif callable(L):
self.Lt = L
else:
self.Lt = lambda x: np.dot(np.transpose(L), x)
self.Lt = lambda x: L.T.dot(x)
else:
if callable(Lt):
self.Lt = Lt
else:
self.Lt = lambda x: np.dot(Lt, x)
self.Lt = lambda x: Lt.dot(x)

self.d0 = d0

Expand Down
3 changes: 0 additions & 3 deletions pyunlocbox/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,8 @@ def assert_equivalent(param1, param2):
nptest.assert_array_equal(f1.grad(x), f2.grad(x))

# Default parameters. Callable or matrices.
assert_equivalent({}, {'A': None, 'y': 0, 'At': 1})
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2})
assert_equivalent({'A': None}, {'A': np.identity(3)})
assert_equivalent({'A': None}, {'A': 1})
assert_equivalent({'A': 6.4}, {'A': lambda x: 6.4 * x})
A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]])
assert_equivalent({'A': A}, {'A': A, 'At': A.T})
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A})
Expand Down

0 comments on commit d1af42a

Please sign in to comment.