-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CHG: adding changes to add backtracking fista algorithm
- Loading branch information
1 parent
065c63f
commit 9fa24a2
Showing
11 changed files
with
220 additions
and
114 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .fista import FISTA | ||
from .fista_backtracking import BacktrackingFISTA | ||
from .fista_general import GeneralFISTA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import copy | ||
from dataclasses import dataclass | ||
|
||
from ....objectivefunction import Chi2, Fi | ||
from ....optimization.optimizer import Optimizer | ||
|
||
|
||
@dataclass(init=True, repr=True) | ||
class FISTA(Optimizer): | ||
|
||
fx: Chi2 = None | ||
gx: Fi = None | ||
lipschitz_constant: float = None | ||
noise: float = None | ||
|
||
def run(self): | ||
pass |
29 changes: 29 additions & 0 deletions
29
src/csromer/optimization/methods/fista/fista_backtracking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import copy | ||
from dataclasses import dataclass | ||
|
||
from ....optimization.methods.fista.fista_backtracking_algorithm import fista_backtracking_algorithm | ||
from .fista import FISTA | ||
|
||
|
||
@dataclass(init=True, repr=True) | ||
class BacktrackingFISTA(FISTA): | ||
eta: float = None | ||
|
||
def run(self): | ||
ret, x = fista_backtracking_algorithm( | ||
self.guess_param.data, | ||
self.F_obj.evaluate, | ||
self.fx.evaluate, | ||
self.fx.calculate_gradient_fista, | ||
self.gx, | ||
self.lipschitz_constant, | ||
self.eta, | ||
self.maxiter, | ||
self.guess_param.n, | ||
self.noise, | ||
self.verbose, | ||
) | ||
|
||
param = copy.deepcopy(self.guess_param) | ||
param.data = x | ||
return ret, param |
79 changes: 79 additions & 0 deletions
79
src/csromer/optimization/methods/fista/fista_backtracking_algorithm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
|
||
|
||
def calculate_Q(x, y, fx, fx_grad, gx, lipschitz_constant): | ||
x_minus_y = x - y | ||
res = fx(y) + np.dot(x_minus_y, | ||
-fx_grad(y)) + 0.5 * lipschitz_constant * np.sum(x_minus_y**2) + gx(x) | ||
return res | ||
|
||
|
||
def fista_backtracking_algorithm( | ||
x=None, | ||
F=None, | ||
fx=None, | ||
fx_grad=None, | ||
g_prox=None, | ||
lipschitz_constant=None, | ||
eta=None, | ||
max_iter=None, | ||
n=None, | ||
noise=None, | ||
verbose=True, | ||
): | ||
if x is None and n is not None: | ||
x = np.zeros(n, dtype=np.float32) | ||
t = 1 | ||
z = x.copy() | ||
min_cost = 0.0 | ||
|
||
if max_iter is None and noise is not None: | ||
if noise is not np.nan: | ||
if noise != 0.0: | ||
max_iter = int(np.floor(g_prox.getLambda() / noise)) | ||
else: | ||
noise = 1e-5 | ||
max_iter = int(np.floor(g_prox.getLambda() / noise)) | ||
else: | ||
raise ValueError("Noise must be a number") | ||
if verbose: | ||
print("Iterations set to " + str(max_iter)) | ||
|
||
if noise is None: | ||
noise = 1e-5 | ||
|
||
if noise >= g_prox.getLambda(): | ||
if verbose: | ||
print("Error, noise cannot be greater than lambda") | ||
return min_cost, x | ||
|
||
for it in range(0, max_iter): | ||
xold = x.copy() | ||
z = z - fx_grad(z) | ||
x = g_prox.calc_prox(z) | ||
|
||
t0 = t | ||
t = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t**2)) | ||
z = x + ((t0 - 1.0) / t) * (x - xold) | ||
# e = np.sqrt(np.sum((x-xold)**2)) / np.sqrt(np.sum(xold**2)) | ||
# print(e) | ||
e = np.sum(np.abs(x - xold)) / len(x) | ||
|
||
# if e <= tol: | ||
# if verbose: | ||
# print("Exit due to tolerance: ", e, " < ", tol) | ||
# print("Iterations: ", it + 1) | ||
# break | ||
|
||
if verbose and it % 10 == 0: | ||
cost = F(x) | ||
print("Iteration: ", it, " objective function value: {0:0.5f}".format(cost)) | ||
new_lambda = g_prox.getLambda() - noise | ||
if new_lambda > 0.0: | ||
g_prox.setLambda(reg=new_lambda) | ||
else: | ||
if verbose: | ||
print("Exit due to negative regularization parameter") | ||
break | ||
min_cost = F(x) | ||
return min_cost, x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import copy | ||
from dataclasses import dataclass | ||
|
||
from ....optimization.methods.fista.fista_general_algorithm import fista_general_algorithm | ||
from .fista import FISTA | ||
|
||
|
||
@dataclass(init=True, repr=True) | ||
class GeneralFISTA(FISTA): | ||
|
||
def run(self): | ||
ret, x = fista_general_algorithm( | ||
self.guess_param.data, | ||
self.F_obj.evaluate, | ||
self.fx.calculate_gradient_fista, | ||
self.gx, | ||
self.lipschitz_constant, | ||
self.maxiter, | ||
self.guess_param.n, | ||
self.verbose, | ||
) | ||
|
||
param = copy.deepcopy(self.guess_param) | ||
param.data = x | ||
return ret, param |
43 changes: 43 additions & 0 deletions
43
src/csromer/optimization/methods/fista/fista_general_algorithm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
|
||
|
||
def fista_general_algorithm( | ||
x=None, | ||
F=None, | ||
fx=None, | ||
g_prox=None, | ||
lipschitz_constant=None, | ||
max_iter=None, | ||
n=None, | ||
verbose=True, | ||
): | ||
if x is None and n is not None: | ||
x = np.zeros(n, dtype=np.float32) | ||
|
||
if lipschitz_constant is None: | ||
lipschitz_constant = 1. | ||
|
||
t = 1 | ||
z = x.copy() | ||
g_prox.setLambda(reg=g_prox.getLambda() / lipschitz_constant) | ||
|
||
if max_iter is None: | ||
max_iter = 110 | ||
if verbose: | ||
print("Iterations set to " + str(max_iter)) | ||
|
||
for it in range(0, max_iter): | ||
x_old = x.copy() | ||
z = z - fx(z) / lipschitz_constant | ||
x = g_prox.calc_prox(z) | ||
|
||
t0 = t | ||
t = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t**2)) | ||
z = x + ((t0 - 1.0) / t) * (x - x_old) | ||
|
||
if verbose and it % 10 == 0: | ||
cost = F(x) | ||
print("Iteration: ", it, " objective function value: {0:0.5f}".format(cost)) | ||
|
||
min_cost = F(x) | ||
return min_cost, x |
20 changes: 20 additions & 0 deletions
20
src/csromer/wrappers/reconstructors/csromer_3d_reconstructor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Union | ||
|
||
import numpy as np | ||
from astropy.stats import sigma_clipped_stats | ||
|
||
from ...dictionaries import Wavelet | ||
from ...objectivefunction import L1, TSV, TV, Chi2, OFunction | ||
from ...optimization import FISTA | ||
from ...reconstruction import Parameter | ||
from ...transformers.dfts import NDFT1D, NUFFT1D | ||
from ...transformers.flaggers.flagger import Flagger | ||
from .faraday_reconstructor import FaradayReconstructorWrapper | ||
|
||
|
||
@dataclass(init=True, repr=True) | ||
class CSROMER3DReconstructorWrapper(FaradayReconstructorWrapper): | ||
sigma_threshold_p: float = None | ||
sigma_threshold_intensity: float = None | ||
spectral_index: Union[Union[str, float], np.ndarray] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters