# Total Variation Image Inpainting
#### From: Chan, T. F., & Shen, J. (2005). Variational image inpainting. *Communications on Pure and Applied Mathematics: A Journal Issued by the Courant Institute of Mathematical Sciences*, 58(5), 579-619.

Given an image domain $\Omega$ and an inpainting domain $D \subset \Omega$ on which no image information is available, the goal is to find the image $u$ that minimizes the energy:
$$
E = \int_\Omega | \nabla u | dx + \frac{\lambda}{2} \int_{\Omega - D} (u - u_0)^2 dx.
$$
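
On a pixel grid the energy can be evaluated directly, which gives a convergence diagnostic that complements the gradient norm tracked later in the notebook. The following is a minimal sketch. It assumes unit pixel spacing, central differences via `np.gradient`, and a boolean array `known` that marks $\Omega - D$. The function name `tv_inpainting_energy` and the optional `delta` regularizer are illustrative and do not come from the paper.

```python
import numpy as np

def tv_inpainting_energy(u, u0, known, lam, delta=0.0):
    """Hypothetical helper (not from the paper): discrete version of
    E = sum |grad u| + (lam / 2) * sum over Omega - D of (u - u0)^2.

    u, u0 : 2-D float arrays of the same shape (unit pixel spacing assumed)
    known : boolean array, True on Omega - D (pixels whose values are trusted)
    lam   : fidelity weight lambda
    delta : optional regularizer; |grad u| is replaced by sqrt(|grad u|^2 + delta)
    """
    uy, ux = np.gradient(u)                              # derivatives along rows / columns
    tv = np.sum(np.sqrt(ux**2 + uy**2 + delta))          # total variation term, one term per pixel
    fidelity = 0.5 * lam * np.sum((u - u0)[known] ** 2)  # only Omega - D contributes
    return tv + fidelity
```

Logging this value next to the gradient norm is one way to check that the iterations are actually lowering $E$.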
The Euler-Lagrange equation of this energy leads to the following gradient-descent flow:
$$
\frac{\partial u}{\partial t} = \nabla\cdot\Big[ \frac{\nabla u}{|\nabla u|} \Big] + \lambda_D(x)(u_0 - u), \tag{1}
$$
where:
$$
\lambda_D(x) = \begin{cases}
\lambda \:\: , \text{if }\, x\in \Omega-D \\
0 \:\: , \text{if }\, x \in D
\end{cases}.
\tag{2}
$$ 
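
The following sketch shows where (1) comes from. It assumes the natural (Neumann) boundary condition $\partial u/\partial n = 0$ on $\partial\Omega$ and ignores points where $\nabla u = 0$, since $|\nabla u|$ is not differentiable there. Perturb $u$ by $\epsilon v$ and integrate by parts:

$$
\begin{aligned}
\frac{d}{d\epsilon} E[u+\epsilon v]\Big|_{\epsilon=0}
&= \int_\Omega \frac{\nabla u\cdot\nabla v}{|\nabla u|}\,dx + \lambda\int_{\Omega - D} (u-u_0)\,v\,dx \\
&= \int_\Omega \Big[-\nabla\cdot\Big(\frac{\nabla u}{|\nabla u|}\Big) + \lambda_D(x)\,(u-u_0)\Big]\,v\,dx .
\end{aligned}
$$

The bracket is the $L^2$ gradient of $E$, and setting $\partial u/\partial t$ equal to minus the bracket gives exactly (1). The notebook discretizes (1) with an explicit Euler step of size $\alpha$. It also replaces $|\nabla u|$ by $|\nabla u|_\delta = \sqrt{|\nabla u|^2 + \delta}$ to avoid division by zero:

$$
u^{k+1} = u^k + \alpha\Big[\nabla\cdot\Big(\frac{\nabla u^k}{|\nabla u^k|_\delta}\Big) + \lambda_D(x)\,(u_0 - u^k)\Big].
$$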


In [None]:
# First, import all necessary libraries

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [None]:
# This function builds the array lambda_D(x) of equation (2), with the same shape as the image:
# it equals lam everywhere except on the inpainting domain D, where it is 0.
# Pixels with intensity >= threshold (on the 0-255 scale) are taken to belong to D.

def makeLambdaArray(corrupted, lam, threshold=210):
    in_D = corrupted >= threshold            # boolean mask of the inpainting domain D
    return np.where(in_D, 0.0, float(lam))   # 0 on D, lam on Omega - D
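
`makeLambdaArray` marks $D$ as every pixel with intensity ≥ 210. The noise-filling step further down uses ≥ 250, so the two steps can disagree about which pixels are damaged. One way to make the choice explicit is to build a single boolean mask of $D$ first and then derive $\lambda_D$ from it. The following is a hypothetical variant: the name `make_lambda_from_mask`, the thresholding rule, and the optional `grow` dilation are assumptions, not part of the original method. It uses `scipy.ndimage.binary_dilation` to enlarge $D$ by a few pixels, for example to cover blurred edges of the damage.

```python
import numpy as np
from scipy import ndimage

def make_lambda_from_mask(corrupted, lam, threshold=210, grow=0):
    """Hypothetical variant of makeLambdaArray that also returns the mask of D.

    corrupted : 2-D uint8 array (0..255) in which damaged pixels are bright
    threshold : pixels >= threshold are assumed to belong to D
    grow      : number of binary-dilation steps used to enlarge D (assumption)
    Returns (lambda_D, mask), where mask is True on D.
    """
    mask = corrupted >= threshold
    if grow > 0:  # iterations=0 would mean "dilate until nothing changes", so guard it
        mask = ndimage.binary_dilation(mask, iterations=grow)
    lambda_D = np.where(mask, 0.0, float(lam))  # eq. (2): 0 on D, lam elsewhere
    return lambda_D, mask
```

The returned `mask` can then be reused when filling $D$ with noise, so that both steps agree on the inpainting domain.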

In [None]:
# This function calculates the first term on the right-hand side of equation (1),
# div( grad u / |grad u| ), with |grad u| regularized as sqrt(|grad u|^2 + delta).

def divergence(u, delta=0.01):   # delta is a small term used to avoid division by 0
    uy, ux = np.gradient(u)                  # first derivatives along axis 0 (rows) and axis 1 (columns)
    norm = np.sqrt(ux**2 + uy**2 + delta)    # regularized gradient magnitude
    nx, ny = ux / norm, uy / norm            # components of the normalized gradient field
    # divergence of the normalized field: d(nx)/dx + d(ny)/dy
    div = np.gradient(nx, axis=1) + np.gradient(ny, axis=0)
    return div
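
The term $\nabla\cdot(\nabla u/|\nabla u|)$ is the curvature of the level lines of $u$. A quick check of the finite-difference discretization is to apply it to the cone $u = r = \sqrt{x^2+y^2}$. Its level lines are circles of radius $r$, so the result should be close to $1/r$. The helper below is a hypothetical copy of `divergence` with an explicit grid spacing `h`. It also uses a much smaller `delta` (1e-12) so that the regularization does not bias the comparison.

```python
import numpy as np

def normalized_gradient_divergence(u, h=1.0, delta=1e-12):
    # Hypothetical copy of `divergence` with explicit grid spacing h:
    # div( grad u / sqrt(|grad u|^2 + delta) ) using NumPy central differences.
    uy, ux = np.gradient(u, h)
    norm = np.sqrt(ux**2 + uy**2 + delta)
    return np.gradient(ux / norm, h, axis=1) + np.gradient(uy / norm, h, axis=0)

# Cone u = r on [-1, 1]^2: level lines are circles of radius r, with curvature 1/r.
h = 0.01
x = np.arange(-1.0, 1.0 + h / 2, h)
X, Y = np.meshgrid(x, x)          # X varies along axis 1, Y along axis 0
R = np.sqrt(X**2 + Y**2)
kappa = normalized_gradient_divergence(R, h)

ring = (R > 0.3) & (R < 0.8)      # stay away from the apex and from the image border
max_err = np.max(np.abs(kappa[ring] - 1.0 / R[ring]))
print(f"max |kappa - 1/r| for 0.3 < r < 0.8: {max_err:.2e}")
```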

In [None]:
# This function performs explicit gradient descent on E, i.e. it integrates equation (1)
# with time step alpha. Every 1000 iterations it prints and records the norm of the gradient.

def grad_descent(u0, corrupted, N_iter, lam, alpha):
    u = np.copy(u0)
    LamArr = makeLambdaArray(corrupted, lam)   # lambda_D(x) from equation (2)
    norm_history = []
    for i in range(N_iter):
        div = divergence(u)             # first term of the RHS of (1)
        grad = -div + LamArr*(u - u0)   # gradient of E; note that du/dt = -grad
        
        if i % 1000 == 0:
            norm = np.linalg.norm(grad)
            print('Iteration =', i, ' Gradient norm =', norm)
            norm_history.append(norm)
        
        u = u - alpha*grad              # u <- u + alpha * du/dt
    return(u, norm_history)
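
The cells below need `cameraman.png` and `cameraman_corrupted.png`. To try the same scheme without these files, here is a self-contained miniature on a synthetic image: a horizontal black/white edge with a rectangular hole cut across it. The hypothetical function `tv_inpaint` folds `divergence` and `grad_descent` into one loop. The parameters (`lam=10`, `alpha=0.01`, 3000 iterations) are untuned choices for this 32×32 toy, not the notebook's settings. Because TV penalizes the length of level lines, the edge should be continued straight across the hole.

```python
import numpy as np

def tv_inpaint(u0, known, lam=10.0, alpha=0.01, n_iter=3000, delta=0.01):
    """Hypothetical self-contained sketch of explicit gradient descent on eq. (1).

    u0    : 2-D float image; values inside D only serve as the initial guess
    known : boolean array, True on Omega - D
    """
    lam_D = np.where(known, lam, 0.0)              # lambda_D(x) of eq. (2)
    u = np.array(u0, dtype=float)                  # working copy
    for _ in range(n_iter):
        uy, ux = np.gradient(u)                    # derivatives along rows / columns
        norm = np.sqrt(ux**2 + uy**2 + delta)      # regularized |grad u|
        div = np.gradient(ux / norm, axis=1) + np.gradient(uy / norm, axis=0)
        u += alpha * (div + lam_D * (u0 - u))      # explicit Euler step of eq. (1)
    return u

# Synthetic image: black upper half, white lower half.
toy = np.zeros((32, 32))
toy[16:, :] = 1.0

# Inpainting domain D: a rectangular hole that cuts across the edge.
toy_known = np.ones(toy.shape, dtype=bool)
toy_known[4:28, 12:20] = False

toy_u0 = np.where(toy_known, toy, 0.5)             # neutral gray initial guess inside D
toy_result = tv_inpaint(toy_u0, toy_known)

top = toy_result[4:10, 12:20].mean()     # inside D, well above the edge
bottom = toy_result[22:28, 12:20].mean() # inside D, well below the edge
print(f"mean inside D above the edge: {top:.3f}, below the edge: {bottom:.3f}")
```

As a quick check, `top` and `bottom` should move from the gray initial value 0.5 toward 0 and 1 respectively.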

In [None]:
# Now read the original and the corrupted images as 8-bit grayscale arrays (values 0-255)
not_corrupted = np.array(Image.open("cameraman.png").convert('L'))
corrupted = np.array(Image.open("cameraman_corrupted.png").convert('L'))
u0 = np.copy(corrupted)

# Make a new image where the corrupted pixels (value >= 250) are replaced by uniform random noise.
# makeLambdaArray uses the lower threshold 210, so pixels in [210, 250) also belong to D
# but keep their original values as the initial guess.
noise_mask = u0 >= 250
u0[noise_mask] = np.random.randint(0, 256, size=noise_mask.sum())

# Rescale intensities to [0, 1]
u0 = u0/255


In [None]:
# Define parameters and hyper-parameters

N_iter = 10001
lam = 100
alpha = 0.001
print('N_iter = ', N_iter, ' lam = ', lam, ' alpha = ', alpha)

# Perform Gradient Descent
img_new, norm_history = grad_descent(u0, corrupted, N_iter, lam, alpha)

In [None]:
# Show all images

fig, axs = plt.subplots(2,2, figsize=(6,6))

panels = [(not_corrupted, 'Original'),
          (corrupted, 'Corrupted'),
          (u0, 'Inpainting domain filled w/ noise'),
          (img_new, 'Inpainted')]
for ax, (img, title) in zip(axs.flat, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
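
The 2×2 figure compares the result by eye. To attach a number to it, one option is the peak signal-to-noise ratio (PSNR) restricted to the inpainting domain. Pixels outside $D$ are held close to the data by $\lambda$, so they would otherwise dominate the score. The helper `psnr_on_mask` is hypothetical and assumes that both images are on the same scale, with maximum value `peak`.

```python
import numpy as np

def psnr_on_mask(reference, estimate, mask, peak=1.0):
    """Hypothetical helper: PSNR in dB computed only over pixels where mask is True.

    reference, estimate : arrays on the same intensity scale (maximum value = peak)
    mask                : boolean array selecting the pixels to score (e.g. D)
    """
    ref = np.asarray(reference, dtype=float)[mask]
    est = np.asarray(estimate, dtype=float)[mask]
    mse = np.mean((ref - est) ** 2)
    if mse == 0:
        return np.inf                        # identical on the mask
    return 10.0 * np.log10(peak**2 / mse)
```

With the notebook's arrays, one might call `psnr_on_mask(not_corrupted / 255, img_new, makeLambdaArray(corrupted, lam) == 0)`. Computing the same score for `u0` shows how much the flow improved on the noise fill.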

In [None]:
# Visualize lambda_D: lam (white) on Omega - D, 0 (black) on the inpainting domain D
L = makeLambdaArray(corrupted, lam)
plt.imshow(L, cmap='gray')

In [None]:
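# Graph the norm history (grad_descent records one value every 1000 iterations)

iterations = np.arange(len(norm_history)) * 1000
fig, ax = plt.subplots(1,1)
ax.set_title('Gradient Norm History')
ax.set_xlabel('Iteration')
ax.set_ylabel('Gradient norm')
ax.plot(iterations, norm_history)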