Add optional relaxation to ADMM solver #118

Merged
merged 9 commits on Dec 4, 2021
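A minimal usage sketch of the new relaxation parameter (illustrative only: the data and operator setup are placeholders, and the import paths are assumptions based on the file this PR modifies, scico/admm.py; the constructor arguments mirror the test added below):

import numpy as np
import jax

from scico import functional, linop, loss
from scico.admm import ADMM, LinearSubproblemSolver  # assumed import path

# Placeholder quadratic problem.
y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
A = linop.Identity(y.shape)
f = loss.SquaredL2Loss(y=y, A=A)
g_list = [0.5 * functional.SquaredL2Norm()]
C_list = [linop.Identity(y.shape)]
rho_list = [1e0]

# alpha=1.0 (the default) is the unrelaxed iteration; values in (1, 2)
# give over-relaxation, as exercised by the new test with alpha=1.6.
solver = ADMM(
    f=f,
    g_list=g_list,
    C_list=C_list,
    rho_list=rho_list,
    alpha=1.6,
    maxiter=50,
    subproblem_solver=LinearSubproblemSolver(cg_function="jax"),
    verbose=False,
)
x_hat = solver.solve()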
72 changes: 31 additions & 41 deletions scico/admm.py
@@ -321,7 +321,7 @@ class ADMM:
\text{such that}\; C_i \mb{x} = \mb{z}_i \;,

via an ADMM algorithm :cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual`
:cite:`boyd-2010-distributed`. consisting of the iterations
:cite:`boyd-2010-distributed`, consisting of the iterations (see :meth:`step`)

.. math::
\begin{aligned}
@@ -332,11 +332,6 @@ class ADMM:
\mb{u}_i^{(k+1)} &= \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \; .
\end{aligned}

For documentation on minimization with respect to :math:`\mb{x}`, see :meth:`x_step`.

For documentation on minimization with respect to :math:`\mb{z}_i` and
:math:`\mb{u}_i`, see :meth:`z_and_u_step`.


Attributes:
f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`)
@@ -348,6 +343,7 @@ class ADMM:
timer (:class:`.Timer`): Iteration timer.
rho_list (list of scalars): List of :math:`\rho_i` penalty parameters.
Must be same length as :code:`C_list` and :code:`g_list`.
alpha (float): Relaxation parameter.
u_list (list of array-like): List of scaled Lagrange multipliers
:math:`\mb{u}_i` at current iteration.
x (array-like): Solution
@@ -365,6 +361,7 @@ def __init__(
g_list: List[Functional],
C_list: List[LinearOperator],
rho_list: List[float],
alpha: float = 1.0,
x0: Optional[Union[JaxArray, BlockArray]] = None,
maxiter: int = 100,
subproblem_solver: Optional[SubproblemSolver] = None,
@@ -379,7 +376,8 @@ def __init__(
as :code:`C_list` and :code:`rho_list`
C_list : List of :math:`C_i` operators
rho_list : List of :math:`\rho_i` penalty parameters.
Must be same length as :code:`C_list` and :code:`g_list`
Must be same length as :code:`C_list` and :code:`g_list`.
alpha: Relaxation parameter. The default value of 1.0 corresponds to no relaxation.
x0 : Starting point for :math:`\mb{x}`. If None, defaults to
an array of zeros.
maxiter : Number of ADMM outer-loop iterations. Default: 100.
@@ -407,6 +405,7 @@ def __init__(
self.g_list: List[Functional] = g_list
self.C_list: List[LinearOperator] = C_list
self.rho_list: List[float] = rho_list
self.alpha: float = alpha
self.itnum: int = 0
self.maxiter: int = maxiter
self.timer: Timer = Timer()
@@ -566,58 +565,49 @@ def u_init(self, x0: Union[JaxArray, BlockArray]):
u_list = [snp.zeros(Ci.output_shape, dtype=Ci.output_dtype) for Ci in self.C_list]
return u_list

def x_step(self, x):
r"""Update :math:`\mb{x}` by solving the optimization problem.

.. math::
\mb{x}^{(k+1)} = \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2}
\norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2
def step(self):
r"""Perform a single ADMM iteration.

"""
return self.subproblem_solver.solve(x)
The primary variable :math:`\mb{x}` is updated by solving the
optimization problem

def z_and_u_step(self, u_list, z_list):
Contributor commented:

I suspect the code was written this way so that the z_and_u_step could be jitted (jitted functions should be pure, no side effects). However, z_and_u_step was not actually jitted as far as I can tell, so I don't mind the refactor.

Collaborator (author) replied:

That was my suspicion too. If we decide there is some benefit to jitting, we can refactor again, but for now this simpler structure makes more sense.
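(An aside on the jitting point: if jitting were pursued later, the pure pieces of the update could be jitted directly. A minimal sketch, not part of this PR:

import jax

@jax.jit
def u_update(u, Cx, z):
    # Pure scaled-dual update with no side effects, so it is safe to jit.
    return u + Cx - z
)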

r"""Update the auxiliary variables :math:`\mb{z}_i` and scaled Lagrange multipliers
:math:`\mb{u}_i`.
.. math::
\mb{x}^{(k+1)} = \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i
\frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i -
C_i \mb{x}}_2^2 \;.

The auxiliary variables are updated according to

.. math::
\begin{aligned}
\mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2}
\norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\
&= \mathrm{prox}_{g_i}(C_i \mb{x} + \mb{u}_i, 1 / \rho_i)
\mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) +
\frac{\rho_i}{2} \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i
\mb{x}^{(k+1)}}_2^2 \\
&= \mathrm{prox}_{g_i}(C_i \mb{x} + \mb{u}_i, 1 / \rho_i) \;,
\end{aligned}

while the scaled Lagrange multipliers are updated according to
and the scaled Lagrange multipliers are updated according to

.. math::
\mb{u}_i^{(k+1)} = \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i

\mb{u}_i^{(k+1)} = \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} -
\mb{z}^{(k+1)}_i \;.
"""
z_list_old = z_list.copy()

# Unpack the arrays that will be changing to prevent side-effects
z_list = self.z_list
u_list = self.u_list
self.x = self.subproblem_solver.solve(self.x)

self.z_list_old = self.z_list.copy()

for i, (rhoi, gi, Ci, zi, ui) in enumerate(
zip(self.rho_list, self.g_list, self.C_list, z_list, u_list)
zip(self.rho_list, self.g_list, self.C_list, self.z_list, self.u_list)
):
Cix = Ci(self.x)
if self.alpha == 1.0:
Cix = Ci(self.x)
else:
Cix = self.alpha * Ci(self.x) + (1.0 - self.alpha) * zi
zi = gi.prox(Cix + ui, 1 / rhoi, v0=zi)
ui = ui + Cix - zi
z_list[i] = zi
u_list[i] = ui
return u_list, z_list, z_list_old

def step(self):
"""Perform a single ADMM iteration.

Equivalent to calling :meth:`.x_step` followed by :meth:`.z_and_u_step`.
"""
self.x = self.x_step(self.x)
self.u_list, self.z_list, self.z_list_old = self.z_and_u_step(self.u_list, self.z_list)
self.z_list[i] = zi
self.u_list[i] = ui
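For reference, the relaxed update above follows the standard over-relaxation scheme described in :cite:`boyd-2010-distributed`: for alpha != 1, the quantity C_i x^{(k+1)} is replaced by alpha * C_i x^{(k+1)} + (1 - alpha) * z_i^{(k)} in both the z and u updates. A standalone sketch of one block of the loop body (hypothetical helper, mirroring the code above):

def relaxed_block_update(x, z, u, alpha, rho, C, g):
    # Hypothetical standalone version of one iteration of the loop above.
    Cx = C(x)
    if alpha != 1.0:
        # Relaxation: combine the new C(x) with the previous z_i.
        Cx = alpha * Cx + (1.0 - alpha) * z
    z_new = g.prox(Cx + u, 1.0 / rho, v0=z)  # proximal z update
    u_new = u + Cx - z_new                   # scaled multiplier update
    return z_new, u_new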

def solve(
self,
78 changes: 74 additions & 4 deletions scico/test/test_admm.py
@@ -12,6 +12,54 @@
)


class TestMisc:
def setup_method(self, method):
np.random.seed(12345)
self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
self.λ = 1e0

def test_admm(self):
maxiter = 2
ρ = 1e-1
A = linop.Identity(self.y.shape)
f = loss.SquaredL2Loss(y=self.y, A=A)
g = (self.λ / 2) * functional.BM3D()
C = linop.Identity(self.y.shape)

itstat_dict = {"Iter": "%d", "Time": "%8.2e"}

def itstat_func(obj):
return (obj.itnum, obj.timer.elapsed())

admm_ = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[ρ],
maxiter=maxiter,
verbose=False,
)
assert len(admm_.itstat_object.fieldname) == 4
assert snp.sum(admm_.x) == 0.0
admm_ = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[ρ],
maxiter=maxiter,
verbose=False,
itstat=(itstat_dict, itstat_func),
)
assert len(admm_.itstat_object.fieldname) == 2

def callback(obj):
global flag
flag = True

x = admm_.solve(callback=callback)
assert flag


class TestReal:
def setup_method(self, method):
np.random.seed(12345)
@@ -35,7 +83,7 @@ def setup_method(self, method):

def test_admm_generic(self):
maxiter = 100
ρ = 1e-1
ρ = 2e-1
A = linop.MatrixOperator(self.Amx)
f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
@@ -58,7 +106,7 @@ def test_admm_generic(self):

def test_admm_quadratic_scico(self):
maxiter = 50
ρ = 1e0
ρ = 4e-1
A = linop.MatrixOperator(self.Amx)
f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
@@ -98,6 +146,28 @@ def test_admm_quadratic_jax(self):
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5

def test_admm_quadratic_relax(self):
maxiter = 50
ρ = 1e0
A = linop.MatrixOperator(self.Amx)
f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
C_list = [linop.MatrixOperator(self.Bmx)]
rho_list = [ρ]
admm_ = ADMM(
f=f,
g_list=g_list,
C_list=C_list,
rho_list=rho_list,
alpha=1.6,
maxiter=maxiter,
verbose=False,
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(cg_function="jax"),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5


class TestRealWeighted:
def setup_method(self, method):
@@ -113,7 +183,7 @@ def setup_method(self, method):
𝛼 = np.pi # sort of random number chosen to test non-default scale factor
λ = np.e
self.Amx = Amx
self.W = W
self.W = jax.device_put(W)
self.Bmx = Bmx
self.y = jax.device_put(y)
self.𝛼 = 𝛼
@@ -169,7 +239,7 @@ def setup_method(self, method):

def test_admm_generic(self):
maxiter = 100
ρ = 2e-1
ρ = 1e0
A = linop.MatrixOperator(self.Amx)
f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
g_list = [(self.λ / 2) * functional.SquaredL2Norm()]