# Demo: Bayesian Image Deconvolution
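$$\arg \min_{\mathbf{x}} \; \frac{1}{2} \|\mathbf{y} - \mathbf{F} \mathbf{x}
  \|_2^2 + \lambda_1 \|\mathbf{x}\|_1 + \lambda_2 \| \nabla \mathbf{x} \|_{2,1} +
  \iota_{+}(\mathbf{x}) \;,$$

where $\mathbf{F}$ is the convolution with the point-spread function, $\nabla$ is the finite-difference gradient and $\iota_{+}$ is the indicator function of the non-negative orthant.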

In [None]:
import time as t
import warnings
from functools import partial

import matplotlib.pyplot as plt
import napari
import numpy as np
import scipy as sp
import scipy.signal
import skimage.io
from jupyter_compare_view import compare
from pycsou.abc import DiffFunc, DiffMap, LinOp, Map, ProxFunc
from pycsou.operator import SquaredL2Norm
from pycsou.operator.interop import from_sciop, from_source, from_torch
from pycsou.operator.interop.torch import *
from pycsou.runtime import Precision, Width, enforce_precision
from pycsou.util import get_array_module, to_NUMPY

warnings.filterwarnings("ignore")

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so prefer the new name and fall back to the
# legacy one on older releases.
_darkgrid_style = (
    "seaborn-v0_8-darkgrid"
    if "seaborn-v0_8-darkgrid" in plt.style.available
    else "seaborn-darkgrid"
)
plt.style.use(_darkgrid_style)
plt.rcParams["figure.figsize"] = [9, 6]
plt.rcParams["figure.dpi"] = 150
plt.rcParams["axes.grid"] = False
plt.rcParams["image.cmap"] = "viridis"

# Seeded generator so any random draws in the notebook are reproducible.
rng = np.random.default_rng(seed=0)


def monochromatic(im, chan=0):
    """Return a copy of ``im`` that keeps only slice ``chan`` along axis 0.

    Every entry of the copy outside ``im[chan]`` is set to zero; the input
    array itself is not modified.
    """
    kept = im[chan]
    result = im.copy()
    # Blank everything, then restore the selected slab from the input.
    result[...] = 0
    result[chan] = kept
    return result


def imshow(im, rgb=False):
    """Display ``im`` with matplotlib, hiding the axes.

    2-D arrays are drawn in grayscale. For arrays with more dimensions the
    leading axis is moved last before drawing. With ``rgb=True`` a 2x2 grid
    is drawn instead: the full image, then the versions keeping only slice
    0, 1 and 2 along the leading axis.
    """
    im = to_NUMPY(im)
    if im.ndim <= 2:
        plt.imshow(im, cmap="gray")
    elif not rgb:
        plt.imshow(np.moveaxis(im, 0, -1))
    else:
        # Panel 1 is the full image (chan=None); panels 2-4 are single slices.
        for slot, chan in enumerate((None, 0, 1, 2), start=1):
            plt.subplot(2, 2, slot)
            panel = im if chan is None else monochromatic(im, chan)
            plt.imshow(np.moveaxis(panel, 0, -1))
    plt.axis("off")


def imshow_compare(*images, **kwargs):
    """Show ``images`` in an interactive ``jupyter_compare_view`` widget.

    Each image is converted to NumPy and clipped to ``[0, 1]``. Images with
    more than two dimensions have their leading axis moved last. Extra
    keyword arguments are forwarded to ``compare``.
    """
    prepared = []
    for image in images:
        image = np.clip(to_NUMPY(image), 0, 1)
        if image.ndim > 2:
            image = np.moveaxis(image, 0, -1)
        prepared.append(image)
    return compare(
        *prepared, height=700, add_controls=True, display_format="jpg", **kwargs
    )


# Re-assert the global "ignore" filter (same filter entry as filterwarnings("ignore")).
warnings.simplefilter("ignore")

In [None]:
from pycsou.operator import DiagonalOp, FFT , Pad, SubSample, Sum
from pycsou.util import view_as_complex, view_as_real
from numpy.fft import fftshift, fft, ifft
from pycsou.operator import block_diag, hstack, vstack
import pycsou.util as pycu
import pycsou.abc as pyca


class Roll(pyca.UnitOp):
    """Circular shift of a flattened N-D array, as a unitary operator.

    Inputs have shape ``(..., prod(arg_shape))``. Each trailing vector is
    reshaped to ``arg_shape``, rolled along ``axes`` by ``shift`` with the
    array module's ``roll``, and flattened back. The adjoint rolls by
    ``-shift``.

    Parameters
    ----------
    arg_shape : tuple of int
        Shape of the underlying N-D array.
    axes : int or tuple of int, optional
        Axes to roll along. Defaults to every axis of ``arg_shape``.
    shift : int or sequence of int, optional
        Roll amount for each axis in ``axes``; a scalar is applied to all of
        them. Defaults to ``arg_shape[ax] // 2`` for each rolled axis.

    Raises
    ------
    ValueError
        If ``shift`` and ``axes`` have different lengths.
    """

    def __init__(self, arg_shape, axes=None, shift=None):
        self.arg_shape = arg_shape
        # Previously custom `axes`/`shift` were silently never stored, so any
        # non-default construction crashed; normalize and store both here.
        if axes is None:
            axes = tuple(range(len(arg_shape)))
        elif np.ndim(axes) == 0:
            axes = (axes,)
        self.axes = tuple(int(ax) for ax in axes)

        if shift is None:
            shift = [arg_shape[ax] // 2 for ax in self.axes]
        elif np.ndim(shift) == 0:
            shift = [shift] * len(self.axes)
        self.shift = [int(sh) for sh in shift]
        if len(self.shift) != len(self.axes):
            raise ValueError(
                f"shift has {len(self.shift)} entries but axes has {len(self.axes)}."
            )

        self.shift_adjoint = [-sh for sh in self.shift]
        dim = np.prod(arg_shape).item()
        super().__init__(shape=(dim, dim))

    def _roll(self, arr, shift):
        # Unflatten the trailing axis, roll, then flatten back.
        sh = arr.shape[:-1]
        arr = arr.reshape(*sh, *self.arg_shape)
        xp = pycu.get_array_module(arr)
        return xp.roll(arr, shift, self.axes).reshape(*sh, -1)

    def apply(self, arr):
        return self._roll(arr, self.shift)

    def adjoint(self, arr):
        return self._roll(arr, self.shift_adjoint)

    
def ComplexMult(arr):
    """Element-wise complex multiplication by ``arr``, as a real linear operator.

    ``arr`` and the operator's inputs/outputs are real vectors that interleave
    real parts (even indices) and imaginary parts (odd indices). For an input
    ``x`` this computes ``Re = a_r x_r - a_i x_i`` and
    ``Im = a_i x_r + a_r x_i`` and scatters them back to even/odd slots.
    """
    n = arr.size
    even = SubSample((n,), slice(0, None, 2))
    odd = SubSample((n,), slice(1, None, 2))

    a_re = DiagonalOp(even(arr))
    a_im = DiagonalOp(odd(arr))

    out_re = a_re * even - a_im * odd
    out_im = a_im * even + a_re * odd

    # Upsample each part back into its interleaved position and combine.
    return even.T * out_re + odd.T * out_im


def PsfFourier(psf):
    """Build a "same"-size convolution operator with ``psf`` via the FFT.

    Both the PSF and the input are zero-padded by ``(N - 1) // 2`` samples
    on each side of every axis. The padded length L is ``2N - 1`` for odd N
    and ``2N - 2`` for even N, which keeps circular wrap-around out of the
    cropped region. The padded spectra are multiplied, the result is
    circularly re-centered and cropped back to ``psf.shape``.

    Parameters
    ----------
    psf : array
        Point-spread function. Operator inputs are assumed to have the same
        shape as ``psf`` (flattened).

    Returns
    -------
    op : LinOp
        The composed operator. Its building blocks are attached as the
        attributes ``center``, ``roll``, ``ft``, ``mult``, ``fft_f`` and ``pad``.
    """
    shape = np.array(psf.shape)
    half = (shape - 1) // 2
    pad = Pad(arg_shape=psf.shape, pad_width=[(h, h) for h in half.tolist()])
    ft = FFT(arg_shape=pad._pad_shape, real=True)
    fft_f = ft * pad
    # Real view (interleaved re/im) of the padded PSF's spectrum.
    psf_fft = fft_f(psf.ravel())
    mult = ComplexMult(psf_fft)

    roll = Roll(pad._pad_shape)
    # The padded PSF's centre sits at index L // 2 of each padded axis, so the
    # circular convolution output must be shifted by -(L // 2) (numpy's
    # ifftshift). Roll's default is +(L // 2) (fftshift), which agrees only
    # for even L and is off by one sample for odd L, so set it explicitly.
    roll.shift = [-(d // 2) for d in pad._pad_shape]
    roll.shift_adjoint = [d // 2 for d in pad._pad_shape]

    # Crop the central psf.shape region of the padded grid.
    start = (np.array(pad._pad_shape) - shape) // 2
    stop = start + shape
    center = SubSample(pad._pad_shape, *(slice(a, b) for a, b in zip(start, stop)))

    # Presumably ft.T is the unnormalized adjoint DFT, so 1 / ft.dim turns it
    # into the inverse transform -- TODO confirm against pycsou's FFT docs.
    op = (1 / ft.dim) * center * roll * ft.T * mult * fft_f

    op.center = center
    op.roll = roll
    op.ft = ft
    op.mult = mult
    op.fft_f = fft_f
    op.pad = pad

    return op

In [None]:
# Set to True to run on a CUDA GPU through CuPy.
gpu = False

In [None]:
# Load data

from utils import downsample_volume, epfl_deconv_data

# Select the array backend once; later cells reuse `xp`.
if gpu:
    import cupy as xp
else:
    xp = np

y, psf = [], []
for channel in range(3):
    y_, psf_ = epfl_deconv_data(channel)
    y_ = xp.asarray(downsample_volume(y_, 2))
    psf_ = xp.asarray(downsample_volume(psf_, 2))

    # Same preprocessing as in Scico: shift/scale data to [0, 1] and
    # normalize the PSF to unit sum.
    y_ -= y_.min()
    y_ /= y_.max()
    psf_ /= psf_.sum()

    y.append(y_)
    psf.append(psf_)

# Stack with the selected backend, as the reconstruction cells do.
y = xp.stack(y)
psf = xp.stack(psf)

In [None]:
# Sanity check of the stacked data and PSF arrays.
for summary in (f"{y.shape=}, {y.dtype=}", f"{psf.shape=}, {psf.dtype=}"):
    print(summary)

In [None]:
# Inspect the preprocessed data; transposing puts the 3 stacked channels last for rgb=True.
viewer = napari.view_image((y.get() if gpu else y).T, rgb=True)

# Regularized least squares (damped pseudo-inverse, PINV)

In [None]:
from pycsou.opt.stop import MaxIter, RelError
from pycsou.operator import Jacobian, PositiveL1Norm, L21Norm

In [None]:
# One FFT-based convolution operator per channel's PSF.
forwards = list(map(PsfFourier, psf))

In [None]:

# Stop when the relative change in x drops below 1e-3 (and at least 20
# iterations have run), or after 400 iterations at most.
default_stop_crit = (
    RelError(eps=1e-3, var="x", f=None, norm=2, satisfy_all=True) & MaxIter(20)
) | MaxIter(400)

pinv_channels = []
for y_channel, forward in zip(y, forwards):
    x_channel = forward.pinv(
        y_channel.ravel(),
        damp=1.0,
        kwargs_init=dict(show_progress=True, verbosity=10),
        kwargs_fit=dict(stop_crit=default_stop_crit),
    )
    pinv_channels.append(x_channel.reshape(y_channel.shape))

x_pinv_recons = xp.stack(pinv_channels)
# Scale each channel by its own maximum.
x_pinv_recons /= x_pinv_recons.max(axis=(1, 2, 3), keepdims=True)

In [None]:
# Display the damped pseudo-inverse reconstruction, channels last.
viewer = napari.view_image(
    (x_pinv_recons.get() if gpu else x_pinv_recons).reshape(y.shape).T, rgb=True
)

# MAP estimate (positivity, sparsity and TV priors; PD3O solver)

In [None]:
import pycsou.runtime as pycrt
from pycsou.opt.solver import PD3O

# Stopping criterion: relative change of both the x and z iterates below
# 1e-3 (after at least 20 iterations), or 500 iterations at most.
default_stop_crit = (
    RelError(eps=1e-3, var="x", f=None, norm=2, satisfy_all=True)
    & RelError(eps=1e-3, var="z", f=None, norm=2, satisfy_all=True)
    & MaxIter(20)
) | MaxIter(500)

from pycsou.operator import Gradient, Sum

# Finite-difference gradient on each channel's 3-D grid, with per-axis
# sampling steps. It must use the same backend as the data, so follow the
# notebook's `gpu` flag instead of hard-coding GPU execution.
grad = Gradient(
    arg_shape=y.shape[1:],
    diff_method="fd",
    accuracy=4,
    sampling=[0.64, 0.64, 1.6],
    gpu=gpu,
)

grad.lipschitz(tight=False)

λ1 = 1e-1
λ2 = 1e-4

posl1 = PositiveL1Norm(dim=y[0].size)

x_recons_tv = []

for y_channel, forward in zip(y, forwards):
    sl2 = SquaredL2Norm(dim=y_channel.size).asloss(y_channel.ravel())
    sl2.diff_lipschitz()

    # One gradient component per spatial axis, stacked on a leading axis
    # (assumes Gradient's output layout is (ndim, *shape), as the original
    # hard-coded (3, *shape) did). The L2 norm is taken across components.
    l21_norm = L21Norm(arg_shape=(y_channel.ndim, *y_channel.shape), l2_axis=(0,))

    loss = sl2 * forward
    loss.diff_lipschitz()

    solver = PD3O(
        f=loss, g=λ1 * posl1, h=λ2 * l21_norm, K=grad, show_progress=True, verbosity=100
    )
    # Fit in single precision.
    with pycrt.Precision(pycrt.Width.SINGLE):
        solver.fit(x0=0 * y_channel.ravel(), tuning_strategy=2, stop_crit=default_stop_crit)
        x_recons_tv.append(solver.solution().reshape(y_channel.shape))

x_recons_tv = xp.stack(x_recons_tv)

In [None]:
# Display the MAP (TV) reconstruction, channels last.
viewer = napari.view_image((x_recons_tv.get() if gpu else x_recons_tv).T, rgb=True)

In [None]:
# Histogram of reconstructed intensities on [0, 1], one series per channel.
hist_bins = np.linspace(0, 1, 20)
if not gpu:
    plt.hist(x_recons_tv.reshape(3, -1).T, bins=hist_bins)
else:
    plt.hist(x_recons_tv.reshape(3, -1).T.get(), bins=hist_bins)

In [None]:
# Sum-projection of the third channel's PSF along its last axis.
imshow(psf[2].sum(axis=-1))

In [None]:
# Profile of the third channel's PSF along its last axis (summed over axes 0 and 1).
# to_NUMPY works for both backends; `.get()` only exists on CuPy arrays.
plt.plot(to_NUMPY(psf[2].sum((0, 1))))

In [None]:
# Profile along axis 0 of the third channel's PSF, taken at last-axis index 26
# and summed over axis 1. to_NUMPY works for both NumPy and CuPy arrays.
plt.plot(to_NUMPY(psf[2, ..., 26].sum(axis=1)))