# Diffusion-Shock Inpainting in $SE(2)$
Diffusion-shock inpainting (DS) is a technique to fill in missing structures in images, developed in ["Diffusion-Shock Inpainting" (2023) by K. Schaefer and J. Weickert](https://link.springer.com/chapter/10.1007/978-3-031-31975-4_45) and the follow-up paper ["Regularised Diffusion-Shock Inpainting" (2023) by K. Schaefer and J. Weickert](https://arxiv.org/abs/2309.08761). In this notebook, we will look at DS applied to images lifted into $SE(2)$.

In $\mathbb{R}^2$, we can describe DS in a PDE-based formulation as
$$
\partial_t u = g(\lvert \nabla (G_{\nu} * u) \rvert^2) \underbrace{\Delta u}_{\textrm{Diffusion}} - \left(1 - g(\lvert \nabla (G_{\nu} * u) \rvert^2)\right) \underbrace{\mathrm{sgn}(\partial_{\vec{w} \vec{w}} (G_{\sigma} * u)) \lvert \nabla u \rvert}_{\textrm{Shock}},
$$
in which $g: [0, \infty) \to (0, 1]$ is a decreasing function with $g(0) = 1$, $G_{\alpha}$ is a Gaussian with standard deviation $\alpha$, and $\vec{w}$ is the dominant eigenvector of the structure tensor. The function $g$ therefore acts as a switch between diffusion and shock: where the regularised gradient is small, diffusion dominates, and where it is large, shock dominates. This is sensible, since a large gradient indicates an edge or other feature, which we want to sharpen rather than blur.
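The two switches in the PDE above can be sketched in NumPy/SciPy on a 2D image. This is a minimal sketch with unit grid spacing, not the `dsfilter` implementation. The choice $g(s^2) = 1/(1 + s^2/\lambda^2)$ is an assumption (any decreasing $g$ with $g(0) = 1$ fits the description), and `rho` is an assumed integration scale for the structure tensor, which the formula above leaves implicit.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def diffusivity(s2, lam):
    """Assumed Perona-Malik-type g(s^2) = 1 / (1 + s^2 / lam^2): decreasing, g(0) = 1, values in (0, 1]."""
    return 1.0 / (1.0 + s2 / lam**2)


def ds_ingredients(u, nu, sigma, rho, lam):
    """Return the diffusion/shock weight g(|grad u_nu|^2) and sgn(d_ww u_sigma).

    Axis 0 is treated as x and axis 1 as y, with unit grid spacing.
    """
    # Switch between diffusion and shock: g of the squared gradient of the smoothed image.
    u_nu = gaussian_filter(u, nu)
    ux, uy = np.gradient(u_nu)
    weight = diffusivity(ux**2 + uy**2, lam)

    # Structure tensor J_rho(grad u_sigma): gradient outer product, smoothed componentwise.
    u_s = gaussian_filter(u, sigma)
    sx, sy = np.gradient(u_s)
    J11 = gaussian_filter(sx * sx, rho)
    J12 = gaussian_filter(sx * sy, rho)
    J22 = gaussian_filter(sy * sy, rho)

    # Dominant eigenvector w of the symmetric 2x2 tensor, via its orientation angle.
    phi = 0.5 * np.arctan2(2 * J12, J11 - J22)
    wx, wy = np.cos(phi), np.sin(phi)

    # Second directional derivative d_ww u_sigma = w^T Hess(u_sigma) w.
    sxx, sxy = np.gradient(sx)
    _, syy = np.gradient(sy)
    d_ww = wx**2 * sxx + 2 * wx * wy * sxy + wy**2 * syy
    return weight, np.sign(d_ww)
```

On a smooth step, the weight drops near the edge (shock) and the sign is positive on the dark side (erosion) and negative on the bright side (dilation), which is what sharpens the edge.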

The signum in the shock term switches between erosion and dilation. If the second derivative of $G_{\sigma} * u$ in the direction of the dominant eigenvector $\vec{w}$ of the structure tensor is positive, we perform erosion (defined by the PDE $\partial_t u = -\lvert \nabla u \rvert$); otherwise we perform dilation (defined by the PDE $\partial_t u = \lvert \nabla u \rvert$). In regularised DS, the signum is replaced with a soft signum, which makes the selection of erosion versus dilation less sensitive to noise.
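A common way to discretise erosion and dilation is an upwind (Rouy-Tourin) scheme, which only looks at darker neighbours for erosion and brighter neighbours for dilation. The sketch below assumes unit grid spacing and replicated borders. The soft signum $\tanh(x/\varepsilon)$ is an assumed choice and may differ from the regularisation used in the paper or in `dsfilter`.

```python
import numpy as np


def upwind_gradient_norms(u):
    """Rouy-Tourin upwind approximations of |grad u| for dilation and erosion (unit spacing)."""
    p = np.pad(u, 1, mode="edge")
    c = p[1:-1, 1:-1]
    # Differences towards each of the four neighbours.
    d_xp = p[2:, 1:-1] - c   # u[i+1, j] - u[i, j]
    d_xm = p[:-2, 1:-1] - c  # u[i-1, j] - u[i, j]
    d_yp = p[1:-1, 2:] - c
    d_ym = p[1:-1, :-2] - c
    # Dilation (u_t = +|grad u|) uses brighter neighbours, erosion (u_t = -|grad u|) darker ones.
    dil = np.sqrt(np.maximum(np.maximum(d_xp, d_xm), 0) ** 2 + np.maximum(np.maximum(d_yp, d_ym), 0) ** 2)
    ero = np.sqrt(np.minimum(np.minimum(d_xp, d_xm), 0) ** 2 + np.minimum(np.minimum(d_yp, d_ym), 0) ** 2)
    return dil, ero


def soft_sign(x, eps):
    """Smooth stand-in for sgn(x); tanh(x / eps) is an assumed choice that tends to sgn as eps -> 0."""
    return np.tanh(x / eps)


def shock_step(u, s, dt):
    """One explicit step of u_t = -s |grad u|: erosion where s > 0, dilation where s < 0.

    Stable for |s| <= 1 and dt <= 1 / sqrt(2).
    """
    dil, ero = upwind_gradient_norms(u)
    return u - dt * np.maximum(s, 0) * ero + dt * np.maximum(-s, 0) * dil
```

Usage: `shock_step(u, soft_sign(d_ww, 0.05), 0.5)` applies the soft-switched shock term for one step, where `d_ww` is a precomputed second directional derivative.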

The sign of the second derivative along the dominant eigenvector of the structure tensor is not unlike the convexity criterion we know from vesselness filtering; perhaps we could use that criterion instead?
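As a point of comparison, a hypothetical $\mathbb{R}^2$ variant could take the sign of the Hessian eigenvalue of $G_{\sigma} * u$ with the largest magnitude, as in Frangi-type vesselness, where it distinguishes bright lines (negative) from dark lines (positive). This is only a sketch of the idea raised above, not something the notebook implements.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def convexity_sign(u, sigma):
    """Sign of the Hessian eigenvalue of u_sigma with the largest magnitude (+1 convex, -1 concave)."""
    u_s = gaussian_filter(u, sigma)
    ux, uy = np.gradient(u_s)
    uxx, uxy = np.gradient(ux)
    _, uyy = np.gradient(uy)
    # Closed-form eigenvalues of the symmetric 2x2 Hessian [[uxx, uxy], [uxy, uyy]].
    mean = 0.5 * (uxx + uyy)
    radius = np.sqrt(0.25 * (uxx - uyy) ** 2 + uxy**2)
    lam_1, lam_2 = mean + radius, mean - radius
    dominant = np.where(np.abs(lam_1) >= np.abs(lam_2), lam_1, lam_2)
    return np.sign(dominant)
```

On a bright line this returns $-1$ at the centreline (dilation, in the convention above) and $+1$ on its flanks (erosion).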

What is the correct way to extend DS to $SE(2)$? It would make sense to keep the gradient and the Laplacian, now taken on $SE(2)$. To select between erosion and dilation we could again use the vesselness convexity criterion, and to switch between diffusion and shock we could use some sort of line or edge detector.
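One natural way to carry the gradient over to $SE(2)$ is to express it in the left-invariant frame $\mathcal{A}_1 = \cos\theta\,\partial_x + \sin\theta\,\partial_y$, $\mathcal{A}_2 = -\sin\theta\,\partial_x + \cos\theta\,\partial_y$, $\mathcal{A}_3 = \partial_\theta$ and measure it with a diagonal metric, $\lvert \nabla U \rvert^2 = \sum_i g^{ii} (\mathcal{A}_i U)^2$. The sketch below assumes that parameters such as `G_D_inv` and `G_S_inv` hold the diagonal $(g^{11}, g^{22}, g^{33})$ of such an inverse metric; that is a guess from their names, not something documented here.

```python
import numpy as np


def left_invariant_gradient_norm(U, G_inv, dxy, dtheta, thetas):
    """|grad U| w.r.t. a diagonal left-invariant metric on SE(2).

    U has axes (x, y, theta), with theta periodic. G_inv is the assumed diagonal of the
    inverse metric in the frame (A1, A2, A3); thetas broadcasts along the last axis.
    """
    # Central differences in space (one-sided at the border) and periodic ones in theta.
    Ux, Uy = np.gradient(U, dxy, axis=(0, 1))
    Ut = (np.roll(U, -1, axis=2) - np.roll(U, 1, axis=2)) / (2 * dtheta)
    c, s = np.cos(thetas), np.sin(thetas)
    A1U = c * Ux + s * Uy   # derivative along the orientation
    A2U = -s * Ux + c * Uy  # derivative perpendicular to the orientation
    A3U = Ut                # derivative in the orientation direction
    return np.sqrt(G_inv[0] * A1U**2 + G_inv[1] * A2U**2 + G_inv[2] * A3U**2)
```

With `G_inv = (1., 0.1, 0.)`, an edge contributes fully at orientations aligned with its gradient and only weakly at perpendicular ones, which is the kind of anisotropy the notebook's `G_D_inv` appears to encode.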

In [None]:
import taichi as ti
ti.init(arch=ti.gpu, debug=False, device_memory_GB=3.5)  # Optionally pass kernel_profiler=True. Stay below 4 GB so that RAM and VRAM are not mixed (unverified).
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# %matplotlib widget
import dsfilter

In [None]:
test_case = "spiral"

In [None]:
match test_case:
    case "spiral":
        ground_truth = dsfilter.SE2.utils.align_to_real_axis_scalar_field(np.array(Image.open("data/spiral.tif").convert("L")).astype(np.float64))
        noisy = dsfilter.SE2.utils.align_to_real_axis_scalar_field(np.array(Image.open("data/noisyspiral.tif")).astype(np.float64) / 256)
        t = 1.
        G_D_inv = 1.8 * np.array((1., 0.1, 0.0))
        G_S_inv = np.array((1., 1., 0.0))
        # Internal regularisation for switching between dilation and erosion.
        σ_1, σ_2, σ_3 = np.array((2.5, 2.5, 0.6))
        # External regularisation for switching between dilation and erosion.
        ρ_1, ρ_2, ρ_3 = np.array((1., 1., 0.6))
        # Internal and external regularisation of gradient for switching between diffusion and shock.
        ν_1, ν_2, ν_3 = np.array((2.5, 2.5, 0.6))
        λ = 0.1 # Contrast parameter for switching between diffusion and shock.
        ε = 0.5 # Regularisation parameter for signum.
    case "monalisa":
        ground_truth = dsfilter.SE2.utils.align_to_real_axis_scalar_field(np.array(Image.open("data/monalisa.tif").convert("L")).astype(np.float64) / 256)
        noisy = dsfilter.SE2.utils.align_to_real_axis_scalar_field(np.array(Image.open("data/noisymonalisa.tif")).astype(np.float64) / 256**2)
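        t = 1.
        # G_D_inv and G_S_inv are only defined in the "spiral" case; set them here before running the DS cells below.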

clip = (ground_truth.min(), ground_truth.max())

dim_I, dim_J = ground_truth.shape
dim_K = 16
Is, Js, Ks = np.indices((dim_I, dim_J, dim_K))
x_min, x_max = 0., dim_I - 1.
y_min, y_max = 0., dim_J - 1.
θ_min, θ_max = 0., 2 * np.pi
dxy = (x_max - x_min) / (dim_I - 1)
dθ = (θ_max - θ_min) / dim_K
xs, ys, θs = dsfilter.SE2.utils.coordinate_array_to_real(Is, Js, Ks, x_min, y_min, θ_min, dxy, dθ)
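The grid spacings above differ in one respect: `dxy` divides by `dim_I - 1` because both spatial endpoints are grid points, while `dθ` divides by `dim_K` because $\theta$ is periodic and $\theta = 2\pi$ coincides with $\theta = 0$. A hypothetical stand-in for the index-to-coordinate map, assuming `coordinate_array_to_real` is a plain affine map, could look like this:

```python
import numpy as np


def coordinate_array_to_real_sketch(I, J, K, x_min, y_min, theta_min, dxy, dtheta):
    """Hypothetical affine map from integer grid indices to real (x, y, theta) coordinates."""
    return x_min + I * dxy, y_min + J * dxy, theta_min + K * dtheta


# Small grid with 16 orientations; the last orientation stops one step short of 2*pi.
Is_demo, Js_demo, Ks_demo = np.indices((4, 3, 16))
xs_demo, ys_demo, thetas_demo = coordinate_array_to_real_sketch(Is_demo, Js_demo, Ks_demo, 0., 0., 0., 1., 2 * np.pi / 16)
```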

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth, 0., dim_I - 1., 0., dim_J - 1., fig=fig, ax=ax[0])
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, 0., dim_I - 1., 0., dim_J - 1., fig=fig, ax=ax[1])
fig.colorbar(cbar, ax=ax[1]);

### Orientation Score

In [None]:
cws_check = dsfilter.orientationscore.cakewavelet_stack(dim_I, dim_K, Gaussian_σ=dim_I/16)

In [None]:
K = 1
print(θs[0, 0, K])
fig, ax, cbar = dsfilter.visualisations.plot_image_array(cws_check.real[K], x_min, x_max, y_min, y_max, cmap="gray")
ax.set_title(r"$\psi$")
fig.colorbar(cbar, ax=ax);

In [None]:
cws = dsfilter.orientationscore.cakewavelet_stack(min(dim_I, dim_J), dim_K, Gaussian_σ=dim_I / 16).real
U = dsfilter.orientationscore.wavelet_transform(noisy, cws).real
U = np.transpose(U, axes=(1, 2, 0)) # x, y, θ
mask = np.zeros_like(U) # Filtering, so there is no region outside of the mask
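Later cells reconstruct the image by summing the orientation score over $\theta$ (`U.sum(-1)`). This works because cake wavelets are designed to (approximately) form a partition of unity over orientations in the Fourier domain. The sketch below illustrates that property with simple piecewise-linear angular windows; it is not the construction in `dsfilter.orientationscore.cakewavelet_stack` (no radial window, no Gaussian damping).

```python
import numpy as np


def angular_hat_filters(shape, n_orientations):
    """Fourier-domain angular windows (piecewise linear in angle) that sum to one at every frequency."""
    wx = np.fft.fftfreq(shape[0])[:, None]
    wy = np.fft.fftfreq(shape[1])[None, :]
    angle = np.arctan2(wy, wx)  # the zero frequency gets angle 0, which is still covered
    step = 2 * np.pi / n_orientations
    filters = []
    for k in range(n_orientations):
        # Signed angular distance to the k-th orientation, wrapped into [-pi, pi).
        d = np.mod(angle - k * step + np.pi, 2 * np.pi) - np.pi
        filters.append(np.clip(1 - np.abs(d) / step, 0, None))
    return np.stack(filters)


def lift(image, filters):
    """Orientation score with axes (x, y, theta): real part of each angularly filtered image."""
    F = np.fft.fft2(image)
    U = np.real(np.fft.ifft2(filters * F[None], axes=(-2, -1)))
    return np.moveaxis(U, 0, -1)
```

Because the windows sum to one, `lift(image, filters).sum(-1)` recovers `image` up to rounding; taking the real part per orientation does not break this, since the real part of a sum is the sum of real parts.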

In [None]:
K = 0
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(cws[K], x_min, x_max, y_min, y_max, fig=fig, ax=ax[0])
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(U[..., K], x_min, x_max, y_min, y_max, fig=fig, ax=ax[1])
fig.colorbar(cbar, ax=ax[1]);

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0])
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(U.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[1])
fig.colorbar(cbar, ax=ax[1]);

### Perform Filtering

#### TV Flow Filtering

In [None]:
λ_TV = 50. / 256.
U_TV = dsfilter.TV_inpainting_LI(U * λ_TV, mask, np.array((1., 1., 0.01)), dxy, dθ, θs, 1., 0.5, t) / λ_TV
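For reference, the TV flow baseline $\partial_t u = \operatorname{div}\!\left(\nabla u / \sqrt{\lvert \nabla u \rvert^2 + \varepsilon^2}\right)$ can be sketched in $\mathbb{R}^2$ with a conservative explicit scheme. This is a plain $\mathbb{R}^2$ sketch with periodic boundaries and assumed parameters; `dsfilter.TV_inpainting_LI` works on the lifted $SE(2)$ data, and the meaning of its positional arguments is not documented here.

```python
import numpy as np


def tv_flow(u, t, eps=1.0, dt=None):
    """Explicit scheme for regularised TV flow on a periodic grid with unit spacing."""
    if dt is None:
        dt = eps / 4  # makes each update a convex combination of neighbours (diffusivities <= 1/eps)
    u = u.astype(np.float64).copy()
    n_steps = max(1, int(np.ceil(t / dt)))
    dt = t / n_steps
    for _ in range(n_steps):
        # Central-difference gradient magnitude at pixel centres.
        ux = 0.5 * (np.roll(u, -1, 0) - np.roll(u, 1, 0))
        uy = 0.5 * (np.roll(u, -1, 1) - np.roll(u, 1, 1))
        c = 1.0 / np.sqrt(ux**2 + uy**2 + eps**2)
        # Diffusivities at half-grid points, then differences of fluxes (mass-conserving).
        cx = 0.5 * (c + np.roll(c, -1, 0))
        cy = 0.5 * (c + np.roll(c, -1, 1))
        fx = cx * (np.roll(u, -1, 0) - u)
        fy = cy * (np.roll(u, -1, 1) - u)
        u += dt * (fx - np.roll(fx, 1, 0) + fy - np.roll(fy, 1, 1))
    return u
```

In the continuous $\mathbb{R}^2$ flow, filtering $\lambda u$ for time $t$ and dividing by $\lambda$ is equivalent to filtering $u$ for time $t/\lambda$ with $\varepsilon$ replaced by $\varepsilon/\lambda$; the `U * λ_TV ... / λ_TV` pattern above may serve a similar purpose, but that depends on how `TV_inpainting_LI` is defined.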

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0])
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(U_TV.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[1])
fig.colorbar(cbar, ax=ax[1]);

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(24, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0])
ax[0].set_title("Ground Truth")
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, x_min, x_max, y_min, y_max, fig=fig, ax=ax[1])
ax[1].set_title("Noisy")
fig.colorbar(cbar, ax=ax[1])
_, _, cbar = dsfilter.visualisations.plot_image_array(U_TV.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[2])
ax[2].set_title("Denoised")
fig.colorbar(cbar, ax=ax[2])
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth - np.clip(U_TV.sum(-1), *clip), x_min, x_max, y_min, y_max, fig=fig, ax=ax[3])
fig.colorbar(cbar, ax=ax[3])
ax[3].set_title("Error");

In [None]:
margin = 0.1
fig, ax, cbar = dsfilter.visualisations.plot_image_array((np.abs((ground_truth - np.clip(U_TV.sum(-1), *clip))) > margin).astype(np.float64), x_min, x_max, y_min, y_max)
fig.colorbar(cbar, ax=ax)
ax.set_title(f"Error > {margin}");

#### DS Filtering

In [None]:
T_short = 0.2
T_medium = 2.
T_long = 5.
T_mega_long = 10.

In [None]:
ε = 0.05
λ = 2.

In [None]:
# Internal regularisation for switching between dilation and erosion.
σ_1, σ_2, σ_3 = np.array((1., 1., 0.6))
# External regularisation for switching between dilation and erosion.
ρ_1, ρ_2, ρ_3 = np.array((1., 1., 0.6))
# Internal and external regularisation of gradient for switching between diffusion and shock.
ν_1, ν_2, ν_3 = np.array((1., 1., 0.6))

In [None]:
u_filtered_short, switch_DS_short, switch_morph_short = dsfilter.DS_filter_spatial(U, mask, θs, T_short, G_D_inv, G_S_inv, σ_1, σ_3, ρ_1, ρ_3, ν_1, ν_3, λ, ε=ε, dxy=dxy)
u_filtered_medium, switch_DS_medium, switch_morph_medium = dsfilter.DS_filter_spatial(U, mask, θs, T_medium, G_D_inv, G_S_inv, σ_1, σ_3, ρ_1, ρ_3, ν_1, ν_3, λ, ε=ε, dxy=dxy)
u_filtered_long, switch_DS_long, switch_morph_long = dsfilter.DS_filter_spatial(U, mask, θs, T_long, G_D_inv, G_S_inv, σ_1, σ_3, ρ_1, ρ_3, ν_1, ν_3, λ, ε=ε, dxy=dxy)
u_filtered_mega_long, switch_DS_mega_long, switch_morph_mega_long = dsfilter.DS_filter_spatial(U, mask, θs, T_mega_long, G_D_inv, G_S_inv, σ_1, σ_3, ρ_1, ρ_3, ν_1, ν_3, λ, ε=ε, dxy=dxy)

In [None]:
fig, ax, cbar = dsfilter.visualisations.plot_image_array((u_filtered_long - u_filtered_mega_long)[..., 0], x_min, x_max, y_min, y_max, cmap="gray")
fig.colorbar(cbar, ax=ax);

In [None]:
K = 0
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_short[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_medium[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
K = 1
fig, ax = plt.subplots(1, 5, figsize=(30, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0])
ax[0].set_title(f"$\\theta = {θs[0, 0, K]:.2f}$")
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long[..., K + 2], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1])
ax[1].set_title(f"$\\theta = {θs[0, 0, K + 2]:.2f}$")
fig.colorbar(cbar, ax=ax[1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long[..., K + 4], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[2])
ax[2].set_title(f"$\\theta = {θs[0, 0, K + 4]:.2f}$")
fig.colorbar(cbar, ax=ax[2])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long[..., K + 8], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[3])
ax[3].set_title(f"$\\theta = {θs[0, 0, K + 8]:.2f}$")
fig.colorbar(cbar, ax=ax[3])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, clip=clip, cmap="gray", fig=fig, ax=ax[4])
ax[4].set_title(r"$\sum_\theta U(\theta)$")
fig.colorbar(cbar, ax=ax[4]);

In [None]:
u_change = u_filtered_mega_long - u_filtered_short
K = 0
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_change[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$\\theta = {θs[0, 0, K]:.2f}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_change[..., K + 1], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$\\theta = {θs[0, 0, K + 1]:.2f}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_change[..., K + 2], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$\\theta = {θs[0, 0, K + 2]:.2f}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_change[..., K + 4], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$\\theta = {θs[0, 0, K + 4]:.2f}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_short.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[0])
ax[0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[1])
ax[1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1]);

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_short.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_medium.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_long.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", clip=clip, fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_short.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_medium.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_long.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(np.clip(u_filtered_short.sum(-1), *clip) - ground_truth, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(np.clip(u_filtered_medium.sum(-1), *clip) - ground_truth, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(np.clip(u_filtered_long.sum(-1), *clip) - ground_truth, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(np.clip(u_filtered_mega_long.sum(-1), *clip) - ground_truth, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
K = 2
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_short[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_medium[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_mega_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_short.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_medium.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_long.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_morph_mega_long.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
K = 0
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$\\theta = {θs[0, 0, K]:.2f}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short[..., K + 1], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$\\theta = {θs[0, 0, K + 1]:.2f}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short[..., K + 2], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$\\theta = {θs[0, 0, K + 2]:.2f}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short[..., K + 4], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$\\theta = {θs[0, 0, K + 4]:.2f}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
K = 0
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_medium[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_mega_long[..., K], x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_short.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 0])
ax[0, 0].set_title(f"$T = {round(T_short, ndigits=2)}$")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_medium.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, 1])
ax[0, 1].set_title(f"$T = {T_medium}$")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_long.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 0])
ax[1, 0].set_title(f"$T = {T_long}$")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(switch_DS_mega_long.min(-1), x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, 1])
ax[1, 1].set_title(f"$T = {T_mega_long}$")
fig.colorbar(cbar, ax=ax[1, 1]);

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(24, 5))
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0])
ax[0].set_title("Ground Truth")
fig.colorbar(cbar, ax=ax[0])
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, x_min, x_max, y_min, y_max, fig=fig, ax=ax[1])
ax[1].set_title("Noisy")
fig.colorbar(cbar, ax=ax[1])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[2])
ax[2].set_title("Denoised")
fig.colorbar(cbar, ax=ax[2])
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth - np.clip(u_filtered_mega_long.sum(-1), *clip), x_min, x_max, y_min, y_max, fig=fig, ax=ax[3])
fig.colorbar(cbar, ax=ax[3])
ax[3].set_title("Error");

In [None]:
margin = 0.1
fig, ax, cbar = dsfilter.visualisations.plot_image_array((np.abs((ground_truth - np.clip(u_filtered_mega_long.sum(-1), *clip))) > margin).astype(np.float64), x_min, x_max, y_min, y_max)
fig.colorbar(cbar, ax=ax)
ax.set_title(f"Error > {margin}");

#### Comparison

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(12, 10))
_, _, cbar = dsfilter.visualisations.plot_image_array(ground_truth, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0, 0])
ax[0, 0].set_title("Ground Truth")
fig.colorbar(cbar, ax=ax[0, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(noisy, x_min, x_max, y_min, y_max, fig=fig, ax=ax[0, 1])
ax[0, 1].set_title("Noisy")
fig.colorbar(cbar, ax=ax[0, 1])
_, _, cbar = dsfilter.visualisations.plot_image_array(U_TV.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[1, 0])
ax[1, 0].set_title("TV Flow")
fig.colorbar(cbar, ax=ax[1, 0])
_, _, cbar = dsfilter.visualisations.plot_image_array(u_filtered_mega_long.sum(-1), x_min, x_max, y_min, y_max, clip=clip, fig=fig, ax=ax[1, 1])
fig.colorbar(cbar, ax=ax[1, 1])
ax[1, 1].set_title("DS Filtering");
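To put numbers on the visual comparison above, a small helper (hypothetical, not part of `dsfilter`) can report the fraction of pixels whose error exceeds `margin`, matching the thresholded error plots, together with the root-mean-square error. The result is clipped to `clip` first, as in the error plots.

```python
import numpy as np


def error_summary(ground_truth, result, clip, margin=0.1):
    """Fraction of pixels with |error| > margin, and the RMSE, after clipping result to clip."""
    error = ground_truth - np.clip(result, *clip)
    return float(np.mean(np.abs(error) > margin)), float(np.sqrt(np.mean(error**2)))
```

Usage in this notebook would be `error_summary(ground_truth, U_TV.sum(-1), clip)` versus `error_summary(ground_truth, u_filtered_mega_long.sum(-1), clip)`.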

In [None]:
# import h5py

In [None]:
# filename = f".\\data\\u_short.hdf5"
# with h5py.File(filename, "w") as distance_file:
#     distance_file.create_dataset("Dataset1", data=u_filtered_short)
# filename = f".\\data\\u_medium.hdf5"
# with h5py.File(filename, "w") as distance_file:
#     distance_file.create_dataset("Dataset1", data=u_filtered_medium)
# filename = f".\\data\\u_long.hdf5"
# with h5py.File(filename, "w") as distance_file:
#     distance_file.create_dataset("Dataset1", data=u_filtered_long)
# filename = f".\\data\\u_mega_long.hdf5"
# with h5py.File(filename, "w") as distance_file:
#     distance_file.create_dataset("Dataset1", data=u_filtered_mega_long)

In [None]:
# filename = f".\\data\\u_init.hdf5"
# with h5py.File(filename, "w") as distance_file:
#     distance_file.create_dataset("Dataset1", data=U)