diff --git a/scico/admm.py b/scico/admm.py
index 23461ffcd..2f6ac25f3 100644
--- a/scico/admm.py
+++ b/scico/admm.py
@@ -185,7 +185,7 @@ def internal_init(self, admm):
         )
         if admm.f is not None:
             # hessian = A.T @ W @ A; W may be identity
-            lhs_op = lhs_op + 2.0 * admm.f.scale * admm.f.hessian
+            lhs_op = lhs_op + admm.f.hessian
         lhs_op.jit()
         self.lhs_op = lhs_op

diff --git a/scico/test/test_admm.py b/scico/test/test_admm.py
index 7dd74e970..504e61a62 100644
--- a/scico/test/test_admm.py
+++ b/scico/test/test_admm.py
@@ -16,24 +16,26 @@ def setup_method(self, method):
         MA = 9
         MB = 10
         N = 8
-        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
+        # Set up arrays for problem argmin (α/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
         Amx = np.random.randn(MA, N)
         Bmx = np.random.randn(MB, N)
         y = np.random.randn(MA)
+        α = np.pi  # sort of random number chosen to test non-default scale factor
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
         self.y = y
+        self.α = α
         self.λ = λ
-        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
-        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
-        self.grdb = Amx.T @ y
+        # Solution of problem is given by linear system (α A^T A + λ B^T B) x = α A^T y
+        self.grdA = lambda x: (α * Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
+        self.grdb = α * Amx.T @ y

     def test_admm_generic(self):
         maxiter = 100
         ρ = 1e-1
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.α / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -56,7 +58,7 @@ def test_admm_quadratic_scico(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.α / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -77,7 +79,7 @@ def test_admm_quadratic_jax(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.α / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -100,24 +102,26 @@ def setup_method(self, method):
         MA = 9
         MB = 10
         N = 8
-        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
+        # Set up arrays for problem argmin (α/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
         Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)
         Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
         y = np.random.randn(MA)
+        α = 1.0 / 3.0
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
         self.y = y
+        self.α = α
         self.λ = λ
-        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
-        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
-        self.grdb = Amx.conj().T @ y
+        # Solution of problem is given by linear system (α A^T A + λ B^T B) x = α A^T y
+        self.grdA = lambda x: (α * Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
+        self.grdb = α * Amx.conj().T @ y

     def test_admm_generic(self):
         maxiter = 100
-        ρ = 1e-1
+        ρ = 2e-1
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.α / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
@@ -140,7 +144,7 @@ def test_admm_quadratic(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
-        f = loss.SquaredL2Loss(y=self.y, A=A)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.α / 2.0)
         g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
         C_list = [linop.MatrixOperator(self.Bmx)]
         rho_list = [ρ]
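
Since the scale argument of loss.SquaredL2Loss multiplies the squared residual (the diff passes scale=self.α / 2.0), the reference solution used by the updated tests follows from the normal equations (α A^T A + λ B^T B) x = α A^T y, which is exactly what grdA/grdb encode. Below is a minimal NumPy sketch of that relationship, not part of the diff; the sizes, seed, and values of α and λ are illustrative only.

import numpy as np

# Illustrative problem sizes and weights (assumed for this sketch only).
MA, MB, N = 9, 10, 8
rng = np.random.default_rng(0)
A = rng.standard_normal((MA, N))
B = rng.standard_normal((MB, N))
y = rng.standard_normal(MA)
α, λ = np.pi, 1.0

# Normal equations for argmin (α/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2.
x = np.linalg.solve(α * A.T @ A + λ * B.T @ B, α * A.T @ y)

# The gradient of the objective should vanish at the solution.
grad = α * A.T @ (A @ x - y) + λ * B.T @ (B @ x)
assert np.allclose(grad, 0.0, atol=1e-8)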