"""This module provides an abstract class Alg for iterative algorithms,
and implements commonly used methods, such as gradient methods,
Newton's method, and the augmented Lagrangian method.
"""

import numpy as np
import sigpy as sp
from sigpy import backend, util


class Alg(object):
    """Abstraction for iterative algorithms.

    The standard way of using an :class:`Alg` object, say ``alg``, is::

        while not alg.done():
            alg.update()

    The user is free to run other things in the while loop.
    An :class:`Alg` object is meant to run once.
    Once done, the object should not be run again.

    When creating a new :class:`Alg` class, the user should supply
    an _update() function to perform the iterative update, and optionally
    a _done() function to determine when to terminate the iteration.
    The default _done() function simply checks whether the number of
    iterations has reached the maximum.

    The interface for each :class:`Alg` class should not depend on
    Linop or Prox explicitly. For example, if the user wants to design an
    :class:`Alg` class to accept a Linop, say A, as an argument, then it
    should also accept any function that can be called to compute
    x -> A(x). Similarly, to accept a Prox, say proxg, as an argument,
    the Alg class should accept any function that can be called to compute
    alpha, x -> proxg(alpha, x).

    Args:
        max_iter (int): Maximum number of iterations.

    Attributes:
        max_iter (int): Maximum number of iterations.
        iter (int): Current iteration.
    """

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter = 0

    def _update(self):
        """Perform one iteration of the algorithm.

        Subclasses must override this method.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def _done(self):
        """Return whether ``iter`` has reached ``max_iter``.

        Subclasses may override this to add further stopping criteria.
        """
        return self.iter >= self.max_iter

    def update(self):
        """Perform one update step.

        Call the user-defined _update() function and increment iter.
        If _update() raises, iter is left unchanged.
        """
        self._update()
        self.iter += 1

    def done(self):
        """Return whether the algorithm is done.

        Call the user-defined _done() function.
        """
        return self._done()


class ConjugateGradient(Alg):
    r"""Conjugate gradient method.

    Solves for:

    .. math:: A x = b

    where A is a Hermitian linear operator.

    Args:
        A (Linop or function): Linop or function to compute A.
        b (array): Observation.
        x (array): Variable. Serves as the initial guess. It is passed to
            ``util.axpy`` each iteration and never rebound, so it presumably
            holds the current estimate in place.
        P (function or None): Preconditioner.
        max_iter (int): Maximum number of iterations.
        tol (float): Tolerance for stopping condition. Iteration stops
            once ``resid <= tol``.

    Attributes:
        r (array): Residual, initialized to ``b - A(x)``.
        p (array): Current search direction.
        rzold (array): Real part of ``vdot(r, z)``, where ``z = P(r)``,
            or ``z = r`` when no preconditioner is given.
        resid (float): ``rzold ** 0.5``. This is the 2-norm of ``r`` only
            when no preconditioner is given.
        not_positive_definite (bool): Set when ``real(vdot(p, A(p))) <= 0``
            is encountered. This stops the iteration.
        alpha: Step size used by the most recent update.
    """

    def __init__(self, A, b, x, P=None, max_iter=100, tol=0):
        self.A = A
        self.b = b
        self.P = P
        self.x = x
        self.tol = tol
        self.device = backend.get_device(x)
        with self.device:
            xp = self.device.xp
            self.r = b - self.A(self.x)

            if self.P is None:
                z = self.r
            else:
                z = self.P(self.r)

            # p is later modified in place by util.xpay. When P is None, z
            # is the same object as self.r, so copy it to keep p and r
            # independent. With max_iter <= 1 the direction update never
            # runs, so the copy is skipped.
            if max_iter > 1:
                self.p = z.copy()
            else:
                self.p = z

            self.not_positive_definite = False
            self.rzold = xp.real(xp.vdot(self.r, z))
            # NOTE(review): if P is not positive semidefinite, rzold can be
            # negative. ``** 0.5`` of a negative Python float is complex,
            # and the ``resid <= tol`` comparison in _done() would then
            # raise TypeError. Confirm P is PSD.
            self.resid = self.rzold.item() ** 0.5

            super().__init__(max_iter)

    def _update(self):
        """Perform one conjugate gradient iteration.

        Steps ``x`` along the search direction ``p``. Unless this is the
        final iteration, also updates ``r``, ``p`` and ``rzold``. If
        ``real(vdot(p, A(p))) <= 0``, sets ``not_positive_definite`` and
        returns without touching ``x``.
        """
        with self.device:
            xp = self.device.xp
            Ap = self.A(self.p)
            pAp = xp.real(xp.vdot(self.p, Ap)).item()
            if pAp <= 0:
                # A is not positive definite along p, so the CG step size is
                # invalid. Flag it and let _done() stop the loop.
                self.not_positive_definite = True
                return

            self.alpha = self.rzold / pAp
            # Presumably x += alpha * p, in place (return value unused).
            util.axpy(self.x, self.alpha, self.p)
            # r and p are only consumed by a later iteration, so skip
            # updating them on the last one.
            if self.iter < self.max_iter - 1:
                util.axpy(self.r, -self.alpha, Ap)
                if self.P is not None:
                    z = self.P(self.r)
                else:
                    z = self.r

                rznew = xp.real(xp.vdot(self.r, z))
                beta = rznew / self.rzold
                # Presumably p = z + beta * p, in place.
                util.xpay(self.p, beta, z)
                self.rzold = rznew

            # On the final iteration rzold was not refreshed, so resid
            # still reflects the residual from before this step.
            self.resid = self.rzold.item() ** 0.5

    def _done(self):
        """Return whether the iteration should stop.

        Stops after ``max_iter`` iterations, when A was found not to be
        positive definite, or once ``resid <= tol``.
        """
        return (
            self.iter >= self.max_iter
            or self.not_positive_definite
            or self.resid <= self.tol
        )

