## IPython notebook comparing several denoising filters

### Import all modules and functions needed for this example.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage import img_as_float
from skimage.filters import median
from skimage.morphology import square
from scipy.signal import correlate2d
from skimage.util import random_noise


### Implementations of the mean and Gaussian filter kernels.

In [None]:
def mean_filter_kernel(a):
    """Return a normalized square box (mean) filter kernel.

    All entries are equal and the kernel sums to one, so correlating an
    image with it replaces each pixel by the average of its neighbourhood.

    Parameters
    ----------
    a : int
        Filter radius; the kernel side length is ``2 * a + 1``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(2 * a + 1, 2 * a + 1)``.
    """
    side = 2 * a + 1
    box = np.ones((side, side))
    return box / np.sum(box)


def gauss_filter_kernel(a, sigma):
    """Return a normalized square Gaussian filter kernel.

    The entry at offset ``(i, j)`` from the centre, for ``i, j`` in
    ``[-a, a]``, is proportional to ``exp(-(i**2 + j**2) / (2 * sigma**2))``.
    The kernel is then scaled so that it sums to one.

    Parameters
    ----------
    a : int
        Filter radius; must be non-negative. The kernel side length is
        ``2 * a + 1``.
    sigma : float
        Standard deviation of the Gaussian; must be non-zero. Only
        ``sigma**2`` is used, so the sign of ``sigma`` does not matter.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(2 * a + 1, 2 * a + 1)``.

    Raises
    ------
    ValueError
        If ``a`` is negative (which would otherwise silently give an empty
        kernel) or ``sigma`` is zero (which would otherwise give NaNs).
    """
    if a < 0:
        raise ValueError(f"filter radius a must be non-negative, got {a}")
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    # range() keeps the original strictness: a non-integer radius raises
    # TypeError instead of producing an off-centre, even-sized kernel.
    offsets = range(-a, a + 1)
    rows, cols = np.meshgrid(offsets, offsets, indexing="ij")
    kernel = np.exp(-(rows**2 + cols**2) / (2 * sigma**2))
    return kernel / np.sum(kernel)


### Read and plot the input image.

In [None]:
# Load the test image as grayscale and convert it to floating point.
true_input_image = img_as_float(imread("astronaut.png", as_gray=True))

# Pass the colormap by name: matplotlib.cm.get_cmap (reached here as
# plt.cm.get_cmap) was deprecated in Matplotlib 3.7 and removed in 3.9.
plt.imshow(true_input_image, interpolation="nearest", cmap="gray", vmin=0, vmax=1)
plt.axis("off")
plt.title("Original image")
plt.show()

### Add noise to the input image.

In [None]:
# Different noise models have very different effects
input_image = random_noise(true_input_image, mode="gaussian")
# input_image = random_noise(true_input_image, mode="poisson")
# input_image = random_noise(true_input_image, mode="s&p")

plt.imshow(input_image, interpolation="nearest", cmap=plt.cm.get_cmap("gray"), vmin=0, vmax=1)
plt.axis("off")
plt.title("Noisy input image")
plt.show()

### Set the filter radius.

In [None]:
# Filter radius: the mean/Gaussian kernels and the median footprint below all
# use a (2 * a + 1) x (2 * a + 1) window, i.e. 5 x 5 for a = 2.
a = 2

### Apply the mean, Gaussian, and median filters to the noisy input image.

In [None]:
# All three filters mirror the image at its borders so that their edge
# behaviour is comparable: correlate2d's boundary="symm" is the same
# half-sample reflection as ndimage's mode="reflect" used by median(). The
# correlate2d default (zero fill) would darken a band of width a along
# every edge of the mean/Gauss results.
mean_filtered_image = correlate2d(
    input_image, mean_filter_kernel(a), mode="same", boundary="symm"
)
gauss_filtered_image = correlate2d(
    input_image, gauss_filter_kernel(a, 1), mode="same", boundary="symm"
)
# An all-True boolean footprint is equivalent to square(2 * a + 1); `square`
# is deprecated in recent scikit-image releases.
median_filtered_image = median(
    input_image,
    footprint=np.ones((2 * a + 1, 2 * a + 1), dtype=bool),
    mode="reflect",
)

# Plot all filtered images side by side. The colormap is passed by name
# because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
plt.figure(figsize=(15, 5))
for panel, (filtered, caption) in enumerate(
    [
        (mean_filtered_image, "Mean filtered image"),
        (gauss_filtered_image, "Gauss filtered image"),
        (median_filtered_image, "Median filtered image"),
    ],
    start=1,
):
    plt.subplot(1, 3, panel)
    plt.imshow(filtered, interpolation="nearest", cmap="gray", vmin=0, vmax=1)
    plt.axis("off")
    plt.title(caption)
plt.show()
