In [None]:
import time as t
import warnings
from functools import partial

import matplotlib.pyplot as plt
import napari
import numpy as np
import scipy as sp
import scipy.signal
import skimage.io
from jupyter_compare_view import compare
from pycsou.abc import DiffFunc, DiffMap, LinOp, Map, ProxFunc
from pycsou.operator import SquaredL2Norm
from pycsou.operator.interop import from_sciop, from_source, from_torch
from pycsou.operator.interop.torch import *
from pycsou.runtime import Precision, Width, enforce_precision
from pycsou.util import get_array_module, to_NUMPY

warnings.filterwarnings("ignore")

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and later releases removed
# the old names, on which `plt.style.use("seaborn-darkgrid")` raises. Use whichever name exists.
_darkgrid_style = next(
    (name for name in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid") if name in plt.style.available),
    "default",
)
plt.style.use(_darkgrid_style)
plt.rcParams["figure.figsize"] = [9, 6]
plt.rcParams["figure.dpi"] = 150
plt.rcParams["axes.grid"] = False
plt.rcParams["image.cmap"] = "viridis"

# Single seeded generator used for every random draw in the notebook.
rng = np.random.default_rng(seed=0)


def monochromatic(im, chan=0):
    """
    Keep a single channel of a channel-first image.

    Returns a copy of ``im`` in which every entry outside ``im[chan]`` is set to 0;
    ``im`` itself is left untouched.
    """
    xp = get_array_module(im)
    drop = xp.ones(im.shape, dtype=bool)
    drop[chan] = False  # every index along axis 0 except `chan` gets zeroed
    result = im.copy()
    result[drop] = 0
    return result


def imshow(im, rgb=False):
    """
    Display an image on the current matplotlib figure, with axes hidden.

    Parameters
    ----------
    im: array
        2-D image, or channel-first image (channels on axis 0) when ``im.ndim > 2``.
        It is converted to NumPy with ``to_NUMPY`` first.
    rgb: bool
        For channel-first images, draw a 2x2 grid: the composite image followed by the
        single-channel views of channels 0, 1 and 2 (see ``monochromatic``).
        Otherwise only the composite image is drawn. Ignored for 2-D images,
        which are drawn in grayscale.
    """
    im = to_NUMPY(im)
    if im.ndim > 2 and rgb:
        panels = [im] + [monochromatic(im, chan) for chan in range(3)]
        for idx, panel in enumerate(panels, start=1):
            plt.subplot(2, 2, idx)
            plt.imshow(np.moveaxis(panel, 0, -1))
            # Hide the axes of every panel; a single call at the end only affected the last one.
            plt.axis("off")
        return
    if im.ndim > 2:
        plt.imshow(np.moveaxis(im, 0, -1))
    else:
        plt.imshow(im, cmap="gray")
    plt.axis("off")


def imshow_compare(*images, **kwargs):
    """
    Show ``images`` in an interactive ``jupyter_compare_view`` comparison widget.

    Each image is converted to NumPy, clipped to [0, 1] and, when it has more than two
    dimensions, moved from channel-first to channel-last layout. Extra keyword arguments
    are forwarded to ``compare``.
    """
    prepared = []
    for image in images:
        clipped = np.clip(to_NUMPY(image), 0, 1)
        prepared.append(clipped if clipped.ndim <= 2 else np.moveaxis(clipped, 0, -1))
    return compare(*prepared, height=700, add_controls=True, display_format="jpg", **kwargs)


# Warnings are already silenced by the warnings.filterwarnings("ignore") call at the top of this cell.

<p align="center">
<img src="https://matthieumeo.github.io/pycsou/html/_images/pycsou.png" alt="Pycsou logo" width="65%">
</p>

# A High-Performance Computational Imaging Framework for Python

In [None]:
from pycsou.operator import DiagonalOp, FFT , Pad, SubSample, Sum
from pycsou.util import view_as_complex, view_as_real
from numpy.fft import fftshift, fft, ifft
from pycsou.operator import block_diag, hstack, vstack
import pycsou.util as pycu
import pycsou.abc as pyca


class Roll(pyca.UnitOp):
    r"""
    Circular shift (``numpy.roll``) of a flattened N-D array, exposed as a unitary operator.

    Inputs have shape ``(..., prod(arg_shape))``: the trailing axis is reshaped to ``arg_shape``,
    rolled, and flattened back. Leading axes are stacking dimensions and are never rolled.

    Parameters
    ----------
    arg_shape: tuple[int, ...]
        Shape of the un-flattened input.
    axes: int | tuple[int, ...] | None
        Axes of ``arg_shape`` along which to roll (negative values allowed). Defaults to all axes.
    shift: int | tuple[int, ...] | None
        Shift along each axis of ``axes``; a scalar is applied to every axis. Defaults to
        ``arg_shape[ax] // 2`` along each rolled axis, i.e. the shift of ``numpy.fft.fftshift``.

    Raises
    ------
    ValueError
        If an axis is out of bounds, or ``shift`` and ``axes`` have different lengths.
    """

    def __init__(self, arg_shape, axes=None, shift=None):
        self.arg_shape = tuple(int(n) for n in arg_shape)
        ndim = len(self.arg_shape)

        if axes is None:
            axes = range(ndim)
        elif np.ndim(axes) == 0:
            axes = (axes,)
        # Normalise (possibly negative) axes to indices into `arg_shape`.
        normalized = []
        for ax in axes:
            ax = int(ax)
            if not -ndim <= ax < ndim:
                raise ValueError(f"axis {ax} is out of bounds for arg_shape {self.arg_shape}.")
            normalized.append(ax % ndim)
        self.axes = tuple(normalized)

        if shift is None:
            shift = [self.arg_shape[ax] // 2 for ax in self.axes]
        elif np.ndim(shift) == 0:
            shift = [shift] * len(self.axes)
        self.shift = [int(sh) for sh in shift]
        if len(self.shift) != len(self.axes):
            raise ValueError(f"Expected {len(self.axes)} shift value(s), got {len(self.shift)}.")
        # Rolling is a permutation, so its adjoint (and inverse) is the opposite roll.
        self.shift_adjoint = [-sh for sh in self.shift]

        dim = int(np.prod(self.arg_shape))
        super().__init__(shape=(dim, dim))

    def _roll(self, arr, shift):
        # Roll the un-flattened trailing dimensions of `arr` by `shift` along `self.axes`.
        sh = arr.shape[:-1]
        arr = arr.reshape(*sh, *self.arg_shape)
        xp = pycu.get_array_module(arr)
        # Offset past the leading stacking axes so that batch dimensions are never rolled.
        axes = tuple(len(sh) + ax for ax in self.axes)
        return xp.roll(arr, self.shift if shift is None else shift, axes).reshape(*sh, -1)

    def apply(self, arr):
        return self._roll(arr, self.shift)

    def adjoint(self, arr):
        return self._roll(arr, self.shift_adjoint)

    
def ComplexMult(arr):
    r"""
    Pointwise multiplication by a complex vector stored in interleaved real form.

    ``arr`` holds ``[Re(a_0), Im(a_0), Re(a_1), Im(a_1), ...]``; the returned operator maps a
    vector ``z`` in the same interleaved layout to the interleaved representation of ``a * z``.
    """
    even = SubSample((arr.size,), slice(0, None, 2))  # selects the real parts
    odd = SubSample((arr.size,), slice(1, None, 2))  # selects the imaginary parts
    a_re, a_im = DiagonalOp(even(arr)), DiagonalOp(odd(arr))

    # (a_re + i a_im) (z_re + i z_im) = (a_re z_re - a_im z_im) + i (a_im z_re + a_re z_im),
    # each part being scattered back to its interleaved slots by the adjoint sub-sampling.
    re_part = even.T * (a_re * even - a_im * odd)
    im_part = odd.T * (a_im * even + a_re * odd)
    return re_part + im_part


def PsfFourier(psf):
    r"""
    Build the operator ``x -> scipy.signal.fftconvolve(x, psf, mode="same")`` from Pycsou blocks.

    Inputs are flattened arrays with the same shape as ``psf``. They are zero-padded, multiplied
    by the PSF spectrum in the Fourier domain, transformed back, circularly shifted and cropped
    back to ``psf.shape``.

    Parameters
    ----------
    psf: ndarray
        Point-spread function, sampled on the same grid (same shape) as the inputs. Its origin is
        taken at index ``(n - 1) // 2`` along each axis of length ``n``, as in
        ``fftconvolve(mode="same")``.

    Returns
    -------
    op
        The convolution operator. Its building blocks are attached as the attributes ``center``,
        ``roll``, ``ft``, ``mult``, ``fft_f`` and ``pad``.
    """
    shape = np.array(psf.shape)
    ndim = shape.size

    # Pad every axis by n // 2 on both sides: the padded length n + 2 * (n // 2) is at least the
    # 2n - 1 samples of a full linear convolution, so the circular convolution computed through
    # the FFT has no wrap-around inside the cropped output.
    pad_before = shape // 2
    pad = Pad(arg_shape=psf.shape, pad_width=[(pw, pw) for pw in pad_before])
    pad_shape = np.array(pad._pad_shape)
    ft = FFT(arg_shape=pad._pad_shape, real=True)
    fft_f = ft * pad
    roll = Roll(pad._pad_shape)

    # Circular shifts commute with circular convolution. The PSF spectrum is therefore computed
    # from a pre-shifted copy of the padded PSF, chosen so that, combined with `roll`, the PSF
    # origin lands on index 0. With `roll` alone (an fftshift), the output would be offset by
    # one sample from `fftconvolve(mode="same")` along every axis.
    origin = pad_before + (shape - 1) // 2  # PSF origin inside the padded array
    pre_shift = np.mod(-(origin + np.array(roll.shift)), pad_shape)
    xp = pycu.get_array_module(psf)
    psf_pad = pad(psf.ravel()).reshape(tuple(int(n) for n in pad_shape))
    psf_pad = xp.roll(psf_pad, tuple(int(t) for t in pre_shift), axis=tuple(range(ndim)))
    psf_fft = ft(psf_pad.ravel())
    mult = ComplexMult(psf_fft)

    # Crop the central psf.shape window, i.e. where the un-padded input sits.
    startind = (pad_shape - shape) // 2
    slices = [slice(start, start + n) for start, n in zip(startind, shape)]
    center = SubSample(pad._pad_shape, *slices)

    # ft.T is the adjoint DFT; assuming `FFT` is the unnormalised DFT, dividing by ft.dim
    # turns it into the inverse transform.
    op = (1 / ft.dim) * center * roll * ft.T * mult * fft_f

    parts = dict(center=center, roll=roll, ft=ft, mult=mult, fft_f=fft_f, pad=pad)
    for name, part in parts.items():
        setattr(op, name, part)
    return op

In [None]:
def _centered(arr, newshape):
    # Return the center newshape portion of the array.
    newshape = np.asarray(newshape)
    currshape = np.array(arr.shape)
    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
    return arr[tuple(myslice)]

In [None]:
# Try 1D: compare the Pycsou convolution operator with two reference implementations.
import skimage

npix = 512
# 1-D test signal: a 2-D blob image summed along its first axis.
# NOTE(review): binary_blobs is not seeded here, so the signal differs between runs.
x = skimage.data.binary_blobs(length=npix, blob_size_fraction=0.1, n_dim=2, volume_fraction=0.01).astype(float).sum(0)
sigma = 10
# Gaussian PSF centred on the grid, with the 1-D normalisation 1 / (sqrt(2 pi) sigma)
# (1 / (2 pi sigma^2) is the 2-D constant and would shrink the kernel mass by ~25x here).
psf = np.fromfunction(
    lambda i: np.exp(-((i - (npix - 1) / 2) ** 2) / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma),
    (npix,),
)

forward = PsfFourier(psf)
y_pycsou = forward(x.ravel()).reshape(x.shape)

# Reference 1: linear convolution with NumPy FFTs zero-padded to a power of two, cropped to "same".
full_size = x.size + psf.size - 1
fft_size = int(2 ** np.ceil(np.log2(full_size)))
# Both inputs are real: keep the real part of the inverse FFT (the imaginary part is round-off).
y_fft = _centered(ifft(fft(psf, fft_size) * fft(x, fft_size))[:full_size].real, x.shape)

# Reference 2: SciPy.
y_scipy = sp.signal.fftconvolve(x, psf, mode="same")

for title, y in [("Pycsou PsfFourier", y_pycsou), ("NumPy FFT", y_fft), ("scipy.signal.fftconvolve", y_scipy)]:
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(title)
    for ax, (label, curve) in zip(axs, [("x", x), ("psf", psf), ("psf * x", y)]):
        ax.plot(curve)
        ax.set_title(label)

print(f"max |pycsou - scipy| = {np.abs(y_pycsou - y_scipy).max():.3e}")
print(f"max |numpy  - scipy| = {np.abs(y_fft - y_scipy).max():.3e}")

In [None]:
# Try 2D: blur a binary blob image with a Gaussian PSF and add Gaussian noise.

import skimage

npix = 512
# NOTE(review): binary_blobs is not seeded here, so the test image differs between runs.
x = skimage.data.binary_blobs(length=npix, blob_size_fraction=0.1, n_dim=2, volume_fraction=0.01).astype(float)
sigma = 5

# Isotropic Gaussian PSF centred on the grid, normalised by 1 / (2 pi sigma^2).
centre = (npix - 1) / 2
psf = np.fromfunction(
    lambda r, c: (1 / (2 * np.pi * sigma**2)) * np.exp(-((r - centre) ** 2 + (c - centre) ** 2) / (2 * sigma**2)),
    (npix, npix),
)
forward = PsfFourier(psf)

# Clean blurred image, then a noisy version clipped to [0, 1].
y = forward(x.ravel()).reshape(x.shape)
y_noise = np.clip(rng.normal(loc=y, scale=0.05), a_min=0, a_max=1)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, img in zip(axs, (x, psf, y)):
    im = ax.imshow(img)
    plt.colorbar(im, ax=ax)

# Least-squares reconstruction (damped pseudo-inverse)

In [None]:
from pycsou.opt.stop import MaxIter, RelError

# Pseudo-inverse of `forward` applied to the noisy data, with damping `damp`;
# the estimate is then rescaled so that its maximum is 1.
damp = 1e-6
pinv_kwargs = dict(damp=damp, kwargs_init=dict(show_progress=False))
x_pinv = forward.pinv(y_noise.ravel(), **pinv_kwargs)
x_pinv = x_pinv.reshape(y_noise.shape) / x_pinv.max()

In [None]:
# Noisy measurements vs. pseudo-inverse estimate (both clipped to [0, 1] by imshow_compare).
imshow_compare(y_noise, x_pinv)

In [None]:
from pycsou.operator import Gradient, L1Norm, L21Norm, PositiveL1Norm

# Squared-L2 data-fidelity term centred on the noisy measurements, composed with the blur.
sl2 = SquaredL2Norm(dim=y_noise.size).asloss(y_noise.ravel())
loss = sl2 * forward

# TODO: the GPU variant `Gradient(..., gpu=True)` was flagged as buggy ("% BUG"); CPU is used for now.
# diff_method="gd" with sigma=[2.0, 2.0] — presumably a Gaussian-derivative gradient; confirm in the Pycsou docs.
grad = Gradient(arg_shape=y_noise.shape, diff_method="gd", sigma=[2.0, 2.0])
posl1 = PositiveL1Norm(dim=y_noise.size)
l1_grad = L1Norm(dim=grad.codim)

# Only the last expression of a cell is displayed, so report every Lipschitz constant explicitly.
print(f"{sl2.diff_lipschitz()=}")
print(f"{grad.lipschitz(tight=False, tol=0.1)=}")
print(f"{loss.diff_lipschitz()=}")

In [None]:
import pycsou.runtime as pycrt
from pycsou.opt.solver import PD3O

# Stopping criterion: stop once both RelError criteria (on the iterates "x" and "z", eps=1e-3)
# and MaxIter(20) are met, or unconditionally at MaxIter(500).
# Presumably RelError measures the relative change between successive iterates — confirm in Pycsou docs.
rel_x = RelError(eps=1e-3, var="x", f=None, norm=2, satisfy_all=True)
rel_z = RelError(eps=1e-3, var="z", f=None, norm=2, satisfy_all=True)
converged = rel_x & rel_z & MaxIter(20)
default_stop_crit = converged | MaxIter(500)

In [None]:
# Weights of the regularisers: λ1 scales the positive-L1 term (posl1),
# λ2 scales the L1 norm of the image gradient (l1_grad * grad).
λ1 = 5e-2
λ2 = 5e-2

# Evaluate each term of the objective at the noisy data to compare their relative magnitudes.
print(f"{loss(y_noise.ravel())=}")
print(f"{(λ1 * posl1)(y_noise.ravel())=}")
print(f"{(λ2 * l1_grad * grad)(y_noise.ravel())=}")


In [None]:
# Initialize solver (PD3O) with f = loss, g = λ1 * posl1, h = λ2 * l1_grad composed with K = grad.
solver = PD3O(
    f=loss,
    g=λ1 * posl1,
    h=λ2 * l1_grad,
    K=grad,
    show_progress=True,
    verbosity=50,
)

# Fit in single precision, starting from the noisy data, then rescale the estimate to a peak of 1.
with pycrt.Precision(pycrt.Width.SINGLE):
    solver.fit(x0=y_noise.ravel(), tuning_strategy=2, stop_crit=default_stop_crit)
    y_tv = solver.solution().reshape(y_noise.shape)
    peak = y_tv.max()
    print(f"{peak} before normalization")
    y_tv /= peak

In [None]:
# Noisy measurements vs. PD3O estimate (both clipped to [0, 1] by imshow_compare).
imshow_compare(y_noise, y_tv)