Skip to content

Commit

Permalink
Merge branch 'main' into brendt/issue#240
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 24, 2022
2 parents 3fe52d0 + edc00ed commit 53daa5d
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 121 deletions.
65 changes: 35 additions & 30 deletions scico/optimize/_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Callable, List, Optional, Tuple, Union

import scico.numpy as snp
from scico.diagnostics import IterationStats
from scico.functional import Functional
from scico.linop import LinearOperator
from scico.numpy import BlockArray
Expand All @@ -24,6 +23,7 @@
from scico.util import Timer

from ._admmaux import GenericSubproblemSolver, LinearSubproblemSolver, SubproblemSolver
from ._common import itstat_func_and_object


class ADMM:
Expand Down Expand Up @@ -122,12 +122,12 @@ def __init__(
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and an `ADMM` object, responsible
for constructing a tuple ready for insertion into the
:class:`.diagnostics.IterationStats` object. If ``None``,
default values are used for the dict entries, otherwise
the default dict is updated with the dict specified by
this parameter.
parameters, an integer and an :class:`ADMM` object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
"""
N = len(g_list)
if len(C_list) != N:
Expand All @@ -148,6 +148,31 @@ def __init__(
self.subproblem_solver: SubproblemSolver = subproblem_solver
self.subproblem_solver.internal_init(self)

if x0 is None:
input_shape = C_list[0].input_shape
dtype = C_list[0].input_dtype
x0 = snp.zeros(input_shape, dtype=dtype)
self.x = ensure_on_device(x0)
self.z_list, self.z_list_old = self.z_init(self.x)
self.u_list = self.u_init(self.x)

self._itstat_init(itstat_options)

def _itstat_init(self, itstat_options: Optional[dict] = None):
"""Initialize iteration statistics mechanism.
Args:
itstat_options: A dict of named parameters to be passed to
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and an :class:`ADMM` object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
"""
# iteration number and time fields
itstat_fields = {
"Iter": "%d",
Expand Down Expand Up @@ -177,29 +202,9 @@ def __init__(
["subproblem_solver.info['num_iter']", "subproblem_solver.info['rel_res']"]
)

# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)

# determine itstat options and initialize IterationStats object
default_itstat_options = {
"fields": itstat_fields,
"itstat_func": scope["itstat_func"],
"display": False,
}
if itstat_options:
default_itstat_options.update(itstat_options)
self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore
self.itstat_object = IterationStats(**default_itstat_options) # type: ignore

if x0 is None:
input_shape = C_list[0].input_shape
dtype = C_list[0].input_dtype
x0 = snp.zeros(input_shape, dtype=dtype)
self.x = ensure_on_device(x0)
self.z_list, self.z_list_old = self.z_init(self.x)
self.u_list = self.u_init(self.x)
self.itstat_insert_func, self.itstat_object = itstat_func_and_object(
itstat_fields, itstat_attrib, itstat_options
)

def objective(
self,
Expand Down
62 changes: 62 additions & 0 deletions scico/optimize/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functions common to multiple optimizer modules."""


from typing import Callable, Dict, List, Optional, Tuple, Union

from scico.diagnostics import IterationStats


def itstat_func_and_object(
itstat_fields: dict, itstat_attrib: List, itstat_options: Optional[dict] = None
) -> Tuple[Callable, IterationStats]:
"""Iteration statistics initialization.
Iteration statistics initialization steps common to all optimizer
classes.
Args:
itstat_fields: A dictionary associating field names with format
strings for displaying the corresponding values.
itstat_attrib: A list of expressions corresponding of optimizer
class attributes to be evaluated when computing iteration
statistics.
itstat_options: A dict of named parameters to be passed to
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and an optimizer object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
Returns:
A tuple consisting of the statistics insertion function and the
:class:`.diagnostics.IterationStats` object.
"""
# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: Dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)

# determine itstat options and initialize IterationStats object
default_itstat_options: Dict[str, Union[dict, Callable, bool]] = {
"fields": itstat_fields,
"itstat_func": scope["itstat_func"],
"display": False,
}
if itstat_options:
default_itstat_options.update(itstat_options)

itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore
itstat_object = IterationStats(**default_itstat_options) # type: ignore

return itstat_insert_func, itstat_object
68 changes: 37 additions & 31 deletions scico/optimize/_ladmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Callable, List, Optional, Tuple, Union

import scico.numpy as snp
from scico.diagnostics import IterationStats
from scico.functional import Functional
from scico.linop import LinearOperator
from scico.numpy import BlockArray
Expand All @@ -23,6 +22,8 @@
from scico.typing import JaxArray
from scico.util import Timer

from ._common import itstat_func_and_object


class LinearizedADMM:
r"""Linearized alternating direction method of multipliers algorithm.
Expand Down Expand Up @@ -113,12 +114,12 @@ def __init__(
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and a `LinearizedADMM` object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
parameters, an integer and a :class:`LinearizedADMM`
object, responsible for constructing a tuple ready for
insertion into the :class:`.diagnostics.IterationStats`
object. If ``None``, default values are used for the dict
entries, otherwise the default dict is updated with the
dict specified by this parameter.
"""
self.f: Functional = f
self.g: Functional = g
Expand All @@ -129,43 +130,48 @@ def __init__(
self.maxiter: int = maxiter
self.timer: Timer = Timer()

if x0 is None:
input_shape = C.input_shape
dtype = C.input_dtype
x0 = snp.zeros(input_shape, dtype=dtype)
self.x = ensure_on_device(x0)
self.z, self.z_old = self.z_init(self.x)
self.u = self.u_init(self.x)

self._itstat_init(itstat_options)

def _itstat_init(self, itstat_options: Optional[dict] = None):
"""Initialize iteration statistics mechanism.
Args:
itstat_options: A dict of named parameters to be passed to
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and a :class:`LinearizedADMM`
object, responsible for constructing a tuple ready for
insertion into the :class:`.diagnostics.IterationStats`
object. If ``None``, default values are used for the dict
entries, otherwise the default dict is updated with the
dict specified by this parameter.
"""
# iteration number and time fields
itstat_fields = {
"Iter": "%d",
"Time": "%8.2e",
}
itstat_attrib = ["itnum", "timer.elapsed()"]
# objective function can be evaluated if 'g' function can be evaluated
if g.has_eval:
if self.g.has_eval:
itstat_fields.update({"Objective": "%9.3e"})
itstat_attrib.append("objective()")
# primal and dual residual fields
itstat_fields.update({"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"})
itstat_attrib.extend(["norm_primal_residual()", "norm_dual_residual()"])

# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)

# determine itstat options and initialize IterationStats object
default_itstat_options = {
"fields": itstat_fields,
"itstat_func": scope["itstat_func"],
"display": False,
}
if itstat_options:
default_itstat_options.update(itstat_options)
self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore
self.itstat_object = IterationStats(**default_itstat_options) # type: ignore

if x0 is None:
input_shape = C.input_shape
dtype = C.input_dtype
x0 = snp.zeros(input_shape, dtype=dtype)
self.x = ensure_on_device(x0)
self.z, self.z_old = self.z_init(self.x)
self.u = self.u_init(self.x)
self.itstat_insert_func, self.itstat_object = itstat_func_and_object(
itstat_fields, itstat_attrib, itstat_options
)

def objective(
self,
Expand Down
54 changes: 30 additions & 24 deletions scico/optimize/_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import jax

import scico.numpy as snp
from scico.diagnostics import IterationStats
from scico.functional import Functional
from scico.loss import Loss
from scico.numpy import BlockArray
from scico.numpy.util import ensure_on_device
from scico.typing import JaxArray
from scico.util import Timer

from ._common import itstat_func_and_object
from ._pgmaux import (
AdaptiveBBStepSize,
BBStepSize,
Expand Down Expand Up @@ -68,12 +68,12 @@ def __init__(
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and a PGM object, responsible
for constructing a tuple ready for insertion into the
:class:`.diagnostics.IterationStats` object. If ``None``,
default values are used for the dict entries, otherwise
the default dict is updated with the dict specified by
this parameter.
parameters, an integer and a :class:`PGM` object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
"""

#: Functional or Loss to minimize; must have grad method defined.
Expand All @@ -100,36 +100,42 @@ def x_step(v: Union[JaxArray, BlockArray], L: float) -> Union[JaxArray, BlockArr

self.x_step = jax.jit(x_step)

self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution

self._itstat_init(itstat_options)

def _itstat_init(self, itstat_options: Optional[dict] = None):
"""Initialize iteration statistics mechanism.
Args:
itstat_options: A dict of named parameters to be passed to
the :class:`.diagnostics.IterationStats` initializer. The
dict may also include an additional key "itstat_func"
with the corresponding value being a function with two
parameters, an integer and a :class:`PGM` object,
responsible for constructing a tuple ready for insertion
into the :class:`.diagnostics.IterationStats` object. If
``None``, default values are used for the dict entries,
otherwise the default dict is updated with the dict
specified by this parameter.
"""
# iteration number and time fields
itstat_fields = {
"Iter": "%d",
"Time": "%8.2e",
}
itstat_attrib = ["itnum", "timer.elapsed()"]
# objective function can be evaluated if 'g' function can be evaluated
if g.has_eval:
if self.g.has_eval:
itstat_fields.update({"Objective": "%9.3e"})
itstat_attrib.append("objective()")
# step size and residual fields
itstat_fields.update({"L": "%9.3e", "Residual": "%9.3e"})
itstat_attrib.extend(["L", "norm_residual()"])

# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)

default_itstat_options: dict[str, Union[dict, Callable, bool]] = {
"fields": itstat_fields,
"itstat_func": scope["itstat_func"],
"display": False,
}
if itstat_options:
default_itstat_options.update(itstat_options)
self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore
self.itstat_object = IterationStats(**default_itstat_options) # type: ignore

self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution
self.itstat_insert_func, self.itstat_object = itstat_func_and_object(
itstat_fields, itstat_attrib, itstat_options
)

def objective(self, x: Optional[Union[JaxArray, BlockArray]] = None) -> float:
r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`."""
Expand Down
Loading

0 comments on commit 53daa5d

Please sign in to comment.