In [None]:
!pip install pyxu[viz]@git+https://github.com/pyxu-org/pyxu.git@v3-dev
!pip install scipy scikit-image

# Build Your Deconvolution Algorithm

In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import optax
from scipy.ndimage import gaussian_filter
from skimage.draw import ellipse
from skimage.util import random_noise

import pyxu.abc as pxa
import pyxu.math as pxm
import pyxu.operator as pxo
import pyxu.util as pxu

In [None]:
# A helper function to simplify plotting.

def plot_image(ax, img: npt.ArrayLike, title: str = ""):
    """
    Draw `img` on `ax` as a grayscale picture.

    Parameters
    ----------
    ax:
        Object exposing ``imshow``, ``set_title`` and ``axis`` (e.g. a Matplotlib Axes).
    img: ArrayLike
        Image forwarded unchanged to ``ax.imshow``.
    title: str
        Text placed above the image; empty by default.
    """
    # Render with the gray colormap.
    ax.imshow(img, cmap="gray")

    # Caption, then switch the axis decorations off.
    ax.set_title(title)
    ax.axis("off")

## Ground truth generation

Let's generate some ground truth microscopy data.

In [None]:
# A helper function to generate cell-like images.

def generate_textured_cells(
    image_size: tuple[int, int],
    num_cell: int,
    cell_radius: tuple[int, int],
    seed: int = 0,
):
    """
    Generate 2D grayscale image with randomly oriented textured ellipses.

    Parameters
    ----------
    image_size: tuple[int, int]
        (height, width)
    num_cell: int
        Number of ellipses to draw.
        Where ellipses overlap, the one drawn last overwrites the earlier ones.
    cell_radius: tuple[int, int]
        (min, max) cell radius in pixels.
        Semi-axis lengths are drawn from the half-open range [min, max).
        Centers are kept at least `max` pixels away from the image border.
    seed: int
        Seed of the random generator.
        The default (0) reproduces the image obtained before this parameter existed.

    Returns
    -------
    img: ndarray
        (height, width) image rescaled to [0, 1].
        A constant image (e.g. `num_cell=0`) is returned as all zeros.
    """
    rng = np.random.default_rng(seed)

    height, width = image_size
    r_min, r_max = cell_radius

    img = np.zeros((height, width))
    for _ in range(num_cell):
        # Random center within img bounds
        center = (
            rng.integers(r_max, height - r_max),
            rng.integers(r_max, width - r_max),
        )

        # Random semi-axis lengths for the ellipse
        radius = rng.integers(r_min, r_max, size=2)

        # Random orientation angle [rad]
        angle = rng.uniform(0, np.pi)

        # Generate elliptical mask
        rr, cc = ellipse(*center, *radius, shape=(height, width), rotation=angle)

        # Generate textured pattern: uniform intensities in [0.4, 1.0)
        texture = rng.uniform(0, 1, len(rr)) * 0.6 + 0.4

        img[rr, cc] = texture

    # Smoothen the img slightly to make cells look more natural
    img = gaussian_filter(img, sigma=1)

    # Normalize to [0, 1].
    # A constant image has zero range: only shift it, otherwise 0/0 would fill it with NaNs.
    img = img - img.min()
    value_range = np.ptp(img)
    if value_range > 0:
        img = img / value_range
    return img

In [None]:
# Synthesize the ground truth: 20 textured cells on a 512 x 513 canvas.
height, width = 512, 513
img_gt = generate_textured_cells(
    (height, width),
    num_cell=20,
    cell_radius=(20, 40),
)

# Show the synthetic sample.
fig, ax = plt.subplots()
plot_image(ax, img_gt, title="ground truth")

## Modeling the Microscope

The microscope never sees the ground truth due to the physics of the acquisition system.
Let's define the forward model describing the acquisition physics, i.e. what happens between the sample and the sensor plane.
We'll assume a simple convolutional relationship between the sample and the sensor plane.

In [None]:
# Modeling the acquisition system via Pyxu operators.

class MicroscopeModel(pxa.LinearOperator):
    """
    Convolutional forward model: the sample is convolved with the instrument's PSF.

    Parameters
    ----------
    psf: ArrayLike
        Point-spread function. Its shape is used for both `dim_shape` and `codim_shape`
        (presumably the operator's input and output shapes -- confirm against pyxu).
    """

    psf: jax.Array  # the model's sole parameter
    # point-spread function of the instrument

    def __init__(self, psf: npt.ArrayLike):
        # Convert first: a generic array-like (e.g. nested lists) has no `.shape`.
        psf = jnp.asarray(psf)

        self.dim_shape = pxu.ShapeStruct(psf.shape)
        self.codim_shape = pxu.ShapeStruct(psf.shape)

        self.psf = psf

    def apply(self, x: jax.Array) -> jax.Array:
        # FFT-based convolution; mode="same" crops the result to the shape of `x`.
        return jax.scipy.signal.fftconvolve(x, self.psf, mode="same")

With a microscope model available, let's simulate what the microscope actually captures.
For this we'll generate the impulse response of the system (point-spread function; PSF), then add shot noise.

In [None]:
# Impulse response of the instrument: a unit impulse at the image centre,
# spread out by a Gaussian filter (sigma = 3).
impulse = np.zeros_like(img_gt)
impulse[height // 2, width // 2] = 1
psf = gaussian_filter(impulse, sigma=3)

fig, ax = plt.subplots()
plot_image(ax, psf, title="point spread function")

In [None]:
# instantiate microscope model
model = MicroscopeModel(psf)

# apply model + add shot noise
# The noise generator is seeded so that the simulated acquisition (and every
# reconstruction computed from it) is identical on each Restart & Run All.
img = random_noise(image=model(img_gt), mode="poisson", rng=0).clip(0, None)

fig, ax = plt.subplots(ncols=2, figsize=[10, 10])
plot_image(ax[0], img_gt, "ground truth")
plot_image(ax[1], img, "microscope image")

## Deconvolution 1: Richardson-Lucy Method

The Richardson-Lucy algorithm (RL) is an iterative scheme to deconvolve microscope images where the forward model is convolutional and noise is Poisson-distributed. It should work relatively well if the captured data actually follows this model.

In [None]:
def rl_step(
    u: jax.Array,
    data: jax.Array,
    model: MicroscopeModel,
    eps: float = 1e-12,
) -> jax.Array:
    """
    Perform one multiplicative Richardson-Lucy update.

    Computes ``u * model.adjoint(data / (model(u) + eps))``.

    Parameters
    ----------
    u: jax.Array
        Current estimate.
    data: jax.Array
        Measured image.
    model:
        Forward operator; must be callable and expose ``adjoint``.
    eps: float
        Small constant added to the predicted data before dividing, so that pixels
        where the prediction is exactly zero do not produce NaN/inf.

    Returns
    -------
    u_next: jax.Array
        Updated estimate.

    Notes
    -----
    No division by ``model.adjoint(ones)`` is performed; presumably the operator is
    assumed to be normalised -- confirm before reusing with other models.
    """
    # Ratio between the measurement and the current prediction.
    z = data / (model(u) + eps)

    # Back-project the ratio and use it as a multiplicative correction.
    u_next = model.adjoint(z) * u
    return u_next


# Richardson-Lucy iterations, started from the back-projected data.
l2_norm = pxo.L2Norm()  # built once instead of twice per iteration

u = model.adjoint(img)
for i in range(100):
    u_next = rl_step(u, img, model)
    rel_err = l2_norm.apply(u_next - u) / l2_norm.apply(u)
    u = u_next

    # Stop once the estimate barely changes between two iterations.
    if rel_err < 1e-3:
        break
img_rl = u

fig, ax = plt.subplots(ncols=3, figsize=[15, 10])
plot_image(ax[0], img, "microscope image")
# `i` is the zero-based index of the last update, hence i + 1 iterations were run.
plot_image(ax[1], img_rl, f"RL deconvolved ({i + 1} iter)")
plot_image(ax[2], img_gt, "ground truth")


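The RL algorithm makes strong assumptions about the microscope model: the forward operator is a pure convolution and the noise is Poisson-distributed.
When these assumptions do not hold, it produces poor estimates. To illustrate this, the model below additionally multiplies every pixel of the blurred image by a random weight.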

In [None]:
class MicroscopeModel2(pxa.LinearOperator):
    """
    Forward model made of a convolution with the PSF followed by a pixel-wise re-weighting.

    Parameters
    ----------
    psf: ArrayLike
        Point-spread function. Its shape is used for both `dim_shape` and `codim_shape`
        (presumably the operator's input and output shapes -- confirm against pyxu).
    weight: ArrayLike
        Per-pixel multiplicative weights; must have the same shape as `psf`.
    """

    psf: jax.Array  # instrument point-spread function
    weight: jax.Array  # pixel re-weighting function

    def __init__(self, psf: npt.ArrayLike, weight: npt.ArrayLike):
        # Convert first: generic array-likes (e.g. nested lists) have no `.shape`.
        psf = jnp.asarray(psf)
        weight = jnp.asarray(weight)
        assert psf.shape == weight.shape, (
            f"psf and weight must have the same shape, got {psf.shape} and {weight.shape}"
        )

        self.dim_shape = pxu.ShapeStruct(psf.shape)
        self.codim_shape = pxu.ShapeStruct(psf.shape)

        self.psf = psf
        self.weight = weight

    def apply(self, x: jax.Array) -> jax.Array:
        # Blur (FFT convolution cropped to the shape of `x`) ...
        y = jax.scipy.signal.fftconvolve(x, self.psf, mode="same")
        # ... then scale each pixel by its weight.
        z = self.weight * y
        return z


# Random per-pixel weights in [0, 1).
# The generator is seeded so that the weights and the noise are identical on every run.
rng = np.random.default_rng(0)
weight = rng.uniform(0, 1, psf.shape)
model2 = MicroscopeModel2(psf, weight)

# apply model + add shot noise
img2 = random_noise(image=model2(img_gt), mode="poisson", rng=rng).clip(0, None)

# Same Richardson-Lucy iterations as before, now driven by the second model.
l2_norm = pxo.L2Norm()  # built once instead of twice per iteration

u = model2.adjoint(img2)
for i in range(100):
    u_next = rl_step(u, img2, model2)
    rel_err = l2_norm.apply(u_next - u) / l2_norm.apply(u)
    u = u_next

    # Stop once the estimate barely changes between two iterations.
    if rel_err < 1e-3:
        break
img_rl2 = u


fig, ax = plt.subplots(nrows=2, ncols=2, figsize=[10, 10])
plot_image(ax[0, 0], img2, "microscope image - model 2")
plot_image(ax[0, 1], img_rl, "RL - model 1")
# `i` is the zero-based index of the last update, hence i + 1 iterations were run.
plot_image(ax[1, 0], img_rl2, f"RL - model 2 ({i + 1} iter)")
plot_image(ax[1, 1], img_gt, "ground truth")


## Deconvolution 2: Least-Squares Method

In [None]:
# Least-squares estimate for model 2.
# Colab users: tau=0.1 may fail due to different linalg backend -> use tau=7
img_lsq, diag = pxm.pinv(
    model2,
    img2,
    tau=0.1,
    rtol=1e-3,
    atol=1e-6,
    max_steps=30,
)

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=[20, 15])
plot_image(ax[0, 0], img_gt, "ground truth")
plot_image(ax[0, 1], img_rl, "RL - model 1")
# `i` still holds the zero-based index of the last RL update on model 2 (set by the
# loop that produced `img_rl2`), hence i + 1 iterations were run.
plot_image(ax[1, 0], img_rl2, f"RL - model 2 ({i + 1} iter)")
plot_image(ax[1, 1], img_lsq, "LSQ - model 2")
ax[0, 2].axis("off")  # unused panel
plot_image(ax[1, 2], img2, "microscope image - model 2")

## Deconvolution 3: Total Variation Method

In [None]:
def loss(x: jax.Array, data: jax.Array, model: MicroscopeModel, tv_scale: float):
    """
    Objective of the TV-regularised reconstruction.

    Evaluates SquaredL2Norm(model.apply(x) - data) + tv_scale * TV(x), where TV(x) is
    the sum of the L1 norms of the two arrays returned by ``jnp.gradient(x)``.

    Parameters
    ----------
    x: jax.Array
        Candidate image (two gradient components are unpacked, so 2D is expected).
    data: jax.Array
        Measured image.
    model:
        Forward operator exposing ``apply``.
    tv_scale: float
        Weight of the total-variation term.
    """
    # Data-fidelity term
    residual = model.apply(x) - data
    fidelity = pxo.SquaredL2Norm()(residual)

    # TV regularizer: L1 norm of each finite-difference component
    l1 = pxo.L1Norm()
    d_axis0, d_axis1 = jnp.gradient(x)
    total_variation = l1.apply(d_axis0) + l1.apply(d_axis1)

    return fidelity + tv_scale * total_variation

tv_scale = 0.01

solver = optax.adam(learning_rate=0.01)

# Loop-invariant helpers, built once rather than on every iteration.
grad_fn = jax.grad(loss)  # gradient w.r.t. the first argument of `loss` (the image)
l2_norm = pxo.L2Norm()

u = jnp.zeros_like(img2)
opt_state = solver.init(u)
for i in range(200):
    grad = grad_fn(u, img2, model2, tv_scale)
    updates, opt_state = solver.update(grad, opt_state, u)
    u_next = optax.apply_updates(u, updates)

    # NOTE: `u` starts at zero, so the very first ratio presumably divides by 0;
    # the resulting non-finite value fails the `<` test below and the loop goes on.
    rel_err = l2_norm.apply(u_next - u) / l2_norm.apply(u)
    u = u_next

    # Uncomment to monitor convergence:
    # if i % 20 == 0:
    #     print(f"Objective function: {loss(u, img2, model2, tv_scale)}")
    #     print(f"Relative error: {rel_err}")

    # Stop once the estimate barely changes between two iterations.
    if rel_err < 1e-3:
        break
img_tv = u

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=[20, 15])
plot_image(ax[0, 0], img_gt, "ground truth")
plot_image(ax[0, 1], img2, "microscope image - model 2")
plot_image(ax[1, 0], img_rl2, "RL - model 2")
plot_image(ax[1, 1], img_lsq, "LSQ - model 2")
ax[0, 2].axis("off")  # unused panel
plot_image(ax[1, 2], img_tv, "TV - model 2")