Skip to content

Commit

Permalink
CHG: adding changes to add backtracking fista algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelcarcamov committed Apr 27, 2023
1 parent 065c63f commit 9fa24a2
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 114 deletions.
Empty file added src/__init__.py
Empty file.
1 change: 0 additions & 1 deletion src/csromer/objectivefunction/priors/l1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def evaluate(self, x, epsilon=np.finfo(np.float32).tiny):
return val

def calculate_gradient(self, x, epsilon=np.finfo(np.float32).tiny):
dx = np.zeros(len(x), x.dtype)

dx = x / approx_abs(x, epsilon)

Expand Down
108 changes: 0 additions & 108 deletions src/csromer/optimization/methods/fista.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/csromer/optimization/methods/fista/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .fista import FISTA
from .fista_backtracking import BacktrackingFISTA
from .fista_general import GeneralFISTA
17 changes: 17 additions & 0 deletions src/csromer/optimization/methods/fista/fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import copy
from dataclasses import dataclass

from ....objectivefunction import Chi2, Fi
from ....optimization.optimizer import Optimizer


@dataclass(init=True, repr=True)
class FISTA(Optimizer):
    """Base dataclass for the FISTA family of solvers.

    Holds the two terms of the objective F(x) = f(x) + g(x) together with
    step-size-related settings shared by the concrete variants
    (``GeneralFISTA``, ``BacktrackingFISTA``).
    """

    # smooth data-fidelity term f(x) (a chi-squared objective)
    fx: Chi2 = None
    # non-smooth regularization term g(x)
    gx: Fi = None
    # Lipschitz constant of grad f; None defers the choice to the algorithm
    lipschitz_constant: float = None
    # noise level — used by BacktrackingFISTA to decay the regularization
    noise: float = None

    def run(self):
        # Intentionally a no-op: concrete subclasses override this.
        pass
29 changes: 29 additions & 0 deletions src/csromer/optimization/methods/fista/fista_backtracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import copy
from dataclasses import dataclass

from ....optimization.methods.fista.fista_backtracking_algorithm import fista_backtracking_algorithm
from .fista import FISTA


@dataclass(init=True, repr=True)
class BacktrackingFISTA(FISTA):
    """FISTA solver configured with a backtracking step-size factor.

    Delegates the iteration to :func:`fista_backtracking_algorithm` and
    wraps the resulting array back into a copy of the initial parameter.
    """

    # multiplicative factor for the backtracking line search
    eta: float = None

    def run(self):
        """Run the solver; returns ``(final_cost, recovered_parameter)``."""
        final_cost, solution = fista_backtracking_algorithm(
            x=self.guess_param.data,
            F=self.F_obj.evaluate,
            fx=self.fx.evaluate,
            fx_grad=self.fx.calculate_gradient_fista,
            g_prox=self.gx,
            lipschitz_constant=self.lipschitz_constant,
            eta=self.eta,
            max_iter=self.maxiter,
            n=self.guess_param.n,
            noise=self.noise,
            verbose=self.verbose,
        )

        recovered = copy.deepcopy(self.guess_param)
        recovered.data = solution
        return final_cost, recovered
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np


def calculate_Q(x, y, fx, fx_grad, gx, lipschitz_constant):
    """Evaluate the FISTA quadratic surrogate Q_L(x, y).

    Combines f at y, a linear term built from the gradient of f at y, a
    quadratic proximity term scaled by the Lipschitz constant, and the
    non-smooth term g at x.
    """
    # NOTE(review): the linear term negates fx_grad(y) — presumably
    # fx_grad returns a descent direction (negative gradient); confirm
    # against the gradient convention used elsewhere in the package.
    diff = x - y
    linear_term = np.dot(diff, -fx_grad(y))
    quadratic_term = 0.5 * lipschitz_constant * np.sum(diff**2)
    return fx(y) + linear_term + quadratic_term + gx(x)


def fista_backtracking_algorithm(
    x=None,
    F=None,
    fx=None,
    fx_grad=None,
    g_prox=None,
    lipschitz_constant=None,
    eta=None,
    max_iter=None,
    n=None,
    noise=None,
    verbose=True,
):
    """Accelerated proximal-gradient iteration with a decaying regularization.

    Parameters
    ----------
    x : np.ndarray, optional
        Initial iterate; when None and ``n`` is given, starts from zeros.
    F : callable
        Full objective F(x) = f(x) + g(x); used for progress reporting and
        the returned final cost.
    fx : callable
        Smooth term f. NOTE(review): accepted but never used in this body.
    fx_grad : callable
        Gradient-step operator for f, applied as ``z - fx_grad(z)``, i.e.
        with an implicit unit step size.
    g_prox : object
        Proximal operator of g, exposing ``calc_prox``, ``getLambda`` and
        ``setLambda``; its regularization parameter is mutated in place.
    lipschitz_constant, eta : float
        NOTE(review): accepted but unused — despite the function's name, no
        backtracking line search is performed (``calculate_Q`` is never
        called). Implement the line search or rename the function.
    max_iter : int, optional
        Iteration count; when None it is derived from lambda / noise below.
    n : int, optional
        Problem size used to build the zero initial guess.
    noise : float, optional
        Amount subtracted from the regularization parameter each iteration.
    verbose : bool
        When True, print progress every 10 iterations.

    Returns
    -------
    tuple
        ``(min_cost, x)`` — final objective value and final iterate.
    """
    if x is None and n is not None:
        x = np.zeros(n, dtype=np.float32)
    t = 1  # Nesterov momentum weight t_k
    z = x.copy()  # extrapolated point y_k
    min_cost = 0.0

    # Derive the iteration count from how many noise-sized decrements fit
    # into the initial regularization parameter.
    if max_iter is None and noise is not None:
        # NOTE(review): ``is not np.nan`` is an identity test; a NaN value
        # that is not the np.nan singleton slips through — prefer np.isnan.
        if noise is not np.nan:
            if noise != 0.0:
                max_iter = int(np.floor(g_prox.getLambda() / noise))
            else:
                noise = 1e-5
                max_iter = int(np.floor(g_prox.getLambda() / noise))
        else:
            raise ValueError("Noise must be a number")
        if verbose:
            print("Iterations set to " + str(max_iter))

    if noise is None:
        noise = 1e-5

    # The lambda decay below only makes sense if at least one decrement
    # fits; otherwise bail out with the unmodified initial guess.
    if noise >= g_prox.getLambda():
        if verbose:
            print("Error, noise cannot be greater than lambda")
        return min_cost, x

    for it in range(0, max_iter):
        xold = x.copy()
        # Gradient step (unit step size), then proximal step on g.
        z = z - fx_grad(z)
        x = g_prox.calc_prox(z)

        # Standard FISTA momentum update.
        t0 = t
        t = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t**2))
        z = x + ((t0 - 1.0) / t) * (x - xold)
        # e = np.sqrt(np.sum((x-xold)**2)) / np.sqrt(np.sum(xold**2))
        # print(e)
        # Mean absolute change between iterates; currently computed but
        # unused — the tolerance-based exit below is commented out.
        e = np.sum(np.abs(x - xold)) / len(x)

        # if e <= tol:
        # if verbose:
        # print("Exit due to tolerance: ", e, " < ", tol)
        # print("Iterations: ", it + 1)
        # break

        if verbose and it % 10 == 0:
            cost = F(x)
            print("Iteration: ", it, " objective function value: {0:0.5f}".format(cost))
        # Shrink the regularization parameter by the noise level each pass;
        # stop once it would become non-positive.
        new_lambda = g_prox.getLambda() - noise
        if new_lambda > 0.0:
            g_prox.setLambda(reg=new_lambda)
        else:
            if verbose:
                print("Exit due to negative regularization parameter")
            break
    min_cost = F(x)
    return min_cost, x
25 changes: 25 additions & 0 deletions src/csromer/optimization/methods/fista/fista_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import copy
from dataclasses import dataclass

from ....optimization.methods.fista.fista_general_algorithm import fista_general_algorithm
from .fista import FISTA


@dataclass(init=True, repr=True)
class GeneralFISTA(FISTA):
    """Plain FISTA solver with a fixed step size.

    Delegates the iteration to :func:`fista_general_algorithm` and wraps
    the resulting array back into a copy of the initial parameter.
    """

    def run(self):
        """Run the solver; returns ``(final_cost, recovered_parameter)``."""
        final_cost, solution = fista_general_algorithm(
            x=self.guess_param.data,
            F=self.F_obj.evaluate,
            fx=self.fx.calculate_gradient_fista,
            g_prox=self.gx,
            lipschitz_constant=self.lipschitz_constant,
            max_iter=self.maxiter,
            n=self.guess_param.n,
            verbose=self.verbose,
        )

        recovered = copy.deepcopy(self.guess_param)
        recovered.data = solution
        return final_cost, recovered
43 changes: 43 additions & 0 deletions src/csromer/optimization/methods/fista/fista_general_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np


def fista_general_algorithm(
    x=None,
    F=None,
    fx=None,
    g_prox=None,
    lipschitz_constant=None,
    max_iter=None,
    n=None,
    verbose=True,
):
    """Run FISTA with a fixed step size 1/L.

    Parameters
    ----------
    x : np.ndarray, optional
        Initial iterate; when None and ``n`` is given, starts from zeros.
    F : callable
        Full objective F(x) = f(x) + g(x), used for progress reporting and
        the returned final cost.
    fx : callable
        Gradient of the smooth term f.
    g_prox : object
        Proximal operator of g exposing ``calc_prox``, ``getLambda`` and
        ``setLambda``. NOTE: mutated in place — its regularization
        parameter is rescaled by 1/L and not restored afterwards.
    lipschitz_constant : float, optional
        Lipschitz constant of grad f; defaults to 1.
    max_iter : int, optional
        Iteration count; defaults to 110.
    n : int, optional
        Problem size used to build the zero initial guess.
    verbose : bool
        When True, print progress every 10 iterations.

    Returns
    -------
    tuple
        ``(min_cost, x)`` — final objective value and final iterate.
    """
    if x is None and n is not None:
        x = np.zeros(n, dtype=np.float32)

    if lipschitz_constant is None:
        lipschitz_constant = 1.

    momentum = 1  # Nesterov extrapolation weight t_k
    z = x.copy()  # extrapolated point y_k
    # Fold the 1/L step size into the proximal operator's parameter.
    g_prox.setLambda(reg=g_prox.getLambda() / lipschitz_constant)

    if max_iter is None:
        max_iter = 110
        if verbose:
            print("Iterations set to " + str(max_iter))

    for it in range(max_iter):
        previous_x = x.copy()
        # Gradient step on the smooth term, then proximal step on g.
        z -= fx(z) / lipschitz_constant
        x = g_prox.calc_prox(z)

        # Standard FISTA momentum update.
        previous_momentum = momentum
        momentum = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * momentum**2))
        z = x + ((previous_momentum - 1.0) / momentum) * (x - previous_x)

        if verbose and it % 10 == 0:
            cost = F(x)
            print("Iteration: ", it, " objective function value: {0:0.5f}".format(cost))

    min_cost = F(x)
    return min_cost, x
20 changes: 20 additions & 0 deletions src/csromer/wrappers/reconstructors/csromer_3d_reconstructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass, field
from typing import Union

import numpy as np
from astropy.stats import sigma_clipped_stats

from ...dictionaries import Wavelet
from ...objectivefunction import L1, TSV, TV, Chi2, OFunction
from ...optimization import FISTA
from ...reconstruction import Parameter
from ...transformers.dfts import NDFT1D, NUFFT1D
from ...transformers.flaggers.flagger import Flagger
from .faraday_reconstructor import FaradayReconstructorWrapper


@dataclass(init=True, repr=True)
class CSROMER3DReconstructorWrapper(FaradayReconstructorWrapper):
    """Dataclass wrapper for a 3D CS-ROMER Faraday depth reconstruction.

    NOTE(review): currently only declares configuration fields — no
    reconstruction logic is implemented in this class yet.
    """

    # threshold (in sigmas) applied to polarized intensity — presumably used
    # with sigma-clipped statistics; TODO confirm against callers
    sigma_threshold_p: float = None
    # threshold (in sigmas) applied to total intensity — TODO confirm
    sigma_threshold_intensity: float = None
    # spectral index: scalar, string specification, or per-pixel array
    spectral_index: Union[Union[str, float], np.ndarray] = None
9 changes: 4 additions & 5 deletions src/csromer/wrappers/reconstructors/csromer_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ...dictionaries import Wavelet
from ...objectivefunction import L1, TSV, TV, Chi2, OFunction
from ...optimization import FISTA
from ...optimization.methods.fista import BacktrackingFISTA, GeneralFISTA
from ...reconstruction import Parameter
from ...transformers.dfts import NDFT1D, NUFFT1D
from ...transformers.flaggers.flagger import Flagger
Expand Down Expand Up @@ -187,11 +187,13 @@ def reconstruct(self):
else:
opt_noise = self.dataset.theo_noise

opt = FISTA(
opt = BacktrackingFISTA(
guess_param=self.parameter,
F_obj=F_obj,
fx=chi2,
gx=g_obj,
lipschitz_constant=1,
eta=2,
noise=opt_noise,
verbose=True,
)
Expand Down Expand Up @@ -228,9 +230,6 @@ def reconstruct(self):
)

def calculate_second_moment(self):
# phi_nonzero_positions = np.abs(self.fd_model) > 0.
# phi_nonzero = self.parameter.phi[phi_nonzero_positions]
# fd_model_nonzero = self.fd_model[phi_nonzero_positions]

fd_model_abs = np.abs(self.fd_model)
k_parameter = np.sum(fd_model_abs)
Expand Down

0 comments on commit 9fa24a2

Please sign in to comment.