# Inpainting
In this notebook we inpaint images with left-invariant Regularised Diffusion-Shock (RDS) filtering on $\mathbb{M}_2$ and with RDS filtering on $\mathbb{R}^2$, as in Section 6.2 of *"Diffusion-Shock Filtering on the Space of Positions and Orientations"* (link to be added).

In [None]:
from pathlib import Path

import taichi as ti
ti.init(arch=ti.gpu, debug=False, device_memory_GB=3.5)
import numpy as np
import scipy as sp
from PIL import Image
from dsfilter import (
    DS_inpainting_LI,
    DS_inpainting_R2
)
from dsfilter.M2.utils import (
    coordinate_array_to_real,
    clean_mask_boundaries
)
from dsfilter.orientationscore import (
    cakewavelet_stack,
    wavelet_transform
)
from dsfilter.visualisations import plot_image_array
import matplotlib.pyplot as plt

## Test Cases
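In the paper, we inpaint a grey image with a grid of alternating black and white lines (Fig. 6); this is the `"black_and_white_on_grey"` test case. The other test cases show that other images can be inpainted as well; select one by setting `test_case` below.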

In [None]:
# Which synthetic image to inpaint; one of the cases handled in the setup cell below:
# "black_and_white_on_grey" (Fig. 6 of the paper) "black_on_white" "diagonal" "cross"
test_case = "black_and_white_on_grey"

### Setup

In [None]:
# Image size (dim_I x dim_J pixels) and number of sampled orientations dim_K.
dim_I, dim_J, dim_K = 256, 256, 16
# Start from a uniformly white (255) image.
u_ground_truth = np.full((dim_I, dim_J), 255.)
# Normalised coordinates on [-1, 1]^2, used to draw the patterns and masks.
x_axis = np.linspace(-1, 1, dim_I)
y_axis = np.linspace(-1, 1, dim_J)
xs, ys = np.meshgrid(x_axis, y_axis)

In [None]:
match test_case:
    case "black_on_white": # Grid of lines (black)
        N_lines = 4
        offset = dim_I // (N_lines + 1)
        for k in range(N_lines):
            centre = (k + 1) * offset
            u_ground_truth[:, (centre-2):(centre+3)] = 0.
            u_ground_truth[(centre-2):(centre+3), :] = 0.

        l = 0.4
        mask = (xs**2 < l) * (ys**2 < l)

        u = u_ground_truth.copy()
        u[mask] = 255.
    case "black_and_white_on_grey": # Grid of lines (alternating black and white)
        u_ground_truth *= 0.5
        N_lines = 4 # 7
        offset = dim_I // (N_lines + 1)
        colour = 0. # black
        for k in range(N_lines):
            centre = (k + 1) * offset
            u_ground_truth[:, (centre-2):(centre+3)] = 255. - colour
            u_ground_truth[(centre-2):(centre+3), :] = colour
            colour = 255. - colour
        u_ground_truth = sp.ndimage.gaussian_filter(u_ground_truth, 1)

        l = 0.4
        mask = (xs**2 < l) * (ys**2 < l)

        u = u_ground_truth.copy()
        u[mask] = 0.5 * 255.
    case "diagonal": # Lines with diagonal (black)
        N_lines = 4
        offset = dim_I // (N_lines + 1)
        for k in range(N_lines):
            centre = (k + 1) * offset
            u_ground_truth[:, (centre-2):(centre+3)] = 0.

        diagonal = (np.abs(xs - ys) < 0.03)
        u_ground_truth[diagonal] = 0.
        l = 0.4
        mask = (xs**2 < l) * (ys**2 < l)

        u = u_ground_truth.copy()
        u[mask] = 255.

        dim_K = 32 # We need more orientations to have sufficient distance between the horizontal and diagonal lines.
    case "cross": # Chunky cross
        vertical = (-0.2 < xs) * (xs < 0.2)
        horizontal = (-0.2 < ys) * (ys < 0.2)
        u_ground_truth = np.zeros((dim_I, dim_J))
        u_ground_truth[vertical + horizontal] = 1. * 255.

        mask = (-0.3 < xs) * (xs < 0.3) * (-0.3 < ys) * (ys < 0.3)
        u = u_ground_truth.copy()
        u[mask] = 0.

u = sp.ndimage.gaussian_filter(u, 1.) # Smooth for well-posed lifting.
clip = (0., 255.)

mask_orig = 1 - mask.astype(int)
mask = sp.ndimage.binary_erosion(mask_orig, iterations=10, border_value=1).astype(int) # Deal with boundary artefacts.

Is, Js, Ks = np.indices((dim_I, dim_J, dim_K))
x_min, x_max = 0., dim_I - 1.
y_min, y_max = 0., dim_J - 1.
θ_min, θ_max = 0., 2 * np.pi
dxy = (x_max - x_min) / (dim_I - 1)
dθ = (θ_max - θ_min) / dim_K
xs, ys, θs = coordinate_array_to_real(Is, Js, Ks, x_min, y_min, θ_min, dxy, dθ)

In [None]:
# Show the clean image next to the input whose hole has been filled with a constant value.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
panels = ((u_ground_truth, "Ground Truth"), (u, "Masked"))
for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    _, _, cbar = plot_image_array(panel_image, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=panel_ax)
    panel_ax.set_title(panel_title)
    fig.colorbar(cbar, ax=panel_ax)

### Inpainting Parameters

In [None]:
# --- M2 RDS parameters (passed to DS_inpainting_LI below) ---
# NOTE(review): presumably the diagonal of the inverse metric for the diffusion (G_D) and shock (G_S)
# parts, ordered as (main spatial, lateral spatial, angular) -- TODO confirm against dsfilter.
G_D_inv = 1.8 * np.array((1., 0.1, 0.0))
G_S_inv = np.array((1., 1., 0.0))
# Internal regularisation for switching between dilation and erosion.
σ_s, σ_o = 2., 0.6
# External regularisation for switching between dilation and erosion.
ρ_s, ρ_o = 1., 0.6
# Internal and external regularisation of gradient for switching between diffusion and shock.
ν_s, ν_o = 2.5, 0.6
λ = 0.35 # Contrast parameter for switching between diffusion and shock.
ε = 0.1 # Regularisation parameter for signum.
T_M2 = 2500. # Presumably the stopping time of the M2 evolution -- TODO confirm.

# --- R2 RDS parameters (passed to DS_inpainting_R2 below) ---
# Presumably the R2 analogues of σ, ρ, ν, λ, ε above; ρ, ν and ε are tied to σ and λ.
σ_R2 = 2.5
ρ_R2 = 1.6 * σ_R2
ν_R2 = 1.6 * σ_R2
λ_R2 = 1.
ε_R2 = 0.15 * λ_R2

T_R2 = 500. # Presumably the stopping time of the R2 evolution -- TODO confirm.

## Orientation Score

In [None]:
# Lift u to an orientation score with dim_K cake wavelets, keeping only the real part.
cws = cakewavelet_stack(dim_I, dim_K, Gaussian_σ=dim_I/8)
U = wavelet_transform(u, cws).real
# Put the orientation axis last: (θ, x, y) -> (x, y, θ).
U = np.moveaxis(U, 0, -1)
# Repeat the 2D mask of known pixels along every orientation -> (x, y, θ).
Mask = np.stack([mask] * dim_K, axis=-1)

In [None]:
# Inspect the lifting: three orientation slices of U (top row), and the input, its reconstruction
# by summing over orientations, and the reconstruction error (bottom row).
K = 0
print(θs[0, 0, K])
fig, ax = plt.subplots(2, 3, figsize=(18, 10))

for col, k_offset in enumerate((0, 2, 4)):
    k_slice = K + k_offset
    _, _, cbar = plot_image_array(U[..., k_slice] * mask, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[0, col])
    # Raw f-string: "\c" is an invalid escape sequence in an ordinary string literal.
    ax[0, col].set_title(rf"$U(\cdot, {k_slice*dθ:.2f})$")
    fig.colorbar(cbar, ax=ax[0, col])

reconstruction = U.sum(-1)
bottom_panels = (
    (u, "$u$"),
    (reconstruction, "Reconstruction"),
    (u - reconstruction, "Reconstruction error"),
)
for col, (panel_image, panel_title) in enumerate(bottom_panels):
    _, _, cbar = plot_image_array(panel_image, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[1, col])
    ax[1, col].set_title(panel_title)
    fig.colorbar(cbar, ax=ax[1, col])

### Preprocess
The mask causes a large step edge in the image, which is picked up by the orientation score. We therefore cannot trust the orientation-score data near the boundary of the mask. For this reason, the masked region is dilated (equivalently, the mask of known pixels is eroded in the setup cell). The image is then inpainted using only data that is far enough from the boundary to carry reliable orientation information.

In [None]:
# Clean the orientation score near the boundary of the (already dilated) hole, where the lifted
# step edge makes the data unreliable. NOTE(review): the exact operation is defined in
# dsfilter.M2.utils.clean_mask_boundaries -- confirm there.
U_preprocessed = clean_mask_boundaries(U, Mask)

## $\mathbb{R}^2$ RDS Inpainting

In [None]:
# Baseline: RDS inpainting directly on the image, without lifting to M2.
u_R2 = DS_inpainting_R2(
    u, mask,
    T_R2,
    σ_R2, ρ_R2, ν_R2,
    λ_R2,
    ε=ε_R2,
    dxy=dxy,
)

In [None]:
# Show the R2 result, clipped to the valid intensity range.
fig, ax, cbar = plot_image_array(u_R2, x_min, x_max, y_min, y_max, clip=clip, cmap="gray", figsize=(6, 5))
# Raw string: "\m" is an invalid escape sequence in an ordinary string literal.
ax.set_title(r"Inpainted with $\mathbb{R}^2$ RDS")
fig.colorbar(cbar, ax=ax);

## $\mathbb{M}_2$ RDS Inpainting

In [None]:
# Inpaint the preprocessed orientation score with left-invariant RDS on M2.
U_M2 = DS_inpainting_LI(
    U_preprocessed, Mask, θs,
    T_M2,
    G_D_inv, G_S_inv,
    σ_s, σ_o,
    ρ_s, ρ_o,
    ν_s, ν_o,
    λ,
    ε=ε,
    dxy=dxy,
)
# Back to an image: sum over the orientation axis, as in the reconstruction check above.
u_M2 = U_M2.sum(axis=-1)

In [None]:
# Show the M2 result, clipped to the valid intensity range.
fig, ax, cbar = plot_image_array(u_M2, x_min, x_max, y_min, y_max, clip=clip, cmap="gray", figsize=(6, 5))
# Raw string: "\m" is an invalid escape sequence in an ordinary string literal.
ax.set_title(r"Inpainted with $\mathbb{M}_2$ RDS")
fig.colorbar(cbar, ax=ax);

In [None]:
# Signed error of each (clipped) reconstruction with respect to the ground truth.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
error_panels = ((u_R2, r"$\mathbb{R}^2$"), (u_M2, r"$\mathbb{M}_2$"))
for col, (inpainted, label) in enumerate(error_panels):
    error = np.clip(inpainted, *clip) - u_ground_truth
    _, _, cbar = plot_image_array(error, x_min, x_max, y_min, y_max, cmap="gray", fig=fig, ax=ax[col])
    ax[col].set_title(f"Error of {label} RDS")
    fig.colorbar(cbar, ax=ax[col])

## Comparison

In [None]:
# 2x2 overview on a common grey scale: ground truth, input, R2 result, M2 result.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
comparison_panels = (
    (u_ground_truth, "Ground Truth"),
    (u, "Masked"),
    (u_R2, r"$\mathbb{R}^2$"),  # raw strings: "\m" is an invalid escape sequence
    (u_M2, r"$\mathbb{M}_2$"),
)
# ax.flat walks the grid row by row: [0, 0], [0, 1], [1, 0], [1, 1].
for panel_ax, (panel_image, panel_title) in zip(ax.flat, comparison_panels):
    panel_ax.imshow(panel_image, vmin=clip[0], vmax=clip[1], cmap="gray")
    panel_ax.set_title(panel_title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.set_frame_on(False)

## Save Results

In [None]:
# Clip each reconstruction to the valid range and tag it for its output filename.
names = ["R2", "M2"]
images = [np.clip(result, *clip) for result in (u_R2, u_M2)]

In [None]:
# Save each clipped reconstruction as an 8-bit greyscale PNG in ./output/<test_case>_<name>.png.
# pathlib builds a portable path (the old "output\\..." string only worked on Windows) and the
# directory is created if it does not exist yet.
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)
for image, name in zip(images, names):
    # Round to the nearest integer instead of truncating; `images` is already clipped to [0, 255].
    Image.fromarray(
        np.rint(image).astype(np.uint8)
    ).save(output_dir / f"{test_case}_{name}.png")