
<br>
=========================================<br>
Image denoising using dictionary learning<br>
=========================================<br>
An example comparing the effect of reconstructing noisy fragments<br>
of a raccoon face image using, first, online :ref:`DictionaryLearning` and<br>
then various transform methods.<br>
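The right half of the image is corrupted with Gaussian noise. The dictionary is<br>
fitted on patches from the left half of the image, and subsequently used to<br>
reconstruct the noisy right half. Note that in this example no noise is added<br>
to the left half, so the dictionary is learned from clean data; in practice an<br>
undistorted (i.e. noiseless) image is often not available, and fitting the<br>
dictionary to noisy data would typically give worse results.<br>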
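A common practice for evaluating the results of image denoising is to look<br>
at the difference between the reconstruction and the original image. If the<br>
reconstruction were perfect this difference would be zero; for a good<br>
reconstruction it should at least look like unstructured (e.g. Gaussian)<br>
noise, without visible traces of the image itself.<br>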
It can be seen from the plots that the results of :ref:`omp` with two<br>
non-zero coefficients are a bit less biased than when keeping only one<br>
(the edges look less prominent). They are, in addition, closer to the ground<br>
truth in Frobenius norm.<br>
The result of :ref:`least_angle_regression` is much more strongly biased: the<br>
difference is reminiscent of the local intensity value of the original image.<br>
Thresholding is clearly not useful for denoising, but it is here to show that<br>
it can produce a suggestive output with very high speed, and thus be useful<br>
for other tasks such as object classification, where performance is not<br>
necessarily related to visualisation.<br>


In [None]:
# Print the module docstring (``__doc__``).
# NOTE(review): in a notebook kernel ``__doc__`` is presumably the interactive
# namespace's docstring rather than the description in the markdown cell
# above — confirm this cell is still wanted.
print(__doc__)

In [None]:
from time import time

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

In [None]:
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d

In [None]:
# Load the grey-scale raccoon face image. Where it lives depends on the SciPy
# version: recent releases provide ``scipy.datasets.face`` (it needs the
# optional ``pooch`` package to fetch the data and raises ImportError without
# it), while older releases provided ``scipy.misc.face``, which has since been
# removed. Try the modern location first and fall back to the legacy one.
try:
    from scipy.datasets import face
    face = face(gray=True)
except ImportError:
    from scipy.misc import face
    face = face(gray=True)

Convert from uint8 representation with values between 0 and 255 to<br>
a floating point representation with values between 0 and 1.

In [None]:
face = face / 255.0  # 0..255 integer intensities -> floats in [0, 1]

Downsample the image (by a factor of 4 along each axis) for higher speed

In [None]:
# Downsample by a factor of 4 along each axis: each output pixel is the mean
# of the 2x2 pixels in the top-left corner of a 4x4 tile.
# A dimension of the form 4k + 1 would make the ``[::4]`` slices one
# row/column longer than the ``[1::4]`` ones, so the sum below would fail to
# broadcast; drop that single trailing row/column first. For every other
# size this crop is a no-op, so the result is unchanged.
n_rows, n_cols = face.shape
face = face[:n_rows - (n_rows % 4 == 1), :n_cols - (n_cols % 4 == 1)]
face = face[::4, ::4] + face[1::4, ::4] + face[::4, 1::4] + face[1::4, 1::4]
face /= 4.0
height, width = face.shape

Distort the right half of the image

In [None]:
print('Distorting image...')
distorted = face.copy()
# Add Gaussian noise (standard deviation 0.075) to the right half only.
# The right half ``[:, width // 2:]`` has ``width - width // 2`` columns,
# which differs from ``width // 2`` when ``width`` is odd, so size the noise
# from that count (identical draws to before for even widths).
distorted[:, width // 2:] += 0.075 * np.random.randn(
    height, width - width // 2)

Extract all reference patches from the left half of the image

In [None]:
print('Extracting reference patches...')
t0 = time()
patch_size = (7, 7)
# Every 7x7 patch of the left half, flattened to one row per patch.
data = extract_patches_2d(distorted[:, :width // 2], patch_size)
data = data.reshape(data.shape[0], -1)
# Standardise each pixel position across all patches. A pixel position that
# is constant over every patch has zero standard deviation; dividing by it
# would fill that column with NaN/inf, so leave such columns unscaled.
data -= np.mean(data, axis=0)
scale = np.std(data, axis=0)
scale[scale == 0] = 1.0
data /= scale
print('done in %.2fs.' % (time() - t0))

#############################################################################<br>
Learn the dictionary from reference patches

In [None]:
print('Learning the dictionary...')
t0 = time()
try:
    # Older scikit-learn: ``n_iter`` is the number of mini-batch steps.
    dico = MiniBatchDictionaryLearning(n_components=100, alpha=1, n_iter=500)
except TypeError:
    # Newer scikit-learn removed ``n_iter`` (unknown keyword -> TypeError) in
    # favour of ``max_iter``, which counts full passes over the data rather
    # than mini-batch steps, so the two settings are not strictly equivalent.
    # NOTE(review): 10 epochs is an approximation — tune if results differ.
    dico = MiniBatchDictionaryLearning(n_components=100, alpha=1, max_iter=10)
# Rows of ``components_`` are the learned dictionary atoms.
V = dico.fit(data).components_
dt = time() - t0
print('done in %.2fs.' % dt)

In [None]:
plt.figure(figsize=(4.2, 4))
# One grey-scale tile per learned atom (at most 100), on a 10x10 grid.
for position, atom in enumerate(V[:100], start=1):
    plt.subplot(10, 10, position)
    plt.imshow(atom.reshape(patch_size), cmap=plt.cm.gray_r,
               interpolation='nearest')
    # Ticks carry no information for image tiles.
    plt.xticks(())
    plt.yticks(())
train_info = 'Train time %.1fs on %d patches' % (dt, len(data))
plt.suptitle('Dictionary learned from face patches\n' + train_info,
             fontsize=16)
# Positional args: left, bottom, right, top, wspace, hspace.
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

#############################################################################<br>
Display the distorted image

In [None]:
def show_with_diff(image, reference, title):
    """Plot an image next to its difference from a reference image.

    Opens a new figure with two panels: ``image`` shown on a fixed [0, 1]
    grey scale, and the signed difference ``image - reference`` shown on a
    fixed [-0.5, 0.5] purple/orange scale, titled with the difference's
    Frobenius norm (square root of the sum of squared entries).

    Parameters
    ----------
    image : array
        Image to display; presumably 2-D grey-scale — TODO confirm.
    reference : array
        Image subtracted from ``image`` to form the difference panel.
    title : str
        Overall figure title.
    """
    def draw_panel(position, panel_title, picture, **display_kwargs):
        # One untickmarked panel in a 1x2 grid.
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(picture, interpolation='nearest', **display_kwargs)
        plt.xticks(())
        plt.yticks(())

    plt.figure(figsize=(5, 3.3))
    draw_panel(1, 'Image', image, vmin=0, vmax=1, cmap=plt.cm.gray)

    residual = image - reference
    residual_norm = np.sqrt(np.sum(residual ** 2))
    draw_panel(2, 'Difference (norm: %.2f)' % residual_norm, residual,
               vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr)

    plt.suptitle(title, size=16)
    # Positional args: left, bottom, right, top, wspace, hspace.
    plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)

In [None]:
show_with_diff(image=distorted, reference=face, title='Distorted image')

#############################################################################<br>
Extract noisy patches and reconstruct them using the dictionary

In [None]:
print('Extracting noisy patches... ')
t0 = time()
# Flatten every patch of the noisy right half to one row, then centre each
# pixel position; ``intercept`` is kept so the centring can be undone after
# the patches are re-encoded.
data = extract_patches_2d(distorted[:, width // 2:], patch_size)
data = data.reshape(len(data), -1)
intercept = data.mean(axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))

In [None]:
# Each entry: (plot title, ``transform_algorithm`` name, extra parameters
# passed to ``set_params`` on the dictionary learner).
transform_algorithms = [
    ('Orthogonal Matching Pursuit\n1 atom', 'omp',
     dict(transform_n_nonzero_coefs=1)),
    ('Orthogonal Matching Pursuit\n2 atoms', 'omp',
     dict(transform_n_nonzero_coefs=2)),
    ('Least-angle regression\n5 atoms', 'lars',
     dict(transform_n_nonzero_coefs=5)),
    ('Thresholding\n alpha=0.1', 'threshold',
     dict(transform_alpha=0.1)),
]

In [None]:
reconstructions = {}
for title, transform_algorithm, kwargs in transform_algorithms:
    print(title + '...')
    # Start from the clean image; only the right half is overwritten below.
    reconstructions[title] = face.copy()
    t0 = time()
    # Encode the centred noisy patches on the learned dictionary with the
    # current transform, then map the codes back to pixel space.
    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code = dico.transform(data)
    patches = np.dot(code, V)
    patches += intercept  # undo the per-pixel centring
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        # Rescale the thresholded reconstruction to [0, 1]. If every value is
        # equal the shifted maximum is 0; skip the division to avoid NaNs.
        patches -= patches.min()
        peak = patches.max()
        if peak > 0:
            patches /= peak
    # Reassemble the patches into the right half. Use the slice's own shape:
    # ``(height, width // 2)`` is one column too narrow when ``width`` is odd.
    right_half = reconstructions[title][:, width // 2:]
    right_half[...] = reconstruct_from_patches_2d(patches, right_half.shape)
    dt = time() - t0
    print('done in %.2fs.' % dt)
    show_with_diff(reconstructions[title], face,
                   title + ' (time: %.1fs)' % dt)

In [None]:
# Display all figures created above.
plt.show()