### Backprojection idea

Variables:
- $M, N, C$: the height, width and number of channels;
- $\boldsymbol{y} \in \mathbb{R}^{MN}$: the observation;
- $\boldsymbol{x} \in \mathbb{R}^{MNC}$: a reconstruction;
- $\boldsymbol{A} \in \mathbb{R}^{MN \times MNC}$: the degradation operator (or forward operator).

#### Usual problem to be solved:

$$\boldsymbol{\hat{x}} = \operatorname*{arg\,min}_{\boldsymbol{x}} \frac{1}{2} \| \boldsymbol{Ax} - \boldsymbol{y} \|_2^2 + \lambda \mathcal{R}(\boldsymbol{x})$$

where the norm in the data-fidelity term is the Euclidean norm of $\mathbb{R}^{MN}$ (the observation space).


#### Backprojected problem:

$$\boldsymbol{\hat{x}} = \operatorname*{arg\,min}_{\boldsymbol{x}} \frac{1}{2} \| \boldsymbol{P} (\boldsymbol{Ax} - \boldsymbol{y}) \|_2^2 + \lambda \mathcal{R}(\boldsymbol{x})$$

where $\boldsymbol{P} \in \mathbb{R}^{MNC \times MN}$ is the backprojection matrix; the norm is now the Euclidean norm of $\mathbb{R}^{MNC}$ (the reconstruction space).


#### The backprojection matrix

Two possibilities are explored here:

The adjoint of the forward operator:
$$\boldsymbol{P} = \boldsymbol{A^T}$$

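The pseudo-inverse of the forward operator (well defined when $\boldsymbol{AA^T}$ is invertible, i.e. when $\boldsymbol{A}$ has full row rank):
$$\boldsymbol{P} = \boldsymbol{A^\dagger} = \boldsymbol{A^T} (\boldsymbol{AA^T})^{-1}$$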

In [2]:
import numpy as np
import jax.numpy as jnp
from scico import functional, linop, loss, metric
from scico.optimize import PDHG
from scico.optimize.admm import ADMM, LinearSubproblemSolver

import matplotlib.pyplot as plt

from scipy.sparse.linalg import splu

from time import perf_counter

from src.input_initialization import *
from src.forward_operator.operators import cfa_operator
from src.forward_operator.forward_operator import forward_operator
from src.inversions.baseline_method.inversion_baseline import *

### General parameters

In [3]:
# Directory holding the ground-truth images.
INPUT_DIR = 'input/'

# Colour filter array pattern; binning is switched on only for quad-Bayer.
CFA = 'sparse_3'
BINNING = (CFA == 'quad_bayer')

# Iteration budget of the iterative solvers.
MAX_ITER = 400
# Noise standard deviation, in percent of the [0, 1] intensity range.
NOISE_LEVEL = 5

### Ground truth and raw acquisition

In [4]:
input_name = '01690'
gt, spectral_stencil = initialize_input(INPUT_DIR + input_name + '.png')

cfa_op = cfa_operator(CFA, gt.shape, spectral_stencil)
forward_op = forward_operator([cfa_op])
mat = forward_op.matrix
# Sparse LU factorisation of the Gram matrix A A^T. Its `solve` method applies
# (A A^T)^{-1}, which the pseudo-inverse backprojection A^T (A A^T)^{-1} needs.
tmp = splu(mat @ mat.T)

# Raw acquisition: A x_gt plus Gaussian noise with standard deviation
# NOISE_LEVEL / 100, clipped to [0, 1]. A fixed seed makes the noise draw,
# and therefore every PSNR reported below, reproducible across runs.
NOISE_SEED = 0
rng = np.random.default_rng(NOISE_SEED)
noise = rng.normal(0, NOISE_LEVEL / 100, gt.shape[:-1])
acq = np.clip(forward_op.direct(gt) + noise, 0, 1)

### Data fidelity term (no backprojection)

In [5]:
def forward_pass(x):
    """Apply the forward (degradation) operator A to ``x``; return a JAX array."""
    degraded = forward_op.direct(x)
    return jnp.array(degraded)

def adjoint_pass(y):
    """Apply the adjoint operator A^T to ``y``; return a JAX array."""
    backprojected = forward_op.adjoint(y)
    return jnp.array(backprojected)

# Standard data term 1/2 || A x - y ||^2, measured in the observation space.
A = linop.LinearOperator(
    input_shape=gt.shape,
    output_shape=acq.shape,
    eval_fn=forward_pass,
    adj_fn=adjoint_pass,
)
f = loss.SquaredL2Loss(y=jnp.array(acq), A=A)



### Data fidelity term with adjoint backprojection

In [7]:
def adj_forw_pass(x):
    """Apply the normal operator A^T A to ``x``; return a JAX array.

    A^T A is self-adjoint, which is why the same function is used below as
    both ``eval_fn`` and ``adj_fn``.
    """
    degraded = forward_op.direct(x)
    return jnp.array(forward_op.adjoint(degraded))

# Data term 1/2 || A^T (A x - y) ||^2: residuals are backprojected with A^T,
# so the operator is A^T A and the target is A^T y.
A = linop.LinearOperator(
    input_shape=gt.shape,
    output_shape=gt.shape,
    eval_fn=adj_forw_pass,
    adj_fn=adj_forw_pass,
)
f_adj = loss.SquaredL2Loss(y=jnp.array(forward_op.adjoint(acq)), A=A)

### Data fidelity term with pseudo-inverse backprojection

In [6]:
def _pinv_backproject(y):
    """Apply the pseudo-inverse backprojection A^T (A A^T)^{-1} to a raw image ``y``.

    ``y`` is flattened for the solve with the LU factors held in ``tmp``. The
    solution is reshaped to the raw-image shape ``gt.shape[:-1]`` before the
    adjoint is applied. Returns whatever ``forward_op.adjoint`` returns.
    """
    # np.array (not asarray) copies, so JAX inputs become writable NumPy arrays.
    solved = tmp.solve(np.array(y).reshape(-1))
    return forward_op.adjoint(solved.reshape(gt.shape[:-1]))


def pinv_forw_pass(x):
    """Apply A^+ A = A^T (A A^T)^{-1} A to ``x``; return a JAX array.

    A^+ A is the orthogonal projector onto the row space of A, hence
    self-adjoint, so the same function serves as ``eval_fn`` and ``adj_fn``.
    """
    return jnp.array(_pinv_backproject(forward_op.direct(x)))


# Data term 1/2 || A^+ (A x - y) ||^2. The operator and the target share the
# same backprojection helper, so they cannot drift apart.
A = linop.LinearOperator(input_shape=gt.shape, output_shape=gt.shape, eval_fn=pinv_forw_pass, adj_fn=pinv_forw_pass)
f_pinv = loss.SquaredL2Loss(y=jnp.array(_pinv_backproject(acq)), A=A)

### Baseline inversion

In [7]:
# Baseline reconstruction. It is also the warm start (x0) of every solver below.
# NOTE(review): the meaning of the third positional argument (0) is defined by
# Inverse_problem in src.inversions.baseline_method — confirm there.
baseline_inverse = Inverse_problem(CFA, BINNING, 0, gt.shape, spectral_stencil)

# The reported baseline time covers the inversion call and the JAX conversion.
start = perf_counter()
x_baseline = jnp.asarray(baseline_inverse(acq))
mid_1 = perf_counter()

### TV PDHG (no backprojection)

In [None]:
# Regularisation weight and dual step size.
lambd = 5e-3
sigma = 1e2

# Vectorial TV: weighted finite differences along the two spatial axes; the
# L2 part of the L21 norm groups axes 0 and 3 of C's output (presumably the
# difference direction and the colour channel).
C = lambd * linop.FiniteDifference(input_shape=gt.shape, append=0, axes=(0, 1))
g = functional.L21Norm(l2_axis=(0, 3))

# Usual PDHG step-size condition tau * sigma * ||C||^2 < 1; use 99% of the bound.
C_squared_norm = np.float64(linop.operator_norm(C)) ** 2
tau = 0.99 / (sigma * C_squared_norm)

# Warm-started from the baseline reconstruction.
solver_PDHG = PDHG(
    f=f, g=g, C=C,
    tau=tau, sigma=sigma,
    x0=x_baseline,
    maxiter=MAX_ITER,
)

### TV PDHG with adjoint backprojection

In [None]:
# Regularisation weight and dual step size.
lambd = 5e-3
sigma = 1e2

# Vectorial TV prior (same construction as the non-backprojected solver).
C = lambd * linop.FiniteDifference(input_shape=gt.shape, append=0, axes=(0, 1))
g = functional.L21Norm(l2_axis=(0, 3))

# Usual PDHG step-size condition tau * sigma * ||C||^2 < 1; use 99% of the bound.
C_squared_norm = np.float64(linop.operator_norm(C)) ** 2
tau = 0.99 / (sigma * C_squared_norm)

# Same TV problem, but with the A^T-backprojected data term.
solver_PDHG_adj = PDHG(
    f=f_adj, g=g, C=C,
    tau=tau, sigma=sigma,
    x0=x_baseline,
    maxiter=MAX_ITER,
)

### TV PDHG with pseudo-inverse backprojection

In [None]:
# Regularisation weight and dual step size (tuned separately for this variant).
lambd = 5e-2
sigma = 3e1

# Vectorial TV prior (same construction as the other TV solvers).
C = lambd * linop.FiniteDifference(input_shape=gt.shape, append=0, axes=(0, 1))
g = functional.L21Norm(l2_axis=(0, 3))

# Step size taken at 40% of the bound tau * sigma * ||C||^2 < 1.
C_squared_norm = np.float64(linop.operator_norm(C)) ** 2
tau = 0.4 / (sigma * C_squared_norm)

# Same TV problem, but with the pseudo-inverse-backprojected data term.
solver_PDHG_pinv = PDHG(
    f=f_pinv, g=g, C=C,
    tau=tau, sigma=sigma,
    x0=x_baseline,
    maxiter=MAX_ITER,
)

### BM3D PnP ADMM (no backprojection)

In [None]:
lambd = 5e-2
# ADMM penalty parameter.
rho = 2 * lambd * 10**-1

# Plug-and-play prior: a scaled BM3D functional (RGB mode) acting on x itself.
C = linop.Identity(input_shape=gt.shape)
g = lambd * 6e-2 * functional.BM3D(is_rgb=True)

# BM3D is expensive, so ADMM gets only MAX_ITER // 20 outer iterations. The
# linear x-update uses conjugate gradient (tol 1e-3, at most 100 iterations).
solver_BM3D = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[rho],
    x0=x_baseline,
    maxiter=MAX_ITER // 20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={'tol': 1e-3, 'maxiter': 100}),
)

### BM3D PnP ADMM with adjoint backprojection

In [None]:
lambd = 5e-2
# ADMM penalty parameter.
rho = 2 * lambd * 10**-1

# Plug-and-play BM3D prior (RGB mode) applied to x directly.
C = linop.Identity(input_shape=gt.shape)
g = lambd * 6e-2 * functional.BM3D(is_rgb=True)

# Same PnP problem with the A^T-backprojected data term; CG solves the x-update.
solver_BM3D_adj = ADMM(
    f=f_adj,
    g_list=[g],
    C_list=[C],
    rho_list=[rho],
    x0=x_baseline,
    maxiter=MAX_ITER // 20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={'tol': 1e-3, 'maxiter': 100}),
)

### BM3D PnP ADMM with pseudo-inverse backprojection

In [None]:
lambd = 5e-2
# ADMM penalty parameter.
rho = 2 * lambd * 10**-1

# Plug-and-play BM3D prior (RGB mode) applied to x directly.
C = linop.Identity(input_shape=gt.shape)
g = lambd * 6e-2 * functional.BM3D(is_rgb=True)

# Same PnP problem with the pseudo-inverse-backprojected data term.
solver_BM3D_pinv = ADMM(
    f=f_pinv,
    g_list=[g],
    C_list=[C],
    rho_list=[rho],
    x0=x_baseline,
    maxiter=MAX_ITER // 20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={'tol': 1e-3, 'maxiter': 100}),
)

### Running the solvers

In [None]:
# Select which reconstructions to compute. Solving is expensive, so these
# flags replace commenting lines in and out. The defaults reproduce the
# previous state of this cell: only the PnP pseudo-inverse variant runs.
RUN_PDHG = False
RUN_PDHG_ADJ = False
RUN_PDHG_PINV = False
RUN_BM3D = False
RUN_BM3D_ADJ = False
RUN_BM3D_PINV = True

# Each solver is bracketed by the timestamp names the plotting cells read
# (mid_2 ... mid_7, end). Timestamps are refreshed only when their solver
# runs. Neighbouring solvers share a boundary name (e.g. mid_3 ends TV and
# starts TV-adj), so enable them together to keep both timings meaningful.
if RUN_PDHG:
    mid_2 = perf_counter()
    x_PDHG = solver_PDHG.solve()
    mid_3 = perf_counter()

if RUN_PDHG_ADJ:
    mid_3 = perf_counter()
    x_PDHG_adj = solver_PDHG_adj.solve()
    mid_4 = perf_counter()

if RUN_PDHG_PINV:
    mid_4 = perf_counter()
    x_PDHG_pinv = solver_PDHG_pinv.solve()
    mid_5 = perf_counter()

if RUN_BM3D:
    mid_5 = perf_counter()
    x_BM3D = solver_BM3D.solve()
    mid_6 = perf_counter()

if RUN_BM3D_ADJ:
    mid_6 = perf_counter()
    x_BM3D_adj = solver_BM3D_adj.solve()
    mid_7 = perf_counter()

if RUN_BM3D_PINV:
    mid_7 = perf_counter()
    x_BM3D_pinv = solver_BM3D_pinv.solve()
    end = perf_counter()

### Plotting the results

In [None]:
# One entry per reconstruction panel:
# (row, col, title prefix, result variable, start-time variable, end-time variable).
# A panel is drawn only if all three variables exist. Skipped solvers therefore
# leave an empty panel, instead of requiring lines to be commented out by hand
# or raising NameError.
recon_panels = [
    (0, 2, 'Baseline', 'x_baseline', 'start', 'mid_1'),
    (1, 0, 'TV', 'x_PDHG', 'mid_2', 'mid_3'),
    (1, 1, 'TV adj backproj', 'x_PDHG_adj', 'mid_3', 'mid_4'),
    (1, 2, 'TV pinv backproj', 'x_PDHG_pinv', 'mid_4', 'mid_5'),
    (2, 0, 'PnP', 'x_BM3D', 'mid_5', 'mid_6'),
    (2, 1, 'PnP adj backproj', 'x_BM3D_adj', 'mid_6', 'mid_7'),
    (2, 2, 'PnP pinv backproj', 'x_BM3D_pinv', 'mid_7', 'end'),
]

fig, ax = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(20, 20))
ax[0][0].imshow(gt)
ax[0][0].set_title('Reference', size=20)
ax[0][1].imshow(acq, cmap='gray')
ax[0][1].set_title(f'Raw image, noise level: {NOISE_LEVEL}', size=20)

namespace = globals()
for row, col, label, x_name, t0_name, t1_name in recon_panels:
    if any(name not in namespace for name in (x_name, t0_name, t1_name)):
        continue
    x_rec = namespace[x_name]
    elapsed = namespace[t1_name] - namespace[t0_name]
    ax[row][col].imshow(x_rec)
    ax[row][col].set_title(f'{label}: {metric.psnr(gt, x_rec):.2f} (dB), {elapsed:.1f}s', size=20)

# fig.savefig('backproj_comparisons.png')

### Plotting the errors

In [None]:
# Same layout as the reconstruction figure, but each panel shows the absolute
# error |gt - x|. imshow clips float RGB data to [0, 1], so a signed difference
# would silently hide every negative error.
# Entries: (row, col, title prefix, result variable, start-time var, end-time var).
# A panel is drawn only if all three variables exist, which avoids NameError
# for solvers that were not run in this session.
error_panels = [
    (0, 2, 'Baseline', 'x_baseline', 'start', 'mid_1'),
    (1, 0, 'TV', 'x_PDHG', 'mid_2', 'mid_3'),
    (1, 1, 'TV adj backproj', 'x_PDHG_adj', 'mid_3', 'mid_4'),
    (1, 2, 'TV pinv backproj', 'x_PDHG_pinv', 'mid_4', 'mid_5'),
    (2, 0, 'PnP', 'x_BM3D', 'mid_5', 'mid_6'),
    (2, 1, 'PnP adj backproj', 'x_BM3D_adj', 'mid_6', 'mid_7'),
    (2, 2, 'PnP pinv backproj', 'x_BM3D_pinv', 'mid_7', 'end'),
]

fig, ax = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(20, 20))
ax[0][0].imshow(gt)
ax[0][0].set_title('Reference', size=20)
ax[0][1].imshow(acq, cmap='gray')
ax[0][1].set_title(f'Raw image, noise level: {NOISE_LEVEL}', size=20)

namespace = globals()
for row, col, label, x_name, t0_name, t1_name in error_panels:
    if any(name not in namespace for name in (x_name, t0_name, t1_name)):
        continue
    x_rec = namespace[x_name]
    elapsed = namespace[t1_name] - namespace[t0_name]
    ax[row][col].imshow(np.abs(gt - np.asarray(x_rec)))
    ax[row][col].set_title(f'{label}: {metric.psnr(gt, x_rec):.2f} (dB), {elapsed:.1f}s', size=20)

# fig.savefig('backproj_comparisons_errors.png')