Improve ADMM tests (#46)
* Fixes scaling bug in CG from #38 (this commit already merged in PR #45)

* Tests solvers for loss.SquaredL2Loss with non-default scale factor

* Improve solution accuracy

* scale=pi, justify comment

Co-authored-by: Thilo Balke <thilo.balke@gmail.com>
bwohlberg and tbalke committed Oct 18, 2021
1 parent c844535 commit be808db
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions scico/test/test_admm.py
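Context for the diff below: with the default scale, loss.SquaredL2Loss gives a (1/2) ||A x - y||_2^2 data term (the old comments), while passing scale=𝛼/2 gives (𝛼/2) ||A x - y||_2^2 (the updated comments and reference linear systems). A minimal sketch of that usage, assuming only the scico modules already imported by the test file (linop, loss); the array sizes and variable names here are illustrative, not part of the commit:

```python
import numpy as np
from scico import linop, loss

A = linop.MatrixOperator(np.random.randn(9, 8))  # same shape as Amx in the test
y = np.random.randn(9)
alpha = np.pi                                    # the non-default scale factor (𝛼)

# Default scale: f0(x) = (1/2) ||A x - y||_2^2
f0 = loss.SquaredL2Loss(y=y, A=A)
# Non-default scale, as exercised by this commit: f(x) = (alpha / 2) ||A x - y||_2^2
f = loss.SquaredL2Loss(y=y, A=A, scale=alpha / 2.0)
```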
@@ -16,24 +16,26 @@ def setup_method(self, method):
         MA = 9
         MB = 10
         N = 8
-        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
+        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
         Amx = np.random.randn(MA, N)
         Bmx = np.random.randn(MB, N)
         y = np.random.randn(MA)
+        𝛼 = np.pi  # sort of random number chosen to test non-default scale factor
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
         self.y = y
+        self.𝛼 = 𝛼
         self.λ = λ
-        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
-        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
-        self.grdb = Amx.T @ y
+        # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y
+        self.grdA = lambda x: (𝛼 * Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
+        self.grdb = 𝛼 * Amx.T @ y

     def test_admm_generic(self):
         maxiter = 100
         ρ = 1e-1
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
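The grdA/grdb lines above encode the normal equations of the scaled problem: setting the gradient 𝛼 A^T (A x - y) + λ B^T B x to zero gives (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y. A minimal NumPy sketch of the closed-form reference solution this implies, reusing the shapes and constants from setup_method (variable names here are illustrative, not the test's):

```python
import numpy as np

MA, MB, N = 9, 10, 8
Amx = np.random.randn(MA, N)
Bmx = np.random.randn(MB, N)
y = np.random.randn(MA)
alpha = np.pi  # 𝛼 in the test
lam = 1e0      # λ in the test

# Normal equations: (alpha A^T A + lam B^T B) x = alpha A^T y
x_ref = np.linalg.solve(alpha * Amx.T @ Amx + lam * Bmx.T @ Bmx, alpha * Amx.T @ y)

# The gradient vanishes at x_ref, matching the grdA/grdb check used by the tests.
grad = (alpha * Amx.T @ Amx + lam * Bmx.T @ Bmx) @ x_ref - alpha * Amx.T @ y
assert np.allclose(grad, 0.0, atol=1e-8)
```

An ADMM run on the scaled objective (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 should approach x_ref.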
@@ -56,7 +58,7 @@ def test_admm_quadratic_scico(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -77,7 +79,7 @@ def test_admm_quadratic_jax(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -100,24 +102,26 @@ def setup_method(self, method):
         MA = 9
         MB = 10
         N = 8
-        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
+        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
         Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)
         Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
         y = np.random.randn(MA)
+        𝛼 = 1.0 / 3.0
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
         self.y = y
+        self.𝛼 = 𝛼
         self.λ = λ
-        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
-        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
-        self.grdb = Amx.conj().T @ y
+        # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y
+        self.grdA = lambda x: (𝛼 * Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
+        self.grdb = 𝛼 * Amx.conj().T @ y

     def test_admm_generic(self):
         maxiter = 100
-        ρ = 1e-1
+        ρ = 2e-1
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
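The complex-valued setup above follows the same pattern with the conjugate transpose in place of the transpose, i.e. (𝛼 A^H A + λ B^H B) x = 𝛼 A^H y, which is what the conj().T calls in grdA/grdb implement. A corresponding NumPy sketch (dtype, shapes, and 𝛼 = 1/3 as above; the random complex arrays here are generated with plain NumPy rather than scico.random, purely for illustration):

```python
import numpy as np

MA, MB, N = 9, 10, 8
rng = np.random.default_rng(0)
Amx = (rng.standard_normal((MA, N)) + 1j * rng.standard_normal((MA, N))).astype(np.complex64)
Bmx = (rng.standard_normal((MB, N)) + 1j * rng.standard_normal((MB, N))).astype(np.complex64)
y = rng.standard_normal(MA)
alpha = 1.0 / 3.0
lam = 1e0

# Complex normal equations use the conjugate transpose, as in grdA/grdb above.
x_ref = np.linalg.solve(
    alpha * Amx.conj().T @ Amx + lam * Bmx.conj().T @ Bmx,
    alpha * Amx.conj().T @ y,
)
```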
@@ -140,7 +144,7 @@ def test_admm_quadratic(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
