# The Split-Bregman Algorithm: Sparsity Enforcing Inversions

@Author: Francesco Picetti - francesco.picetti@polimi.it

This notebook demonstrates the Split-Bregman algorithm,
as described by [Goldstein and Osher, 2009](https://doi.org/10.1137/080725891).

In its generalized unconstrained formulation (the one implemented in this library), the algorithm
solves what we call a `ProblemLinearReg` (see `pyProblem.py`):

\begin{equation}
    \arg \min_\mathbf{m} \Vert  \mathbf{A} \mathbf{m} - \mathbf{d}\Vert_2^2 + \sum_i \varepsilon_{L2,i}\Vert \mathbf{R}_{L2,i} \mathbf{m} - \mathbf{p}_i\Vert_2^2 + \sum_i \varepsilon_{L1,i}\Vert \mathbf{R}_{L1,i} \mathbf{m}\Vert_1,
\end{equation}

where $\mathbf{m}$ is our model, $\mathbf{d}$ the observed data, $\mathbf{A}$ a linear modeling operator,
$\mathbf{R}_{L2,i}$ the $i$-th linear regularizer, to which we associate a prior $\mathbf{p}_i$ and a weight $\varepsilon_{L2,i}$,
and $\mathbf{R}_{L1,i}$ the $i$-th sparsity-promoting regularizer with weight $\varepsilon_{L1,i}$.
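
To make the three terms concrete, here is a minimal NumPy sketch that evaluates this objective for dense matrices. It is only an illustration: the function `regularized_objective` and its tuple-based arguments are hypothetical and are not part of the library, which works with its own operator and vector classes.

```python
import numpy as np

def regularized_objective(A, m, d, regs_l2=(), regs_l1=()):
    """Evaluate ||A m - d||_2^2 + sum eps ||R m - p||_2^2 + sum eps ||R m||_1.

    regs_l2 is an iterable of (eps, R, p) tuples and regs_l1 an iterable of
    (eps, R) tuples (hypothetical layout chosen for this sketch).
    """
    value = np.sum((A @ m - d) ** 2)              # data misfit
    for eps, R, p in regs_l2:
        value += eps * np.sum((R @ m - p) ** 2)   # L2 regularizers with priors
    for eps, R in regs_l1:
        value += eps * np.sum(np.abs(R @ m))      # sparsity-promoting terms
    return float(value)

# Toy problem with identity operators
A_toy = np.eye(3)
m_toy = np.array([1.0, 0.0, -2.0])
d_toy = np.array([1.0, 1.0, -2.0])
obj = regularized_objective(
    A_toy, m_toy, d_toy,
    regs_l2=[(0.5, np.eye(3), np.zeros(3))],
    regs_l1=[(0.1, np.eye(3))],
)
print(obj)
```

Here the misfit contributes 1, the L2 term 0.5 · 5 = 2.5 and the L1 term 0.1 · 3 = 0.3.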

#### Import modules

In [2]:
# Importing necessary modules
import os
import genericIO
import numpy as np
import pyOperator as pyOp
import pyVector as pyVec
import pyNpOperator as NpOp
from pyProblem import ProblemLinearReg, ProblemL2Linear, ProblemL1Lasso
from pyLinearSolver import LSQRsolver, LCGsolver
from pySparseSolver import ISTAsolver, SplitBregmanSolver
from pyStopper import BasicStopper
from sys_util import logger

# Plotting library
from matplotlib import pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib inline
params = {
    'image.cmap': 'gray',
    'axes.grid': False,
    'savefig.dpi': 300,  # to adjust notebook inline plot size
    'axes.labelsize': 14, # fontsize for x and y labels (was 10)
    'axes.titlesize': 14,
    'font.size': 14,
    'legend.fontsize': 14,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'text.usetex':True
}
matplotlib.rcParams.update(params)


##  Example 1:  1D Ricker Deconvolution

In this example we want to invert a 1D seismic trace for the Earth's reflectivity.
We assume the wavelet is known, which yields a deterministic deconvolution problem.

In [None]:
# load wavelet
ricker = genericIO.defaultIO.getVector("../testdata/ricker1.H")
waveletHyper = ricker.getHyper()
nt = waveletHyper.getAxis(1).n
ot = waveletHyper.getAxis(1).o
dt = waveletHyper.getAxis(1).d

# load reflectivity
model = genericIO.defaultIO.getVector('../testdata/ref_randomspikes.H')

# instantiate the convolution operator
A = NpOp.ConvNDscipy(model, ricker)

# Generating true recorded trace
data = A * model

In [None]:
fig, ax = plt.subplots(figsize=(5,3))
# time axis of the wavelet
time_range = np.linspace(ot, ot + (nt - 1) * dt, nt)
plt.plot(time_range, ricker.getNdArray())
plt.title('Wavelet'), plt.xlabel("Time [s]"), plt.ylabel("Amplitude")
ax.autoscale(enable=True, axis='x', tight=True)
plt.show()

fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(model.getNdArray(), 'k',  label='model')
plt.plot(data.getNdArray(), 'b',  label='data')
ax.autoscale(enable=True, axis='x', tight=True)
plt.ylim(-1., 1.)
plt.legend()
plt.grid(True)
plt.show()

#### Solve with Conjugate Gradient

Conjugate gradient is a fast and powerful adjoint-state algorithm for solving problems of the form:

\begin{equation}
    \hat{\mathbf{m}} = \arg \min_\mathbf{m} \Vert \mathbf{A}\mathbf{m} - \mathbf{d}\Vert_2^2
\end{equation}
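
As an illustration of the kind of iteration involved, below is a minimal NumPy sketch of conjugate gradient applied to the least-squares problem (the CGLS variant), which only needs products with $\mathbf{A}$ and its adjoint. The function `cgls` and the toy matrix are assumptions made for this example; this is not the code of `LCGsolver`.

```python
import numpy as np

def cgls(A, d, niter=50, tol=1e-12):
    """Conjugate gradient for min_m ||A m - d||_2^2 (dense-matrix sketch)."""
    m = np.zeros(A.shape[1])
    r = d - A @ m            # data residual
    s = A.T @ r              # adjoint applied to the residual
    p = s.copy()             # first search direction
    gamma = s @ s
    for _ in range(niter):
        if gamma < tol:      # gradient is numerically zero: stop
            break
        q = A @ p
        alpha = gamma / (q @ q)          # exact line search along p
        m += alpha * p
        r -= alpha * q
        s = A.T @ r
        gamma_new = s @ s
        p = s + (gamma_new / gamma) * p  # conjugate direction update
        gamma = gamma_new
    return m

# Toy overdetermined system with noise-free data
rng = np.random.default_rng(0)
A_toy = rng.standard_normal((20, 5))
m_true = rng.standard_normal(5)
m_cg = cgls(A_toy, A_toy @ m_true)
```

With noise-free data and a full-rank matrix, the recovered `m_cg` matches `m_true` up to rounding.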

In [None]:
# instantiate a linear L2 problem
problemCG = ProblemL2Linear(model.clone().zero(), data, A)

# define the solver with 10,000 iterations
CG = LCGsolver(BasicStopper(10000))

# solve the problem
CG.run(problemCG, verbose=True)

In [None]:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(model.getNdArray(), 'k', label="true model")
plt.plot(problemCG.model.getNdArray(), 'r--', label="CG")
plt.title('Conjugate Gradient, %d iterations' % CG.stopper.niter)
plt.ylim(-1., 1.)
ax.autoscale(enable=True, axis='x', tight=True)
plt.grid(True)
plt.legend()
plt.show()

#### Impose sparsity in the solution: FISTA

Iterative Shrinkage-Thresholding Algorithms solve the so-called LASSO problem:
\begin{equation}
    \hat{\mathbf{m}} = \arg \min_\mathbf{m} \Vert \mathbf{A}\mathbf{m} - \mathbf{d}\Vert_2^2 + \lambda \Vert \mathbf{m}\Vert_1
\end{equation}

If we provide the operator's maximum eigenvalue $\eta$, we can build a fast version
of this algorithm by choosing a step size $\alpha < 1/\eta^2$, as proposed in [Beck and Teboulle, 2009](https://doi.org/10.1137/080716542).
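
The following NumPy sketch shows what a fast shrinkage-thresholding iteration can look like for the LASSO objective exactly as written above. It is a hedged illustration, not the library's `ISTAsolver`: the names `soft_threshold` and `fista` are hypothetical, and because the misfit here carries no 1/2 factor, the sketch uses the gradient's Lipschitz constant $2\sigma_{\max}^2(\mathbf{A})$, so its step scaling may differ from the one used in the cell below.

```python
import numpy as np

def soft_threshold(x, thresh):
    """Proximal operator of thresh * ||x||_1 (element-wise shrinkage)."""
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

def fista(A, d, lam, niter=500):
    """FISTA sketch for min_m ||A m - d||_2^2 + lam * ||m||_1."""
    # Lipschitz constant of the gradient 2 A^T (A m - d)
    lipschitz = 2.0 * np.linalg.norm(A, 2) ** 2
    step = 1.0 / lipschitz
    m = np.zeros(A.shape[1])
    z = m.copy()   # extrapolated point
    t = 1.0        # momentum parameter
    for _ in range(niter):
        grad = 2.0 * A.T @ (A @ z - d)
        m_new = soft_threshold(z - step * grad, lam * step)
        t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t ** 2))
        z = m_new + ((t - 1.0) / t_new) * (m_new - m)
        m, t = m_new, t_new
    return m

# With an identity operator the minimizer is soft_threshold(d, lam / 2)
d_toy = np.array([3.0, -0.2, 1.0])
m_fista = fista(np.eye(3), d_toy, lam=1.0)
```

The identity case also shows why amplitudes shrink: every surviving coefficient is reduced by $\lambda/2$.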

In [None]:
# compute maximum eigenvalue
maxeig = A.powerMethod()
print('FISTA α=%.6e' % (1/(maxeig**2)))

# define the LASSO problem
problemFISTA = ProblemL1Lasso(model.clone().zero(), data, A, lambda_value=1, op_norm=maxeig**2)

# instantiate the fast solver with 1000 iterations
FISTA = ISTAsolver(BasicStopper(1000), fast=True)

# solve the problem
FISTA.run(problemFISTA, verbose=True)

In [None]:
fig, ax = plt.subplots(figsize=(12,3))
plt.plot(model.getNdArray(), 'k', label="true model")
plt.plot(problemFISTA.model.getNdArray(), 'r--', label="FISTA")
plt.title(r'Fast ISTA, $\lambda=%.e$' % problemFISTA.lambda_value)
plt.ylim(-1, 1)
ax.autoscale(enable=True, axis='x', tight=True)
plt.grid(True)
plt.legend()
plt.show()

Note that FISTA recovers the model's kinematics but not the true amplitudes.
Indeed, this algorithm is very sensitive to the $\lambda$ parameter.
Let's try another solver.

#### Split-Bregman for LASSO problem

In [None]:
# define the Linear Regularized Problem
problemSB = ProblemLinearReg(model.clone().zero(), data, A,
                             regsL1=pyOp.IdentityOp(model), epsL1=.1)

# instantiate the Split-Bregman solver
SB = SplitBregmanSolver(BasicStopper(1000), niter_inner=5, niter_solver=15,
                        linear_solver='LSQR', breg_weight=1., warm_start=True)

# solve the problem
SB.run(problemSB, verbose=True, inner_verbose=False)
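
To give an idea of what happens inside such a solver, here is a minimal NumPy sketch of Split-Bregman iterations for the LASSO objective with an identity regularizer, following the scheme of Goldstein and Osher. It is an assumption-laden illustration and not the library's `SplitBregmanSolver`: the function `split_bregman_lasso`, the parameter `beta` (playing the role of a Bregman weight) and the direct call to `np.linalg.solve` are choices made for this sketch, whereas the cell above selects an LSQR linear solver.

```python
import numpy as np

def soft_threshold(x, thresh):
    """Element-wise shrinkage, the proximal operator of the L1 norm."""
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

def split_bregman_lasso(A, d, eps, beta=1.0, niter=200, niter_inner=1):
    """Sketch for min_m ||A m - d||_2^2 + eps * ||m||_1 with the split z = m."""
    n = A.shape[1]
    m = np.zeros(n)
    z = np.zeros(n)   # auxiliary sparse variable
    b = np.zeros(n)   # Bregman variable
    lhs = A.T @ A + beta * np.eye(n)
    Atd = A.T @ d
    for _ in range(niter):
        for _ in range(niter_inner):
            # L2 subproblem: min ||A m - d||^2 + beta * ||z - m - b||^2
            m = np.linalg.solve(lhs, Atd + beta * (z - b))
            # L1 subproblem: min eps * ||z||_1 + beta * ||z - m - b||^2
            z = soft_threshold(m + b, eps / (2.0 * beta))
        # Bregman update
        b = b + m - z
    return m

# With an identity operator the minimizer is soft_threshold(d, eps / 2)
d_toy = np.array([3.0, -0.2, 1.0])
m_sb = split_bregman_lasso(np.eye(3), d_toy, eps=1.0)
```

The splitting turns one non-smooth problem into a smooth least-squares solve and a closed-form shrinkage, which is what makes the approach attractive for general regularizers.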

In [None]:
fig, ax = plt.subplots(figsize=(12,3))
plt.plot(model.getNdArray(), 'k', label="true model")
plt.plot(problemSB.model.getNdArray(), 'r--', label="SB")
plt.title(r'Split-Bregman, $\varepsilon=%.e$' % problemSB.epsL1[0])
plt.ylim(-1, 1)
ax.autoscale(enable=True, axis='x', tight=True)
plt.grid(True)
plt.legend()
plt.show()

The result is much better than the FISTA output: the dynamics are almost perfectly recovered.


##  Example 2:  1D Velocity Deconvolution

Now we create a synthetic velocity profile and suppose that we have recorded a smoothed version of it.
The deconvolution problem aims at recovering the sharp model, and this is done by using a Total Variation regularizer.

In [None]:
# create the true model
nx = 201
x = pyVec.vectorIC(np.zeros((nx,), dtype=np.float32)).zero()
x.getNdArray()[20:30] = 10.
x.getNdArray()[50:75] = -5.
x.getNdArray()[100:150] = 2.5
x.getNdArray()[175:180] = 7.5

# instantiate the blurring operator
G = NpOp.GaussianFilter(x, 2.0)

# simulate data
y = G * x

In [None]:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(x.getNdArray(), 'k',  label='model')
plt.plot(y.getNdArray(), 'b',  label='data')
ax.autoscale(enable=True, axis='x', tight=True)
plt.ylim(-6, 12)
plt.legend()
plt.grid(True)
plt.tight_layout(pad=.5)
plt.show()

As we can see, the model has several flat regions. Let's have a look at its first derivative:

In [None]:
TV = NpOp.Gradient(x)
Dx = TV * x

In [None]:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(x.getNdArray(), 'k',  label='model')
plt.plot(Dx.getNdArray(), 'b',  label='gradient')
ax.autoscale(enable=True, axis='x', tight=True)
plt.ylim(-6, 12)
plt.legend()
plt.grid(True)
plt.tight_layout(pad=.5)
plt.show()
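
The sparsity of the derivative can also be checked numerically. The sketch below rebuilds the same profile with plain NumPy and counts the non-zero samples of its first difference; `np.diff` is used here as a stand-in for the library's gradient operator, whose discretization and boundary handling may differ.

```python
import numpy as np

# Same piecewise-constant profile as above, as a plain NumPy array
nx = 201
profile = np.zeros(nx, dtype=np.float32)
profile[20:30] = 10.
profile[50:75] = -5.
profile[100:150] = 2.5
profile[175:180] = 7.5

# First difference: non-zero only where the profile jumps
first_diff = np.diff(profile)
n_jumps = np.count_nonzero(first_diff)
n_nonzero_model = np.count_nonzero(profile)
total_variation = float(np.sum(np.abs(first_diff)))

print("non-zero model samples:     ", n_nonzero_model)
print("non-zero derivative samples:", n_jumps)
print("total variation:            ", total_variation)
```

Each of the four blocks contributes two jumps, so the derivative has far fewer non-zero samples than the model itself.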

The model derivative is sparse! We can set up a Split-Bregman solver:

In [None]:
problem = ProblemLinearReg(x.clone().zero(), y, G, regsL1=TV, epsL1=.1)

SB = SplitBregmanSolver(BasicStopper(1000), niter_inner=5, niter_solver=10,
                        linear_solver='LSQR', breg_weight=1., warm_start=True)
SB.setDefaults(save_obj=True, save_model=True)

SB.run(problem, verbose=True, inner_verbose=False)

In [None]:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(x.getNdArray(), 'k', label="true model")
plt.plot(problem.model.getNdArray(), 'r--', label="SB")
plt.title(r'Split-Bregman, $\varepsilon=%.e$' % problem.epsL1[0])
plt.ylim(-6, 12)
ax.autoscale(enable=True, axis='x', tight=True)
plt.grid(True)
plt.legend()
plt.tight_layout(pad=.5)
plt.show()

##  Example 3: 2D Phantom Deconvolution

In this example we reconstruct a phantom CT image starting from a blurred acquisition.
Again, we regularize the inversion by requiring the first derivative to be sparse.

Model:

In [None]:
x = pyVec.vectorIC(np.load('../testdata/shepp_logan_phantom.npy', allow_pickle=True).astype(np.float32))
x = x.scale(1 / 255.)

plt.figure(figsize=(5, 4))
plt.imshow(x.getNdArray(), cmap='bone', vmin=x.min(), vmax=x.max()), plt.colorbar()
plt.title('Model')
plt.tight_layout(pad=.5)
plt.show()

Data:

In [None]:
Op = NpOp.GaussianFilter(x, [3, 3])
y = Op * x

plt.figure(figsize=(5, 4))
plt.imshow(y.getNdArray(), cmap='bone', vmin=x.min(), vmax=x.max()), plt.colorbar()
plt.title('Data')
plt.tight_layout(pad=.5)
plt.show()

Problem and solver:

In [None]:
G = NpOp.Gradient(x)

x_grad = G*x
plt.imshow(x_grad.vecs[1].getNdArray(), cmap='gray'), plt.colorbar(), plt.title('Gx'), plt.show()
plt.imshow(x_grad.vecs[0].getNdArray(), cmap='gray'), plt.colorbar(), plt.title('Gz'), plt.show()

# Anisotropic TV
x_grad_aniso = x.clone().zero()
G.merge_directions(False, x_grad, x_grad_aniso, iso=False)
plt.figure(figsize=(5, 4))
plt.imshow(x_grad_aniso.getNdArray(), cmap='gray'), plt.colorbar()
plt.title('Sum of G')
plt.show()
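
For reference, the two usual flavours of 2D total variation can be compared with a few lines of NumPy. This is only a sketch: the helper names and the forward-difference discretization are assumptions made for the example, and they are not meant to reproduce what `NpOp.Gradient` or `merge_directions` do internally.

```python
import numpy as np

def forward_differences(img):
    """Forward differences along both axes, zero at the last row/column."""
    gz = np.zeros_like(img)
    gx = np.zeros_like(img)
    gz[:-1, :] = img[1:, :] - img[:-1, :]
    gx[:, :-1] = img[:, 1:] - img[:, :-1]
    return gz, gx

def tv_anisotropic(img):
    """Sum of |Gz| + |Gx|: the L1 norm of the stacked gradient components."""
    gz, gx = forward_differences(img)
    return float(np.sum(np.abs(gz) + np.abs(gx)))

def tv_isotropic(img):
    """Sum of sqrt(Gz^2 + Gx^2): the L1 norm of the gradient magnitude."""
    gz, gx = forward_differences(img)
    return float(np.sum(np.sqrt(gz ** 2 + gx ** 2)))

# Toy image: a 4x4 square of ones on an 8x8 background of zeros
img = np.zeros((8, 8))
img[2:6, 2:6] = 1.0
tv_a = tv_anisotropic(img)
tv_i = tv_isotropic(img)
```

The two measures differ only at pixels where both gradient components are non-zero, here a single corner of the square.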

In [None]:
problemSB = ProblemLinearReg(x.clone().zero(), y, Op, regsL1=G, epsL1=1e-3)
SB = SplitBregmanSolver(BasicStopper(niter=200), niter_inner=5, niter_solver=15,
                        linear_solver='LSQR', breg_weight=1., warm_start=True)
SB.setDefaults(save_obj=True, save_model=True)
SB.run(problemSB, verbose=True, inner_verbose=False)

In [None]:
plt.figure(figsize=(5, 4))
plt.imshow(problemSB.model.getNdArray(), cmap='bone', vmin=x.min(), vmax=x.max()), plt.colorbar()
plt.title(r'SB TV, $\varepsilon=%.e$, %d iter' % (problemSB.epsL1[0], SB.stopper.niter))
plt.tight_layout(pad=.5)
plt.show()

In [None]:
plt.figure(figsize=(5, 4))
plt.semilogy(SB.obj / SB.obj[0], 'b', lw=3, label='Obj')
plt.semilogy(np.asarray(SB.obj_terms)[:, 0] / SB.obj[0], 'r--', label=r'0.5 $\Vert\cdot\Vert_2^2$')
plt.semilogy(np.asarray(SB.obj_terms)[:, 1] / SB.obj[0], 'y--', label=r'$\varepsilon \Vert\cdot\Vert_1$')
plt.legend()
plt.grid(True)
plt.tight_layout(pad=.5)
plt.title('Convergence curve')
plt.show()