# Image denoising using SVD

In [None]:
# start by importing some necessary packages
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread

We read in a series of stamp images:

In [None]:
def load_data(
    url="https://raw.githubusercontent.com/jandrejevic12/svd_files/master/stamps/",
    count=22,
):
    """Read the stamp images and stack them into one float array.

    Parameters
    ----------
    url : str
        Location prefix; image ``i`` is read from
        ``url + "im_{:02d}.jpg".format(i)``, so it should end with a separator.
    count : int
        Number of images to read (indices ``0 .. count-1``).

    Returns
    -------
    numpy.ndarray
        The stacked images as float64 values divided by 255 (i.e. a [0, 1]
        scale, assuming 8-bit source images). Stacking assumes every image
        has the same shape.
    """
    print("loading data ...")
    ims = [imread(url + "im_{:02d}.jpg".format(i)) for i in range(count)]

    # np.float was removed in NumPy 1.24; the builtin float maps to float64.
    ims = np.array(ims).astype(float) / 255.  # convert to [0,1] scale
    print("load complete!")
    return ims

# Load every stamp, then unpack the array dimensions: N images of m x n
# pixels with p values per pixel (presumably color channels).
ims = load_data()
N, m, n, p = np.shape(ims)

In [None]:
# Visualize the data in a grid:
def plot_data(data):
    m = 2
    n = len(data)//m
    fig, axes = plt.subplots(m, n, figsize=(2*n,2*m))
    for i,ax in enumerate(axes.flat):
        ax.imshow(data[i])
        ax.axis('off')
    plt.show()

# Show all loaded stamps side by side.
plot_data(data=ims)

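Next, we reshape the data so that each image is represented as a single long column of length $mnp$ (its $m\times n\times p$ pixel values flattened), giving an $mnp\times N$ data matrix. We center the data by subtracting the mean image from every column, and then compute the reduced SVD.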

In [None]:
# Flatten each image into one column, giving an (m*n*p) x N data matrix.
S = np.reshape(ims, (N, m * n * p)).T

# Mean image: the average across columns, kept as an (m*n*p, 1) column so
# that it broadcasts against S.
Sm = S.mean(axis=1, keepdims=True)

# Center the data, then take the reduced ("economy") SVD.
A = S - Sm
U, s, Vt = np.linalg.svd(A, full_matrices=False)

Finally, we can pick a sample image and reconstruct it up to a desired rank $r$. Provided the images are well aligned, their high similarity means we expect only a few singular vectors to be needed for a faithful reconstruction.

Notice that, since we are reconstructing just a single column rather than the whole dataset, we can simply pick out the entry of each $v_i^T$ corresponding to our chosen image $j$ instead of constructing the full rank-one matrices $\sigma_iu_iv_i^T$. Writing $\bar{s}$ for the mean image, the reconstruction is $\bar{s} + \sum_{i=1}^{r}\sigma_i u_i (v_i^T)_j$.

In [None]:
# Pick an image index and reconstruct it up to rank r.
index = 5
r = 1

# Validate r explicitly: a larger r would otherwise fail with an opaque
# IndexError (or be silently clamped by slicing), and a negative r would
# quietly select the wrong singular triplets.
if not 0 <= r <= s.size:
    raise ValueError("r must be between 0 and {}".format(s.size))

# Column `index` of the rank-r approximation is
#   sum_i s[i] * U[:, i] * Vt[i, index]  =  U[:, :r] @ (s[:r] * Vt[:r, index]),
# computed as one matrix-vector product instead of materializing r
# full-length vectors and summing them.
Ar = U[:, :r] @ (s[:r] * Vt[:r, index])

# Add back the mean and reshape into m by n by p.
Ar += Sm.ravel()
imr = Ar.reshape(m, n, p)

We conclude with a visualization of our reconstructed image:

In [None]:
# Clip the reconstruction into the valid [0, 1] intensity range, in place.
np.clip(imr, 0, 1, out=imr)

# Show the original image next to its rank-r reconstruction.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 4))
for ax, img, title in ((ax1, ims[index], "original"),
                       (ax2, imr, "$r={:d}$".format(r))):
    ax.imshow(img)
    ax.set_title(title, size=16)
    ax.axis("off")
plt.show()